diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2017-07-24 12:09:15 +0200 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2017-07-24 12:09:15 +0200 |
commit | 75492827c89a47cbc6807d4859be178d255c49bc (patch) | |
tree | 197d5d420d71b265e6e0ce9134c572156db54ac0 | |
parent | ff104a6dde2d652283f72d7901bbe79dfb8571ed (diff) |
add test sinkhorn
-rw-r--r-- | Makefile | 3 | ||||
-rw-r--r-- | ot/gpu/bregman.py | 2 | ||||
-rw-r--r-- | test/test_ot.py | 46 |
3 files changed, 45 insertions, 6 deletions
@@ -39,6 +39,9 @@ pep8 : test : FORCE pep8 python -m py.test -v test/ + +pytest : FORCE + python -m py.test -v test/ uploadpypi : #python setup.py register diff --git a/ot/gpu/bregman.py b/ot/gpu/bregman.py index 7881c65..2302f80 100644 --- a/ot/gpu/bregman.py +++ b/ot/gpu/bregman.py @@ -9,7 +9,7 @@ import cudamat def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False, log=False, returnAsGPU=False): - """ + r""" Solve the entropic regularization optimal transport problem on GPU The function solves the following optimization problem: diff --git a/test/test_ot.py b/test/test_ot.py index 6976818..b69d080 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -18,8 +18,9 @@ def test_doctest(): def test_emd_emd2(): - # test emd + # test emd and emd2 for simple identity n = 100 + np.random.seed(0) x = np.random.randn(n, 2) u = ot.utils.unif(n) @@ -35,14 +36,13 @@ def test_emd_emd2(): # check loss=0 assert np.allclose(w, 0) - - -#@pytest.mark.skip(reason="Seems to be a conflict between pytest and multiprocessing") + def test_emd2_multi(): from ot.datasets import get_1D_gauss as gauss n = 1000 # nb bins + np.random.seed(0) # bin positions x = np.arange(n, dtype=np.float64) @@ -72,4 +72,40 @@ def test_emd2_multi(): emdn = ot.emd2(a, b, M) ot.toc('multi proc : {} s') - assert np.allclose(emd1, emdn) + assert np.allclose(emd1, emdn) + + +def test_sinkhorn(): + # test sinkhorn + n = 100 + np.random.seed(0) + + x = np.random.randn(n, 2) + u = ot.utils.unif(n) + + M = ot.dist(x, x) + + G = ot.sinkhorn(u, u, M,1,stopThr=1e-10) + + # check constratints + assert np.allclose(u, G.sum(1), atol=1e-05) # cf convergence sinkhorn + assert np.allclose(u, G.sum(0), atol=1e-05) # cf convergence sinkhorn + +def test_sinkhorn_variants(): + # test sinkhorn + n = 100 + np.random.seed(0) + + x = np.random.randn(n, 2) + u = ot.utils.unif(n) + + M = ot.dist(x, x) + + G0 = ot.sinkhorn(u, u, M,1, method='sinkhorn',stopThr=1e-10) + Gs = ot.sinkhorn(u, u, M,1, method='sinkhorn_stabilized',stopThr=1e-10) + Ges = ot.sinkhorn(u, u, M,1, method='sinkhorn_epsilon_scaling',stopThr=1e-10) + + # check constratints + assert np.allclose(G0, Gs, atol=1e-05) + assert np.allclose(G0, Ges, atol=1e-05) # + |