diff options
author | Romain Tavenard <romain.tavenard@univ-rennes2.fr> | 2019-06-20 14:29:56 +0200 |
---|---|---|
committer | Romain Tavenard <romain.tavenard@univ-rennes2.fr> | 2019-06-20 14:29:56 +0200 |
commit | f63f34f8adb6943b6410f8b773b4b4d8f1c7b4ba (patch) | |
tree | 96dd2a29842c86a3e3875feba1e8fa8ad3076eb7 /test | |
parent | 5a6b226de20624b51c2ff98bc30e5611a7a788c7 (diff) |
EMD 1d without doc
Diffstat (limited to 'test')
-rw-r--r-- | test/test_ot.py | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/test/test_ot.py b/test/test_ot.py index 7652394..7008002 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -46,6 +46,32 @@ def test_emd_emd2(): np.testing.assert_allclose(w, 0) +def test_emd1d(): + # test emd1d gives similar results as emd + n = 20 + m = 30 + u = np.random.randn(n, 1) + v = np.random.randn(m, 1) + + M = ot.dist(u, v, metric='sqeuclidean') + + G, log = ot.emd([], [], M, log=True) + wass = log["cost"] + G_1d, log = ot.emd_1d([], [], u, v, metric='sqeuclidean', log=True) + wass1d = log["cost"] + + # check loss is similar + np.testing.assert_allclose(wass, wass1d) + + # check G is similar + np.testing.assert_allclose(G, G_1d) + + # check AssertionError is raised if called on non 1d arrays + u = np.random.randn(n, 2) + v = np.random.randn(m, 2) + np.testing.assert_raises(AssertionError, ot.emd_1d, [], [], u, v) + + def test_emd_empty(): # test emd and emd2 for simple identity n = 100 |