From f8e822c48eff02a3d65fc83d09dc0471bc9555aa Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Mon, 24 Jul 2017 14:49:14 +0200 Subject: test sinkhorn with empty marginals --- test/test_bregman.py | 31 ++++++++++++++++++++++++++++++- test/test_ot.py | 23 +++++++++++++++++++++++ 2 files changed, 53 insertions(+), 1 deletion(-) (limited to 'test') diff --git a/test/test_bregman.py b/test/test_bregman.py index fd2c972..b65de11 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -23,6 +23,33 @@ def test_sinkhorn(): assert np.allclose(u, G.sum(0), atol=1e-05) # cf convergence sinkhorn +def test_sinkhorn_empty(): + # test sinkhorn + n = 100 + np.random.seed(0) + + x = np.random.randn(n, 2) + u = ot.utils.unif(n) + + M = ot.dist(x, x) + + G = ot.sinkhorn([], [], M, 1, stopThr=1e-10) + # check constratints + assert np.allclose(u, G.sum(1), atol=1e-05) # cf convergence sinkhorn + assert np.allclose(u, G.sum(0), atol=1e-05) # cf convergence sinkhorn + + G = ot.sinkhorn([], [], M, 1, stopThr=1e-10, method='sinkhorn_stabilized') + # check constratints + assert np.allclose(u, G.sum(1), atol=1e-05) # cf convergence sinkhorn + assert np.allclose(u, G.sum(0), atol=1e-05) # cf convergence sinkhorn + + G = ot.sinkhorn( + [], [], M, 1, stopThr=1e-10, method='sinkhorn_epsilon_scaling') + # check constratints + assert np.allclose(u, G.sum(1), atol=1e-05) # cf convergence sinkhorn + assert np.allclose(u, G.sum(0), atol=1e-05) # cf convergence sinkhorn + + def test_sinkhorn_variants(): # test sinkhorn n = 100 @@ -37,7 +64,9 @@ def test_sinkhorn_variants(): Gs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10) Ges = ot.sinkhorn( u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10) + Gerr = ot.sinkhorn(u, u, M, 1, method='do_not_exists', stopThr=1e-10) - # check constratints + # check values assert np.allclose(G0, Gs, atol=1e-05) assert np.allclose(G0, Ges, atol=1e-05) + assert np.allclose(G0, Gerr) diff --git a/test/test_ot.py b/test/test_ot.py index 16fd510..3897397 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -40,6 +40,29 @@ def test_emd_emd2(): assert np.allclose(w, 0) +def test_emd_empty(): + # test emd and emd2 for simple identity + n = 100 + np.random.seed(0) + + x = np.random.randn(n, 2) + u = ot.utils.unif(n) + + M = ot.dist(x, x) + + G = ot.emd([], [], M) + + # check G is identity + assert np.allclose(G, np.eye(n) / n) + # check constratints + assert np.allclose(u, G.sum(1)) # cf convergence sinkhorn + assert np.allclose(u, G.sum(0)) # cf convergence sinkhorn + + w = ot.emd2([], [], M) + # check loss=0 + assert np.allclose(w, 0) + + def test_emd2_multi(): from ot.datasets import get_1D_gauss as gauss -- cgit v1.2.3