From e1b67c641da3b3e497db6811af2c200022b10302 Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Wed, 3 Nov 2021 08:41:35 +0100 Subject: [WIP] Add debiased barycenter (Sinkhorn + convolutional sinkhorn) (#291) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add debiased sinkhorn barycenter + make loops pythonic * add debiased arg in tests * add 1d and 2d examples of debiased barycenters * fix doctest * fix flake8 * pep8 + make func private + add convergence warnings * remove rel paths + add rng + pylab to pyplot * fix stopping criterion debiased * pass alex * change params with new API * add logdomain barycenters + separate debiased API * test new API * fix jax read-only ? * raise error for jax * test catch jax error * fix pytest catch error * fix relative path * fix flake8 * add warn arg everywhere * fix ref number * catch warnings in tests * add contrib to readme + change ref number * fix convolution example + gallery thumbnails * increase coverage * fix flake Co-authored-by: Hicham Janati Co-authored-by: RĂ©mi Flamary Co-authored-by: Alexandre Gramfort --- test/test_bregman.py | 365 ++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 290 insertions(+), 75 deletions(-) (limited to 'test') diff --git a/test/test_bregman.py b/test/test_bregman.py index 6923d31..edfe9c3 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -6,6 +6,8 @@ # # License: MIT License +from itertools import product + import numpy as np import pytest @@ -13,7 +15,8 @@ import ot from ot.backend import torch -def test_sinkhorn(): +@pytest.mark.parametrize("verbose, warn", product([True, False], [True, False])) +def test_sinkhorn(verbose, warn): # test sinkhorn n = 100 rng = np.random.RandomState(0) @@ -23,7 +26,7 @@ def test_sinkhorn(): M = ot.dist(x, x) - G = ot.sinkhorn(u, u, M, 1, stopThr=1e-10) + G = ot.sinkhorn(u, u, M, 1, stopThr=1e-10, verbose=verbose, warn=warn) # check constraints np.testing.assert_allclose( @@ -31,8 +34,92 @@ def test_sinkhorn(): np.testing.assert_allclose( u, G.sum(0), atol=1e-05) # cf convergence sinkhorn + with pytest.warns(UserWarning): + ot.sinkhorn(u, u, M, 1, stopThr=0, numItermax=1) + + +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", + "sinkhorn_epsilon_scaling", + "greenkhorn", + "sinkhorn_log"]) +def test_convergence_warning(method): + # test sinkhorn + n = 100 + a1 = ot.datasets.make_1D_gauss(n, m=30, s=10) + a2 = ot.datasets.make_1D_gauss(n, m=40, s=10) + A = np.asarray([a1, a2]).T + M = ot.utils.dist0(n) + + with pytest.warns(UserWarning): + ot.sinkhorn(a1, a2, M, 1., method=method, stopThr=0, numItermax=1) + + if method in ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"]: + with pytest.warns(UserWarning): + ot.barycenter(A, M, 1, method=method, stopThr=0, numItermax=1) + with pytest.warns(UserWarning): + ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1) + + +def test_not_impemented_method(): + # test sinkhorn + w = 10 + n = w ** 2 + rng = np.random.RandomState(42) + A_img = rng.rand(2, w, w) + A_flat = A_img.reshape(n, 2) + a1, a2 = A_flat.T + M_flat = ot.utils.dist0(n) + not_implemented = "new_method" + reg = 0.01 + with pytest.raises(ValueError): + ot.sinkhorn(a1, a2, M_flat, reg, method=not_implemented) + with pytest.raises(ValueError): + ot.sinkhorn2(a1, a2, M_flat, reg, method=not_implemented) + with pytest.raises(ValueError): + ot.barycenter(A_flat, M_flat, reg, method=not_implemented) + with pytest.raises(ValueError): + ot.bregman.barycenter_debiased(A_flat, M_flat, reg, + method=not_implemented) + with pytest.raises(ValueError): + ot.bregman.convolutional_barycenter2d(A_img, reg, + method=not_implemented) + with pytest.raises(ValueError): + ot.bregman.convolutional_barycenter2d_debiased(A_img, reg, + method=not_implemented) + + +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) +def test_nan_warning(method): + # test sinkhorn + n = 100 + a1 = ot.datasets.make_1D_gauss(n, m=30, s=10) + a2 = ot.datasets.make_1D_gauss(n, m=40, s=10) + + M = ot.utils.dist0(n) + reg = 0 + with pytest.warns(UserWarning): + # warn set to False to avoid catching a convergence warning instead + ot.sinkhorn(a1, a2, M, reg, method=method, warn=False) + + +def test_sinkhorn_stabilization(): + # test sinkhorn + n = 100 + a1 = ot.datasets.make_1D_gauss(n, m=30, s=10) + a2 = ot.datasets.make_1D_gauss(n, m=40, s=10) + M = ot.utils.dist0(n) + reg = 1e-5 + loss1 = ot.sinkhorn2(a1, a2, M, reg, method="sinkhorn_log") + loss2 = ot.sinkhorn2(a1, a2, M, reg, tau=1, method="sinkhorn_stabilized") + np.testing.assert_allclose( + loss1, loss2, atol=1e-06) # cf convergence sinkhorn + -def test_sinkhorn_multi_b(): +@pytest.mark.parametrize("method, verbose, warn", + product(["sinkhorn", "sinkhorn_stabilized", + "sinkhorn_log"], + [True, False], [True, False])) +def test_sinkhorn_multi_b(method, verbose, warn): # test sinkhorn n = 10 rng = np.random.RandomState(0) @@ -45,12 +132,14 @@ def test_sinkhorn_multi_b(): M = ot.dist(x, x) - loss0, log = ot.sinkhorn(u, b, M, .1, stopThr=1e-10, log=True) + loss0, log = ot.sinkhorn(u, b, M, .1, method=method, stopThr=1e-10, + log=True) - loss = [ot.sinkhorn2(u, b[:, k], M, .1, stopThr=1e-10) for k in range(3)] + loss = [ot.sinkhorn2(u, b[:, k], M, .1, method=method, stopThr=1e-10, + verbose=verbose, warn=warn) for k in range(3)] # check constraints np.testing.assert_allclose( - loss0, loss, atol=1e-06) # cf convergence sinkhorn + loss0, loss, atol=1e-4) # cf convergence sinkhorn def test_sinkhorn_backends(nx): @@ -67,9 +156,9 @@ def test_sinkhorn_backends(nx): G = ot.sinkhorn(a, a, M, 1) ab = nx.from_numpy(a) - Mb = nx.from_numpy(M) + M_nx = nx.from_numpy(M) - Gb = ot.sinkhorn(ab, ab, Mb, 1) + Gb = ot.sinkhorn(ab, ab, M_nx, 1) np.allclose(G, nx.to_numpy(Gb)) @@ -88,9 +177,9 @@ def test_sinkhorn2_backends(nx): G = ot.sinkhorn(a, a, M, 1) ab = nx.from_numpy(a) - Mb = nx.from_numpy(M) + M_nx = nx.from_numpy(M) - Gb = ot.sinkhorn2(ab, ab, Mb, 1) + Gb = ot.sinkhorn2(ab, ab, M_nx, 1) np.allclose(G, nx.to_numpy(Gb)) @@ -131,6 +220,12 @@ def test_sinkhorn_empty(): M = ot.dist(x, x) + G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10, method="sinkhorn_log", + verbose=True, log=True) + # check constraints + np.testing.assert_allclose(u, G.sum(1), atol=1e-05) + np.testing.assert_allclose(u, G.sum(0), atol=1e-05) + G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10, verbose=True, log=True) # check constraints np.testing.assert_allclose(u, G.sum(1), atol=1e-05) @@ -165,15 +260,15 @@ def test_sinkhorn_variants(nx): M = ot.dist(x, x) ub = nx.from_numpy(u) - Mb = nx.from_numpy(M) + M_nx = nx.from_numpy(M) G = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10) - Gl = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn_log', stopThr=1e-10)) - G0 = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn', stopThr=1e-10)) - Gs = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn_stabilized', stopThr=1e-10)) + Gl = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) + G0 = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn', stopThr=1e-10)) + Gs = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10)) Ges = nx.to_numpy(ot.sinkhorn( - ub, ub, Mb, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10)) - G_green = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='greenkhorn', stopThr=1e-10)) + ub, ub, M_nx, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10)) + G_green = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='greenkhorn', stopThr=1e-10)) # check values np.testing.assert_allclose(G, G0, atol=1e-05) @@ -199,12 +294,12 @@ def test_sinkhorn_variants_multi_b(nx): ub = nx.from_numpy(u) bb = nx.from_numpy(b) - Mb = nx.from_numpy(M) + M_nx = nx.from_numpy(M) G = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10) - Gl = nx.to_numpy(ot.sinkhorn(ub, bb, Mb, 1, method='sinkhorn_log', stopThr=1e-10)) - G0 = nx.to_numpy(ot.sinkhorn(ub, bb, Mb, 1, method='sinkhorn', stopThr=1e-10)) - Gs = nx.to_numpy(ot.sinkhorn(ub, bb, Mb, 1, method='sinkhorn_stabilized', stopThr=1e-10)) + Gl = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) + G0 = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10)) + Gs = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10)) # check values np.testing.assert_allclose(G, G0, atol=1e-05) @@ -228,12 +323,12 @@ def test_sinkhorn2_variants_multi_b(nx): ub = nx.from_numpy(u) bb = nx.from_numpy(b) - Mb = nx.from_numpy(M) + M_nx = nx.from_numpy(M) G = ot.sinkhorn2(u, b, M, 1, method='sinkhorn', stopThr=1e-10) - Gl = nx.to_numpy(ot.sinkhorn2(ub, bb, Mb, 1, method='sinkhorn_log', stopThr=1e-10)) - G0 = nx.to_numpy(ot.sinkhorn2(ub, bb, Mb, 1, method='sinkhorn', stopThr=1e-10)) - Gs = nx.to_numpy(ot.sinkhorn2(ub, bb, Mb, 1, method='sinkhorn_stabilized', stopThr=1e-10)) + Gl = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) + G0 = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10)) + Gs = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10)) # check values np.testing.assert_allclose(G, G0, atol=1e-05) @@ -255,7 +350,7 @@ def test_sinkhorn_variants_log(): Gl, logl = ot.sinkhorn(u, u, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True) Gs, logs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True) Ges, loges = ot.sinkhorn( - u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True) + u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True,) G_green, loggreen = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10, log=True) # check values @@ -265,7 +360,8 @@ def test_sinkhorn_variants_log(): np.testing.assert_allclose(G0, G_green, atol=1e-5) -def test_sinkhorn_variants_log_multib(): +@pytest.mark.parametrize("verbose, warn", product([True, False], [True, False])) +def test_sinkhorn_variants_log_multib(verbose, warn): # test sinkhorn n = 50 rng = np.random.RandomState(0) @@ -278,16 +374,20 @@ def test_sinkhorn_variants_log_multib(): M = ot.dist(x, x) G0, log0 = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10, log=True) - Gl, logl = ot.sinkhorn(u, b, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True) - Gs, logs = ot.sinkhorn(u, b, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True) + Gl, logl = ot.sinkhorn(u, b, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True, + verbose=verbose, warn=warn) + Gs, logs = ot.sinkhorn(u, b, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True, + verbose=verbose, warn=warn) # check values np.testing.assert_allclose(G0, Gs, atol=1e-05) np.testing.assert_allclose(G0, Gl, atol=1e-05) -@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) -def test_barycenter(nx, method): +@pytest.mark.parametrize("method, verbose, warn", + product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"], + [True, False], [True, False])) +def test_barycenter(nx, method, verbose, warn): n_bins = 100 # nb bins # Gaussian distributions @@ -304,20 +404,98 @@ def test_barycenter(nx, method): alpha = 0.5 # 0<=alpha<=1 weights = np.array([1 - alpha, alpha]) - Ab = nx.from_numpy(A) - Mb = nx.from_numpy(M) - weightsb = nx.from_numpy(weights) + A_nx = nx.from_numpy(A) + M_nx = nx.from_numpy(M) + weights_nx = nx.from_numpy(weights) + reg = 1e-2 + + if nx.__name__ == "jax" and method == "sinkhorn_log": + with pytest.raises(NotImplementedError): + ot.bregman.barycenter(A_nx, M_nx, reg, weights, method=method) + else: + # wasserstein + bary_wass_np = ot.bregman.barycenter(A, M, reg, weights, method=method, verbose=verbose, warn=warn) + bary_wass, _ = ot.bregman.barycenter(A_nx, M_nx, reg, weights_nx, method=method, log=True) + bary_wass = nx.to_numpy(bary_wass) + + np.testing.assert_allclose(1, np.sum(bary_wass)) + np.testing.assert_allclose(bary_wass, bary_wass_np) + + ot.bregman.barycenter(A_nx, M_nx, reg, log=True) + + +@pytest.mark.parametrize("method, verbose, warn", + product(["sinkhorn", "sinkhorn_log"], + [True, False], [True, False])) +def test_barycenter_debiased(nx, method, verbose, warn): + n_bins = 100 # nb bins + + # Gaussian distributions + a1 = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std + a2 = ot.datasets.make_1D_gauss(n_bins, m=40, s=10) + + # creating matrix A containing all distributions + A = np.vstack((a1, a2)).T + + # loss matrix + normalization + M = ot.utils.dist0(n_bins) + M /= M.max() + + alpha = 0.5 # 0<=alpha<=1 + weights = np.array([1 - alpha, alpha]) + + A_nx = nx.from_numpy(A) + M_nx = nx.from_numpy(M) + weights_nx = nx.from_numpy(weights) # wasserstein reg = 1e-2 - bary_wass_np, log = ot.bregman.barycenter(A, M, reg, weights, method=method, log=True) - bary_wass, _ = ot.bregman.barycenter(Ab, Mb, reg, weightsb, method=method, log=True) - bary_wass = nx.to_numpy(bary_wass) + if nx.__name__ == "jax" and method == "sinkhorn_log": + with pytest.raises(NotImplementedError): + ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights, method=method) + else: + bary_wass_np = ot.bregman.barycenter_debiased(A, M, reg, weights, method=method, + verbose=verbose, warn=warn) + bary_wass, _ = ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights_nx, method=method, log=True) + bary_wass = nx.to_numpy(bary_wass) + + np.testing.assert_allclose(1, np.sum(bary_wass), atol=1e-3) + np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-5) + + ot.bregman.barycenter_debiased(A_nx, M_nx, reg, log=True, verbose=False) - np.testing.assert_allclose(1, np.sum(bary_wass)) - np.testing.assert_allclose(bary_wass, bary_wass_np) - ot.bregman.barycenter(Ab, Mb, reg, log=True, verbose=True) +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"]) +def test_convergence_warning_barycenters(method): + w = 10 + n_bins = w ** 2 # nb bins + + # Gaussian distributions + a1 = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std + a2 = ot.datasets.make_1D_gauss(n_bins, m=40, s=10) + + # creating matrix A containing all distributions + A = np.vstack((a1, a2)).T + A_img = A.reshape(2, w, w) + A_img /= A_img.sum((1, 2))[:, None, None] + + # loss matrix + normalization + M = ot.utils.dist0(n_bins) + M /= M.max() + + alpha = 0.5 # 0<=alpha<=1 + weights = np.array([1 - alpha, alpha]) + reg = 0.1 + with pytest.warns(UserWarning): + ot.bregman.barycenter_debiased(A, M, reg, weights, method=method, numItermax=1) + with pytest.warns(UserWarning): + ot.bregman.barycenter(A, M, reg, weights, method=method, numItermax=1) + with pytest.warns(UserWarning): + ot.bregman.convolutional_barycenter2d(A_img, reg, weights, + method=method, numItermax=1) + with pytest.warns(UserWarning): + ot.bregman.convolutional_barycenter2d_debiased(A_img, reg, weights, + method=method, numItermax=1) def test_barycenter_stabilization(nx): @@ -337,31 +515,64 @@ def test_barycenter_stabilization(nx): alpha = 0.5 # 0<=alpha<=1 weights = np.array([1 - alpha, alpha]) - Ab = nx.from_numpy(A) - Mb = nx.from_numpy(M) + A_nx = nx.from_numpy(A) + M_nx = nx.from_numpy(M) weights_b = nx.from_numpy(weights) # wasserstein reg = 1e-2 bar_np = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn", stopThr=1e-8, verbose=True) bar_stable = nx.to_numpy(ot.bregman.barycenter( - Ab, Mb, reg, weights_b, method="sinkhorn_stabilized", + A_nx, M_nx, reg, weights_b, method="sinkhorn_stabilized", stopThr=1e-8, verbose=True )) bar = nx.to_numpy(ot.bregman.barycenter( - Ab, Mb, reg, weights_b, method="sinkhorn", + A_nx, M_nx, reg, weights_b, method="sinkhorn", stopThr=1e-8, verbose=True )) np.testing.assert_allclose(bar, bar_stable) np.testing.assert_allclose(bar, bar_np) -def test_wasserstein_bary_2d(nx): - size = 100 # size of a square image - a1 = np.random.randn(size, size) +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"]) +def test_wasserstein_bary_2d(nx, method): + size = 20 # size of a square image + a1 = np.random.rand(size, size) + a1 += a1.min() + a1 = a1 / np.sum(a1) + a2 = np.random.rand(size, size) + a2 += a2.min() + a2 = a2 / np.sum(a2) + # creating matrix A containing all distributions + A = np.zeros((2, size, size)) + A[0, :, :] = a1 + A[1, :, :] = a2 + + A_nx = nx.from_numpy(A) + + # wasserstein + reg = 1e-2 + if nx.__name__ == "jax" and method == "sinkhorn_log": + with pytest.raises(NotImplementedError): + ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method) + else: + bary_wass_np = ot.bregman.convolutional_barycenter2d(A, reg, method=method) + bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)) + + np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) + np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) + + # help in checking if log and verbose do not bug the function + ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) + + +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"]) +def test_wasserstein_bary_2d_debiased(nx, method): + size = 20 # size of a square image + a1 = np.random.rand(size, size) a1 += a1.min() a1 = a1 / np.sum(a1) - a2 = np.random.randn(size, size) + a2 = np.random.rand(size, size) a2 += a2.min() a2 = a2 / np.sum(a2) # creating matrix A containing all distributions @@ -369,18 +580,22 @@ def test_wasserstein_bary_2d(nx): A[0, :, :] = a1 A[1, :, :] = a2 - Ab = nx.from_numpy(A) + A_nx = nx.from_numpy(A) # wasserstein reg = 1e-2 - bary_wass_np = ot.bregman.convolutional_barycenter2d(A, reg) - bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d(Ab, reg)) + if nx.__name__ == "jax" and method == "sinkhorn_log": + with pytest.raises(NotImplementedError): + ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method) + else: + bary_wass_np = ot.bregman.convolutional_barycenter2d_debiased(A, reg, method=method) + bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)) - np.testing.assert_allclose(1, np.sum(bary_wass)) - np.testing.assert_allclose(bary_wass, bary_wass_np) + np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) + np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) - # help in checking if log and verbose do not bug the function - ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) + # help in checking if log and verbose do not bug the function + ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) def test_unmix(nx): @@ -405,20 +620,20 @@ def test_unmix(nx): ab = nx.from_numpy(a) Db = nx.from_numpy(D) - Mb = nx.from_numpy(M) + M_nx = nx.from_numpy(M) M0b = nx.from_numpy(M0) h0b = nx.from_numpy(h0) # wasserstein reg = 1e-3 um_np = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01) - um = nx.to_numpy(ot.bregman.unmix(ab, Db, Mb, M0b, h0b, reg, 1, alpha=0.01)) + um = nx.to_numpy(ot.bregman.unmix(ab, Db, M_nx, M0b, h0b, reg, 1, alpha=0.01)) np.testing.assert_allclose(1, np.sum(um), rtol=1e-03, atol=1e-03) np.testing.assert_allclose([0.5, 0.5], um, rtol=1e-03, atol=1e-03) np.testing.assert_allclose(um, um_np) - ot.bregman.unmix(ab, Db, Mb, M0b, h0b, reg, + ot.bregman.unmix(ab, Db, M_nx, M0b, h0b, reg, 1, alpha=0.01, log=True, verbose=True) @@ -437,22 +652,22 @@ def test_empirical_sinkhorn(nx): bb = nx.from_numpy(b) X_sb = nx.from_numpy(X_s) X_tb = nx.from_numpy(X_t) - Mb = nx.from_numpy(M, type_as=ab) + M_nx = nx.from_numpy(M, type_as=ab) M_mb = nx.from_numpy(M_m, type_as=ab) G_sqe = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1)) - sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, Mb, 1)) + sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1)) G_log, log_es = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 0.1, log=True) G_log = nx.to_numpy(G_log) - sinkhorn_log, log_s = ot.sinkhorn(ab, bb, Mb, 0.1, log=True) + sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 0.1, log=True) sinkhorn_log = nx.to_numpy(sinkhorn_log) G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean')) sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1)) loss_emp_sinkhorn = nx.to_numpy(ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1)) - loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, Mb, 1)) + loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1)) # check constraints np.testing.assert_allclose( @@ -486,18 +701,18 @@ def test_lazy_empirical_sinkhorn(nx): bb = nx.from_numpy(b) X_sb = nx.from_numpy(X_s) X_tb = nx.from_numpy(X_t) - Mb = nx.from_numpy(M, type_as=ab) + M_nx = nx.from_numpy(M, type_as=ab) M_mb = nx.from_numpy(M_m, type_as=ab) f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True) f, g = nx.to_numpy(f), nx.to_numpy(g) G_sqe = np.exp(f[:, None] + g[None, :] - M / 1) - sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, Mb, 1)) + sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1)) f, g, log_es = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True) f, g = nx.to_numpy(f), nx.to_numpy(g) G_log = np.exp(f[:, None] + g[None, :] - M / 0.1) - sinkhorn_log, log_s = ot.sinkhorn(ab, bb, Mb, 0.1, log=True) + sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 0.1, log=True) sinkhorn_log = nx.to_numpy(sinkhorn_log) f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean', numIterMax=numIterMax, isLazy=True, batchSize=1) @@ -507,7 +722,7 @@ def test_lazy_empirical_sinkhorn(nx): loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True) loss_emp_sinkhorn = nx.to_numpy(loss_emp_sinkhorn) - loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, Mb, 1)) + loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1)) # check constraints np.testing.assert_allclose( @@ -541,13 +756,13 @@ def test_empirical_sinkhorn_divergence(nx): bb = nx.from_numpy(b) X_sb = nx.from_numpy(X_s) X_tb = nx.from_numpy(X_t) - Mb = nx.from_numpy(M, type_as=ab) + M_nx = nx.from_numpy(M, type_as=ab) M_sb = nx.from_numpy(M_s, type_as=ab) M_tb = nx.from_numpy(M_t, type_as=ab) emp_sinkhorn_div = nx.to_numpy(ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb)) sinkhorn_div = nx.to_numpy( - ot.sinkhorn2(ab, bb, Mb, 1) + ot.sinkhorn2(ab, bb, M_nx, 1) - 1 / 2 * ot.sinkhorn2(ab, ab, M_sb, 1) - 1 / 2 * ot.sinkhorn2(bb, bb, M_tb, 1) ) @@ -580,14 +795,14 @@ def test_stabilized_vs_sinkhorn_multidim(nx): ab = nx.from_numpy(a) bb = nx.from_numpy(b) - Mb = nx.from_numpy(M, type_as=ab) + M_nx = nx.from_numpy(M, type_as=ab) G_np, _ = ot.bregman.sinkhorn(a, b, M, reg=epsilon, method="sinkhorn", log=True) - G, log = ot.bregman.sinkhorn(ab, bb, Mb, reg=epsilon, + G, log = ot.bregman.sinkhorn(ab, bb, M_nx, reg=epsilon, method="sinkhorn_stabilized", log=True) G = nx.to_numpy(G) - G2, log2 = ot.bregman.sinkhorn(ab, bb, Mb, epsilon, + G2, log2 = ot.bregman.sinkhorn(ab, bb, M_nx, epsilon, method="sinkhorn", log=True) G2 = nx.to_numpy(G2) @@ -642,14 +857,14 @@ def test_screenkhorn(nx): ab = nx.from_numpy(a) bb = nx.from_numpy(b) - Mb = nx.from_numpy(M, type_as=ab) + M_nx = nx.from_numpy(M, type_as=ab) # np sinkhorn G_sink_np = ot.sinkhorn(a, b, M, 1e-03) # sinkhorn - G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, Mb, 1e-03)) + G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-03)) # screenkhorn - G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, Mb, 1e-03, uniform=True, verbose=True)) + G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, M_nx, 1e-03, uniform=True, verbose=True)) # check marginals np.testing.assert_allclose(G_sink_np, G_sink) np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02) @@ -659,10 +874,10 @@ def test_screenkhorn(nx): def test_convolutional_barycenter_non_square(nx): # test for image with height not equal width A = np.ones((2, 2, 3)) / (2 * 3) - Ab = nx.from_numpy(A) + A_nx = nx.from_numpy(A) b_np = ot.bregman.convolutional_barycenter2d(A, 1e-03) - b = nx.to_numpy(ot.bregman.convolutional_barycenter2d(Ab, 1e-03)) + b = nx.to_numpy(ot.bregman.convolutional_barycenter2d(A_nx, 1e-03)) np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02) np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02) -- cgit v1.2.3