summaryrefslogtreecommitdiff
path: root/test/test_ot.py
diff options
context:
space:
mode:
authorRomain Tavenard <romain.tavenard@univ-rennes2.fr>2019-06-27 10:23:32 +0200
committerRomain Tavenard <romain.tavenard@univ-rennes2.fr>2019-06-27 10:23:32 +0200
commit0d333e004636f5d25edea6bb195e8e4d9a95ba98 (patch)
tree75b40d64101c0f503e6ed8d3101db1c095ffbd72 /test/test_ot.py
parent1140141938c678d267f688dbb9106d3422d633c5 (diff)
Improved tests and docs for wasserstein_1d
Diffstat (limited to 'test/test_ot.py')
-rw-r--r--test/test_ot.py23
1 files changed, 23 insertions, 0 deletions
diff --git a/test/test_ot.py b/test/test_ot.py
index 6d6ea26..48423e7 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -85,6 +85,29 @@ def test_emd_1d_emd2_1d():
np.testing.assert_raises(AssertionError, ot.emd_1d, u, v, [], [])
+def test_wass_1d():
+ # test emd1d gives similar results as emd
+ n = 20
+ m = 30
+ rng = np.random.RandomState(0)
+ u = rng.randn(n, 1)
+ v = rng.randn(m, 1)
+
+ M = ot.dist(u, v, metric='sqeuclidean')
+
+ G, log = ot.emd([], [], M, log=True)
+ wass = log["cost"]
+
+ G_1d, log = ot.wasserstein_1d(u, v, [], [], p=2., log=True)
+ wass1d = log["cost"]
+
+ # check loss is similar
+ np.testing.assert_allclose(np.sqrt(wass), wass1d)
+
+ # check G is similar
+ np.testing.assert_allclose(G, G_1d)
+
+
def test_emd_empty():
# test emd and emd2 for simple identity
n = 100