author    Nathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com>  2022-06-01 08:52:47 +0200
committer GitHub <noreply@github.com>  2022-06-01 08:52:47 +0200
commit    1f307594244dd4c274b64d028823cbcfff302f37 (patch)
tree      f3302cde2a26a8b5c3869d722269e91c23a3ae5b
parent    951209ac3f01c86b35d3beff4679ce47e47c0872 (diff)
[MRG] numItermax in 64 bits in EMD solver (#380)
* Correct test_mm_convergence for cupy
* Fix bug where number of iterations is limited to 2^31
* Update RELEASES.md
* Replace size_t with long long
* Use uint64_t instead of long long
-rw-r--r--  RELEASES.md                          5
-rw-r--r--  ot/lp/EMD.h                          5
-rw-r--r--  ot/lp/EMD_wrapper.cpp                4
-rw-r--r--  ot/lp/emd_wrap.pyx                   9
-rw-r--r--  ot/lp/network_simplex_simple.h       8
-rw-r--r--  ot/lp/network_simplex_simple_omp.h   6
-rw-r--r--  test/test_unbalanced.py              38
7 files changed, 41 insertions, 34 deletions
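For context (not part of the patch), a minimal Python sketch of what the change enables at the public API level: ot.emd takes its iteration cap through numItermax, which is now carried as a 64-bit value down to the C++ solver, so caps above 2^31 - 1 no longer overflow the C-level parameter. The histogram and cost-matrix setup below is illustrative only.

import numpy as np
import ot

n = 50
rng = np.random.RandomState(0)
a = ot.utils.unif(n)                           # uniform source histogram
b = ot.utils.unif(n)                           # uniform target histogram
M = ot.dist(rng.randn(n, 2), rng.randn(n, 2))  # squared Euclidean cost
M /= M.max()

# With maxIter passed as uint64_t end to end, a cap above 2**31 - 1 is accepted.
G = ot.emd(a, b, M, numItermax=2**40)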
diff --git a/RELEASES.md b/RELEASES.md
index e761e64..fdaff59 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -10,7 +10,10 @@
- Fixed an issue where we could not ask TorchBackend to place a random tensor on GPU
(Issue #371, PR #373)
-- Fixed an issue where hitting iteration limits would be reported to stderr by std::cerr regardless of Python's stderr stream status.
+- Fixed an issue where Sinkhorn solver assumed a symmetric cost matrix (Issue #374, PR #375)
+- Fixed an issue where hitting iteration limits would be reported to stderr by std::cerr regardless of Python's stderr stream status (PR #377)
+- Fixed an issue where the metric argument in ot.dist did not allow a callable parameter (Issue #378, PR #379)
+- Fixed an issue where the max number of iterations in ot.emd was not allowed to go beyond 2^31 (PR #380)
## 0.8.2
diff --git a/ot/lp/EMD.h b/ot/lp/EMD.h
index 8a1f9ac..b56f060 100644
--- a/ot/lp/EMD.h
+++ b/ot/lp/EMD.h
@@ -18,6 +18,7 @@
#include <iostream>
#include <vector>
+#include <cstdint>
typedef unsigned int node_id_type;
@@ -28,8 +29,8 @@ enum ProblemType {
MAX_ITER_REACHED
};
-int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter);
-int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter, int numThreads);
+int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter);
+int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads);
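For reference (not part of the diff), the widths at stake in the prototypes above: a 32-bit signed int tops out at 2^31 - 1, while uint64_t from <cstdint> is exactly 64 bits on every platform, which is why the commit message describes moving from int via long long to uint64_t. A quick check of both ceilings with NumPy:

import numpy as np

print(np.iinfo(np.int32).max)   # 2147483647 == 2**31 - 1, the old ceiling
print(np.iinfo(np.uint64).max)  # 18446744073709551615, the new ceiling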
diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp
index 2bdc172..457216b 100644
--- a/ot/lp/EMD_wrapper.cpp
+++ b/ot/lp/EMD_wrapper.cpp
@@ -20,7 +20,7 @@
int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
- double* alpha, double* beta, double *cost, int maxIter) {
+ double* alpha, double* beta, double *cost, uint64_t maxIter) {
// beware M and C are stored in row major C style!!!
using namespace lemon;
@@ -122,7 +122,7 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
- double* alpha, double* beta, double *cost, int maxIter, int numThreads) {
+ double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads) {
// beware M and C are stored in row major C style!!!
using namespace lemon_omp;
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index 42e08f4..e5cec89 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -14,13 +14,14 @@ from ..utils import dist
cimport cython
cimport libc.math as math
+from libc.stdint cimport uint64_t
import warnings
cdef extern from "EMD.h":
- int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter) nogil
- int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter, int numThreads) nogil
+ int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter) nogil
+ int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads) nogil
cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED
@@ -39,7 +40,7 @@ def check_result(result_code):
@cython.boundscheck(False)
@cython.wraparound(False)
-def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter, int numThreads):
+def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, uint64_t max_iter, int numThreads):
"""
Solves the Earth Movers distance problem and returns the optimal transport matrix
@@ -75,7 +76,7 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
target histogram
M : (ns,nt) numpy.ndarray, float64
loss matrix
- max_iter : int
+ max_iter : uint64_t
The maximum number of iterations before stopping the optimization
algorithm if it has not converged.
diff --git a/ot/lp/network_simplex_simple.h b/ot/lp/network_simplex_simple.h
index 5b1038f..9612a8a 100644
--- a/ot/lp/network_simplex_simple.h
+++ b/ot/lp/network_simplex_simple.h
@@ -233,7 +233,7 @@ namespace lemon {
/// mixed order in the internal data structure.
/// In special cases, it could lead to better overall performance,
/// but it is usually slower. Therefore it is disabled by default.
- NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, ArcsType nb_arcs, size_t maxiters) :
+ NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, ArcsType nb_arcs, uint64_t maxiters) :
_graph(graph), //_arc_id(graph),
_arc_mixing(arc_mixing), _init_nb_nodes(nbnodes), _init_nb_arcs(nb_arcs),
MAX(std::numeric_limits<Value>::max()),
@@ -242,7 +242,7 @@ namespace lemon {
{
// Reset data structures
reset();
- max_iter=maxiters;
+ max_iter = maxiters;
}
/// The type of the flow amounts, capacity bounds and supply values
@@ -293,7 +293,7 @@ namespace lemon {
private:
- size_t max_iter;
+ uint64_t max_iter;
TEMPLATE_DIGRAPH_TYPEDEFS(GR);
typedef std::vector<int> IntVector;
@@ -1427,7 +1427,7 @@ namespace lemon {
// Perform heuristic initial pivots
if (!initialPivots()) return UNBOUNDED;
- size_t iter_number=0;
+ uint64_t iter_number = 0;
//pivot.setDantzig(true);
// Execute the Network Simplex algorithm
while (pivot.findEnteringArc()) {
diff --git a/ot/lp/network_simplex_simple_omp.h b/ot/lp/network_simplex_simple_omp.h
index dde84fd..5f19d73 100644
--- a/ot/lp/network_simplex_simple_omp.h
+++ b/ot/lp/network_simplex_simple_omp.h
@@ -244,7 +244,7 @@ namespace lemon_omp {
/// mixed order in the internal data structure.
/// In special cases, it could lead to better overall performance,
/// but it is usually slower. Therefore it is disabled by default.
- NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, ArcsType nb_arcs, size_t maxiters = 0, int numThreads=-1) :
+ NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, ArcsType nb_arcs, uint64_t maxiters = 0, int numThreads=-1) :
_graph(graph), //_arc_id(graph),
_arc_mixing(arc_mixing), _init_nb_nodes(nbnodes), _init_nb_arcs(nb_arcs),
MAX(std::numeric_limits<Value>::max()),
@@ -317,7 +317,7 @@ namespace lemon_omp {
private:
- size_t max_iter;
+ uint64_t max_iter;
int num_threads;
TEMPLATE_DIGRAPH_TYPEDEFS(GR);
@@ -1563,7 +1563,7 @@ namespace lemon_omp {
// Perform heuristic initial pivots
if (!initialPivots()) return UNBOUNDED;
- size_t iter_number = 0;
+ uint64_t iter_number = 0;
// Execute the Network Simplex algorithm
while (pivot.findEnteringArc()) {
if ((++iter_number <= max_iter&&max_iter > 0) || max_iter<=0) {
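As a simplified paraphrase (a sketch, not the solver code itself) of the cap check in the context line above: max_iter == 0 means "no cap", and otherwise the 64-bit counter can run past 2^31 without wrapping.

def within_iteration_cap(iter_number, max_iter):
    # Mirrors the C++ condition: perform another pivot when no cap is set
    # (max_iter == 0) or the counter has not yet passed the cap.
    return max_iter == 0 or iter_number <= max_iter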
diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py
index 02b3fc3..fc40df0 100644
--- a/test/test_unbalanced.py
+++ b/test/test_unbalanced.py
@@ -295,26 +295,27 @@ def test_mm_convergence(nx):
x = rng.randn(n, 2)
rng = np.random.RandomState(75)
y = rng.randn(n, 2)
- a = ot.utils.unif(n)
- b = ot.utils.unif(n)
+ a_np = ot.utils.unif(n)
+ b_np = ot.utils.unif(n)
M = ot.dist(x, y)
M = M / M.max()
reg_m = 100
- a, b, M = nx.from_numpy(a, b, M)
+ a, b, M = nx.from_numpy(a_np, b_np, M)
G_kl, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl',
- verbose=True, log=True)
- loss_kl = nx.to_numpy(ot.unbalanced.mm_unbalanced2(
- a, b, M, reg_m, div='kl', verbose=True))
+ verbose=False, log=True)
+ loss_kl = nx.to_numpy(
+ ot.unbalanced.mm_unbalanced2(a, b, M, reg_m, div='kl', verbose=True)
+ )
G_l2, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2',
verbose=False, log=True)
# check if the marginals come close to the true ones when large reg
- np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 1), a, atol=1e-03)
- np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 0), b, atol=1e-03)
- np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 1), a, atol=1e-03)
- np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 0), b, atol=1e-03)
+ np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 1), a_np, atol=1e-03)
+ np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 0), b_np, atol=1e-03)
+ np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 1), a_np, atol=1e-03)
+ np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 0), b_np, atol=1e-03)
# check if mm_unbalanced2 returns the correct loss
np.testing.assert_allclose(nx.to_numpy(nx.sum(G_kl * M)), loss_kl,
@@ -324,15 +325,16 @@ def test_mm_convergence(nx):
a_np, b_np = np.array([]), np.array([])
a, b = nx.from_numpy(a_np, b_np)
- G_kl_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl')
- G_l2_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2')
- np.testing.assert_allclose(G_kl_null, G_kl)
- np.testing.assert_allclose(G_l2_null, G_l2)
+ G_kl_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', verbose=False)
+ G_l2_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', verbose=False)
+ np.testing.assert_allclose(nx.to_numpy(G_kl_null), nx.to_numpy(G_kl))
+ np.testing.assert_allclose(nx.to_numpy(G_l2_null), nx.to_numpy(G_l2))
# test when G0 is given
G0 = ot.emd(a, b, M)
+ G0_np = nx.to_numpy(G0)
reg_m = 10000
- G_kl = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', G0=G0)
- G_l2 = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', G0=G0)
- np.testing.assert_allclose(G0, G_kl, atol=1e-05)
- np.testing.assert_allclose(G0, G_l2, atol=1e-05)
+ G_kl = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', G0=G0, verbose=False)
+ G_l2 = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', G0=G0, verbose=False)
+ np.testing.assert_allclose(G0_np, nx.to_numpy(G_kl), atol=1e-05)
+ np.testing.assert_allclose(G0_np, nx.to_numpy(G_l2), atol=1e-05)