summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2021-10-27 08:41:08 +0200
committerGitHub <noreply@github.com>2021-10-27 08:41:08 +0200
commitd7554331fc409fea48ee758fd630909dd9dc4827 (patch)
tree9b8ed4bf94c12d034d5fb1de5b7b5b76c23b4d05 /test
parent76450dddf8dd62b9714b72e99ae075516246d433 (diff)
[WIP] Sinkhorn in log space (#290)
* add a sinkhorn log and working sinkhorn2 function * more tests pass * more tests pass * it works but not by default yet * remove warnings * update circleci doc * update circleci doc * new sinkhorn implemented but not by default * better * doctest pass * test doctest * new test utils * remove pep8 errors * remove pep8 errors * doc new implementation with log * test sinkhorn 2 * doc for log implementation
Diffstat (limited to 'test')
-rw-r--r--test/test_bregman.py120
-rw-r--r--test/test_gromov.py10
-rw-r--r--test/test_helpers.py4
-rw-r--r--test/test_utils.py15
4 files changed, 133 insertions, 16 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 942cb6d..c1120ba 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -32,6 +32,27 @@ def test_sinkhorn():
u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
+def test_sinkhorn_multi_b():
+ # test sinkhorn
+ n = 10
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ b = rng.rand(n, 3)
+ b = b / np.sum(b, 0, keepdims=True)
+
+ M = ot.dist(x, x)
+
+ loss0, log = ot.sinkhorn(u, b, M, .1, stopThr=1e-10, log=True)
+
+ loss = [ot.sinkhorn2(u, b[:, k], M, .1, stopThr=1e-10) for k in range(3)]
+ # check constraints
+ np.testing.assert_allclose(
+ loss0, loss, atol=1e-06) # cf convergence sinkhorn
+
+
def test_sinkhorn_backends(nx):
n_samples = 100
n_features = 2
@@ -147,6 +168,7 @@ def test_sinkhorn_variants(nx):
Mb = nx.from_numpy(M)
G = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10)
+ Gl = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn_log', stopThr=1e-10))
G0 = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn', stopThr=1e-10))
Gs = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn_stabilized', stopThr=1e-10))
Ges = nx.to_numpy(ot.sinkhorn(
@@ -155,15 +177,73 @@ def test_sinkhorn_variants(nx):
# check values
np.testing.assert_allclose(G, G0, atol=1e-05)
+ np.testing.assert_allclose(G, Gl, atol=1e-05)
np.testing.assert_allclose(G0, Gs, atol=1e-05)
np.testing.assert_allclose(G0, Ges, atol=1e-05)
np.testing.assert_allclose(G0, G_green, atol=1e-5)
- print(G0, G_green)
+
+
+@pytest.skip_backend("jax")
+def test_sinkhorn_variants_multi_b(nx):
+ # test sinkhorn
+ n = 50
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ b = rng.rand(n, 3)
+ b = b / np.sum(b, 0, keepdims=True)
+
+ M = ot.dist(x, x)
+
+ ub = nx.from_numpy(u)
+ bb = nx.from_numpy(b)
+ Mb = nx.from_numpy(M)
+
+ G = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10)
+ Gl = nx.to_numpy(ot.sinkhorn(ub, bb, Mb, 1, method='sinkhorn_log', stopThr=1e-10))
+ G0 = nx.to_numpy(ot.sinkhorn(ub, bb, Mb, 1, method='sinkhorn', stopThr=1e-10))
+ Gs = nx.to_numpy(ot.sinkhorn(ub, bb, Mb, 1, method='sinkhorn_stabilized', stopThr=1e-10))
+
+ # check values
+ np.testing.assert_allclose(G, G0, atol=1e-05)
+ np.testing.assert_allclose(G, Gl, atol=1e-05)
+ np.testing.assert_allclose(G0, Gs, atol=1e-05)
+
+
+@pytest.skip_backend("jax")
+def test_sinkhorn2_variants_multi_b(nx):
+ # test sinkhorn
+ n = 50
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ b = rng.rand(n, 3)
+ b = b / np.sum(b, 0, keepdims=True)
+
+ M = ot.dist(x, x)
+
+ ub = nx.from_numpy(u)
+ bb = nx.from_numpy(b)
+ Mb = nx.from_numpy(M)
+
+ G = ot.sinkhorn2(u, b, M, 1, method='sinkhorn', stopThr=1e-10)
+ Gl = nx.to_numpy(ot.sinkhorn2(ub, bb, Mb, 1, method='sinkhorn_log', stopThr=1e-10))
+ G0 = nx.to_numpy(ot.sinkhorn2(ub, bb, Mb, 1, method='sinkhorn', stopThr=1e-10))
+ Gs = nx.to_numpy(ot.sinkhorn2(ub, bb, Mb, 1, method='sinkhorn_stabilized', stopThr=1e-10))
+
+ # check values
+ np.testing.assert_allclose(G, G0, atol=1e-05)
+ np.testing.assert_allclose(G, Gl, atol=1e-05)
+ np.testing.assert_allclose(G0, Gs, atol=1e-05)
def test_sinkhorn_variants_log():
# test sinkhorn
- n = 100
+ n = 50
rng = np.random.RandomState(0)
x = rng.randn(n, 2)
@@ -172,6 +252,7 @@ def test_sinkhorn_variants_log():
M = ot.dist(x, x)
G0, log0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10, log=True)
+ Gl, logl = ot.sinkhorn(u, u, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True)
Gs, logs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True)
Ges, loges = ot.sinkhorn(
u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True)
@@ -179,9 +260,30 @@ def test_sinkhorn_variants_log():
# check values
np.testing.assert_allclose(G0, Gs, atol=1e-05)
+ np.testing.assert_allclose(G0, Gl, atol=1e-05)
np.testing.assert_allclose(G0, Ges, atol=1e-05)
np.testing.assert_allclose(G0, G_green, atol=1e-5)
- print(G0, G_green)
+
+
+def test_sinkhorn_variants_log_multib():
+ # test sinkhorn
+ n = 50
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+ b = rng.rand(n, 3)
+ b = b / np.sum(b, 0, keepdims=True)
+
+ M = ot.dist(x, x)
+
+ G0, log0 = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10, log=True)
+ Gl, logl = ot.sinkhorn(u, b, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True)
+ Gs, logs = ot.sinkhorn(u, b, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True)
+
+ # check values
+ np.testing.assert_allclose(G0, Gs, atol=1e-05)
+ np.testing.assert_allclose(G0, Gl, atol=1e-05)
@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
@@ -326,10 +428,10 @@ def test_empirical_sinkhorn(nx):
a = ot.unif(n)
b = ot.unif(n)
- X_s = np.reshape(np.arange(n), (n, 1))
- X_t = np.reshape(np.arange(0, n), (n, 1))
+ X_s = np.reshape(1.0 * np.arange(n), (n, 1))
+ X_t = np.reshape(1.0 * np.arange(0, n), (n, 1))
M = ot.dist(X_s, X_t)
- M_m = ot.dist(X_s, X_t, metric='minkowski')
+ M_m = ot.dist(X_s, X_t, metric='euclidean')
ab = nx.from_numpy(a)
bb = nx.from_numpy(b)
@@ -346,7 +448,7 @@ def test_empirical_sinkhorn(nx):
sinkhorn_log, log_s = ot.sinkhorn(ab, bb, Mb, 0.1, log=True)
sinkhorn_log = nx.to_numpy(sinkhorn_log)
- G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='minkowski'))
+ G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean'))
sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1))
loss_emp_sinkhorn = nx.to_numpy(ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1))
@@ -378,7 +480,7 @@ def test_lazy_empirical_sinkhorn(nx):
X_s = np.reshape(np.arange(n), (n, 1))
X_t = np.reshape(np.arange(0, n), (n, 1))
M = ot.dist(X_s, X_t)
- M_m = ot.dist(X_s, X_t, metric='minkowski')
+ M_m = ot.dist(X_s, X_t, metric='euclidean')
ab = nx.from_numpy(a)
bb = nx.from_numpy(b)
@@ -398,7 +500,7 @@ def test_lazy_empirical_sinkhorn(nx):
sinkhorn_log, log_s = ot.sinkhorn(ab, bb, Mb, 0.1, log=True)
sinkhorn_log = nx.to_numpy(sinkhorn_log)
- f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='minkowski', numIterMax=numIterMax, isLazy=True, batchSize=1)
+ f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean', numIterMax=numIterMax, isLazy=True, batchSize=1)
f, g = nx.to_numpy(f), nx.to_numpy(g)
G_m = np.exp(f[:, None] + g[None, :] - M_m / 1)
sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1))
diff --git a/test/test_gromov.py b/test/test_gromov.py
index 19d61b1..0242d72 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -180,8 +180,8 @@ def test_sampled_gromov():
def test_gromov_barycenter():
- ns = 50
- nt = 60
+ ns = 10
+ nt = 20
Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42)
Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42)
@@ -208,8 +208,8 @@ def test_gromov_barycenter():
@pytest.mark.filterwarnings("ignore:divide")
def test_gromov_entropic_barycenter():
- ns = 20
- nt = 30
+ ns = 10
+ nt = 20
Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42)
Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42)
@@ -222,7 +222,7 @@ def test_gromov_entropic_barycenter():
[ot.unif(ns), ot.unif(nt)
], ot.unif(n_samples), [.5, .5],
'square_loss', 1e-3,
- max_iter=50, tol=1e-5,
+ max_iter=50, tol=1e-3,
verbose=True)
np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
diff --git a/test/test_helpers.py b/test/test_helpers.py
index 8bd0015..cc4c90e 100644
--- a/test/test_helpers.py
+++ b/test/test_helpers.py
@@ -9,8 +9,8 @@ import sys
sys.path.append(os.path.join("ot", "helpers"))
-from openmp_helpers import get_openmp_flag, check_openmp_support # noqa
-from pre_build_helpers import _get_compiler, compile_test_program # noqa
+from openmp_helpers import get_openmp_flag, check_openmp_support # noqa
+from pre_build_helpers import _get_compiler, compile_test_program # noqa
def test_helpers():
diff --git a/test/test_utils.py b/test/test_utils.py
index 60ad5d3..0650ce2 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -7,6 +7,7 @@
import ot
import numpy as np
import sys
+import pytest
def test_proj_simplex(nx):
@@ -108,6 +109,10 @@ def test_dist():
D2 = ot.dist(x, x)
D3 = ot.dist(x)
+ D4 = ot.dist(x, x, metric='minkowski', p=0.5)
+
+ assert D4[0, 1] == D4[1, 0]
+
# dist should return squared euclidean
np.testing.assert_allclose(D, D2, atol=1e-14)
np.testing.assert_allclose(D, D3, atol=1e-14)
@@ -220,6 +225,13 @@ def test_deprecated_func():
class Class():
pass
+ with pytest.warns(DeprecationWarning):
+ fun()
+
+ with pytest.warns(DeprecationWarning):
+ cl = Class()
+ print(cl)
+
if sys.version_info < (3, 5):
print('Not tested')
else:
@@ -250,4 +262,7 @@ def test_BaseEstimator():
params['first'] = 'spam again'
cl.set_params(**params)
+ with pytest.raises(ValueError):
+ cl.set_params(bibi=10)
+
assert cl.first == 'spam again'