summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ot/lp/__init__.py2
-rw-r--r--test/test_ot.py38
2 files changed, 39 insertions, 1 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index cdd505d..4c968ca 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -656,7 +656,7 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
perm_a = np.argsort(x_a_1d)
perm_b = np.argsort(x_b_1d)
- G_sorted, indices, cost = emd_1d_sorted(a, b,
+ G_sorted, indices, cost = emd_1d_sorted(a[perm_a.flatten()], b[perm_b.flatten()],
x_a_1d[perm_a], x_b_1d[perm_b],
metric=metric, p=p)
G = coo_matrix((G_sorted, (perm_a[indices[:, 0]], perm_b[indices[:, 1]])),
diff --git a/test/test_ot.py b/test/test_ot.py
index 47df946..7afdae3 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -91,6 +91,44 @@ def test_emd_1d_emd2_1d():
with pytest.raises(AssertionError):
ot.emd_1d(u, v, [], [])
+def test_emd_1d_emd2_1d_with_weights():
+
+ # test emd1d gives similar results as emd
+ n = 20
+ m = 30
+ rng = np.random.RandomState(0)
+ u = rng.randn(n, 1)
+ v = rng.randn(m, 1)
+
+ w_u = rng.uniform(0., 1., n)
+ w_u = w_u / w_u.sum()
+
+ w_v = rng.uniform(0., 1., m)
+ w_v = w_v / w_v.sum()
+
+ M = ot.dist(u, v, metric='sqeuclidean')
+
+ G, log = ot.emd(w_u, w_v, M, log=True)
+ wass = log["cost"]
+ G_1d, log = ot.emd_1d(u, v, w_u, w_v, metric='sqeuclidean', log=True)
+ wass1d = log["cost"]
+ wass1d_emd2 = ot.emd2_1d(u, v, w_u, w_v, metric='sqeuclidean', log=False)
+ wass1d_euc = ot.emd2_1d(u, v, w_u, w_v, metric='euclidean', log=False)
+
+ # check loss is similar
+ np.testing.assert_allclose(wass, wass1d)
+ np.testing.assert_allclose(wass, wass1d_emd2)
+
+ # check loss is similar to scipy's implementation for Euclidean metric
+ wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)))
+ np.testing.assert_allclose(wass_sp, wass1d_euc)
+
+ # check constraints
+ np.testing.assert_allclose(w_u, G.sum(1))
+ np.testing.assert_allclose(w_v, G.sum(0))
+
+
+
def test_wass_1d():
# test emd1d gives similar results as emd