summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorHicham Janati <hicham.janati100@gmail.com>2021-11-03 08:41:35 +0100
committerGitHub <noreply@github.com>2021-11-03 08:41:35 +0100
commite1b67c641da3b3e497db6811af2c200022b10302 (patch)
tree44d42e1ae50d653bb07dd6ef9c1de14f71b21642 /test
parent61340d526702616ff000d9e1cf71f52dd199a103 (diff)
[WIP] Add debiased barycenter (Sinkhorn + convolutional sinkhorn) (#291)
* add debiased sinkhorn barycenter + make loops pythonic * add debiased arg in tests * add 1d and 2d examples of debiased barycenters * fix doctest * fix flake8 * pep8 + make func private + add convergence warnings * remove rel paths + add rng + pylab to pyplot * fix stopping criterion debiased * pass alex * change params with new API * add logdomain barycenters + separate debiased API * test new API * fix jax read-only ? * raise error for jax * test catch jax error * fix pytest catch error * fix relative path * fix flake8 * add warn arg everywhere * fix ref number * catch warnings in tests * add contrib to readme + change ref number * fix convolution example + gallery thumbnails * increase coverage * fix flake Co-authored-by: Hicham Janati <hicham.janati@inria.fr> Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com> Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
Diffstat (limited to 'test')
-rw-r--r--test/test_bregman.py365
1 files changed, 290 insertions, 75 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 6923d31..edfe9c3 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -6,6 +6,8 @@
#
# License: MIT License
+from itertools import product
+
import numpy as np
import pytest
@@ -13,7 +15,8 @@ import ot
from ot.backend import torch
-def test_sinkhorn():
+@pytest.mark.parametrize("verbose, warn", product([True, False], [True, False]))
+def test_sinkhorn(verbose, warn):
# test sinkhorn
n = 100
rng = np.random.RandomState(0)
@@ -23,7 +26,7 @@ def test_sinkhorn():
M = ot.dist(x, x)
- G = ot.sinkhorn(u, u, M, 1, stopThr=1e-10)
+ G = ot.sinkhorn(u, u, M, 1, stopThr=1e-10, verbose=verbose, warn=warn)
# check constraints
np.testing.assert_allclose(
@@ -31,8 +34,92 @@ def test_sinkhorn():
np.testing.assert_allclose(
u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
+ with pytest.warns(UserWarning):
+ ot.sinkhorn(u, u, M, 1, stopThr=0, numItermax=1)
+
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized",
+ "sinkhorn_epsilon_scaling",
+ "greenkhorn",
+ "sinkhorn_log"])
+def test_convergence_warning(method):
+ # test sinkhorn
+ n = 100
+ a1 = ot.datasets.make_1D_gauss(n, m=30, s=10)
+ a2 = ot.datasets.make_1D_gauss(n, m=40, s=10)
+ A = np.asarray([a1, a2]).T
+ M = ot.utils.dist0(n)
+
+ with pytest.warns(UserWarning):
+ ot.sinkhorn(a1, a2, M, 1., method=method, stopThr=0, numItermax=1)
+
+ if method in ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"]:
+ with pytest.warns(UserWarning):
+ ot.barycenter(A, M, 1, method=method, stopThr=0, numItermax=1)
+ with pytest.warns(UserWarning):
+ ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1)
+
+
+def test_not_impemented_method():
+ # test sinkhorn
+ w = 10
+ n = w ** 2
+ rng = np.random.RandomState(42)
+ A_img = rng.rand(2, w, w)
+ A_flat = A_img.reshape(n, 2)
+ a1, a2 = A_flat.T
+ M_flat = ot.utils.dist0(n)
+ not_implemented = "new_method"
+ reg = 0.01
+ with pytest.raises(ValueError):
+ ot.sinkhorn(a1, a2, M_flat, reg, method=not_implemented)
+ with pytest.raises(ValueError):
+ ot.sinkhorn2(a1, a2, M_flat, reg, method=not_implemented)
+ with pytest.raises(ValueError):
+ ot.barycenter(A_flat, M_flat, reg, method=not_implemented)
+ with pytest.raises(ValueError):
+ ot.bregman.barycenter_debiased(A_flat, M_flat, reg,
+ method=not_implemented)
+ with pytest.raises(ValueError):
+ ot.bregman.convolutional_barycenter2d(A_img, reg,
+ method=not_implemented)
+ with pytest.raises(ValueError):
+ ot.bregman.convolutional_barycenter2d_debiased(A_img, reg,
+ method=not_implemented)
+
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
+def test_nan_warning(method):
+ # test sinkhorn
+ n = 100
+ a1 = ot.datasets.make_1D_gauss(n, m=30, s=10)
+ a2 = ot.datasets.make_1D_gauss(n, m=40, s=10)
+
+ M = ot.utils.dist0(n)
+ reg = 0
+ with pytest.warns(UserWarning):
+ # warn set to False to avoid catching a convergence warning instead
+ ot.sinkhorn(a1, a2, M, reg, method=method, warn=False)
+
+
+def test_sinkhorn_stabilization():
+ # test sinkhorn
+ n = 100
+ a1 = ot.datasets.make_1D_gauss(n, m=30, s=10)
+ a2 = ot.datasets.make_1D_gauss(n, m=40, s=10)
+ M = ot.utils.dist0(n)
+ reg = 1e-5
+ loss1 = ot.sinkhorn2(a1, a2, M, reg, method="sinkhorn_log")
+ loss2 = ot.sinkhorn2(a1, a2, M, reg, tau=1, method="sinkhorn_stabilized")
+ np.testing.assert_allclose(
+ loss1, loss2, atol=1e-06) # cf convergence sinkhorn
+
-def test_sinkhorn_multi_b():
+@pytest.mark.parametrize("method, verbose, warn",
+ product(["sinkhorn", "sinkhorn_stabilized",
+ "sinkhorn_log"],
+ [True, False], [True, False]))
+def test_sinkhorn_multi_b(method, verbose, warn):
# test sinkhorn
n = 10
rng = np.random.RandomState(0)
@@ -45,12 +132,14 @@ def test_sinkhorn_multi_b():
M = ot.dist(x, x)
- loss0, log = ot.sinkhorn(u, b, M, .1, stopThr=1e-10, log=True)
+ loss0, log = ot.sinkhorn(u, b, M, .1, method=method, stopThr=1e-10,
+ log=True)
- loss = [ot.sinkhorn2(u, b[:, k], M, .1, stopThr=1e-10) for k in range(3)]
+ loss = [ot.sinkhorn2(u, b[:, k], M, .1, method=method, stopThr=1e-10,
+ verbose=verbose, warn=warn) for k in range(3)]
# check constraints
np.testing.assert_allclose(
- loss0, loss, atol=1e-06) # cf convergence sinkhorn
+ loss0, loss, atol=1e-4) # cf convergence sinkhorn
def test_sinkhorn_backends(nx):
@@ -67,9 +156,9 @@ def test_sinkhorn_backends(nx):
G = ot.sinkhorn(a, a, M, 1)
ab = nx.from_numpy(a)
- Mb = nx.from_numpy(M)
+ M_nx = nx.from_numpy(M)
- Gb = ot.sinkhorn(ab, ab, Mb, 1)
+ Gb = ot.sinkhorn(ab, ab, M_nx, 1)
np.allclose(G, nx.to_numpy(Gb))
@@ -88,9 +177,9 @@ def test_sinkhorn2_backends(nx):
G = ot.sinkhorn(a, a, M, 1)
ab = nx.from_numpy(a)
- Mb = nx.from_numpy(M)
+ M_nx = nx.from_numpy(M)
- Gb = ot.sinkhorn2(ab, ab, Mb, 1)
+ Gb = ot.sinkhorn2(ab, ab, M_nx, 1)
np.allclose(G, nx.to_numpy(Gb))
@@ -131,6 +220,12 @@ def test_sinkhorn_empty():
M = ot.dist(x, x)
+ G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10, method="sinkhorn_log",
+ verbose=True, log=True)
+ # check constraints
+ np.testing.assert_allclose(u, G.sum(1), atol=1e-05)
+ np.testing.assert_allclose(u, G.sum(0), atol=1e-05)
+
G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10, verbose=True, log=True)
# check constraints
np.testing.assert_allclose(u, G.sum(1), atol=1e-05)
@@ -165,15 +260,15 @@ def test_sinkhorn_variants(nx):
M = ot.dist(x, x)
ub = nx.from_numpy(u)
- Mb = nx.from_numpy(M)
+ M_nx = nx.from_numpy(M)
G = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10)
- Gl = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn_log', stopThr=1e-10))
- G0 = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn', stopThr=1e-10))
- Gs = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn_stabilized', stopThr=1e-10))
+ Gl = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
+ G0 = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn', stopThr=1e-10))
+ Gs = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))
Ges = nx.to_numpy(ot.sinkhorn(
- ub, ub, Mb, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10))
- G_green = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='greenkhorn', stopThr=1e-10))
+ ub, ub, M_nx, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10))
+ G_green = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='greenkhorn', stopThr=1e-10))
# check values
np.testing.assert_allclose(G, G0, atol=1e-05)
@@ -199,12 +294,12 @@ def test_sinkhorn_variants_multi_b(nx):
ub = nx.from_numpy(u)
bb = nx.from_numpy(b)
- Mb = nx.from_numpy(M)
+ M_nx = nx.from_numpy(M)
G = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10)
- Gl = nx.to_numpy(ot.sinkhorn(ub, bb, Mb, 1, method='sinkhorn_log', stopThr=1e-10))
- G0 = nx.to_numpy(ot.sinkhorn(ub, bb, Mb, 1, method='sinkhorn', stopThr=1e-10))
- Gs = nx.to_numpy(ot.sinkhorn(ub, bb, Mb, 1, method='sinkhorn_stabilized', stopThr=1e-10))
+ Gl = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
+ G0 = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10))
+ Gs = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))
# check values
np.testing.assert_allclose(G, G0, atol=1e-05)
@@ -228,12 +323,12 @@ def test_sinkhorn2_variants_multi_b(nx):
ub = nx.from_numpy(u)
bb = nx.from_numpy(b)
- Mb = nx.from_numpy(M)
+ M_nx = nx.from_numpy(M)
G = ot.sinkhorn2(u, b, M, 1, method='sinkhorn', stopThr=1e-10)
- Gl = nx.to_numpy(ot.sinkhorn2(ub, bb, Mb, 1, method='sinkhorn_log', stopThr=1e-10))
- G0 = nx.to_numpy(ot.sinkhorn2(ub, bb, Mb, 1, method='sinkhorn', stopThr=1e-10))
- Gs = nx.to_numpy(ot.sinkhorn2(ub, bb, Mb, 1, method='sinkhorn_stabilized', stopThr=1e-10))
+ Gl = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
+ G0 = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10))
+ Gs = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))
# check values
np.testing.assert_allclose(G, G0, atol=1e-05)
@@ -255,7 +350,7 @@ def test_sinkhorn_variants_log():
Gl, logl = ot.sinkhorn(u, u, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True)
Gs, logs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True)
Ges, loges = ot.sinkhorn(
- u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True)
+ u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True,)
G_green, loggreen = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10, log=True)
# check values
@@ -265,7 +360,8 @@ def test_sinkhorn_variants_log():
np.testing.assert_allclose(G0, G_green, atol=1e-5)
-def test_sinkhorn_variants_log_multib():
+@pytest.mark.parametrize("verbose, warn", product([True, False], [True, False]))
+def test_sinkhorn_variants_log_multib(verbose, warn):
# test sinkhorn
n = 50
rng = np.random.RandomState(0)
@@ -278,16 +374,20 @@ def test_sinkhorn_variants_log_multib():
M = ot.dist(x, x)
G0, log0 = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10, log=True)
- Gl, logl = ot.sinkhorn(u, b, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True)
- Gs, logs = ot.sinkhorn(u, b, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True)
+ Gl, logl = ot.sinkhorn(u, b, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True,
+ verbose=verbose, warn=warn)
+ Gs, logs = ot.sinkhorn(u, b, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True,
+ verbose=verbose, warn=warn)
# check values
np.testing.assert_allclose(G0, Gs, atol=1e-05)
np.testing.assert_allclose(G0, Gl, atol=1e-05)
-@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
-def test_barycenter(nx, method):
+@pytest.mark.parametrize("method, verbose, warn",
+ product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"],
+ [True, False], [True, False]))
+def test_barycenter(nx, method, verbose, warn):
n_bins = 100 # nb bins
# Gaussian distributions
@@ -304,20 +404,98 @@ def test_barycenter(nx, method):
alpha = 0.5 # 0<=alpha<=1
weights = np.array([1 - alpha, alpha])
- Ab = nx.from_numpy(A)
- Mb = nx.from_numpy(M)
- weightsb = nx.from_numpy(weights)
+ A_nx = nx.from_numpy(A)
+ M_nx = nx.from_numpy(M)
+ weights_nx = nx.from_numpy(weights)
+ reg = 1e-2
+
+ if nx.__name__ == "jax" and method == "sinkhorn_log":
+ with pytest.raises(NotImplementedError):
+ ot.bregman.barycenter(A_nx, M_nx, reg, weights, method=method)
+ else:
+ # wasserstein
+ bary_wass_np = ot.bregman.barycenter(A, M, reg, weights, method=method, verbose=verbose, warn=warn)
+ bary_wass, _ = ot.bregman.barycenter(A_nx, M_nx, reg, weights_nx, method=method, log=True)
+ bary_wass = nx.to_numpy(bary_wass)
+
+ np.testing.assert_allclose(1, np.sum(bary_wass))
+ np.testing.assert_allclose(bary_wass, bary_wass_np)
+
+ ot.bregman.barycenter(A_nx, M_nx, reg, log=True)
+
+
+@pytest.mark.parametrize("method, verbose, warn",
+ product(["sinkhorn", "sinkhorn_log"],
+ [True, False], [True, False]))
+def test_barycenter_debiased(nx, method, verbose, warn):
+ n_bins = 100 # nb bins
+
+ # Gaussian distributions
+ a1 = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std
+ a2 = ot.datasets.make_1D_gauss(n_bins, m=40, s=10)
+
+ # creating matrix A containing all distributions
+ A = np.vstack((a1, a2)).T
+
+ # loss matrix + normalization
+ M = ot.utils.dist0(n_bins)
+ M /= M.max()
+
+ alpha = 0.5 # 0<=alpha<=1
+ weights = np.array([1 - alpha, alpha])
+
+ A_nx = nx.from_numpy(A)
+ M_nx = nx.from_numpy(M)
+ weights_nx = nx.from_numpy(weights)
# wasserstein
reg = 1e-2
- bary_wass_np, log = ot.bregman.barycenter(A, M, reg, weights, method=method, log=True)
- bary_wass, _ = ot.bregman.barycenter(Ab, Mb, reg, weightsb, method=method, log=True)
- bary_wass = nx.to_numpy(bary_wass)
+ if nx.__name__ == "jax" and method == "sinkhorn_log":
+ with pytest.raises(NotImplementedError):
+ ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights, method=method)
+ else:
+ bary_wass_np = ot.bregman.barycenter_debiased(A, M, reg, weights, method=method,
+ verbose=verbose, warn=warn)
+ bary_wass, _ = ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights_nx, method=method, log=True)
+ bary_wass = nx.to_numpy(bary_wass)
+
+ np.testing.assert_allclose(1, np.sum(bary_wass), atol=1e-3)
+ np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-5)
+
+ ot.bregman.barycenter_debiased(A_nx, M_nx, reg, log=True, verbose=False)
- np.testing.assert_allclose(1, np.sum(bary_wass))
- np.testing.assert_allclose(bary_wass, bary_wass_np)
- ot.bregman.barycenter(Ab, Mb, reg, log=True, verbose=True)
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"])
+def test_convergence_warning_barycenters(method):
+ w = 10
+ n_bins = w ** 2 # nb bins
+
+ # Gaussian distributions
+ a1 = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std
+ a2 = ot.datasets.make_1D_gauss(n_bins, m=40, s=10)
+
+ # creating matrix A containing all distributions
+ A = np.vstack((a1, a2)).T
+ A_img = A.reshape(2, w, w)
+ A_img /= A_img.sum((1, 2))[:, None, None]
+
+ # loss matrix + normalization
+ M = ot.utils.dist0(n_bins)
+ M /= M.max()
+
+ alpha = 0.5 # 0<=alpha<=1
+ weights = np.array([1 - alpha, alpha])
+ reg = 0.1
+ with pytest.warns(UserWarning):
+ ot.bregman.barycenter_debiased(A, M, reg, weights, method=method, numItermax=1)
+ with pytest.warns(UserWarning):
+ ot.bregman.barycenter(A, M, reg, weights, method=method, numItermax=1)
+ with pytest.warns(UserWarning):
+ ot.bregman.convolutional_barycenter2d(A_img, reg, weights,
+ method=method, numItermax=1)
+ with pytest.warns(UserWarning):
+ ot.bregman.convolutional_barycenter2d_debiased(A_img, reg, weights,
+ method=method, numItermax=1)
def test_barycenter_stabilization(nx):
@@ -337,31 +515,64 @@ def test_barycenter_stabilization(nx):
alpha = 0.5 # 0<=alpha<=1
weights = np.array([1 - alpha, alpha])
- Ab = nx.from_numpy(A)
- Mb = nx.from_numpy(M)
+ A_nx = nx.from_numpy(A)
+ M_nx = nx.from_numpy(M)
weights_b = nx.from_numpy(weights)
# wasserstein
reg = 1e-2
bar_np = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn", stopThr=1e-8, verbose=True)
bar_stable = nx.to_numpy(ot.bregman.barycenter(
- Ab, Mb, reg, weights_b, method="sinkhorn_stabilized",
+ A_nx, M_nx, reg, weights_b, method="sinkhorn_stabilized",
stopThr=1e-8, verbose=True
))
bar = nx.to_numpy(ot.bregman.barycenter(
- Ab, Mb, reg, weights_b, method="sinkhorn",
+ A_nx, M_nx, reg, weights_b, method="sinkhorn",
stopThr=1e-8, verbose=True
))
np.testing.assert_allclose(bar, bar_stable)
np.testing.assert_allclose(bar, bar_np)
-def test_wasserstein_bary_2d(nx):
- size = 100 # size of a square image
- a1 = np.random.randn(size, size)
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"])
+def test_wasserstein_bary_2d(nx, method):
+ size = 20 # size of a square image
+ a1 = np.random.rand(size, size)
+ a1 += a1.min()
+ a1 = a1 / np.sum(a1)
+ a2 = np.random.rand(size, size)
+ a2 += a2.min()
+ a2 = a2 / np.sum(a2)
+ # creating matrix A containing all distributions
+ A = np.zeros((2, size, size))
+ A[0, :, :] = a1
+ A[1, :, :] = a2
+
+ A_nx = nx.from_numpy(A)
+
+ # wasserstein
+ reg = 1e-2
+ if nx.__name__ == "jax" and method == "sinkhorn_log":
+ with pytest.raises(NotImplementedError):
+ ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)
+ else:
+ bary_wass_np = ot.bregman.convolutional_barycenter2d(A, reg, method=method)
+ bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method))
+
+ np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
+ np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
+
+ # help in checking if log and verbose do not bug the function
+ ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)
+
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"])
+def test_wasserstein_bary_2d_debiased(nx, method):
+ size = 20 # size of a square image
+ a1 = np.random.rand(size, size)
a1 += a1.min()
a1 = a1 / np.sum(a1)
- a2 = np.random.randn(size, size)
+ a2 = np.random.rand(size, size)
a2 += a2.min()
a2 = a2 / np.sum(a2)
# creating matrix A containing all distributions
@@ -369,18 +580,22 @@ def test_wasserstein_bary_2d(nx):
A[0, :, :] = a1
A[1, :, :] = a2
- Ab = nx.from_numpy(A)
+ A_nx = nx.from_numpy(A)
# wasserstein
reg = 1e-2
- bary_wass_np = ot.bregman.convolutional_barycenter2d(A, reg)
- bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d(Ab, reg))
+ if nx.__name__ == "jax" and method == "sinkhorn_log":
+ with pytest.raises(NotImplementedError):
+ ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)
+ else:
+ bary_wass_np = ot.bregman.convolutional_barycenter2d_debiased(A, reg, method=method)
+ bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method))
- np.testing.assert_allclose(1, np.sum(bary_wass))
- np.testing.assert_allclose(bary_wass, bary_wass_np)
+ np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
+ np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
- # help in checking if log and verbose do not bug the function
- ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)
+ # help in checking if log and verbose do not bug the function
+ ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)
def test_unmix(nx):
@@ -405,20 +620,20 @@ def test_unmix(nx):
ab = nx.from_numpy(a)
Db = nx.from_numpy(D)
- Mb = nx.from_numpy(M)
+ M_nx = nx.from_numpy(M)
M0b = nx.from_numpy(M0)
h0b = nx.from_numpy(h0)
# wasserstein
reg = 1e-3
um_np = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01)
- um = nx.to_numpy(ot.bregman.unmix(ab, Db, Mb, M0b, h0b, reg, 1, alpha=0.01))
+ um = nx.to_numpy(ot.bregman.unmix(ab, Db, M_nx, M0b, h0b, reg, 1, alpha=0.01))
np.testing.assert_allclose(1, np.sum(um), rtol=1e-03, atol=1e-03)
np.testing.assert_allclose([0.5, 0.5], um, rtol=1e-03, atol=1e-03)
np.testing.assert_allclose(um, um_np)
- ot.bregman.unmix(ab, Db, Mb, M0b, h0b, reg,
+ ot.bregman.unmix(ab, Db, M_nx, M0b, h0b, reg,
1, alpha=0.01, log=True, verbose=True)
@@ -437,22 +652,22 @@ def test_empirical_sinkhorn(nx):
bb = nx.from_numpy(b)
X_sb = nx.from_numpy(X_s)
X_tb = nx.from_numpy(X_t)
- Mb = nx.from_numpy(M, type_as=ab)
+ M_nx = nx.from_numpy(M, type_as=ab)
M_mb = nx.from_numpy(M_m, type_as=ab)
G_sqe = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1))
- sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, Mb, 1))
+ sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1))
G_log, log_es = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 0.1, log=True)
G_log = nx.to_numpy(G_log)
- sinkhorn_log, log_s = ot.sinkhorn(ab, bb, Mb, 0.1, log=True)
+ sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 0.1, log=True)
sinkhorn_log = nx.to_numpy(sinkhorn_log)
G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean'))
sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1))
loss_emp_sinkhorn = nx.to_numpy(ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1))
- loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, Mb, 1))
+ loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1))
# check constraints
np.testing.assert_allclose(
@@ -486,18 +701,18 @@ def test_lazy_empirical_sinkhorn(nx):
bb = nx.from_numpy(b)
X_sb = nx.from_numpy(X_s)
X_tb = nx.from_numpy(X_t)
- Mb = nx.from_numpy(M, type_as=ab)
+ M_nx = nx.from_numpy(M, type_as=ab)
M_mb = nx.from_numpy(M_m, type_as=ab)
f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True)
f, g = nx.to_numpy(f), nx.to_numpy(g)
G_sqe = np.exp(f[:, None] + g[None, :] - M / 1)
- sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, Mb, 1))
+ sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1))
f, g, log_es = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
f, g = nx.to_numpy(f), nx.to_numpy(g)
G_log = np.exp(f[:, None] + g[None, :] - M / 0.1)
- sinkhorn_log, log_s = ot.sinkhorn(ab, bb, Mb, 0.1, log=True)
+ sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 0.1, log=True)
sinkhorn_log = nx.to_numpy(sinkhorn_log)
f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean', numIterMax=numIterMax, isLazy=True, batchSize=1)
@@ -507,7 +722,7 @@ def test_lazy_empirical_sinkhorn(nx):
loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
loss_emp_sinkhorn = nx.to_numpy(loss_emp_sinkhorn)
- loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, Mb, 1))
+ loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1))
# check constraints
np.testing.assert_allclose(
@@ -541,13 +756,13 @@ def test_empirical_sinkhorn_divergence(nx):
bb = nx.from_numpy(b)
X_sb = nx.from_numpy(X_s)
X_tb = nx.from_numpy(X_t)
- Mb = nx.from_numpy(M, type_as=ab)
+ M_nx = nx.from_numpy(M, type_as=ab)
M_sb = nx.from_numpy(M_s, type_as=ab)
M_tb = nx.from_numpy(M_t, type_as=ab)
emp_sinkhorn_div = nx.to_numpy(ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb))
sinkhorn_div = nx.to_numpy(
- ot.sinkhorn2(ab, bb, Mb, 1)
+ ot.sinkhorn2(ab, bb, M_nx, 1)
- 1 / 2 * ot.sinkhorn2(ab, ab, M_sb, 1)
- 1 / 2 * ot.sinkhorn2(bb, bb, M_tb, 1)
)
@@ -580,14 +795,14 @@ def test_stabilized_vs_sinkhorn_multidim(nx):
ab = nx.from_numpy(a)
bb = nx.from_numpy(b)
- Mb = nx.from_numpy(M, type_as=ab)
+ M_nx = nx.from_numpy(M, type_as=ab)
G_np, _ = ot.bregman.sinkhorn(a, b, M, reg=epsilon, method="sinkhorn", log=True)
- G, log = ot.bregman.sinkhorn(ab, bb, Mb, reg=epsilon,
+ G, log = ot.bregman.sinkhorn(ab, bb, M_nx, reg=epsilon,
method="sinkhorn_stabilized",
log=True)
G = nx.to_numpy(G)
- G2, log2 = ot.bregman.sinkhorn(ab, bb, Mb, epsilon,
+ G2, log2 = ot.bregman.sinkhorn(ab, bb, M_nx, epsilon,
method="sinkhorn", log=True)
G2 = nx.to_numpy(G2)
@@ -642,14 +857,14 @@ def test_screenkhorn(nx):
ab = nx.from_numpy(a)
bb = nx.from_numpy(b)
- Mb = nx.from_numpy(M, type_as=ab)
+ M_nx = nx.from_numpy(M, type_as=ab)
# np sinkhorn
G_sink_np = ot.sinkhorn(a, b, M, 1e-03)
# sinkhorn
- G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, Mb, 1e-03))
+ G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-03))
# screenkhorn
- G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, Mb, 1e-03, uniform=True, verbose=True))
+ G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, M_nx, 1e-03, uniform=True, verbose=True))
# check marginals
np.testing.assert_allclose(G_sink_np, G_sink)
np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02)
@@ -659,10 +874,10 @@ def test_screenkhorn(nx):
def test_convolutional_barycenter_non_square(nx):
# test for image with height not equal width
A = np.ones((2, 2, 3)) / (2 * 3)
- Ab = nx.from_numpy(A)
+ A_nx = nx.from_numpy(A)
b_np = ot.bregman.convolutional_barycenter2d(A, 1e-03)
- b = nx.to_numpy(ot.bregman.convolutional_barycenter2d(Ab, 1e-03))
+ b = nx.to_numpy(ot.bregman.convolutional_barycenter2d(A_nx, 1e-03))
np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02)
np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02)