From 57321bd0172c97b77dfc8b14972c18d063b6dda8 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Mon, 2 Dec 2019 11:13:07 +0100 Subject: add awesome sparse solver --- test/test_ot.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) (limited to 'test/test_ot.py') diff --git a/test/test_ot.py b/test/test_ot.py index dacae0a..4d59e12 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -118,6 +118,26 @@ def test_emd_empty(): np.testing.assert_allclose(w, 0) +def test_emd_sparse(): + + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + x2 = rng.randn(n, 2) + u = ot.utils.unif(n) + + M = ot.dist(x, x2) + + G = ot.emd([], [], M) + + Gs = ot.emd([], [], M, sparse=True) + + # check G is the same + np.testing.assert_allclose(G, Gs.todense()) + # check constraints + + def test_emd2_multi(): n = 500 # nb bins -- cgit v1.2.3 From a6a654de5e78dd388a793fbd26f60045b05d519c Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Mon, 2 Dec 2019 11:31:32 +0100 Subject: proper documentation and parameter --- ot/lp/EMD.h | 2 +- ot/lp/EMD_wrapper.cpp | 3 ++- ot/lp/__init__.py | 16 ++++++++++++++-- ot/lp/emd_wrap.pyx | 10 ++++++---- test/test_ot.py | 2 +- 5 files changed, 24 insertions(+), 9 deletions(-) (limited to 'test/test_ot.py') diff --git a/ot/lp/EMD.h b/ot/lp/EMD.h index bc513d2..9896091 100644 --- a/ot/lp/EMD.h +++ b/ot/lp/EMD.h @@ -33,7 +33,7 @@ enum ProblemType { int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter); int EMD_wrap_return_sparse(int n1, int n2, double *X, double *Y, double *D, - long *iG, long *jG, double *G, + long *iG, long *jG, double *G, long * nG, double* alpha, double* beta, double *cost, int maxIter); #endif diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp index 2aa44c1..9be2cdc 100644 --- a/ot/lp/EMD_wrapper.cpp +++ b/ot/lp/EMD_wrapper.cpp @@ -108,7 +108,7 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, int EMD_wrap_return_sparse(int n1, int n2, double *X, double *Y, double *D, - long *iG, long *jG, double *G, + long *iG, long *jG, double *G, long * nG, double* alpha, double* beta, double *cost, int maxIter) { // beware M and C anre strored in row major C style!!! @@ -202,6 +202,7 @@ int EMD_wrap_return_sparse(int n1, int n2, double *X, double *Y, double *D, cur++; } } + *nG=cur; // nb of value +1 for numpy indexing } diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 4fec7d9..d476071 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -27,7 +27,7 @@ __all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', 'emd_1d', 'emd2_1d', 'wasserstein_1d'] -def emd(a, b, M, numItermax=100000, log=False, sparse=False): +def emd(a, b, M, numItermax=100000, log=False, dense=True): r"""Solves the Earth Movers distance problem and returns the OT matrix @@ -62,6 +62,10 @@ def emd(a, b, M, numItermax=100000, log=False, sparse=False): log: bool, optional (default=False) If True, returns a dictionary containing the cost and dual variables. Otherwise returns only the optimal transportation matrix. + dense: boolean, optional (default=True) + If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt). + Otherwise returns a sparse representation using scipy's `coo_matrix` + format. Returns ------- @@ -103,6 +107,8 @@ def emd(a, b, M, numItermax=100000, log=False, sparse=False): b = np.asarray(b, dtype=np.float64) M = np.asarray(M, dtype=np.float64) + sparse= not dense + # if empty array given then use uniform distributions if len(a) == 0: a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] @@ -128,7 +134,7 @@ def emd(a, b, M, numItermax=100000, log=False, sparse=False): def emd2(a, b, M, processes=multiprocessing.cpu_count(), - numItermax=100000, log=False, sparse=False, return_matrix=False): + numItermax=100000, log=False, dense=True, return_matrix=False): r"""Solves the Earth Movers distance problem and returns the loss .. math:: @@ -166,6 +172,10 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), variables. Otherwise returns only the optimal transportation cost. return_matrix: boolean, optional (default=False) If True, returns the optimal transportation matrix in the log. + dense: boolean, optional (default=True) + If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt). + Otherwise returns a sparse representation using scipy's `coo_matrix` + format. Returns ------- @@ -207,6 +217,8 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), b = np.asarray(b, dtype=np.float64) M = np.asarray(M, dtype=np.float64) + sparse=not dense + # problem with pikling Forks if sys.platform.endswith('win32'): processes=1 diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index f183995..4b6cdce 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -21,7 +21,7 @@ import warnings cdef extern from "EMD.h": int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter) int EMD_wrap_return_sparse(int n1, int n2, double *X, double *Y, double *D, - long *iG, long *jG, double *G, + long *iG, long *jG, double *G, long * nG, double* alpha, double* beta, double *cost, int maxIter) cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED @@ -75,7 +75,8 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod max_iter : int The maximum number of iterations before stopping the optimization algorithm if it has not converged. - + sparse : bool + Returning a sparse transport matrix if set to True Returns ------- @@ -87,6 +88,7 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod cdef int n2= M.shape[1] cdef int nmax=n1+n2-1 cdef int result_code = 0 + cdef int nG=0 cdef double cost=0 cdef np.ndarray[double, ndim=1, mode="c"] alpha=np.zeros(n1) @@ -111,10 +113,10 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod jG=np.zeros(nmax,dtype=np.int) - result_code = EMD_wrap_return_sparse(n1, n2, a.data, b.data, M.data, iG.data, jG.data, Gv.data, alpha.data, beta.data, &cost, max_iter) + result_code = EMD_wrap_return_sparse(n1, n2, a.data, b.data, M.data, iG.data, jG.data, Gv.data, &nG, alpha.data, beta.data, &cost, max_iter) - return Gv, iG, jG, cost, alpha, beta, result_code + return Gv[:nG], iG[:nG], jG[:nG], cost, alpha, beta, result_code else: diff --git a/test/test_ot.py b/test/test_ot.py index 4d59e12..7b44fd1 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -131,7 +131,7 @@ def test_emd_sparse(): G = ot.emd([], [], M) - Gs = ot.emd([], [], M, sparse=True) + Gs = ot.emd([], [], M, dense=False) # check G is the same np.testing.assert_allclose(G, Gs.todense()) -- cgit v1.2.3 From 127adbaf4eef7a6dffbdcd4f930fc6301587f861 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Mon, 2 Dec 2019 11:41:13 +0100 Subject: remove useless variable --- test/test_ot.py | 1 - 1 file changed, 1 deletion(-) (limited to 'test/test_ot.py') diff --git a/test/test_ot.py b/test/test_ot.py index 7b44fd1..8602022 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -125,7 +125,6 @@ def test_emd_sparse(): x = rng.randn(n, 2) x2 = rng.randn(n, 2) - u = ot.utils.unif(n) M = ot.dist(x, x2) -- cgit v1.2.3 From 84384dd9e5dc78ed5cc867a53bd1de31c05d77fc Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Mon, 2 Dec 2019 13:34:05 +0100 Subject: add test emd2 --- test/test_ot.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'test/test_ot.py') diff --git a/test/test_ot.py b/test/test_ot.py index 8602022..507d188 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -132,9 +132,12 @@ def test_emd_sparse(): Gs = ot.emd([], [], M, dense=False) + ws = ot.emd2([], [], M, dense=False) + # check G is the same np.testing.assert_allclose(G, Gs.todense()) - # check constraints + # check value + np.testing.assert_allclose(Gs.multiply(M).sum(), ws, rtol=1e-6) def test_emd2_multi(): -- cgit v1.2.3 From 7371b2f4f931db8f67ec2967253be8d95ff9fe80 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Mon, 2 Dec 2019 13:34:55 +0100 Subject: add test emd2 --- test/test_ot.py | 5 +++++ 1 file changed, 5 insertions(+) (limited to 'test/test_ot.py') diff --git a/test/test_ot.py b/test/test_ot.py index 507d188..48ea87f 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -171,7 +171,12 @@ def test_emd2_multi(): emdn = ot.emd2(a, b, M) ot.toc('multi proc : {} s') + ot.tic() + emdn2 = ot.emd2(a, b, M, dense = False) + ot.toc('multi proc : {} s') + np.testing.assert_allclose(emd1, emdn) + np.testing.assert_allclose(emd1, emdn2) # emd loss multipro proc with log ot.tic() -- cgit v1.2.3 From dfaba55affcca606e8e041bdbd0fc5a7735c2b07 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Mon, 2 Dec 2019 13:36:08 +0100 Subject: add test emd2 multi --- test/test_ot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'test/test_ot.py') diff --git a/test/test_ot.py b/test/test_ot.py index 48ea87f..470fd0f 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -176,7 +176,7 @@ def test_emd2_multi(): ot.toc('multi proc : {} s') np.testing.assert_allclose(emd1, emdn) - np.testing.assert_allclose(emd1, emdn2) + np.testing.assert_allclose(emd1, emdn2, rtol=1e-6) # emd loss multipro proc with log ot.tic() -- cgit v1.2.3 From c439e3efb920086154c741b41f65d99165e875d8 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Mon, 2 Dec 2019 13:57:13 +0100 Subject: pep8 --- test/test_ot.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'test/test_ot.py') diff --git a/test/test_ot.py b/test/test_ot.py index 470fd0f..fbacd8b 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -172,8 +172,8 @@ def test_emd2_multi(): ot.toc('multi proc : {} s') ot.tic() - emdn2 = ot.emd2(a, b, M, dense = False) - ot.toc('multi proc : {} s') + emdn2 = ot.emd2(a, b, M, dense=False) + ot.toc('multi proc : {} s') np.testing.assert_allclose(emd1, emdn) np.testing.assert_allclose(emd1, emdn2, rtol=1e-6) -- cgit v1.2.3 From d97f81dd731c4b1132939500076fd48c89f19d1f Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Wed, 18 Dec 2019 10:17:31 +0100 Subject: update test --- test/test_ot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'test/test_ot.py') diff --git a/test/test_ot.py b/test/test_ot.py index fbacd8b..3dd544c 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -128,7 +128,7 @@ def test_emd_sparse(): M = ot.dist(x, x2) - G = ot.emd([], [], M) + G = ot.emd([], [], M, dense=True) Gs = ot.emd([], [], M, dense=False) -- cgit v1.2.3