diff options
Diffstat (limited to 'test')
-rw-r--r-- | test/test_ot.py | 27 |
1 files changed, 27 insertions, 0 deletions
diff --git a/test/test_ot.py b/test/test_ot.py index dacae0a..3dd544c 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -118,6 +118,28 @@ def test_emd_empty(): np.testing.assert_allclose(w, 0) +def test_emd_sparse(): + + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + x2 = rng.randn(n, 2) + + M = ot.dist(x, x2) + + G = ot.emd([], [], M, dense=True) + + Gs = ot.emd([], [], M, dense=False) + + ws = ot.emd2([], [], M, dense=False) + + # check G is the same + np.testing.assert_allclose(G, Gs.todense()) + # check value + np.testing.assert_allclose(Gs.multiply(M).sum(), ws, rtol=1e-6) + + def test_emd2_multi(): n = 500 # nb bins @@ -149,7 +171,12 @@ def test_emd2_multi(): emdn = ot.emd2(a, b, M) ot.toc('multi proc : {} s') + ot.tic() + emdn2 = ot.emd2(a, b, M, dense=False) + ot.toc('multi proc : {} s') + np.testing.assert_allclose(emd1, emdn) + np.testing.assert_allclose(emd1, emdn2, rtol=1e-6) # emd loss multipro proc with log ot.tic() |