diff options
author | RĂ©mi Flamary <remi.flamary@gmail.com> | 2019-12-19 07:46:47 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-12-19 07:46:47 +0100 |
commit | c5039bcafde999114283f7e59fb03e176027d740 (patch) | |
tree | e4831b2dc8d8f1e58ed122b77d6537bb82408c00 /test/test_ot.py | |
parent | bbd8f2046ec42751eba8e5356366aded74a2930d (diff) | |
parent | e9954bbd959be41c59136cf921fbf094b127eb4e (diff) |
Merge pull request #109 from rflamary/sparse_emd
[MRG] Sparse emd solution
Diffstat (limited to 'test/test_ot.py')
-rw-r--r-- | test/test_ot.py | 27 |
1 files changed, 27 insertions, 0 deletions
diff --git a/test/test_ot.py b/test/test_ot.py index dacae0a..3dd544c 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -118,6 +118,28 @@ def test_emd_empty(): np.testing.assert_allclose(w, 0) +def test_emd_sparse(): + + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + x2 = rng.randn(n, 2) + + M = ot.dist(x, x2) + + G = ot.emd([], [], M, dense=True) + + Gs = ot.emd([], [], M, dense=False) + + ws = ot.emd2([], [], M, dense=False) + + # check G is the same + np.testing.assert_allclose(G, Gs.todense()) + # check value + np.testing.assert_allclose(Gs.multiply(M).sum(), ws, rtol=1e-6) + + def test_emd2_multi(): n = 500 # nb bins @@ -149,7 +171,12 @@ def test_emd2_multi(): emdn = ot.emd2(a, b, M) ot.toc('multi proc : {} s') + ot.tic() + emdn2 = ot.emd2(a, b, M, dense=False) + ot.toc('multi proc : {} s') + np.testing.assert_allclose(emd1, emdn) + np.testing.assert_allclose(emd1, emdn2, rtol=1e-6) # emd loss multipro proc with log ot.tic() |