summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2017-07-24 12:09:15 +0200
committerRémi Flamary <remi.flamary@gmail.com>2017-07-24 12:09:15 +0200
commit75492827c89a47cbc6807d4859be178d255c49bc (patch)
tree197d5d420d71b265e6e0ce9134c572156db54ac0
parentff104a6dde2d652283f72d7901bbe79dfb8571ed (diff)
add test sinkhorn
-rw-r--r--Makefile3
-rw-r--r--ot/gpu/bregman.py2
-rw-r--r--test/test_ot.py46
3 files changed, 45 insertions, 6 deletions
diff --git a/Makefile b/Makefile
index cabe6a9..577bbbe 100644
--- a/Makefile
+++ b/Makefile
@@ -39,6 +39,9 @@ pep8 :
test : FORCE pep8
python -m py.test -v test/
+
+pytest : FORCE
+ python -m py.test -v test/
uploadpypi :
#python setup.py register
diff --git a/ot/gpu/bregman.py b/ot/gpu/bregman.py
index 7881c65..2302f80 100644
--- a/ot/gpu/bregman.py
+++ b/ot/gpu/bregman.py
@@ -9,7 +9,7 @@ import cudamat
def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False,
log=False, returnAsGPU=False):
- """
+ r"""
Solve the entropic regularization optimal transport problem on GPU
The function solves the following optimization problem:
diff --git a/test/test_ot.py b/test/test_ot.py
index 6976818..b69d080 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -18,8 +18,9 @@ def test_doctest():
def test_emd_emd2():
- # test emd
+ # test emd and emd2 for simple identity
n = 100
+ np.random.seed(0)
x = np.random.randn(n, 2)
u = ot.utils.unif(n)
@@ -35,14 +36,13 @@ def test_emd_emd2():
# check loss=0
assert np.allclose(w, 0)
-
-
-#@pytest.mark.skip(reason="Seems to be a conflict between pytest and multiprocessing")
+
def test_emd2_multi():
from ot.datasets import get_1D_gauss as gauss
n = 1000 # nb bins
+ np.random.seed(0)
# bin positions
x = np.arange(n, dtype=np.float64)
@@ -72,4 +72,40 @@ def test_emd2_multi():
emdn = ot.emd2(a, b, M)
ot.toc('multi proc : {} s')
- assert np.allclose(emd1, emdn)
+ assert np.allclose(emd1, emdn)
+
+
+def test_sinkhorn():
+ # test sinkhorn
+ n = 100
+ np.random.seed(0)
+
+ x = np.random.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ M = ot.dist(x, x)
+
+ G = ot.sinkhorn(u, u, M,1,stopThr=1e-10)
+
+ # check constratints
+ assert np.allclose(u, G.sum(1), atol=1e-05) # cf convergence sinkhorn
+ assert np.allclose(u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
+
+def test_sinkhorn_variants():
+ # test sinkhorn
+ n = 100
+ np.random.seed(0)
+
+ x = np.random.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ M = ot.dist(x, x)
+
+ G0 = ot.sinkhorn(u, u, M,1, method='sinkhorn',stopThr=1e-10)
+ Gs = ot.sinkhorn(u, u, M,1, method='sinkhorn_stabilized',stopThr=1e-10)
+ Ges = ot.sinkhorn(u, u, M,1, method='sinkhorn_epsilon_scaling',stopThr=1e-10)
+
+ # check constratints
+ assert np.allclose(G0, Gs, atol=1e-05)
+ assert np.allclose(G0, Ges, atol=1e-05) #
+