summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2017-07-24 14:49:14 +0200
committerRémi Flamary <remi.flamary@gmail.com>2017-07-24 14:49:14 +0200
commitf8e822c48eff02a3d65fc83d09dc0471bc9555aa (patch)
tree9bc2df7c16a30bdde3ca5a029b2f6b25e65dc771 /test
parenta31d3c2375ffec7eb3754ab4b66f75ce9a51eddd (diff)
test sinkhorn with empty marginals
Diffstat (limited to 'test')
-rw-r--r--test/test_bregman.py31
-rw-r--r--test/test_ot.py23
2 files changed, 53 insertions, 1 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py
index fd2c972..b65de11 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -23,6 +23,33 @@ def test_sinkhorn():
assert np.allclose(u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
+def test_sinkhorn_empty():
+ # test sinkhorn
+ n = 100
+ np.random.seed(0)
+
+ x = np.random.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ M = ot.dist(x, x)
+
+ G = ot.sinkhorn([], [], M, 1, stopThr=1e-10)
+ # check constratints
+ assert np.allclose(u, G.sum(1), atol=1e-05) # cf convergence sinkhorn
+ assert np.allclose(u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
+
+ G = ot.sinkhorn([], [], M, 1, stopThr=1e-10, method='sinkhorn_stabilized')
+ # check constratints
+ assert np.allclose(u, G.sum(1), atol=1e-05) # cf convergence sinkhorn
+ assert np.allclose(u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
+
+ G = ot.sinkhorn(
+ [], [], M, 1, stopThr=1e-10, method='sinkhorn_epsilon_scaling')
+ # check constratints
+ assert np.allclose(u, G.sum(1), atol=1e-05) # cf convergence sinkhorn
+ assert np.allclose(u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
+
+
def test_sinkhorn_variants():
# test sinkhorn
n = 100
@@ -37,7 +64,9 @@ def test_sinkhorn_variants():
Gs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10)
Ges = ot.sinkhorn(
u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10)
+ Gerr = ot.sinkhorn(u, u, M, 1, method='do_not_exists', stopThr=1e-10)
- # check constratints
+ # check values
assert np.allclose(G0, Gs, atol=1e-05)
assert np.allclose(G0, Ges, atol=1e-05)
+ assert np.allclose(G0, Gerr)
diff --git a/test/test_ot.py b/test/test_ot.py
index 16fd510..3897397 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -40,6 +40,29 @@ def test_emd_emd2():
assert np.allclose(w, 0)
+def test_emd_empty():
+ # test emd and emd2 for simple identity
+ n = 100
+ np.random.seed(0)
+
+ x = np.random.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ M = ot.dist(x, x)
+
+ G = ot.emd([], [], M)
+
+ # check G is identity
+ assert np.allclose(G, np.eye(n) / n)
+ # check constratints
+ assert np.allclose(u, G.sum(1)) # cf convergence sinkhorn
+ assert np.allclose(u, G.sum(0)) # cf convergence sinkhorn
+
+ w = ot.emd2([], [], M)
+ # check loss=0
+ assert np.allclose(w, 0)
+
+
def test_emd2_multi():
from ot.datasets import get_1D_gauss as gauss