summaryrefslogtreecommitdiff
path: root/test/test_ot.py
diff options
context:
space:
mode:
authorRomain Tavenard <romain.tavenard@univ-rennes2.fr>2019-06-20 14:29:56 +0200
committerRomain Tavenard <romain.tavenard@univ-rennes2.fr>2019-06-20 14:29:56 +0200
commitf63f34f8adb6943b6410f8b773b4b4d8f1c7b4ba (patch)
tree96dd2a29842c86a3e3875feba1e8fa8ad3076eb7 /test/test_ot.py
parent5a6b226de20624b51c2ff98bc30e5611a7a788c7 (diff)
EMD 1d without doc
Diffstat (limited to 'test/test_ot.py')
-rw-r--r--test/test_ot.py26
1 files changed, 26 insertions, 0 deletions
diff --git a/test/test_ot.py b/test/test_ot.py
index 7652394..7008002 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -46,6 +46,32 @@ def test_emd_emd2():
np.testing.assert_allclose(w, 0)
+def test_emd1d():
+ # test emd1d gives similar results as emd
+ n = 20
+ m = 30
+ u = np.random.randn(n, 1)
+ v = np.random.randn(m, 1)
+
+ M = ot.dist(u, v, metric='sqeuclidean')
+
+ G, log = ot.emd([], [], M, log=True)
+ wass = log["cost"]
+ G_1d, log = ot.emd_1d([], [], u, v, metric='sqeuclidean', log=True)
+ wass1d = log["cost"]
+
+ # check loss is similar
+ np.testing.assert_allclose(wass, wass1d)
+
+ # check G is similar
+ np.testing.assert_allclose(G, G_1d)
+
+ # check AssertionError is raised if called on non 1d arrays
+ u = np.random.randn(n, 2)
+ v = np.random.randn(m, 2)
+ np.testing.assert_raises(AssertionError, ot.emd_1d, [], [], u, v)
+
+
def test_emd_empty():
# test emd and emd2 for simple identity
n = 100