summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2019-12-19 07:46:47 +0100
committerGitHub <noreply@github.com>2019-12-19 07:46:47 +0100
commitc5039bcafde999114283f7e59fb03e176027d740 (patch)
treee4831b2dc8d8f1e58ed122b77d6537bb82408c00 /test
parentbbd8f2046ec42751eba8e5356366aded74a2930d (diff)
parente9954bbd959be41c59136cf921fbf094b127eb4e (diff)
Merge pull request #109 from rflamary/sparse_emd
[MRG] Sparse emd solution
Diffstat (limited to 'test')
-rw-r--r--test/test_ot.py27
1 files changed, 27 insertions, 0 deletions
diff --git a/test/test_ot.py b/test/test_ot.py
index dacae0a..3dd544c 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -118,6 +118,28 @@ def test_emd_empty():
np.testing.assert_allclose(w, 0)
+def test_emd_sparse():
+
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ x2 = rng.randn(n, 2)
+
+ M = ot.dist(x, x2)
+
+ G = ot.emd([], [], M, dense=True)
+
+ Gs = ot.emd([], [], M, dense=False)
+
+ ws = ot.emd2([], [], M, dense=False)
+
+ # check G is the same
+ np.testing.assert_allclose(G, Gs.todense())
+ # check value
+ np.testing.assert_allclose(Gs.multiply(M).sum(), ws, rtol=1e-6)
+
+
def test_emd2_multi():
n = 500 # nb bins
@@ -149,7 +171,12 @@ def test_emd2_multi():
emdn = ot.emd2(a, b, M)
ot.toc('multi proc : {} s')
+ ot.tic()
+ emdn2 = ot.emd2(a, b, M, dense=False)
+ ot.toc('multi proc : {} s')
+
np.testing.assert_allclose(emd1, emdn)
+ np.testing.assert_allclose(emd1, emdn2, rtol=1e-6)
# emd loss multipro proc with log
ot.tic()