From 0d333e004636f5d25edea6bb195e8e4d9a95ba98 Mon Sep 17 00:00:00 2001 From: Romain Tavenard Date: Thu, 27 Jun 2019 10:23:32 +0200 Subject: Improved tests and docs for wasserstein_1d --- test/test_ot.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) (limited to 'test/test_ot.py') diff --git a/test/test_ot.py b/test/test_ot.py index 6d6ea26..48423e7 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -85,6 +85,29 @@ def test_emd_1d_emd2_1d(): np.testing.assert_raises(AssertionError, ot.emd_1d, u, v, [], []) +def test_wass_1d(): + # test emd1d gives similar results as emd + n = 20 + m = 30 + rng = np.random.RandomState(0) + u = rng.randn(n, 1) + v = rng.randn(m, 1) + + M = ot.dist(u, v, metric='sqeuclidean') + + G, log = ot.emd([], [], M, log=True) + wass = log["cost"] + + G_1d, log = ot.wasserstein_1d(u, v, [], [], p=2., log=True) + wass1d = log["cost"] + + # check loss is similar + np.testing.assert_allclose(np.sqrt(wass), wass1d) + + # check G is similar + np.testing.assert_allclose(G, G_1d) + + def test_emd_empty(): # test emd and emd2 for simple identity n = 100 -- cgit v1.2.3