summaryrefslogtreecommitdiff
path: root/test/test_ot.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2017-07-24 12:09:15 +0200
committerRémi Flamary <remi.flamary@gmail.com>2017-07-24 12:09:15 +0200
commit75492827c89a47cbc6807d4859be178d255c49bc (patch)
tree197d5d420d71b265e6e0ce9134c572156db54ac0 /test/test_ot.py
parentff104a6dde2d652283f72d7901bbe79dfb8571ed (diff)
add test sinkhorn
Diffstat (limited to 'test/test_ot.py')
-rw-r--r--test/test_ot.py46
1 files changed, 41 insertions, 5 deletions
diff --git a/test/test_ot.py b/test/test_ot.py
index 6976818..b69d080 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -18,8 +18,9 @@ def test_doctest():
def test_emd_emd2():
- # test emd
+ # test emd and emd2 for simple identity
n = 100
+ np.random.seed(0)
x = np.random.randn(n, 2)
u = ot.utils.unif(n)
@@ -35,14 +36,13 @@ def test_emd_emd2():
# check loss=0
assert np.allclose(w, 0)
-
-
-#@pytest.mark.skip(reason="Seems to be a conflict between pytest and multiprocessing")
+
def test_emd2_multi():
from ot.datasets import get_1D_gauss as gauss
n = 1000 # nb bins
+ np.random.seed(0)
# bin positions
x = np.arange(n, dtype=np.float64)
@@ -72,4 +72,40 @@ def test_emd2_multi():
emdn = ot.emd2(a, b, M)
ot.toc('multi proc : {} s')
- assert np.allclose(emd1, emdn)
+ assert np.allclose(emd1, emdn)
+
+
+def test_sinkhorn():
+ # test sinkhorn
+ n = 100
+ np.random.seed(0)
+
+ x = np.random.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ M = ot.dist(x, x)
+
+ G = ot.sinkhorn(u, u, M,1,stopThr=1e-10)
+
+ # check constratints
+ assert np.allclose(u, G.sum(1), atol=1e-05) # cf convergence sinkhorn
+ assert np.allclose(u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
+
+def test_sinkhorn_variants():
+ # test sinkhorn
+ n = 100
+ np.random.seed(0)
+
+ x = np.random.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ M = ot.dist(x, x)
+
+ G0 = ot.sinkhorn(u, u, M,1, method='sinkhorn',stopThr=1e-10)
+ Gs = ot.sinkhorn(u, u, M,1, method='sinkhorn_stabilized',stopThr=1e-10)
+ Ges = ot.sinkhorn(u, u, M,1, method='sinkhorn_epsilon_scaling',stopThr=1e-10)
+
+ # check constratints
+ assert np.allclose(G0, Gs, atol=1e-05)
+ assert np.allclose(G0, Ges, atol=1e-05) #
+