summaryrefslogtreecommitdiff
path: root/test/test_bregman.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_bregman.py')
-rw-r--r--test/test_bregman.py718
1 files changed, 642 insertions, 76 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 6aa4e08..830052d 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -2,15 +2,21 @@
# Author: Remi Flamary <remi.flamary@unice.fr>
# Kilian Fatras <kilian.fatras@irisa.fr>
+# Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
#
# License: MIT License
+from itertools import product
+
import numpy as np
-import ot
import pytest
+import ot
+from ot.backend import torch
+
-def test_sinkhorn():
+@pytest.mark.parametrize("verbose, warn", product([True, False], [True, False]))
+def test_sinkhorn(verbose, warn):
# test sinkhorn
n = 100
rng = np.random.RandomState(0)
@@ -20,14 +26,189 @@ def test_sinkhorn():
M = ot.dist(x, x)
- G = ot.sinkhorn(u, u, M, 1, stopThr=1e-10)
+ G = ot.sinkhorn(u, u, M, 1, stopThr=1e-10, verbose=verbose, warn=warn)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
u, G.sum(1), atol=1e-05) # cf convergence sinkhorn
np.testing.assert_allclose(
u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
+ with pytest.warns(UserWarning):
+ ot.sinkhorn(u, u, M, 1, stopThr=0, numItermax=1)
+
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized",
+ "sinkhorn_epsilon_scaling",
+ "greenkhorn",
+ "sinkhorn_log"])
+def test_convergence_warning(method):
+ # test sinkhorn
+ n = 100
+ a1 = ot.datasets.make_1D_gauss(n, m=30, s=10)
+ a2 = ot.datasets.make_1D_gauss(n, m=40, s=10)
+ A = np.asarray([a1, a2]).T
+ M = ot.utils.dist0(n)
+
+ with pytest.warns(UserWarning):
+ ot.sinkhorn(a1, a2, M, 1., method=method, stopThr=0, numItermax=1)
+
+ if method in ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"]:
+ with pytest.warns(UserWarning):
+ ot.barycenter(A, M, 1, method=method, stopThr=0, numItermax=1)
+ with pytest.warns(UserWarning):
+ ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1)
+
+
+def test_not_impemented_method():
+ # test sinkhorn
+ w = 10
+ n = w ** 2
+ rng = np.random.RandomState(42)
+ A_img = rng.rand(2, w, w)
+ A_flat = A_img.reshape(n, 2)
+ a1, a2 = A_flat.T
+ M_flat = ot.utils.dist0(n)
+ not_implemented = "new_method"
+ reg = 0.01
+ with pytest.raises(ValueError):
+ ot.sinkhorn(a1, a2, M_flat, reg, method=not_implemented)
+ with pytest.raises(ValueError):
+ ot.sinkhorn2(a1, a2, M_flat, reg, method=not_implemented)
+ with pytest.raises(ValueError):
+ ot.barycenter(A_flat, M_flat, reg, method=not_implemented)
+ with pytest.raises(ValueError):
+ ot.bregman.barycenter_debiased(A_flat, M_flat, reg,
+ method=not_implemented)
+ with pytest.raises(ValueError):
+ ot.bregman.convolutional_barycenter2d(A_img, reg,
+ method=not_implemented)
+ with pytest.raises(ValueError):
+ ot.bregman.convolutional_barycenter2d_debiased(A_img, reg,
+ method=not_implemented)
+
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
+def test_nan_warning(method):
+ # test sinkhorn
+ n = 100
+ a1 = ot.datasets.make_1D_gauss(n, m=30, s=10)
+ a2 = ot.datasets.make_1D_gauss(n, m=40, s=10)
+
+ M = ot.utils.dist0(n)
+ reg = 0
+ with pytest.warns(UserWarning):
+ # warn set to False to avoid catching a convergence warning instead
+ ot.sinkhorn(a1, a2, M, reg, method=method, warn=False)
+
+
+def test_sinkhorn_stabilization():
+ # test sinkhorn
+ n = 100
+ a1 = ot.datasets.make_1D_gauss(n, m=30, s=10)
+ a2 = ot.datasets.make_1D_gauss(n, m=40, s=10)
+ M = ot.utils.dist0(n)
+ reg = 1e-5
+ loss1 = ot.sinkhorn2(a1, a2, M, reg, method="sinkhorn_log")
+ loss2 = ot.sinkhorn2(a1, a2, M, reg, tau=1, method="sinkhorn_stabilized")
+ np.testing.assert_allclose(
+ loss1, loss2, atol=1e-06) # cf convergence sinkhorn
+
+
+@pytest.mark.parametrize("method, verbose, warn",
+ product(["sinkhorn", "sinkhorn_stabilized",
+ "sinkhorn_log"],
+ [True, False], [True, False]))
+def test_sinkhorn_multi_b(method, verbose, warn):
+ # test sinkhorn
+ n = 10
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ b = rng.rand(n, 3)
+ b = b / np.sum(b, 0, keepdims=True)
+
+ M = ot.dist(x, x)
+
+ loss0, log = ot.sinkhorn(u, b, M, .1, method=method, stopThr=1e-10,
+ log=True)
+
+ loss = [ot.sinkhorn2(u, b[:, k], M, .1, method=method, stopThr=1e-10,
+ verbose=verbose, warn=warn) for k in range(3)]
+ # check constraints
+ np.testing.assert_allclose(
+ loss0, loss, atol=1e-4) # cf convergence sinkhorn
+
+
+def test_sinkhorn_backends(nx):
+ n_samples = 100
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
+
+ M = ot.dist(x, y)
+
+ G = ot.sinkhorn(a, a, M, 1)
+
+ ab = nx.from_numpy(a)
+ M_nx = nx.from_numpy(M)
+
+ Gb = ot.sinkhorn(ab, ab, M_nx, 1)
+
+ np.allclose(G, nx.to_numpy(Gb))
+
+
+def test_sinkhorn2_backends(nx):
+ n_samples = 100
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
+
+ M = ot.dist(x, y)
+
+ G = ot.sinkhorn(a, a, M, 1)
+
+ ab = nx.from_numpy(a)
+ M_nx = nx.from_numpy(M)
+
+ Gb = ot.sinkhorn2(ab, ab, M_nx, 1)
+
+ np.allclose(G, nx.to_numpy(Gb))
+
+
+def test_sinkhorn2_gradients():
+ n_samples = 100
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
+
+ M = ot.dist(x, y)
+
+ if torch:
+
+ a1 = torch.tensor(a, requires_grad=True)
+ b1 = torch.tensor(a, requires_grad=True)
+ M1 = torch.tensor(M, requires_grad=True)
+
+ val = ot.sinkhorn2(a1, b1, M1, 1)
+
+ val.backward()
+
+ assert a1.shape == a1.grad.shape
+ assert b1.shape == b1.grad.shape
+ assert M1.shape == M1.grad.shape
+
def test_sinkhorn_empty():
# test sinkhorn
@@ -39,21 +220,27 @@ def test_sinkhorn_empty():
M = ot.dist(x, x)
+ G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10, method="sinkhorn_log",
+ verbose=True, log=True)
+ # check constraints
+ np.testing.assert_allclose(u, G.sum(1), atol=1e-05)
+ np.testing.assert_allclose(u, G.sum(0), atol=1e-05)
+
G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10, verbose=True, log=True)
- # check constratints
+ # check constraints
np.testing.assert_allclose(u, G.sum(1), atol=1e-05)
np.testing.assert_allclose(u, G.sum(0), atol=1e-05)
G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10,
method='sinkhorn_stabilized', verbose=True, log=True)
- # check constratints
+ # check constraints
np.testing.assert_allclose(u, G.sum(1), atol=1e-05)
np.testing.assert_allclose(u, G.sum(0), atol=1e-05)
G, log = ot.sinkhorn(
[], [], M, 1, stopThr=1e-10, method='sinkhorn_epsilon_scaling',
verbose=True, log=True)
- # check constratints
+ # check constraints
np.testing.assert_allclose(u, G.sum(1), atol=1e-05)
np.testing.assert_allclose(u, G.sum(0), atol=1e-05)
@@ -61,7 +248,8 @@ def test_sinkhorn_empty():
ot.sinkhorn([], [], M, 1, method='greenkhorn', stopThr=1e-10, log=True)
-def test_sinkhorn_variants():
+@pytest.skip_backend("jax")
+def test_sinkhorn_variants(nx):
# test sinkhorn
n = 100
rng = np.random.RandomState(0)
@@ -71,22 +259,131 @@ def test_sinkhorn_variants():
M = ot.dist(x, x)
- G0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10)
- Gs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10)
- Ges = ot.sinkhorn(
- u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10)
- G_green = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10)
+ ub = nx.from_numpy(u)
+ M_nx = nx.from_numpy(M)
+
+ G = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10)
+ Gl = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
+ G0 = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn', stopThr=1e-10))
+ Gs = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))
+ Ges = nx.to_numpy(ot.sinkhorn(
+ ub, ub, M_nx, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10))
+ G_green = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='greenkhorn', stopThr=1e-10))
# check values
+ np.testing.assert_allclose(G, G0, atol=1e-05)
+ np.testing.assert_allclose(G, Gl, atol=1e-05)
np.testing.assert_allclose(G0, Gs, atol=1e-05)
np.testing.assert_allclose(G0, Ges, atol=1e-05)
np.testing.assert_allclose(G0, G_green, atol=1e-5)
- print(G0, G_green)
+
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized",
+ "sinkhorn_epsilon_scaling",
+ "greenkhorn",
+ "sinkhorn_log"])
+@pytest.skip_arg(("nx", "method"), ("jax", "sinkhorn_epsilon_scaling"), reason="jax does not support sinkhorn_epsilon_scaling", getter=str)
+@pytest.skip_arg(("nx", "method"), ("jax", "greenkhorn"), reason="jax does not support greenkhorn", getter=str)
+def test_sinkhorn_variants_dtype_device(nx, method):
+ n = 100
+
+ x = np.random.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ M = ot.dist(x, x)
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ ub = nx.from_numpy(u, type_as=tp)
+ Mb = nx.from_numpy(M, type_as=tp)
+
+ Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10)
+
+ nx.assert_same_dtype_device(Mb, Gb)
+
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"])
+def test_sinkhorn2_variants_dtype_device(nx, method):
+ n = 100
+
+ x = np.random.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ M = ot.dist(x, x)
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ ub = nx.from_numpy(u, type_as=tp)
+ Mb = nx.from_numpy(M, type_as=tp)
+
+ lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10)
+
+ nx.assert_same_dtype_device(Mb, lossb)
+
+
+@pytest.skip_backend("jax")
+def test_sinkhorn_variants_multi_b(nx):
+ # test sinkhorn
+ n = 50
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ b = rng.rand(n, 3)
+ b = b / np.sum(b, 0, keepdims=True)
+
+ M = ot.dist(x, x)
+
+ ub = nx.from_numpy(u)
+ bb = nx.from_numpy(b)
+ M_nx = nx.from_numpy(M)
+
+ G = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10)
+ Gl = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
+ G0 = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10))
+ Gs = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))
+
+ # check values
+ np.testing.assert_allclose(G, G0, atol=1e-05)
+ np.testing.assert_allclose(G, Gl, atol=1e-05)
+ np.testing.assert_allclose(G0, Gs, atol=1e-05)
+
+
+@pytest.skip_backend("jax")
+def test_sinkhorn2_variants_multi_b(nx):
+ # test sinkhorn
+ n = 50
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ b = rng.rand(n, 3)
+ b = b / np.sum(b, 0, keepdims=True)
+
+ M = ot.dist(x, x)
+
+ ub = nx.from_numpy(u)
+ bb = nx.from_numpy(b)
+ M_nx = nx.from_numpy(M)
+
+ G = ot.sinkhorn2(u, b, M, 1, method='sinkhorn', stopThr=1e-10)
+ Gl = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
+ G0 = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10))
+ Gs = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))
+
+ # check values
+ np.testing.assert_allclose(G, G0, atol=1e-05)
+ np.testing.assert_allclose(G, Gl, atol=1e-05)
+ np.testing.assert_allclose(G0, Gs, atol=1e-05)
def test_sinkhorn_variants_log():
# test sinkhorn
- n = 100
+ n = 50
rng = np.random.RandomState(0)
x = rng.randn(n, 2)
@@ -95,20 +392,87 @@ def test_sinkhorn_variants_log():
M = ot.dist(x, x)
G0, log0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10, log=True)
+ Gl, logl = ot.sinkhorn(u, u, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True)
Gs, logs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True)
Ges, loges = ot.sinkhorn(
- u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True)
+ u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True,)
G_green, loggreen = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10, log=True)
# check values
np.testing.assert_allclose(G0, Gs, atol=1e-05)
+ np.testing.assert_allclose(G0, Gl, atol=1e-05)
np.testing.assert_allclose(G0, Ges, atol=1e-05)
np.testing.assert_allclose(G0, G_green, atol=1e-5)
- print(G0, G_green)
-@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
-def test_barycenter(method):
+@pytest.mark.parametrize("verbose, warn", product([True, False], [True, False]))
+def test_sinkhorn_variants_log_multib(verbose, warn):
+ # test sinkhorn
+ n = 50
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+ b = rng.rand(n, 3)
+ b = b / np.sum(b, 0, keepdims=True)
+
+ M = ot.dist(x, x)
+
+ G0, log0 = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10, log=True)
+ Gl, logl = ot.sinkhorn(u, b, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True,
+ verbose=verbose, warn=warn)
+ Gs, logs = ot.sinkhorn(u, b, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True,
+ verbose=verbose, warn=warn)
+
+ # check values
+ np.testing.assert_allclose(G0, Gs, atol=1e-05)
+ np.testing.assert_allclose(G0, Gl, atol=1e-05)
+
+
+@pytest.mark.parametrize("method, verbose, warn",
+ product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"],
+ [True, False], [True, False]))
+def test_barycenter(nx, method, verbose, warn):
+ n_bins = 100 # nb bins
+
+ # Gaussian distributions
+ a1 = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std
+ a2 = ot.datasets.make_1D_gauss(n_bins, m=40, s=10)
+
+ # creating matrix A containing all distributions
+ A = np.vstack((a1, a2)).T
+
+ # loss matrix + normalization
+ M = ot.utils.dist0(n_bins)
+ M /= M.max()
+
+ alpha = 0.5 # 0<=alpha<=1
+ weights = np.array([1 - alpha, alpha])
+
+ A_nx = nx.from_numpy(A)
+ M_nx = nx.from_numpy(M)
+ weights_nx = nx.from_numpy(weights)
+ reg = 1e-2
+
+ if nx.__name__ == "jax" and method == "sinkhorn_log":
+ with pytest.raises(NotImplementedError):
+ ot.bregman.barycenter(A_nx, M_nx, reg, weights, method=method)
+ else:
+ # wasserstein
+ bary_wass_np = ot.bregman.barycenter(A, M, reg, weights, method=method, verbose=verbose, warn=warn)
+ bary_wass, _ = ot.bregman.barycenter(A_nx, M_nx, reg, weights_nx, method=method, log=True)
+ bary_wass = nx.to_numpy(bary_wass)
+
+ np.testing.assert_allclose(1, np.sum(bary_wass))
+ np.testing.assert_allclose(bary_wass, bary_wass_np)
+
+ ot.bregman.barycenter(A_nx, M_nx, reg, log=True)
+
+
+@pytest.mark.parametrize("method, verbose, warn",
+ product(["sinkhorn", "sinkhorn_log"],
+ [True, False], [True, False]))
+def test_barycenter_debiased(nx, method, verbose, warn):
n_bins = 100 # nb bins
# Gaussian distributions
@@ -125,16 +489,61 @@ def test_barycenter(method):
alpha = 0.5 # 0<=alpha<=1
weights = np.array([1 - alpha, alpha])
+ A_nx = nx.from_numpy(A)
+ M_nx = nx.from_numpy(M)
+ weights_nx = nx.from_numpy(weights)
+
# wasserstein
reg = 1e-2
- bary_wass, log = ot.bregman.barycenter(A, M, reg, weights, method=method, log=True)
+ if nx.__name__ == "jax" and method == "sinkhorn_log":
+ with pytest.raises(NotImplementedError):
+ ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights, method=method)
+ else:
+ bary_wass_np = ot.bregman.barycenter_debiased(A, M, reg, weights, method=method,
+ verbose=verbose, warn=warn)
+ bary_wass, _ = ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights_nx, method=method, log=True)
+ bary_wass = nx.to_numpy(bary_wass)
+
+ np.testing.assert_allclose(1, np.sum(bary_wass), atol=1e-3)
+ np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-5)
- np.testing.assert_allclose(1, np.sum(bary_wass))
+ ot.bregman.barycenter_debiased(A_nx, M_nx, reg, log=True, verbose=False)
- ot.bregman.barycenter(A, M, reg, log=True, verbose=True)
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"])
+def test_convergence_warning_barycenters(method):
+ w = 10
+ n_bins = w ** 2 # nb bins
+
+ # Gaussian distributions
+ a1 = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std
+ a2 = ot.datasets.make_1D_gauss(n_bins, m=40, s=10)
+
+ # creating matrix A containing all distributions
+ A = np.vstack((a1, a2)).T
+ A_img = A.reshape(2, w, w)
+ A_img /= A_img.sum((1, 2))[:, None, None]
+
+ # loss matrix + normalization
+ M = ot.utils.dist0(n_bins)
+ M /= M.max()
-def test_barycenter_stabilization():
+ alpha = 0.5 # 0<=alpha<=1
+ weights = np.array([1 - alpha, alpha])
+ reg = 0.1
+ with pytest.warns(UserWarning):
+ ot.bregman.barycenter_debiased(A, M, reg, weights, method=method, numItermax=1)
+ with pytest.warns(UserWarning):
+ ot.bregman.barycenter(A, M, reg, weights, method=method, numItermax=1)
+ with pytest.warns(UserWarning):
+ ot.bregman.convolutional_barycenter2d(A_img, reg, weights,
+ method=method, numItermax=1)
+ with pytest.warns(UserWarning):
+ ot.bregman.convolutional_barycenter2d_debiased(A_img, reg, weights,
+ method=method, numItermax=1)
+
+
+def test_barycenter_stabilization(nx):
n_bins = 100 # nb bins
# Gaussian distributions
@@ -151,22 +560,64 @@ def test_barycenter_stabilization():
alpha = 0.5 # 0<=alpha<=1
weights = np.array([1 - alpha, alpha])
+ A_nx = nx.from_numpy(A)
+ M_nx = nx.from_numpy(M)
+ weights_b = nx.from_numpy(weights)
+
# wasserstein
reg = 1e-2
- bar_stable = ot.bregman.barycenter(A, M, reg, weights,
- method="sinkhorn_stabilized",
- stopThr=1e-8, verbose=True)
- bar = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn",
- stopThr=1e-8, verbose=True)
+ bar_np = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn", stopThr=1e-8, verbose=True)
+ bar_stable = nx.to_numpy(ot.bregman.barycenter(
+ A_nx, M_nx, reg, weights_b, method="sinkhorn_stabilized",
+ stopThr=1e-8, verbose=True
+ ))
+ bar = nx.to_numpy(ot.bregman.barycenter(
+ A_nx, M_nx, reg, weights_b, method="sinkhorn",
+ stopThr=1e-8, verbose=True
+ ))
np.testing.assert_allclose(bar, bar_stable)
+ np.testing.assert_allclose(bar, bar_np)
+
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"])
+def test_wasserstein_bary_2d(nx, method):
+ size = 20 # size of a square image
+ a1 = np.random.rand(size, size)
+ a1 += a1.min()
+ a1 = a1 / np.sum(a1)
+ a2 = np.random.rand(size, size)
+ a2 += a2.min()
+ a2 = a2 / np.sum(a2)
+ # creating matrix A containing all distributions
+ A = np.zeros((2, size, size))
+ A[0, :, :] = a1
+ A[1, :, :] = a2
+ A_nx = nx.from_numpy(A)
-def test_wasserstein_bary_2d():
- size = 100 # size of a square image
- a1 = np.random.randn(size, size)
+ # wasserstein
+ reg = 1e-2
+ if nx.__name__ == "jax" and method == "sinkhorn_log":
+ with pytest.raises(NotImplementedError):
+ ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)
+ else:
+ bary_wass_np = ot.bregman.convolutional_barycenter2d(A, reg, method=method)
+ bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method))
+
+ np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
+ np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
+
+ # help in checking if log and verbose do not bug the function
+ ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)
+
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"])
+def test_wasserstein_bary_2d_debiased(nx, method):
+ size = 20 # size of a square image
+ a1 = np.random.rand(size, size)
a1 += a1.min()
a1 = a1 / np.sum(a1)
- a2 = np.random.randn(size, size)
+ a2 = np.random.rand(size, size)
a2 += a2.min()
a2 = a2 / np.sum(a2)
# creating matrix A containing all distributions
@@ -174,17 +625,25 @@ def test_wasserstein_bary_2d():
A[0, :, :] = a1
A[1, :, :] = a2
+ A_nx = nx.from_numpy(A)
+
# wasserstein
reg = 1e-2
- bary_wass = ot.bregman.convolutional_barycenter2d(A, reg)
+ if nx.__name__ == "jax" and method == "sinkhorn_log":
+ with pytest.raises(NotImplementedError):
+ ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)
+ else:
+ bary_wass_np = ot.bregman.convolutional_barycenter2d_debiased(A, reg, method=method)
+ bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method))
- np.testing.assert_allclose(1, np.sum(bary_wass))
+ np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
+ np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
- # help in checking if log and verbose do not bug the function
- ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)
+ # help in checking if log and verbose do not bug the function
+ ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)
-def test_unmix():
+def test_unmix(nx):
n_bins = 50 # nb bins
# Gaussian distributions
@@ -204,41 +663,58 @@ def test_unmix():
M0 /= M0.max()
h0 = ot.unif(2)
+ ab = nx.from_numpy(a)
+ Db = nx.from_numpy(D)
+ M_nx = nx.from_numpy(M)
+ M0b = nx.from_numpy(M0)
+ h0b = nx.from_numpy(h0)
+
# wasserstein
reg = 1e-3
- um = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01, )
+ um_np = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01)
+ um = nx.to_numpy(ot.bregman.unmix(ab, Db, M_nx, M0b, h0b, reg, 1, alpha=0.01))
np.testing.assert_allclose(1, np.sum(um), rtol=1e-03, atol=1e-03)
np.testing.assert_allclose([0.5, 0.5], um, rtol=1e-03, atol=1e-03)
+ np.testing.assert_allclose(um, um_np)
- ot.bregman.unmix(a, D, M, M0, h0, reg,
+ ot.bregman.unmix(ab, Db, M_nx, M0b, h0b, reg,
1, alpha=0.01, log=True, verbose=True)
-def test_empirical_sinkhorn():
+def test_empirical_sinkhorn(nx):
# test sinkhorn
- n = 100
+ n = 10
a = ot.unif(n)
b = ot.unif(n)
- X_s = np.reshape(np.arange(n), (n, 1))
- X_t = np.reshape(np.arange(0, n), (n, 1))
+ X_s = np.reshape(1.0 * np.arange(n), (n, 1))
+ X_t = np.reshape(1.0 * np.arange(0, n), (n, 1))
M = ot.dist(X_s, X_t)
- M_m = ot.dist(X_s, X_t, metric='minkowski')
+ M_m = ot.dist(X_s, X_t, metric='euclidean')
+
+ ab = nx.from_numpy(a)
+ bb = nx.from_numpy(b)
+ X_sb = nx.from_numpy(X_s)
+ X_tb = nx.from_numpy(X_t)
+ M_nx = nx.from_numpy(M, type_as=ab)
+ M_mb = nx.from_numpy(M_m, type_as=ab)
- G_sqe = ot.bregman.empirical_sinkhorn(X_s, X_t, 1)
- sinkhorn_sqe = ot.sinkhorn(a, b, M, 1)
+ G_sqe = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1))
+ sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1))
- G_log, log_es = ot.bregman.empirical_sinkhorn(X_s, X_t, 0.1, log=True)
- sinkhorn_log, log_s = ot.sinkhorn(a, b, M, 0.1, log=True)
+ G_log, log_es = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 0.1, log=True)
+ G_log = nx.to_numpy(G_log)
+ sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 0.1, log=True)
+ sinkhorn_log = nx.to_numpy(sinkhorn_log)
- G_m = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, metric='minkowski')
- sinkhorn_m = ot.sinkhorn(a, b, M_m, 1)
+ G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean'))
+ sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1))
- loss_emp_sinkhorn = ot.bregman.empirical_sinkhorn2(X_s, X_t, 1)
- loss_sinkhorn = ot.sinkhorn2(a, b, M, 1)
+ loss_emp_sinkhorn = nx.to_numpy(ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1))
+ loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1))
- # check constratints
+ # check constraints
np.testing.assert_allclose(
sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidian
np.testing.assert_allclose(
@@ -254,34 +730,98 @@ def test_empirical_sinkhorn():
np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05)
-def test_empirical_sinkhorn_divergence():
- # Test sinkhorn divergence
+def test_lazy_empirical_sinkhorn(nx):
+ # test sinkhorn
n = 10
a = ot.unif(n)
b = ot.unif(n)
+ numIterMax = 1000
+
+ X_s = np.reshape(np.arange(n, dtype=np.float64), (n, 1))
+ X_t = np.reshape(np.arange(0, n, dtype=np.float64), (n, 1))
+ M = ot.dist(X_s, X_t)
+ M_m = ot.dist(X_s, X_t, metric='euclidean')
+
+ ab = nx.from_numpy(a)
+ bb = nx.from_numpy(b)
+ X_sb = nx.from_numpy(X_s)
+ X_tb = nx.from_numpy(X_t)
+ M_nx = nx.from_numpy(M, type_as=ab)
+ M_mb = nx.from_numpy(M_m, type_as=ab)
+
+ f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True)
+ f, g = nx.to_numpy(f), nx.to_numpy(g)
+ G_sqe = np.exp(f[:, None] + g[None, :] - M / 1)
+ sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1))
+
+ f, g, log_es = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
+ f, g = nx.to_numpy(f), nx.to_numpy(g)
+ G_log = np.exp(f[:, None] + g[None, :] - M / 0.1)
+ sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 0.1, log=True)
+ sinkhorn_log = nx.to_numpy(sinkhorn_log)
+
+ f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean', numIterMax=numIterMax, isLazy=True, batchSize=1)
+ f, g = nx.to_numpy(f), nx.to_numpy(g)
+ G_m = np.exp(f[:, None] + g[None, :] - M_m / 1)
+ sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1))
+
+ loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
+ loss_emp_sinkhorn = nx.to_numpy(loss_emp_sinkhorn)
+ loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1))
+
+ # check constraints
+ np.testing.assert_allclose(
+ sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidian
+ np.testing.assert_allclose(
+ sinkhorn_sqe.sum(0), G_sqe.sum(0), atol=1e-05) # metric sqeuclidian
+ np.testing.assert_allclose(
+ sinkhorn_log.sum(1), G_log.sum(1), atol=1e-05) # log
+ np.testing.assert_allclose(
+ sinkhorn_log.sum(0), G_log.sum(0), atol=1e-05) # log
+ np.testing.assert_allclose(
+ sinkhorn_m.sum(1), G_m.sum(1), atol=1e-05) # metric euclidian
+ np.testing.assert_allclose(
+ sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05) # metric euclidian
+ np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05)
+
+
+def test_empirical_sinkhorn_divergence(nx):
+ # Test sinkhorn divergence
+ n = 10
+ a = np.linspace(1, n, n)
+ a /= a.sum()
+ b = ot.unif(n)
X_s = np.reshape(np.arange(n), (n, 1))
X_t = np.reshape(np.arange(0, n * 2, 2), (n, 1))
M = ot.dist(X_s, X_t)
M_s = ot.dist(X_s, X_s)
M_t = ot.dist(X_t, X_t)
- emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1)
- sinkhorn_div = (ot.sinkhorn2(a, b, M, 1) - 1 / 2 * ot.sinkhorn2(a, a, M_s, 1) - 1 / 2 * ot.sinkhorn2(b, b, M_t, 1))
-
- emp_sinkhorn_div_log, log_es = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, log=True)
- sink_div_log_ab, log_s_ab = ot.sinkhorn2(a, b, M, 1, log=True)
- sink_div_log_a, log_s_a = ot.sinkhorn2(a, a, M_s, 1, log=True)
- sink_div_log_b, log_s_b = ot.sinkhorn2(b, b, M_t, 1, log=True)
- sink_div_log = sink_div_log_ab - 1 / 2 * (sink_div_log_a + sink_div_log_b)
-
- # check constratints
+ ab = nx.from_numpy(a)
+ bb = nx.from_numpy(b)
+ X_sb = nx.from_numpy(X_s)
+ X_tb = nx.from_numpy(X_t)
+ M_nx = nx.from_numpy(M, type_as=ab)
+ M_sb = nx.from_numpy(M_s, type_as=ab)
+ M_tb = nx.from_numpy(M_t, type_as=ab)
+
+ emp_sinkhorn_div = nx.to_numpy(ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb))
+ sinkhorn_div = nx.to_numpy(
+ ot.sinkhorn2(ab, bb, M_nx, 1)
+ - 1 / 2 * ot.sinkhorn2(ab, ab, M_sb, 1)
+ - 1 / 2 * ot.sinkhorn2(bb, bb, M_tb, 1)
+ )
+ emp_sinkhorn_div_np = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, a=a, b=b)
+
+ # check constraints
+ np.testing.assert_allclose(emp_sinkhorn_div, emp_sinkhorn_div_np, atol=1e-05)
np.testing.assert_allclose(
emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn
- np.testing.assert_allclose(
- emp_sinkhorn_div_log, sink_div_log, atol=1e-05) # cf conv emp sinkhorn
+
+ ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb, log=True)
-def test_stabilized_vs_sinkhorn_multidim():
+def test_stabilized_vs_sinkhorn_multidim(nx):
# test if stable version matches sinkhorn
# for multidimensional inputs
n = 100
@@ -297,12 +837,21 @@ def test_stabilized_vs_sinkhorn_multidim():
M = ot.utils.dist0(n)
M /= np.median(M)
epsilon = 0.1
- G, log = ot.bregman.sinkhorn(a, b, M, reg=epsilon,
+
+ ab = nx.from_numpy(a)
+ bb = nx.from_numpy(b)
+ M_nx = nx.from_numpy(M, type_as=ab)
+
+ G_np, _ = ot.bregman.sinkhorn(a, b, M, reg=epsilon, method="sinkhorn", log=True)
+ G, log = ot.bregman.sinkhorn(ab, bb, M_nx, reg=epsilon,
method="sinkhorn_stabilized",
log=True)
- G2, log2 = ot.bregman.sinkhorn(a, b, M, epsilon,
+ G = nx.to_numpy(G)
+ G2, log2 = ot.bregman.sinkhorn(ab, bb, M_nx, epsilon,
method="sinkhorn", log=True)
+ G2 = nx.to_numpy(G2)
+ np.testing.assert_allclose(G_np, G2)
np.testing.assert_allclose(G, G2)
@@ -320,8 +869,9 @@ def test_implemented_methods():
# make dists unbalanced
b = ot.utils.unif(n)
A = rng.rand(n, 2)
+ A /= A.sum(0, keepdims=True)
M = ot.dist(x, x)
- epsilon = 1.
+ epsilon = 1.0
for method in IMPLEMENTED_METHODS:
ot.bregman.sinkhorn(a, b, M, epsilon, method=method)
@@ -338,7 +888,9 @@ def test_implemented_methods():
ot.bregman.sinkhorn2(a, b, M, epsilon, method=method)
-def test_screenkhorn():
+@pytest.skip_backend("jax")
+@pytest.mark.filterwarnings("ignore:Bottleneck")
+def test_screenkhorn(nx):
# test screenkhorn
rng = np.random.RandomState(0)
n = 100
@@ -347,17 +899,31 @@ def test_screenkhorn():
x = rng.randn(n, 2)
M = ot.dist(x, x)
+
+ ab = nx.from_numpy(a)
+ bb = nx.from_numpy(b)
+ M_nx = nx.from_numpy(M, type_as=ab)
+
+ # np sinkhorn
+ G_sink_np = ot.sinkhorn(a, b, M, 1e-03)
# sinkhorn
- G_sink = ot.sinkhorn(a, b, M, 1e-03)
+ G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-03))
# screenkhorn
- G_screen = ot.bregman.screenkhorn(a, b, M, 1e-03, uniform=True, verbose=True)
+ G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, M_nx, 1e-03, uniform=True, verbose=True))
# check marginals
+ np.testing.assert_allclose(G_sink_np, G_sink)
np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02)
np.testing.assert_allclose(G_sink.sum(1), G_screen.sum(1), atol=1e-02)
-def test_convolutional_barycenter_non_square():
+def test_convolutional_barycenter_non_square(nx):
# test for image with height not equal width
A = np.ones((2, 2, 3)) / (2 * 3)
- b = ot.bregman.convolutional_barycenter2d(A, 1e-03)
+ A_nx = nx.from_numpy(A)
+
+ b_np = ot.bregman.convolutional_barycenter2d(A, 1e-03)
+ b = nx.to_numpy(ot.bregman.convolutional_barycenter2d(A_nx, 1e-03))
+
+ np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02)
np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02)
+ np.testing.assert_allclose(b, b_np)