summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- ot/lp/__init__.py 21
-rw-r--r-- test/test_ot.py 48
2 files changed, 51 insertions, 18 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index c4b5834..8d1baa0 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -10,16 +10,16 @@ Solvers for the original linear program OT problem
import multiprocessing
import sys
+
import numpy as np
from scipy.sparse import coo_matrix
-from .import cvx
-
+from . import cvx
+from .cvx import barycenter
# import compiled emd
from .emd_wrap import emd_c, check_result, emd_1d_sorted
-from ..utils import parmap
-from .cvx import barycenter
from ..utils import dist
+from ..utils import parmap
__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx',
'emd_1d', 'emd2_1d', 'wasserstein_1d']
@@ -456,7 +456,8 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
return res
-def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100, stopThr=1e-7, verbose=False, log=None):
+def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100,
+ stopThr=1e-7, verbose=False, log=None):
"""
Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance)
@@ -523,8 +524,8 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
T_sum = np.zeros((k, d))
- for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights.tolist()):
-
+ for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights,
+ weights.tolist()):
M_i = dist(X, measure_locations_i)
T_i = emd(b, measure_weights_i, M_i)
T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i)
@@ -649,12 +650,12 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
if b.ndim == 0 or len(b) == 0:
b = np.ones((x_b.shape[0],), dtype=np.float64) / x_b.shape[0]
- x_a_1d = x_a.reshape((-1, ))
- x_b_1d = x_b.reshape((-1, ))
+ x_a_1d = x_a.reshape((-1,))
+ x_b_1d = x_b.reshape((-1,))
perm_a = np.argsort(x_a_1d)
perm_b = np.argsort(x_b_1d)
- G_sorted, indices, cost = emd_1d_sorted(a, b,
+ G_sorted, indices, cost = emd_1d_sorted(a[perm_a], b[perm_b],
x_a_1d[perm_a], x_b_1d[perm_b],
metric=metric, p=p)
G = coo_matrix((G_sorted, (perm_a[indices[:, 0]], perm_b[indices[:, 1]])),
diff --git a/test/test_ot.py b/test/test_ot.py
index 47df946..0f1357f 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -7,11 +7,11 @@
import warnings
import numpy as np
+import pytest
from scipy.stats import wasserstein_distance
import ot
from ot.datasets import make_1D_gauss as gauss
-import pytest
def test_emd_dimension_mismatch():
@@ -75,12 +75,12 @@ def test_emd_1d_emd2_1d():
np.testing.assert_allclose(wass, wass1d_emd2)
# check loss is similar to scipy's implementation for Euclidean metric
- wass_sp = wasserstein_distance(u.reshape((-1, )), v.reshape((-1, )))
+ wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)))
np.testing.assert_allclose(wass_sp, wass1d_euc)
# check constraints
- np.testing.assert_allclose(np.ones((n, )) / n, G.sum(1))
- np.testing.assert_allclose(np.ones((m, )) / m, G.sum(0))
+ np.testing.assert_allclose(np.ones((n,)) / n, G.sum(1))
+ np.testing.assert_allclose(np.ones((m,)) / m, G.sum(0))
# check G is similar
np.testing.assert_allclose(G, G_1d)
@@ -92,6 +92,42 @@ def test_emd_1d_emd2_1d():
ot.emd_1d(u, v, [], [])
+def test_emd_1d_emd2_1d_with_weights():
+ # test emd1d gives similar results as emd
+ n = 20
+ m = 30
+ rng = np.random.RandomState(0)
+ u = rng.randn(n, 1)
+ v = rng.randn(m, 1)
+
+ w_u = rng.uniform(0., 1., n)
+ w_u = w_u / w_u.sum()
+
+ w_v = rng.uniform(0., 1., m)
+ w_v = w_v / w_v.sum()
+
+ M = ot.dist(u, v, metric='sqeuclidean')
+
+ G, log = ot.emd(w_u, w_v, M, log=True)
+ wass = log["cost"]
+ G_1d, log = ot.emd_1d(u, v, w_u, w_v, metric='sqeuclidean', log=True)
+ wass1d = log["cost"]
+ wass1d_emd2 = ot.emd2_1d(u, v, w_u, w_v, metric='sqeuclidean', log=False)
+ wass1d_euc = ot.emd2_1d(u, v, w_u, w_v, metric='euclidean', log=False)
+
+ # check loss is similar
+ np.testing.assert_allclose(wass, wass1d)
+ np.testing.assert_allclose(wass, wass1d_emd2)
+
+ # check loss is similar to scipy's implementation for Euclidean metric
+ wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)), w_u, w_v)
+ np.testing.assert_allclose(wass_sp, wass1d_euc)
+
+ # check constraints
+ np.testing.assert_allclose(w_u, G.sum(1))
+ np.testing.assert_allclose(w_v, G.sum(0))
+
+
def test_wass_1d():
# test emd1d gives similar results as emd
n = 20
@@ -135,7 +171,6 @@ def test_emd_empty():
def test_emd_sparse():
-
n = 100
rng = np.random.RandomState(0)
@@ -211,7 +246,6 @@ def test_emd2_multi():
def test_lp_barycenter():
-
a1 = np.array([1.0, 0, 0])[:, None]
a2 = np.array([0, 0, 1.0])[:, None]
@@ -228,7 +262,6 @@ def test_lp_barycenter():
def test_free_support_barycenter():
-
measures_locations = [np.array([-1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))]
measures_weights = [np.array([1.]), np.array([1.])]
@@ -244,7 +277,6 @@ def test_free_support_barycenter():
@pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available")
def test_lp_barycenter_cvxopt():
-
a1 = np.array([1.0, 0, 0])[:, None]
a2 = np.array([0, 0, 1.0])[:, None]