From 75492827c89a47cbc6807d4859be178d255c49bc Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Mon, 24 Jul 2017 12:09:15 +0200 Subject: add test sinkhorn --- test/test_ot.py | 46 +++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 41 insertions(+), 5 deletions(-) (limited to 'test/test_ot.py') diff --git a/test/test_ot.py b/test/test_ot.py index 6976818..b69d080 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -18,8 +18,9 @@ def test_doctest(): def test_emd_emd2(): - # test emd + # test emd and emd2 for simple identity n = 100 + np.random.seed(0) x = np.random.randn(n, 2) u = ot.utils.unif(n) @@ -35,14 +36,13 @@ def test_emd_emd2(): # check loss=0 assert np.allclose(w, 0) - - -#@pytest.mark.skip(reason="Seems to be a conflict between pytest and multiprocessing") + def test_emd2_multi(): from ot.datasets import get_1D_gauss as gauss n = 1000 # nb bins + np.random.seed(0) # bin positions x = np.arange(n, dtype=np.float64) @@ -72,4 +72,40 @@ def test_emd2_multi(): emdn = ot.emd2(a, b, M) ot.toc('multi proc : {} s') - assert np.allclose(emd1, emdn) + assert np.allclose(emd1, emdn) + + +def test_sinkhorn(): + # test sinkhorn + n = 100 + np.random.seed(0) + + x = np.random.randn(n, 2) + u = ot.utils.unif(n) + + M = ot.dist(x, x) + + G = ot.sinkhorn(u, u, M,1,stopThr=1e-10) + + # check constratints + assert np.allclose(u, G.sum(1), atol=1e-05) # cf convergence sinkhorn + assert np.allclose(u, G.sum(0), atol=1e-05) # cf convergence sinkhorn + +def test_sinkhorn_variants(): + # test sinkhorn + n = 100 + np.random.seed(0) + + x = np.random.randn(n, 2) + u = ot.utils.unif(n) + + M = ot.dist(x, x) + + G0 = ot.sinkhorn(u, u, M,1, method='sinkhorn',stopThr=1e-10) + Gs = ot.sinkhorn(u, u, M,1, method='sinkhorn_stabilized',stopThr=1e-10) + Ges = ot.sinkhorn(u, u, M,1, method='sinkhorn_epsilon_scaling',stopThr=1e-10) + + # check constratints + assert np.allclose(G0, Gs, atol=1e-05) + assert np.allclose(G0, Ges, atol=1e-05) # + -- cgit v1.2.3