From 592f933085d5b521a440eb91eccc283c43732170 Mon Sep 17 00:00:00 2001 From: AdrienCorenflos Date: Wed, 1 Apr 2020 12:14:42 +0100 Subject: Fix ordering --- test/test_ot.py | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) (limited to 'test/test_ot.py') diff --git a/test/test_ot.py b/test/test_ot.py index 47df946..7afdae3 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -91,6 +91,44 @@ def test_emd_1d_emd2_1d(): with pytest.raises(AssertionError): ot.emd_1d(u, v, [], []) +def test_emd_1d_emd2_1d_with_weights(): + + # test emd1d gives similar results as emd + n = 20 + m = 30 + rng = np.random.RandomState(0) + u = rng.randn(n, 1) + v = rng.randn(m, 1) + + w_u = rng.uniform(0., 1., n) + w_u = w_u / w_u.sum() + + w_v = rng.uniform(0., 1., m) + w_v = w_v / w_v.sum() + + M = ot.dist(u, v, metric='sqeuclidean') + + G, log = ot.emd(w_u, w_v, M, log=True) + wass = log["cost"] + G_1d, log = ot.emd_1d(u, v, w_u, w_v, metric='sqeuclidean', log=True) + wass1d = log["cost"] + wass1d_emd2 = ot.emd2_1d(u, v, w_u, w_v, metric='sqeuclidean', log=False) + wass1d_euc = ot.emd2_1d(u, v, w_u, w_v, metric='euclidean', log=False) + + # check loss is similar + np.testing.assert_allclose(wass, wass1d) + np.testing.assert_allclose(wass, wass1d_emd2) + + # check loss is similar to scipy's implementation for Euclidean metric + wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,))) + np.testing.assert_allclose(wass_sp, wass1d_euc) + + # check constraints + np.testing.assert_allclose(w_u, G.sum(1)) + np.testing.assert_allclose(w_v, G.sum(0)) + + + def test_wass_1d(): # test emd1d gives similar results as emd -- cgit v1.2.3