From 1f307594244dd4c274b64d028823cbcfff302f37 Mon Sep 17 00:00:00 2001 From: Nathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com> Date: Wed, 1 Jun 2022 08:52:47 +0200 Subject: [MRG] numItermax in 64 bits in EMD solver (#380) * Correct test_mm_convergence for cupy * Fix bug where number of iterations is limited to 2^31 * Update RELEASES.md * Replace size_t with long long * Use uint64_t instead of long long --- RELEASES.md | 5 ++++- ot/lp/EMD.h | 5 +++-- ot/lp/EMD_wrapper.cpp | 4 ++-- ot/lp/emd_wrap.pyx | 9 +++++---- ot/lp/network_simplex_simple.h | 8 ++++---- ot/lp/network_simplex_simple_omp.h | 6 +++--- test/test_unbalanced.py | 38 ++++++++++++++++++++------------------ 7 files changed, 41 insertions(+), 34 deletions(-) diff --git a/RELEASES.md b/RELEASES.md index e761e64..fdaff59 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -10,7 +10,10 @@ - Fixed an issue where we could not ask TorchBackend to place a random tensor on GPU (Issue #371, PR #373) -- Fixed an issue where hitting iteration limits would be reported to stderr by std::cerr regardless of Python's stderr stream status. 
+- Fixed an issue where Sinkhorn solver assumed a symmetric cost matrix (Issue #374, PR #375) +- Fixed an issue where hitting iteration limits would be reported to stderr by std::cerr regardless of Python's stderr stream status (PR #377) +- Fixed an issue where the metric argument in ot.dist did not allow a callable parameter (Issue #378, PR #379) +- Fixed an issue where the max number of iterations in ot.emd was not allowed to go beyond 2^31 (PR #380) ## 0.8.2 diff --git a/ot/lp/EMD.h b/ot/lp/EMD.h index 8a1f9ac..b56f060 100644 --- a/ot/lp/EMD.h +++ b/ot/lp/EMD.h @@ -18,6 +18,7 @@ #include #include +#include typedef unsigned int node_id_type; @@ -28,8 +29,8 @@ enum ProblemType { MAX_ITER_REACHED }; -int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter); -int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter, int numThreads); +int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter); +int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads); diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp index 2bdc172..457216b 100644 --- a/ot/lp/EMD_wrapper.cpp +++ b/ot/lp/EMD_wrapper.cpp @@ -20,7 +20,7 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, - double* alpha, double* beta, double *cost, int maxIter) { + double* alpha, double* beta, double *cost, uint64_t maxIter) { // beware M and C are stored in row major C style!!! 
using namespace lemon; @@ -122,7 +122,7 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G, - double* alpha, double* beta, double *cost, int maxIter, int numThreads) { + double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads) { // beware M and C are stored in row major C style!!! using namespace lemon_omp; diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 42e08f4..e5cec89 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -14,13 +14,14 @@ from ..utils import dist cimport cython cimport libc.math as math +from libc.stdint cimport uint64_t import warnings cdef extern from "EMD.h": - int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter) nogil - int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter, int numThreads) nogil + int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter) nogil + int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads) nogil cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED @@ -39,7 +40,7 @@ def check_result(result_code): @cython.boundscheck(False) @cython.wraparound(False) -def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter, int numThreads): +def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, uint64_t max_iter, int numThreads): """ Solves the Earth Movers distance problem and returns the optimal transport matrix @@ -75,7 +76,7 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, 
mod target histogram M : (ns,nt) numpy.ndarray, float64 loss matrix - max_iter : int + max_iter : uint64_t The maximum number of iterations before stopping the optimization algorithm if it has not converged. diff --git a/ot/lp/network_simplex_simple.h b/ot/lp/network_simplex_simple.h index 5b1038f..9612a8a 100644 --- a/ot/lp/network_simplex_simple.h +++ b/ot/lp/network_simplex_simple.h @@ -233,7 +233,7 @@ namespace lemon { /// mixed order in the internal data structure. /// In special cases, it could lead to better overall performance, /// but it is usually slower. Therefore it is disabled by default. - NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, ArcsType nb_arcs, size_t maxiters) : + NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, ArcsType nb_arcs, uint64_t maxiters) : _graph(graph), //_arc_id(graph), _arc_mixing(arc_mixing), _init_nb_nodes(nbnodes), _init_nb_arcs(nb_arcs), MAX(std::numeric_limits::max()), @@ -242,7 +242,7 @@ namespace lemon { { // Reset data structures reset(); - max_iter=maxiters; + max_iter = maxiters; } /// The type of the flow amounts, capacity bounds and supply values @@ -293,7 +293,7 @@ namespace lemon { private: - size_t max_iter; + uint64_t max_iter; TEMPLATE_DIGRAPH_TYPEDEFS(GR); typedef std::vector IntVector; @@ -1427,7 +1427,7 @@ namespace lemon { // Perform heuristic initial pivots if (!initialPivots()) return UNBOUNDED; - size_t iter_number=0; + uint64_t iter_number = 0; //pivot.setDantzig(true); // Execute the Network Simplex algorithm while (pivot.findEnteringArc()) { diff --git a/ot/lp/network_simplex_simple_omp.h b/ot/lp/network_simplex_simple_omp.h index dde84fd..5f19d73 100644 --- a/ot/lp/network_simplex_simple_omp.h +++ b/ot/lp/network_simplex_simple_omp.h @@ -244,7 +244,7 @@ namespace lemon_omp { /// mixed order in the internal data structure. /// In special cases, it could lead to better overall performance, /// but it is usually slower. Therefore it is disabled by default. 
- NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, ArcsType nb_arcs, size_t maxiters = 0, int numThreads=-1) : + NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, ArcsType nb_arcs, uint64_t maxiters = 0, int numThreads=-1) : _graph(graph), //_arc_id(graph), _arc_mixing(arc_mixing), _init_nb_nodes(nbnodes), _init_nb_arcs(nb_arcs), MAX(std::numeric_limits::max()), @@ -317,7 +317,7 @@ namespace lemon_omp { private: - size_t max_iter; + uint64_t max_iter; int num_threads; TEMPLATE_DIGRAPH_TYPEDEFS(GR); @@ -1563,7 +1563,7 @@ namespace lemon_omp { // Perform heuristic initial pivots if (!initialPivots()) return UNBOUNDED; - size_t iter_number = 0; + uint64_t iter_number = 0; // Execute the Network Simplex algorithm while (pivot.findEnteringArc()) { if ((++iter_number <= max_iter&&max_iter > 0) || max_iter<=0) { diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index 02b3fc3..fc40df0 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -295,26 +295,27 @@ def test_mm_convergence(nx): x = rng.randn(n, 2) rng = np.random.RandomState(75) y = rng.randn(n, 2) - a = ot.utils.unif(n) - b = ot.utils.unif(n) + a_np = ot.utils.unif(n) + b_np = ot.utils.unif(n) M = ot.dist(x, y) M = M / M.max() reg_m = 100 - a, b, M = nx.from_numpy(a, b, M) + a, b, M = nx.from_numpy(a_np, b_np, M) G_kl, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', - verbose=True, log=True) - loss_kl = nx.to_numpy(ot.unbalanced.mm_unbalanced2( - a, b, M, reg_m, div='kl', verbose=True)) + verbose=False, log=True) + loss_kl = nx.to_numpy( + ot.unbalanced.mm_unbalanced2(a, b, M, reg_m, div='kl', verbose=True) + ) G_l2, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', verbose=False, log=True) # check if the marginals come close to the true ones when large reg - np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 1), a, atol=1e-03) - np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 0), b, atol=1e-03) - 
np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 1), a, atol=1e-03) - np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 0), b, atol=1e-03) + np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 1), a_np, atol=1e-03) + np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 0), b_np, atol=1e-03) + np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 1), a_np, atol=1e-03) + np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 0), b_np, atol=1e-03) # check if mm_unbalanced2 returns the correct loss np.testing.assert_allclose(nx.to_numpy(nx.sum(G_kl * M)), loss_kl, @@ -324,15 +325,16 @@ def test_mm_convergence(nx): a_np, b_np = np.array([]), np.array([]) a, b = nx.from_numpy(a_np, b_np) - G_kl_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl') - G_l2_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2') - np.testing.assert_allclose(G_kl_null, G_kl) - np.testing.assert_allclose(G_l2_null, G_l2) + G_kl_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', verbose=False) + G_l2_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', verbose=False) + np.testing.assert_allclose(nx.to_numpy(G_kl_null), nx.to_numpy(G_kl)) + np.testing.assert_allclose(nx.to_numpy(G_l2_null), nx.to_numpy(G_l2)) # test when G0 is given G0 = ot.emd(a, b, M) + G0_np = nx.to_numpy(G0) reg_m = 10000 - G_kl = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', G0=G0) - G_l2 = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', G0=G0) - np.testing.assert_allclose(G0, G_kl, atol=1e-05) - np.testing.assert_allclose(G0, G_l2, atol=1e-05) + G_kl = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', G0=G0, verbose=False) + G_l2 = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', G0=G0, verbose=False) + np.testing.assert_allclose(G0_np, nx.to_numpy(G_kl), atol=1e-05) + np.testing.assert_allclose(G0_np, nx.to_numpy(G_l2), atol=1e-05) -- cgit v1.2.3