summaryrefslogtreecommitdiff
path: root/test/test_bregman.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2017-07-24 14:49:14 +0200
committerRémi Flamary <remi.flamary@gmail.com>2017-07-24 14:49:14 +0200
commitf8e822c48eff02a3d65fc83d09dc0471bc9555aa (patch)
tree9bc2df7c16a30bdde3ca5a029b2f6b25e65dc771 /test/test_bregman.py
parenta31d3c2375ffec7eb3754ab4b66f75ce9a51eddd (diff)
test sinkhorn with empty marginals
Diffstat (limited to 'test/test_bregman.py')
-rw-r--r--test/test_bregman.py31
1 files changed, 30 insertions, 1 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py
index fd2c972..b65de11 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -23,6 +23,33 @@ def test_sinkhorn():
assert np.allclose(u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
+def test_sinkhorn_empty():
+ # test sinkhorn
+ n = 100
+ np.random.seed(0)
+
+ x = np.random.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ M = ot.dist(x, x)
+
+ G = ot.sinkhorn([], [], M, 1, stopThr=1e-10)
+ # check constratints
+ assert np.allclose(u, G.sum(1), atol=1e-05) # cf convergence sinkhorn
+ assert np.allclose(u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
+
+ G = ot.sinkhorn([], [], M, 1, stopThr=1e-10, method='sinkhorn_stabilized')
+ # check constratints
+ assert np.allclose(u, G.sum(1), atol=1e-05) # cf convergence sinkhorn
+ assert np.allclose(u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
+
+ G = ot.sinkhorn(
+ [], [], M, 1, stopThr=1e-10, method='sinkhorn_epsilon_scaling')
+ # check constratints
+ assert np.allclose(u, G.sum(1), atol=1e-05) # cf convergence sinkhorn
+ assert np.allclose(u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
+
+
def test_sinkhorn_variants():
# test sinkhorn
n = 100
@@ -37,7 +64,9 @@ def test_sinkhorn_variants():
Gs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10)
Ges = ot.sinkhorn(
u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10)
+ Gerr = ot.sinkhorn(u, u, M, 1, method='do_not_exists', stopThr=1e-10)
- # check constratints
+ # check values
assert np.allclose(G0, Gs, atol=1e-05)
assert np.allclose(G0, Ges, atol=1e-05)
+ assert np.allclose(G0, Gerr)