From 592f933085d5b521a440eb91eccc283c43732170 Mon Sep 17 00:00:00 2001 From: AdrienCorenflos Date: Wed, 1 Apr 2020 12:14:42 +0100 Subject: Fix ordering --- ot/lp/__init__.py | 2 +- test/test_ot.py | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index cdd505d..4c968ca 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -656,7 +656,7 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, perm_a = np.argsort(x_a_1d) perm_b = np.argsort(x_b_1d) - G_sorted, indices, cost = emd_1d_sorted(a, b, + G_sorted, indices, cost = emd_1d_sorted(a[perm_a.flatten()], b[perm_b.flatten()], x_a_1d[perm_a], x_b_1d[perm_b], metric=metric, p=p) G = coo_matrix((G_sorted, (perm_a[indices[:, 0]], perm_b[indices[:, 1]])), diff --git a/test/test_ot.py b/test/test_ot.py index 47df946..7afdae3 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -91,6 +91,44 @@ def test_emd_1d_emd2_1d(): with pytest.raises(AssertionError): ot.emd_1d(u, v, [], []) +def test_emd_1d_emd2_1d_with_weights(): + + # test emd1d gives similar results as emd + n = 20 + m = 30 + rng = np.random.RandomState(0) + u = rng.randn(n, 1) + v = rng.randn(m, 1) + + w_u = rng.uniform(0., 1., n) + w_u = w_u / w_u.sum() + + w_v = rng.uniform(0., 1., m) + w_v = w_v / w_v.sum() + + M = ot.dist(u, v, metric='sqeuclidean') + + G, log = ot.emd(w_u, w_v, M, log=True) + wass = log["cost"] + G_1d, log = ot.emd_1d(u, v, w_u, w_v, metric='sqeuclidean', log=True) + wass1d = log["cost"] + wass1d_emd2 = ot.emd2_1d(u, v, w_u, w_v, metric='sqeuclidean', log=False) + wass1d_euc = ot.emd2_1d(u, v, w_u, w_v, metric='euclidean', log=False) + + # check loss is similar + np.testing.assert_allclose(wass, wass1d) + np.testing.assert_allclose(wass, wass1d_emd2) + + # check loss is similar to scipy's implementation for Euclidean metric + wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,))) + np.testing.assert_allclose(wass_sp, wass1d_euc) + + # check constraints + np.testing.assert_allclose(w_u, G.sum(1)) + np.testing.assert_allclose(w_v, G.sum(0)) + + + def test_wass_1d(): # test emd1d gives similar results as emd -- cgit v1.2.3 From 1e2e118e3a30224932ed2f012bb8f9f0f374ef2c Mon Sep 17 00:00:00 2001 From: AdrienCorenflos Date: Thu, 2 Apr 2020 10:39:55 +0100 Subject: Fix test --- test/test_ot.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/test/test_ot.py b/test/test_ot.py index 7afdae3..0f1357f 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -7,11 +7,11 @@ import warnings import numpy as np +import pytest from scipy.stats import wasserstein_distance import ot from ot.datasets import make_1D_gauss as gauss -import pytest def test_emd_dimension_mismatch(): @@ -75,12 +75,12 @@ def test_emd_1d_emd2_1d(): np.testing.assert_allclose(wass, wass1d_emd2) # check loss is similar to scipy's implementation for Euclidean metric - wass_sp = wasserstein_distance(u.reshape((-1, )), v.reshape((-1, ))) + wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,))) np.testing.assert_allclose(wass_sp, wass1d_euc) # check constraints - np.testing.assert_allclose(np.ones((n, )) / n, G.sum(1)) - np.testing.assert_allclose(np.ones((m, )) / m, G.sum(0)) + np.testing.assert_allclose(np.ones((n,)) / n, G.sum(1)) + np.testing.assert_allclose(np.ones((m,)) / m, G.sum(0)) # check G is similar np.testing.assert_allclose(G, G_1d) @@ -91,8 +91,8 @@ def test_emd_1d_emd2_1d(): with pytest.raises(AssertionError): ot.emd_1d(u, v, [], []) -def test_emd_1d_emd2_1d_with_weights(): +def test_emd_1d_emd2_1d_with_weights(): # test emd1d gives similar results as emd n = 20 m = 30 @@ -120,7 +120,7 @@ def test_emd_1d_emd2_1d_with_weights(): np.testing.assert_allclose(wass, wass1d_emd2) # check loss is similar to scipy's implementation for Euclidean metric - wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,))) + wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)), w_u, w_v) np.testing.assert_allclose(wass_sp, wass1d_euc) # check constraints @@ -128,8 +128,6 @@ def test_emd_1d_emd2_1d_with_weights(): np.testing.assert_allclose(w_v, G.sum(0)) - - def test_wass_1d(): # test emd1d gives similar results as emd n = 20 @@ -173,7 +171,6 @@ def test_emd_empty(): def test_emd_sparse(): - n = 100 rng = np.random.RandomState(0) @@ -249,7 +246,6 @@ def test_emd2_multi(): def test_lp_barycenter(): - a1 = np.array([1.0, 0, 0])[:, None] a2 = np.array([0, 0, 1.0])[:, None] @@ -266,7 +262,6 @@ def test_lp_barycenter(): def test_free_support_barycenter(): - measures_locations = [np.array([-1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))] measures_weights = [np.array([1.]), np.array([1.])] @@ -282,7 +277,6 @@ def test_free_support_barycenter(): @pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available") def test_lp_barycenter_cvxopt(): - a1 = np.array([1.0, 0, 0])[:, None] a2 = np.array([0, 0, 1.0])[:, None] -- cgit v1.2.3 From 60943d00bab1682d6fac22b1e1ba5e64569b4e78 Mon Sep 17 00:00:00 2001 From: AdrienCorenflos Date: Thu, 2 Apr 2020 10:41:24 +0100 Subject: Auto PEP8 --- ot/lp/__init__.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 4c968ca..1922785 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -12,16 +12,16 @@ Solvers for the original linear program OT problem import multiprocessing import sys + import numpy as np from scipy.sparse import coo_matrix -from .import cvx - +from . import cvx +from .cvx import barycenter # import compiled emd from .emd_wrap import emd_c, check_result, emd_1d_sorted -from ..utils import parmap -from .cvx import barycenter from ..utils import dist +from ..utils import parmap __all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', 'emd_1d', 'emd2_1d', 'wasserstein_1d'] @@ -458,7 +458,8 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), return res -def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100, stopThr=1e-7, verbose=False, log=None): +def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100, + stopThr=1e-7, verbose=False, log=None): """ Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance) @@ -525,8 +526,8 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None T_sum = np.zeros((k, d)) - for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights.tolist()): - + for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, + weights.tolist()): M_i = dist(X, measure_locations_i) T_i = emd(b, measure_weights_i, M_i) T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i) @@ -651,8 +652,8 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, if b.ndim == 0 or len(b) == 0: b = np.ones((x_b.shape[0],), dtype=np.float64) / x_b.shape[0] - x_a_1d = x_a.reshape((-1, )) - x_b_1d = x_b.reshape((-1, )) + x_a_1d = x_a.reshape((-1,)) + x_b_1d = x_b.reshape((-1,)) perm_a = np.argsort(x_a_1d) perm_b = np.argsort(x_b_1d) -- cgit v1.2.3 From a9e69509412338920142c0615a50bc00739144d0 Mon Sep 17 00:00:00 2001 From: AdrienCorenflos Date: Thu, 2 Apr 2020 11:11:16 +0100 Subject: Remove flatten, it's not useful. --- ot/lp/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 1922785..f4f6861 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -657,7 +657,7 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, perm_a = np.argsort(x_a_1d) perm_b = np.argsort(x_b_1d) - G_sorted, indices, cost = emd_1d_sorted(a[perm_a.flatten()], b[perm_b.flatten()], + G_sorted, indices, cost = emd_1d_sorted(a[perm_a], b[perm_b], x_a_1d[perm_a], x_b_1d[perm_b], metric=metric, p=p) G = coo_matrix((G_sorted, (perm_a[indices[:, 0]], perm_b[indices[:, 1]])), -- cgit v1.2.3