summaryrefslogtreecommitdiff
path: root/test/test_bregman.py
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2021-10-27 08:41:08 +0200
committerGitHub <noreply@github.com>2021-10-27 08:41:08 +0200
commitd7554331fc409fea48ee758fd630909dd9dc4827 (patch)
tree9b8ed4bf94c12d034d5fb1de5b7b5b76c23b4d05 /test/test_bregman.py
parent76450dddf8dd62b9714b72e99ae075516246d433 (diff)
[WIP] Sinkhorn in log space (#290)
* adda sinkhorn log and working sinkhorn2 function * more tests pass * more tests pass * it works but not by default yet * remove warningd * update circleci doc * update circleci doc * new sinkhorn implemeted but not by default * better * doctest pass * test doctest * new test utils * remove pep8 errors * remove pep8 errors * doc new implementtaion with log * test sinkhorn 2 * doc for log implementation
Diffstat (limited to 'test/test_bregman.py')
-rw-r--r--test/test_bregman.py120
1 files changed, 111 insertions, 9 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 942cb6d..c1120ba 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -32,6 +32,27 @@ def test_sinkhorn():
u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
+def test_sinkhorn_multi_b():
+ # test sinkhorn
+ n = 10
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ b = rng.rand(n, 3)
+ b = b / np.sum(b, 0, keepdims=True)
+
+ M = ot.dist(x, x)
+
+ loss0, log = ot.sinkhorn(u, b, M, .1, stopThr=1e-10, log=True)
+
+ loss = [ot.sinkhorn2(u, b[:, k], M, .1, stopThr=1e-10) for k in range(3)]
+ # check constraints
+ np.testing.assert_allclose(
+ loss0, loss, atol=1e-06) # cf convergence sinkhorn
+
+
def test_sinkhorn_backends(nx):
n_samples = 100
n_features = 2
@@ -147,6 +168,7 @@ def test_sinkhorn_variants(nx):
Mb = nx.from_numpy(M)
G = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10)
+ Gl = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn_log', stopThr=1e-10))
G0 = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn', stopThr=1e-10))
Gs = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn_stabilized', stopThr=1e-10))
Ges = nx.to_numpy(ot.sinkhorn(
@@ -155,15 +177,73 @@ def test_sinkhorn_variants(nx):
# check values
np.testing.assert_allclose(G, G0, atol=1e-05)
+ np.testing.assert_allclose(G, Gl, atol=1e-05)
np.testing.assert_allclose(G0, Gs, atol=1e-05)
np.testing.assert_allclose(G0, Ges, atol=1e-05)
np.testing.assert_allclose(G0, G_green, atol=1e-5)
- print(G0, G_green)
+
+
+@pytest.skip_backend("jax")
+def test_sinkhorn_variants_multi_b(nx):
+ # test sinkhorn
+ n = 50
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ b = rng.rand(n, 3)
+ b = b / np.sum(b, 0, keepdims=True)
+
+ M = ot.dist(x, x)
+
+ ub = nx.from_numpy(u)
+ bb = nx.from_numpy(b)
+ Mb = nx.from_numpy(M)
+
+ G = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10)
+ Gl = nx.to_numpy(ot.sinkhorn(ub, bb, Mb, 1, method='sinkhorn_log', stopThr=1e-10))
+ G0 = nx.to_numpy(ot.sinkhorn(ub, bb, Mb, 1, method='sinkhorn', stopThr=1e-10))
+ Gs = nx.to_numpy(ot.sinkhorn(ub, bb, Mb, 1, method='sinkhorn_stabilized', stopThr=1e-10))
+
+ # check values
+ np.testing.assert_allclose(G, G0, atol=1e-05)
+ np.testing.assert_allclose(G, Gl, atol=1e-05)
+ np.testing.assert_allclose(G0, Gs, atol=1e-05)
+
+
+@pytest.skip_backend("jax")
+def test_sinkhorn2_variants_multi_b(nx):
+ # test sinkhorn
+ n = 50
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ b = rng.rand(n, 3)
+ b = b / np.sum(b, 0, keepdims=True)
+
+ M = ot.dist(x, x)
+
+ ub = nx.from_numpy(u)
+ bb = nx.from_numpy(b)
+ Mb = nx.from_numpy(M)
+
+ G = ot.sinkhorn2(u, b, M, 1, method='sinkhorn', stopThr=1e-10)
+ Gl = nx.to_numpy(ot.sinkhorn2(ub, bb, Mb, 1, method='sinkhorn_log', stopThr=1e-10))
+ G0 = nx.to_numpy(ot.sinkhorn2(ub, bb, Mb, 1, method='sinkhorn', stopThr=1e-10))
+ Gs = nx.to_numpy(ot.sinkhorn2(ub, bb, Mb, 1, method='sinkhorn_stabilized', stopThr=1e-10))
+
+ # check values
+ np.testing.assert_allclose(G, G0, atol=1e-05)
+ np.testing.assert_allclose(G, Gl, atol=1e-05)
+ np.testing.assert_allclose(G0, Gs, atol=1e-05)
def test_sinkhorn_variants_log():
# test sinkhorn
- n = 100
+ n = 50
rng = np.random.RandomState(0)
x = rng.randn(n, 2)
@@ -172,6 +252,7 @@ def test_sinkhorn_variants_log():
M = ot.dist(x, x)
G0, log0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10, log=True)
+ Gl, logl = ot.sinkhorn(u, u, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True)
Gs, logs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True)
Ges, loges = ot.sinkhorn(
u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True)
@@ -179,9 +260,30 @@ def test_sinkhorn_variants_log():
# check values
np.testing.assert_allclose(G0, Gs, atol=1e-05)
+ np.testing.assert_allclose(G0, Gl, atol=1e-05)
np.testing.assert_allclose(G0, Ges, atol=1e-05)
np.testing.assert_allclose(G0, G_green, atol=1e-5)
- print(G0, G_green)
+
+
+def test_sinkhorn_variants_log_multib():
+ # test sinkhorn
+ n = 50
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+ b = rng.rand(n, 3)
+ b = b / np.sum(b, 0, keepdims=True)
+
+ M = ot.dist(x, x)
+
+ G0, log0 = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10, log=True)
+ Gl, logl = ot.sinkhorn(u, b, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True)
+ Gs, logs = ot.sinkhorn(u, b, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True)
+
+ # check values
+ np.testing.assert_allclose(G0, Gs, atol=1e-05)
+ np.testing.assert_allclose(G0, Gl, atol=1e-05)
@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
@@ -326,10 +428,10 @@ def test_empirical_sinkhorn(nx):
a = ot.unif(n)
b = ot.unif(n)
- X_s = np.reshape(np.arange(n), (n, 1))
- X_t = np.reshape(np.arange(0, n), (n, 1))
+ X_s = np.reshape(1.0 * np.arange(n), (n, 1))
+ X_t = np.reshape(1.0 * np.arange(0, n), (n, 1))
M = ot.dist(X_s, X_t)
- M_m = ot.dist(X_s, X_t, metric='minkowski')
+ M_m = ot.dist(X_s, X_t, metric='euclidean')
ab = nx.from_numpy(a)
bb = nx.from_numpy(b)
@@ -346,7 +448,7 @@ def test_empirical_sinkhorn(nx):
sinkhorn_log, log_s = ot.sinkhorn(ab, bb, Mb, 0.1, log=True)
sinkhorn_log = nx.to_numpy(sinkhorn_log)
- G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='minkowski'))
+ G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean'))
sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1))
loss_emp_sinkhorn = nx.to_numpy(ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1))
@@ -378,7 +480,7 @@ def test_lazy_empirical_sinkhorn(nx):
X_s = np.reshape(np.arange(n), (n, 1))
X_t = np.reshape(np.arange(0, n), (n, 1))
M = ot.dist(X_s, X_t)
- M_m = ot.dist(X_s, X_t, metric='minkowski')
+ M_m = ot.dist(X_s, X_t, metric='euclidean')
ab = nx.from_numpy(a)
bb = nx.from_numpy(b)
@@ -398,7 +500,7 @@ def test_lazy_empirical_sinkhorn(nx):
sinkhorn_log, log_s = ot.sinkhorn(ab, bb, Mb, 0.1, log=True)
sinkhorn_log = nx.to_numpy(sinkhorn_log)
- f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='minkowski', numIterMax=numIterMax, isLazy=True, batchSize=1)
+ f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean', numIterMax=numIterMax, isLazy=True, batchSize=1)
f, g = nx.to_numpy(f), nx.to_numpy(g)
G_m = np.exp(f[:, None] + g[None, :] - M_m / 1)
sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1))