diff options
Diffstat (limited to 'test/test_ot.py')
-rw-r--r-- | test/test_ot.py | 23 |
1 files changed, 23 insertions, 0 deletions
diff --git a/test/test_ot.py b/test/test_ot.py index 16fd510..3897397 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -40,6 +40,29 @@ def test_emd_emd2(): assert np.allclose(w, 0) +def test_emd_empty(): + # test emd and emd2 for simple identity + n = 100 + np.random.seed(0) + + x = np.random.randn(n, 2) + u = ot.utils.unif(n) + + M = ot.dist(x, x) + + G = ot.emd([], [], M) + + # check G is identity + assert np.allclose(G, np.eye(n) / n) + # check constratints + assert np.allclose(u, G.sum(1)) # cf convergence sinkhorn + assert np.allclose(u, G.sum(0)) # cf convergence sinkhorn + + w = ot.emd2([], [], M) + # check loss=0 + assert np.allclose(w, 0) + + def test_emd2_multi(): from ot.datasets import get_1D_gauss as gauss |