From 7a65086dd340265d0223eb8ffb5c9a5152a82dff Mon Sep 17 00:00:00 2001
From: ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com>
Date: Mon, 25 Oct 2021 11:36:21 +0200
Subject: [MRG] Bregman backend (#280)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Bregman
* Resolve conflicts
* Bug solve
* Bregman updated for JAX compatibility
* Tests coherence between backend improved
* No longer enforcing 64 bits operations on Jax except for tests
* Now using mixtures, to make backend dependent tests with less code
* Better test skipping code
* Pep8 + test optimizations
* redundancy removed
* Docs
* Typo corrected
* Typo
* Typo
* Docs
* Docs
* pep8
* Backend docs
* Prettier docs
* Mistake corrected
* small changes
* Better wording

Co-authored-by: Rémi Flamary
---
 test/test_bregman.py | 217 +++++++++++++++++++++++++++++++++++----------------
 1 file changed, 150 insertions(+), 67 deletions(-)

(limited to 'test/test_bregman.py')

diff --git a/test/test_bregman.py b/test/test_bregman.py
index 88166a5..942cb6d 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -10,11 +10,8 @@ import numpy as np
 import pytest
 
 import ot
-from ot.backend import get_backend_list
 from ot.backend import torch
 
-backend_list = get_backend_list()
-
 
 def test_sinkhorn():
     # test sinkhorn
@@ -28,14 +25,13 @@ def test_sinkhorn():
 
     G = ot.sinkhorn(u, u, M, 1, stopThr=1e-10)
 
-    # check constratints
+    # check constraints
     np.testing.assert_allclose(
         u, G.sum(1), atol=1e-05)  # cf convergence sinkhorn
     np.testing.assert_allclose(
         u, G.sum(0), atol=1e-05)  # cf convergence sinkhorn
 
 
-@pytest.mark.parametrize('nx', backend_list)
 def test_sinkhorn_backends(nx):
     n_samples = 100
     n_features = 2
@@ -57,7 +53,6 @@ def test_sinkhorn_backends(nx):
     np.allclose(G, nx.to_numpy(Gb))
 
 
-@pytest.mark.parametrize('nx', backend_list)
 def test_sinkhorn2_backends(nx):
     n_samples = 100
     n_features = 2
@@ -116,20 +111,20 @@ def test_sinkhorn_empty():
     M = ot.dist(x, x)
 
     G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10, verbose=True, log=True)
-    # check constratints
+    # check constraints
     np.testing.assert_allclose(u, G.sum(1), atol=1e-05)
     np.testing.assert_allclose(u, G.sum(0), atol=1e-05)
 
     G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10,
                          method='sinkhorn_stabilized', verbose=True, log=True)
-    # check constratints
+    # check constraints
     np.testing.assert_allclose(u, G.sum(1), atol=1e-05)
     np.testing.assert_allclose(u, G.sum(0), atol=1e-05)
 
     G, log = ot.sinkhorn(
         [], [], M, 1, stopThr=1e-10, method='sinkhorn_epsilon_scaling',
         verbose=True, log=True)
-    # check constratints
+    # check constraints
     np.testing.assert_allclose(u, G.sum(1), atol=1e-05)
     np.testing.assert_allclose(u, G.sum(0), atol=1e-05)
 
@@ -137,7 +132,8 @@ def test_sinkhorn_empty():
     ot.sinkhorn([], [], M, 1, method='greenkhorn', stopThr=1e-10, log=True)
 
 
-def test_sinkhorn_variants():
+@pytest.skip_backend("jax")
+def test_sinkhorn_variants(nx):
     # test sinkhorn
     n = 100
     rng = np.random.RandomState(0)
@@ -147,13 +143,18 @@ def test_sinkhorn_variants():
 
     M = ot.dist(x, x)
 
-    G0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10)
-    Gs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10)
-    Ges = ot.sinkhorn(
-        u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10)
-    G_green = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10)
+    ub = nx.from_numpy(u)
+    Mb = nx.from_numpy(M)
+
+    G = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10)
+    G0 = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn', stopThr=1e-10))
+    Gs = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn_stabilized', stopThr=1e-10))
+    Ges = nx.to_numpy(ot.sinkhorn(
+        ub, ub, Mb, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10))
+    G_green = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='greenkhorn', stopThr=1e-10))
 
     # check values
+    np.testing.assert_allclose(G, G0, atol=1e-05)
     np.testing.assert_allclose(G0, Gs, atol=1e-05)
     np.testing.assert_allclose(G0, Ges, atol=1e-05)
     np.testing.assert_allclose(G0, G_green, atol=1e-5)
@@ -184,7 +185,7 @@ def test_sinkhorn_variants_log():
 
 
 @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
-def test_barycenter(method):
+def test_barycenter(nx, method):
     n_bins = 100  # nb bins
 
     # Gaussian distributions
@@ -201,16 +202,23 @@ def test_barycenter(method):
     alpha = 0.5  # 0<=alpha<=1
     weights = np.array([1 - alpha, alpha])
 
+    Ab = nx.from_numpy(A)
+    Mb = nx.from_numpy(M)
+    weightsb = nx.from_numpy(weights)
+
     # wasserstein
     reg = 1e-2
-    bary_wass, log = ot.bregman.barycenter(A, M, reg, weights, method=method, log=True)
+    bary_wass_np, log = ot.bregman.barycenter(A, M, reg, weights, method=method, log=True)
+    bary_wass, _ = ot.bregman.barycenter(Ab, Mb, reg, weightsb, method=method, log=True)
+    bary_wass = nx.to_numpy(bary_wass)
 
     np.testing.assert_allclose(1, np.sum(bary_wass))
+    np.testing.assert_allclose(bary_wass, bary_wass_np)
 
-    ot.bregman.barycenter(A, M, reg, log=True, verbose=True)
+    ot.bregman.barycenter(Ab, Mb, reg, log=True, verbose=True)
 
 
-def test_barycenter_stabilization():
+def test_barycenter_stabilization(nx):
     n_bins = 100  # nb bins
 
     # Gaussian distributions
@@ -227,17 +235,26 @@ def test_barycenter_stabilization():
     alpha = 0.5  # 0<=alpha<=1
     weights = np.array([1 - alpha, alpha])
 
+    Ab = nx.from_numpy(A)
+    Mb = nx.from_numpy(M)
+    weights_b = nx.from_numpy(weights)
+
     # wasserstein
     reg = 1e-2
-    bar_stable = ot.bregman.barycenter(A, M, reg, weights,
-                                       method="sinkhorn_stabilized",
-                                       stopThr=1e-8, verbose=True)
-    bar = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn",
-                                stopThr=1e-8, verbose=True)
+    bar_np = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn", stopThr=1e-8, verbose=True)
+    bar_stable = nx.to_numpy(ot.bregman.barycenter(
+        Ab, Mb, reg, weights_b, method="sinkhorn_stabilized",
+        stopThr=1e-8, verbose=True
+    ))
+    bar = nx.to_numpy(ot.bregman.barycenter(
+        Ab, Mb, reg, weights_b, method="sinkhorn",
+        stopThr=1e-8, verbose=True
+    ))
     np.testing.assert_allclose(bar, bar_stable)
+    np.testing.assert_allclose(bar, bar_np)
 
 
-def test_wasserstein_bary_2d():
+def test_wasserstein_bary_2d(nx):
     size = 100  # size of a square image
     a1 = np.random.randn(size, size)
     a1 += a1.min()
@@ -250,17 +267,21 @@ def test_wasserstein_bary_2d():
     A[0, :, :] = a1
     A[1, :, :] = a2
 
+    Ab = nx.from_numpy(A)
+
     # wasserstein
     reg = 1e-2
-    bary_wass = ot.bregman.convolutional_barycenter2d(A, reg)
+    bary_wass_np = ot.bregman.convolutional_barycenter2d(A, reg)
+    bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d(Ab, reg))
 
     np.testing.assert_allclose(1, np.sum(bary_wass))
+    np.testing.assert_allclose(bary_wass, bary_wass_np)
 
     # help in checking if log and verbose do not bug the function
     ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)
 
 
-def test_unmix():
+def test_unmix(nx):
     n_bins = 50  # nb bins
 
     # Gaussian distributions
@@ -280,18 +301,26 @@ def test_unmix():
     M0 /= M0.max()
     h0 = ot.unif(2)
 
+    ab = nx.from_numpy(a)
+    Db = nx.from_numpy(D)
+    Mb = nx.from_numpy(M)
+    M0b = nx.from_numpy(M0)
+    h0b = nx.from_numpy(h0)
+
     # wasserstein
     reg = 1e-3
-    um = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01, )
+    um_np = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01)
+    um = nx.to_numpy(ot.bregman.unmix(ab, Db, Mb, M0b, h0b, reg, 1, alpha=0.01))
 
     np.testing.assert_allclose(1, np.sum(um), rtol=1e-03, atol=1e-03)
     np.testing.assert_allclose([0.5, 0.5], um, rtol=1e-03, atol=1e-03)
+    np.testing.assert_allclose(um, um_np)
 
-    ot.bregman.unmix(a, D, M, M0, h0, reg,
+    ot.bregman.unmix(ab, Db, Mb, M0b, h0b, reg,
                      1, alpha=0.01, log=True, verbose=True)
 
 
-def test_empirical_sinkhorn():
+def test_empirical_sinkhorn(nx):
     # test sinkhorn
     n = 10
     a = ot.unif(n)
@@ -302,19 +331,28 @@ def test_empirical_sinkhorn():
     M = ot.dist(X_s, X_t)
     M_m = ot.dist(X_s, X_t, metric='minkowski')
 
-    G_sqe = ot.bregman.empirical_sinkhorn(X_s, X_t, 1)
-    sinkhorn_sqe = ot.sinkhorn(a, b, M, 1)
+    ab = nx.from_numpy(a)
+    bb = nx.from_numpy(b)
+    X_sb = nx.from_numpy(X_s)
+    X_tb = nx.from_numpy(X_t)
+    Mb = nx.from_numpy(M, type_as=ab)
+    M_mb = nx.from_numpy(M_m, type_as=ab)
+
+    G_sqe = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1))
+    sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, Mb, 1))
 
-    G_log, log_es = ot.bregman.empirical_sinkhorn(X_s, X_t, 0.1, log=True)
-    sinkhorn_log, log_s = ot.sinkhorn(a, b, M, 0.1, log=True)
+    G_log, log_es = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 0.1, log=True)
+    G_log = nx.to_numpy(G_log)
+    sinkhorn_log, log_s = ot.sinkhorn(ab, bb, Mb, 0.1, log=True)
+    sinkhorn_log = nx.to_numpy(sinkhorn_log)
 
-    G_m = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, metric='minkowski')
-    sinkhorn_m = ot.sinkhorn(a, b, M_m, 1)
+    G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='minkowski'))
+    sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1))
 
-    loss_emp_sinkhorn = ot.bregman.empirical_sinkhorn2(X_s, X_t, 1)
-    loss_sinkhorn = ot.sinkhorn2(a, b, M, 1)
+    loss_emp_sinkhorn = nx.to_numpy(ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1))
+    loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, Mb, 1))
 
-    # check constratints
+    # check constraints
     np.testing.assert_allclose(
         sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05)  # metric sqeuclidian
     np.testing.assert_allclose(
@@ -330,7 +368,7 @@ def test_empirical_sinkhorn():
     np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05)
 
 
-def test_lazy_empirical_sinkhorn():
+def test_lazy_empirical_sinkhorn(nx):
     # test sinkhorn
     n = 10
     a = ot.unif(n)
@@ -342,22 +380,34 @@ def test_lazy_empirical_sinkhorn():
     M = ot.dist(X_s, X_t)
     M_m = ot.dist(X_s, X_t, metric='minkowski')
 
-    f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True)
+    ab = nx.from_numpy(a)
+    bb = nx.from_numpy(b)
+    X_sb = nx.from_numpy(X_s)
+    X_tb = nx.from_numpy(X_t)
+    Mb = nx.from_numpy(M, type_as=ab)
+    M_mb = nx.from_numpy(M_m, type_as=ab)
+
+    f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True)
+    f, g = nx.to_numpy(f), nx.to_numpy(g)
     G_sqe = np.exp(f[:, None] + g[None, :] - M / 1)
-    sinkhorn_sqe = ot.sinkhorn(a, b, M, 1)
+    sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, Mb, 1))
 
-    f, g, log_es = ot.bregman.empirical_sinkhorn(X_s, X_t, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
+    f, g, log_es = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
+    f, g = nx.to_numpy(f), nx.to_numpy(g)
     G_log = np.exp(f[:, None] + g[None, :] - M / 0.1)
-    sinkhorn_log, log_s = ot.sinkhorn(a, b, M, 0.1, log=True)
+    sinkhorn_log, log_s = ot.sinkhorn(ab, bb, Mb, 0.1, log=True)
+    sinkhorn_log = nx.to_numpy(sinkhorn_log)
 
-    f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, metric='minkowski', numIterMax=numIterMax, isLazy=True, batchSize=1)
+    f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='minkowski', numIterMax=numIterMax, isLazy=True, batchSize=1)
+    f, g = nx.to_numpy(f), nx.to_numpy(g)
     G_m = np.exp(f[:, None] + g[None, :] - M_m / 1)
-    sinkhorn_m = ot.sinkhorn(a, b, M_m, 1)
+    sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1))
 
-    loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2(X_s, X_t, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
-    loss_sinkhorn = ot.sinkhorn2(a, b, M, 1)
+    loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
+    loss_emp_sinkhorn = nx.to_numpy(loss_emp_sinkhorn)
+    loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, Mb, 1))
 
-    # check constratints
+    # check constraints
     np.testing.assert_allclose(
         sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05)  # metric sqeuclidian
     np.testing.assert_allclose(
@@ -373,7 +423,7 @@ def test_lazy_empirical_sinkhorn():
     np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05)
 
 
-def test_empirical_sinkhorn_divergence():
+def test_empirical_sinkhorn_divergence(nx):
     # Test sinkhorn divergence
     n = 10
     a = np.linspace(1, n, n)
@@ -385,22 +435,31 @@ def test_empirical_sinkhorn_divergence():
     M_s = ot.dist(X_s, X_s)
     M_t = ot.dist(X_t, X_t)
 
-    emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, a=a, b=b)
-    sinkhorn_div = (ot.sinkhorn2(a, b, M, 1) - 1 / 2 * ot.sinkhorn2(a, a, M_s, 1) - 1 / 2 * ot.sinkhorn2(b, b, M_t, 1))
+    ab = nx.from_numpy(a)
+    bb = nx.from_numpy(b)
+    X_sb = nx.from_numpy(X_s)
+    X_tb = nx.from_numpy(X_t)
+    Mb = nx.from_numpy(M, type_as=ab)
+    M_sb = nx.from_numpy(M_s, type_as=ab)
+    M_tb = nx.from_numpy(M_t, type_as=ab)
+
+    emp_sinkhorn_div = nx.to_numpy(ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb))
+    sinkhorn_div = nx.to_numpy(
+        ot.sinkhorn2(ab, bb, Mb, 1)
+        - 1 / 2 * ot.sinkhorn2(ab, ab, M_sb, 1)
+        - 1 / 2 * ot.sinkhorn2(bb, bb, M_tb, 1)
+    )
+    emp_sinkhorn_div_np = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, a=a, b=b)
 
-    emp_sinkhorn_div_log, log_es = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, a=a, b=b, log=True)
-    sink_div_log_ab, log_s_ab = ot.sinkhorn2(a, b, M, 1, log=True)
-    sink_div_log_a, log_s_a = ot.sinkhorn2(a, a, M_s, 1, log=True)
-    sink_div_log_b, log_s_b = ot.sinkhorn2(b, b, M_t, 1, log=True)
-    sink_div_log = sink_div_log_ab - 1 / 2 * (sink_div_log_a + sink_div_log_b)
 
     # check constraints
+    np.testing.assert_allclose(emp_sinkhorn_div, emp_sinkhorn_div_np, atol=1e-05)
     np.testing.assert_allclose(
         emp_sinkhorn_div, sinkhorn_div, atol=1e-05)  # cf conv emp sinkhorn
-    np.testing.assert_allclose(
-        emp_sinkhorn_div_log, sink_div_log, atol=1e-05)  # cf conv emp sinkhorn
+    ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb, log=True)
 
 
-def test_stabilized_vs_sinkhorn_multidim():
+
+def test_stabilized_vs_sinkhorn_multidim(nx):
     # test if stable version matches sinkhorn
     # for multidimensional inputs
     n = 100
@@ -416,12 +475,21 @@ def test_stabilized_vs_sinkhorn_multidim():
     M = ot.utils.dist0(n)
     M /= np.median(M)
     epsilon = 0.1
-    G, log = ot.bregman.sinkhorn(a, b, M, reg=epsilon,
+
+    ab = nx.from_numpy(a)
+    bb = nx.from_numpy(b)
+    Mb = nx.from_numpy(M, type_as=ab)
+
+    G_np, _ = ot.bregman.sinkhorn(a, b, M, reg=epsilon,
+                                  method="sinkhorn", log=True)
+    G, log = ot.bregman.sinkhorn(ab, bb, Mb, reg=epsilon,
                                  method="sinkhorn_stabilized", log=True)
-    G2, log2 = ot.bregman.sinkhorn(a, b, M, epsilon,
+    G = nx.to_numpy(G)
+    G2, log2 = ot.bregman.sinkhorn(ab, bb, Mb, epsilon,
                                    method="sinkhorn", log=True)
+    G2 = nx.to_numpy(G2)
+    np.testing.assert_allclose(G_np, G2)
     np.testing.assert_allclose(G, G2)
@@ -458,8 +526,9 @@ def test_implemented_methods():
         ot.bregman.sinkhorn2(a, b, M, epsilon, method=method)
 
 
+@pytest.skip_backend("jax")
 @pytest.mark.filterwarnings("ignore:Bottleneck")
-def test_screenkhorn():
+def test_screenkhorn(nx):
     # test screenkhorn
     rng = np.random.RandomState(0)
     n = 100
@@ -468,17 +537,31 @@ def test_screenkhorn():
     x = rng.randn(n, 2)
     M = ot.dist(x, x)
 
+    ab = nx.from_numpy(a)
+    bb = nx.from_numpy(b)
+    Mb = nx.from_numpy(M, type_as=ab)
+
+    # np sinkhorn
+    G_sink_np = ot.sinkhorn(a, b, M, 1e-03)
     # sinkhorn
-    G_sink = ot.sinkhorn(a, b, M, 1e-03)
+    G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, Mb, 1e-03))
     # screenkhorn
-    G_screen = ot.bregman.screenkhorn(a, b, M, 1e-03, uniform=True, verbose=True)
+    G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, Mb, 1e-03, uniform=True, verbose=True))
     # check marginals
+    np.testing.assert_allclose(G_sink_np, G_sink)
     np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02)
     np.testing.assert_allclose(G_sink.sum(1), G_screen.sum(1), atol=1e-02)
 
 
-def test_convolutional_barycenter_non_square():
+def test_convolutional_barycenter_non_square(nx):
     # test for image with height not equal width
     A = np.ones((2, 2, 3)) / (2 * 3)
-    b = ot.bregman.convolutional_barycenter2d(A, 1e-03)
+    Ab = nx.from_numpy(A)
+
+    b_np = ot.bregman.convolutional_barycenter2d(A, 1e-03)
+    b = nx.to_numpy(ot.bregman.convolutional_barycenter2d(Ab, 1e-03))
+
+    np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02)
     np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02)
+    np.testing.assert_allclose(b, b_np)
--
cgit v1.2.3
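
Note: the tests above now take an `nx` argument without any `@pytest.mark.parametrize('nx', backend_list)` decorator, so `nx` has to come from a shared pytest fixture defined outside this file and therefore not shown in this diff. Below is a minimal sketch of how such a fixture can be provided, assuming only `ot.backend.get_backend_list` (the import removed from this file above) and pytest's standard parametrized fixtures; the `pytest.skip_backend("jax")` marker used above would likewise be a custom helper registered in the same conftest.py.

    # conftest.py (hypothetical sketch -- the real file is not part of this diff)
    import pytest
    from ot.backend import get_backend_list

    backend_list = get_backend_list()

    @pytest.fixture(params=backend_list)
    def nx(request):
        # Run every test that accepts `nx` once per available backend;
        # tests convert inputs with nx.from_numpy and read results back
        # with nx.to_numpy, as in the hunks above.
        yield request.param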