summaryrefslogtreecommitdiff
path: root/test/test_ot.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2017-07-24 12:13:37 +0200
committerRémi Flamary <remi.flamary@gmail.com>2017-07-24 12:13:37 +0200
commitd9205c886219d5410bc4705b46d9f14710c81ddd (patch)
tree31d2d8137d7a0d874b2fef28fb59f77da08124fc /test/test_ot.py
parenta8a0995edefd437f56b91b95c2628fb031428a08 (diff)
clean tests
Diffstat (limited to 'test/test_ot.py')
-rw-r--r--test/test_ot.py37
1 files changed, 0 insertions, 37 deletions
diff --git a/test/test_ot.py b/test/test_ot.py
index 9103ac8..3fa1bc4 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -74,40 +74,3 @@ def test_emd2_multi():
ot.toc('multi proc : {} s')
assert np.allclose(emd1, emdn)
-
-
-def test_sinkhorn():
- # test sinkhorn
- n = 100
- np.random.seed(0)
-
- x = np.random.randn(n, 2)
- u = ot.utils.unif(n)
-
- M = ot.dist(x, x)
-
- G = ot.sinkhorn(u, u, M, 1, stopThr=1e-10)
-
- # check constratints
- assert np.allclose(u, G.sum(1), atol=1e-05) # cf convergence sinkhorn
- assert np.allclose(u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
-
-
-def test_sinkhorn_variants():
- # test sinkhorn
- n = 100
- np.random.seed(0)
-
- x = np.random.randn(n, 2)
- u = ot.utils.unif(n)
-
- M = ot.dist(x, x)
-
- G0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10)
- Gs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10)
- Ges = ot.sinkhorn(
- u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10)
-
- # check constratints
- assert np.allclose(G0, Gs, atol=1e-05)
- assert np.allclose(G0, Ges, atol=1e-05)