diff options
author | ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com> | 2021-10-25 11:36:21 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-10-25 11:36:21 +0200 |
commit | 7a65086dd340265d0223eb8ffb5c9a5152a82dff (patch) | |
tree | 300f4a1cd645516fba1e440691fe48830d781b5c /test | |
parent | 7af8c2147d61349f4d99ca33318a8a125e4569aa (diff) |
[MRG] Bregman backend (#280)
* Bregman
* Resolve conflicts
* Bug solve
* Bregman updated for JAX compatibility
* Tests coherence between backend improved
* No longer enforcing 64 bits operations on Jax except for tests
* Now using mixtures, to make backend dependent tests with less code
* Better test skipping code
* Pep8 + test optimizations
* redundancy removed
* Docs
* Typo corrected
* Typo
* Typo
* Docs
* Docs
* pep8
* Backend docs
* Prettier docs
* Mistake corrected
* small changes
* Better wording
Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'test')
-rw-r--r-- | test/conftest.py | 49 | ||||
-rw-r--r-- | test/test_backend.py | 102 | ||||
-rw-r--r-- | test/test_bregman.py | 217 | ||||
-rwxr-xr-x | test/test_partial.py | 6 | ||||
-rw-r--r-- | test/test_smooth.py | 12 | ||||
-rw-r--r-- | test/test_stochastic.py | 12 |
6 files changed, 316 insertions, 82 deletions
diff --git a/test/conftest.py b/test/conftest.py new file mode 100644 index 0000000..876b525 --- /dev/null +++ b/test/conftest.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- + +# Configuration file for pytest + +# License: MIT License + +import pytest +from ot.backend import jax +from ot.backend import get_backend_list +import functools + +if jax: + from jax.config import config + +backend_list = get_backend_list() + + +@pytest.fixture(params=backend_list) +def nx(request): + backend = request.param + if backend.__name__ == "jax": + config.update("jax_enable_x64", True) + + yield backend + + if backend.__name__ == "jax": + config.update("jax_enable_x64", False) + + +def skip_arg(arg, value, reason=None, getter=lambda x: x): + if reason is None: + reason = f"Param {arg} should be skipped for value {value}" + + def wrapper(function): + + @functools.wraps(function) + def wrapped(*args, **kwargs): + if arg in kwargs.keys() and getter(kwargs[arg]) == value: + pytest.skip(reason) + return function(*args, **kwargs) + + return wrapped + + return wrapper + + +def pytest_configure(config): + pytest.skip_arg = skip_arg + pytest.skip_backend = functools.partial(skip_arg, "nx", getter=str) diff --git a/test/test_backend.py b/test/test_backend.py index cbfaf94..859da5a 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -1,6 +1,7 @@ """Tests for backend module """ # Author: Remi Flamary <remi.flamary@polytechnique.edu> +# Nicolas Courty <ncourty@irisa.fr> # # License: MIT License @@ -156,6 +157,8 @@ def test_empty_backend(): with pytest.raises(NotImplementedError): nx.sqrt(M) with pytest.raises(NotImplementedError): + nx.power(v, 2) + with pytest.raises(NotImplementedError): nx.dot(v, v) with pytest.raises(NotImplementedError): nx.norm(M) @@ -174,7 +177,37 @@ def test_empty_backend(): with pytest.raises(NotImplementedError): nx.argsort(M) with pytest.raises(NotImplementedError): + nx.searchsorted(v, v) + with pytest.raises(NotImplementedError): nx.flip(M) + with pytest.raises(NotImplementedError): + nx.clip(M, -1, 1) + with pytest.raises(NotImplementedError): + nx.repeat(M, 0, 1) + with pytest.raises(NotImplementedError): + nx.take_along_axis(M, v, 0) + with pytest.raises(NotImplementedError): + nx.concatenate([v, v]) + with pytest.raises(NotImplementedError): + nx.zero_pad(M, v) + with pytest.raises(NotImplementedError): + nx.argmax(M) + with pytest.raises(NotImplementedError): + nx.mean(M) + with pytest.raises(NotImplementedError): + nx.std(M) + with pytest.raises(NotImplementedError): + nx.linspace(0, 1, 50) + with pytest.raises(NotImplementedError): + nx.meshgrid(v, v) + with pytest.raises(NotImplementedError): + nx.diag(M) + with pytest.raises(NotImplementedError): + nx.unique([M, M]) + with pytest.raises(NotImplementedError): + nx.logsumexp(M) + with pytest.raises(NotImplementedError): + nx.stack([M, M]) @pytest.mark.parametrize('backend', backend_list) @@ -278,6 +311,10 @@ def test_func_backends(backend): lst_b.append(nx.to_numpy(A)) lst_name.append('sqrt') + A = nx.power(Mb, 2) + lst_b.append(nx.to_numpy(A)) + lst_name.append('power') + A = nx.dot(vb, vb) lst_b.append(nx.to_numpy(A)) lst_name.append('dot(v,v)') @@ -326,10 +363,75 @@ def test_func_backends(backend): lst_b.append(nx.to_numpy(A)) lst_name.append('argsort') + A = nx.searchsorted(Mb, Mb, 'right') + lst_b.append(nx.to_numpy(A)) + lst_name.append('searchsorted') + A = nx.flip(Mb) lst_b.append(nx.to_numpy(A)) lst_name.append('flip') + A = nx.clip(vb, 0, 1) + lst_b.append(nx.to_numpy(A)) + lst_name.append('clip') + + A = nx.repeat(Mb, 0) + A = nx.repeat(Mb, 2, -1) + lst_b.append(nx.to_numpy(A)) + lst_name.append('repeat') + + A = nx.take_along_axis(vb, nx.arange(3), -1) + lst_b.append(nx.to_numpy(A)) + lst_name.append('take_along_axis') + + A = nx.concatenate((Mb, Mb), -1) + lst_b.append(nx.to_numpy(A)) + lst_name.append('concatenate') + + A = nx.zero_pad(Mb, len(Mb.shape) * [(3, 3)]) + lst_b.append(nx.to_numpy(A)) + lst_name.append('zero_pad') + + A = nx.argmax(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('argmax') + + A = nx.mean(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('mean') + + A = nx.std(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('std') + + A = nx.linspace(0, 1, 50) + lst_b.append(nx.to_numpy(A)) + lst_name.append('linspace') + + X, Y = nx.meshgrid(vb, vb) + lst_b.append(np.stack([nx.to_numpy(X), nx.to_numpy(Y)])) + lst_name.append('meshgrid') + + A = nx.diag(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('diag2D') + + A = nx.diag(vb, 1) + lst_b.append(nx.to_numpy(A)) + lst_name.append('diag1D') + + A = nx.unique(nx.from_numpy(np.stack([M, M]))) + lst_b.append(nx.to_numpy(A)) + lst_name.append('unique') + + A = nx.logsumexp(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('logsumexp') + + A = nx.stack([Mb, Mb]) + lst_b.append(nx.to_numpy(A)) + lst_name.append('stack') + lst_tot.append(lst_b) lst_np = lst_tot[0] diff --git a/test/test_bregman.py b/test/test_bregman.py index 88166a5..942cb6d 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -10,11 +10,8 @@ import numpy as np import pytest import ot -from ot.backend import get_backend_list from ot.backend import torch -backend_list = get_backend_list() - def test_sinkhorn(): # test sinkhorn @@ -28,14 +25,13 @@ def test_sinkhorn(): G = ot.sinkhorn(u, u, M, 1, stopThr=1e-10) - # check constratints + # check constraints np.testing.assert_allclose( u, G.sum(1), atol=1e-05) # cf convergence sinkhorn np.testing.assert_allclose( u, G.sum(0), atol=1e-05) # cf convergence sinkhorn -@pytest.mark.parametrize('nx', backend_list) def test_sinkhorn_backends(nx): n_samples = 100 n_features = 2 @@ -57,7 +53,6 @@ def test_sinkhorn_backends(nx): np.allclose(G, nx.to_numpy(Gb)) -@pytest.mark.parametrize('nx', backend_list) def test_sinkhorn2_backends(nx): n_samples = 100 n_features = 2 @@ -116,20 +111,20 @@ def test_sinkhorn_empty(): M = ot.dist(x, x) G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10, verbose=True, log=True) - # check constratints + # check constraints np.testing.assert_allclose(u, G.sum(1), atol=1e-05) np.testing.assert_allclose(u, G.sum(0), atol=1e-05) G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10, method='sinkhorn_stabilized', verbose=True, log=True) - # check constratints + # check constraints np.testing.assert_allclose(u, G.sum(1), atol=1e-05) np.testing.assert_allclose(u, G.sum(0), atol=1e-05) G, log = ot.sinkhorn( [], [], M, 1, stopThr=1e-10, method='sinkhorn_epsilon_scaling', verbose=True, log=True) - # check constratints + # check constraints np.testing.assert_allclose(u, G.sum(1), atol=1e-05) np.testing.assert_allclose(u, G.sum(0), atol=1e-05) @@ -137,7 +132,8 @@ def test_sinkhorn_empty(): ot.sinkhorn([], [], M, 1, method='greenkhorn', stopThr=1e-10, log=True) -def test_sinkhorn_variants(): +@pytest.skip_backend("jax") +def test_sinkhorn_variants(nx): # test sinkhorn n = 100 rng = np.random.RandomState(0) @@ -147,13 +143,18 @@ def test_sinkhorn_variants(): M = ot.dist(x, x) - G0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10) - Gs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10) - Ges = ot.sinkhorn( - u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10) - G_green = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10) + ub = nx.from_numpy(u) + Mb = nx.from_numpy(M) + + G = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10) + G0 = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn', stopThr=1e-10)) + Gs = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn_stabilized', stopThr=1e-10)) + Ges = nx.to_numpy(ot.sinkhorn( + ub, ub, Mb, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10)) + G_green = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='greenkhorn', stopThr=1e-10)) # check values + np.testing.assert_allclose(G, G0, atol=1e-05) np.testing.assert_allclose(G0, Gs, atol=1e-05) np.testing.assert_allclose(G0, Ges, atol=1e-05) np.testing.assert_allclose(G0, G_green, atol=1e-5) @@ -184,7 +185,7 @@ def test_sinkhorn_variants_log(): @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) -def test_barycenter(method): +def test_barycenter(nx, method): n_bins = 100 # nb bins # Gaussian distributions @@ -201,16 +202,23 @@ def test_barycenter(method): alpha = 0.5 # 0<=alpha<=1 weights = np.array([1 - alpha, alpha]) + Ab = nx.from_numpy(A) + Mb = nx.from_numpy(M) + weightsb = nx.from_numpy(weights) + # wasserstein reg = 1e-2 - bary_wass, log = ot.bregman.barycenter(A, M, reg, weights, method=method, log=True) + bary_wass_np, log = ot.bregman.barycenter(A, M, reg, weights, method=method, log=True) + bary_wass, _ = ot.bregman.barycenter(Ab, Mb, reg, weightsb, method=method, log=True) + bary_wass = nx.to_numpy(bary_wass) np.testing.assert_allclose(1, np.sum(bary_wass)) + np.testing.assert_allclose(bary_wass, bary_wass_np) - ot.bregman.barycenter(A, M, reg, log=True, verbose=True) + ot.bregman.barycenter(Ab, Mb, reg, log=True, verbose=True) -def test_barycenter_stabilization(): +def test_barycenter_stabilization(nx): n_bins = 100 # nb bins # Gaussian distributions @@ -227,17 +235,26 @@ def test_barycenter_stabilization(): alpha = 0.5 # 0<=alpha<=1 weights = np.array([1 - alpha, alpha]) + Ab = nx.from_numpy(A) + Mb = nx.from_numpy(M) + weights_b = nx.from_numpy(weights) + # wasserstein reg = 1e-2 - bar_stable = ot.bregman.barycenter(A, M, reg, weights, - method="sinkhorn_stabilized", - stopThr=1e-8, verbose=True) - bar = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn", - stopThr=1e-8, verbose=True) + bar_np = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn", stopThr=1e-8, verbose=True) + bar_stable = nx.to_numpy(ot.bregman.barycenter( + Ab, Mb, reg, weights_b, method="sinkhorn_stabilized", + stopThr=1e-8, verbose=True + )) + bar = nx.to_numpy(ot.bregman.barycenter( + Ab, Mb, reg, weights_b, method="sinkhorn", + stopThr=1e-8, verbose=True + )) np.testing.assert_allclose(bar, bar_stable) + np.testing.assert_allclose(bar, bar_np) -def test_wasserstein_bary_2d(): +def test_wasserstein_bary_2d(nx): size = 100 # size of a square image a1 = np.random.randn(size, size) a1 += a1.min() @@ -250,17 +267,21 @@ def test_wasserstein_bary_2d(): A[0, :, :] = a1 A[1, :, :] = a2 + Ab = nx.from_numpy(A) + # wasserstein reg = 1e-2 - bary_wass = ot.bregman.convolutional_barycenter2d(A, reg) + bary_wass_np = ot.bregman.convolutional_barycenter2d(A, reg) + bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d(Ab, reg)) np.testing.assert_allclose(1, np.sum(bary_wass)) + np.testing.assert_allclose(bary_wass, bary_wass_np) # help in checking if log and verbose do not bug the function ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) -def test_unmix(): +def test_unmix(nx): n_bins = 50 # nb bins # Gaussian distributions @@ -280,18 +301,26 @@ def test_unmix(): M0 /= M0.max() h0 = ot.unif(2) + ab = nx.from_numpy(a) + Db = nx.from_numpy(D) + Mb = nx.from_numpy(M) + M0b = nx.from_numpy(M0) + h0b = nx.from_numpy(h0) + # wasserstein reg = 1e-3 - um = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01, ) + um_np = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01) + um = nx.to_numpy(ot.bregman.unmix(ab, Db, Mb, M0b, h0b, reg, 1, alpha=0.01)) np.testing.assert_allclose(1, np.sum(um), rtol=1e-03, atol=1e-03) np.testing.assert_allclose([0.5, 0.5], um, rtol=1e-03, atol=1e-03) + np.testing.assert_allclose(um, um_np) - ot.bregman.unmix(a, D, M, M0, h0, reg, + ot.bregman.unmix(ab, Db, Mb, M0b, h0b, reg, 1, alpha=0.01, log=True, verbose=True) -def test_empirical_sinkhorn(): +def test_empirical_sinkhorn(nx): # test sinkhorn n = 10 a = ot.unif(n) @@ -302,19 +331,28 @@ def test_empirical_sinkhorn(): M = ot.dist(X_s, X_t) M_m = ot.dist(X_s, X_t, metric='minkowski') - G_sqe = ot.bregman.empirical_sinkhorn(X_s, X_t, 1) - sinkhorn_sqe = ot.sinkhorn(a, b, M, 1) + ab = nx.from_numpy(a) + bb = nx.from_numpy(b) + X_sb = nx.from_numpy(X_s) + X_tb = nx.from_numpy(X_t) + Mb = nx.from_numpy(M, type_as=ab) + M_mb = nx.from_numpy(M_m, type_as=ab) + + G_sqe = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1)) + sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, Mb, 1)) - G_log, log_es = ot.bregman.empirical_sinkhorn(X_s, X_t, 0.1, log=True) - sinkhorn_log, log_s = ot.sinkhorn(a, b, M, 0.1, log=True) + G_log, log_es = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 0.1, log=True) + G_log = nx.to_numpy(G_log) + sinkhorn_log, log_s = ot.sinkhorn(ab, bb, Mb, 0.1, log=True) + sinkhorn_log = nx.to_numpy(sinkhorn_log) - G_m = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, metric='minkowski') - sinkhorn_m = ot.sinkhorn(a, b, M_m, 1) + G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='minkowski')) + sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1)) - loss_emp_sinkhorn = ot.bregman.empirical_sinkhorn2(X_s, X_t, 1) - loss_sinkhorn = ot.sinkhorn2(a, b, M, 1) + loss_emp_sinkhorn = nx.to_numpy(ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1)) + loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, Mb, 1)) - # check constratints + # check constraints np.testing.assert_allclose( sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidian np.testing.assert_allclose( @@ -330,7 +368,7 @@ def test_empirical_sinkhorn(): np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05) -def test_lazy_empirical_sinkhorn(): +def test_lazy_empirical_sinkhorn(nx): # test sinkhorn n = 10 a = ot.unif(n) @@ -342,22 +380,34 @@ def test_lazy_empirical_sinkhorn(): M = ot.dist(X_s, X_t) M_m = ot.dist(X_s, X_t, metric='minkowski') - f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True) + ab = nx.from_numpy(a) + bb = nx.from_numpy(b) + X_sb = nx.from_numpy(X_s) + X_tb = nx.from_numpy(X_t) + Mb = nx.from_numpy(M, type_as=ab) + M_mb = nx.from_numpy(M_m, type_as=ab) + + f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True) + f, g = nx.to_numpy(f), nx.to_numpy(g) G_sqe = np.exp(f[:, None] + g[None, :] - M / 1) - sinkhorn_sqe = ot.sinkhorn(a, b, M, 1) + sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, Mb, 1)) - f, g, log_es = ot.bregman.empirical_sinkhorn(X_s, X_t, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True) + f, g, log_es = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True) + f, g = nx.to_numpy(f), nx.to_numpy(g) G_log = np.exp(f[:, None] + g[None, :] - M / 0.1) - sinkhorn_log, log_s = ot.sinkhorn(a, b, M, 0.1, log=True) + sinkhorn_log, log_s = ot.sinkhorn(ab, bb, Mb, 0.1, log=True) + sinkhorn_log = nx.to_numpy(sinkhorn_log) - f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, metric='minkowski', numIterMax=numIterMax, isLazy=True, batchSize=1) + f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='minkowski', numIterMax=numIterMax, isLazy=True, batchSize=1) + f, g = nx.to_numpy(f), nx.to_numpy(g) G_m = np.exp(f[:, None] + g[None, :] - M_m / 1) - sinkhorn_m = ot.sinkhorn(a, b, M_m, 1) + sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1)) - loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2(X_s, X_t, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True) - loss_sinkhorn = ot.sinkhorn2(a, b, M, 1) + loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True) + loss_emp_sinkhorn = nx.to_numpy(loss_emp_sinkhorn) + loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, Mb, 1)) - # check constratints + # check constraints np.testing.assert_allclose( sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidian np.testing.assert_allclose( @@ -373,7 +423,7 @@ def test_lazy_empirical_sinkhorn(): np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05) -def test_empirical_sinkhorn_divergence(): +def test_empirical_sinkhorn_divergence(nx): # Test sinkhorn divergence n = 10 a = np.linspace(1, n, n) @@ -385,22 +435,31 @@ def test_empirical_sinkhorn_divergence(): M_s = ot.dist(X_s, X_s) M_t = ot.dist(X_t, X_t) - emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, a=a, b=b) - sinkhorn_div = (ot.sinkhorn2(a, b, M, 1) - 1 / 2 * ot.sinkhorn2(a, a, M_s, 1) - 1 / 2 * ot.sinkhorn2(b, b, M_t, 1)) + ab = nx.from_numpy(a) + bb = nx.from_numpy(b) + X_sb = nx.from_numpy(X_s) + X_tb = nx.from_numpy(X_t) + Mb = nx.from_numpy(M, type_as=ab) + M_sb = nx.from_numpy(M_s, type_as=ab) + M_tb = nx.from_numpy(M_t, type_as=ab) + + emp_sinkhorn_div = nx.to_numpy(ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb)) + sinkhorn_div = nx.to_numpy( + ot.sinkhorn2(ab, bb, Mb, 1) + - 1 / 2 * ot.sinkhorn2(ab, ab, M_sb, 1) + - 1 / 2 * ot.sinkhorn2(bb, bb, M_tb, 1) + ) + emp_sinkhorn_div_np = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, a=a, b=b) - emp_sinkhorn_div_log, log_es = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, a=a, b=b, log=True) - sink_div_log_ab, log_s_ab = ot.sinkhorn2(a, b, M, 1, log=True) - sink_div_log_a, log_s_a = ot.sinkhorn2(a, a, M_s, 1, log=True) - sink_div_log_b, log_s_b = ot.sinkhorn2(b, b, M_t, 1, log=True) - sink_div_log = sink_div_log_ab - 1 / 2 * (sink_div_log_a + sink_div_log_b) # check constraints + np.testing.assert_allclose(emp_sinkhorn_div, emp_sinkhorn_div_np, atol=1e-05) np.testing.assert_allclose( emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn - np.testing.assert_allclose( - emp_sinkhorn_div_log, sink_div_log, atol=1e-05) # cf conv emp sinkhorn + ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb, log=True) -def test_stabilized_vs_sinkhorn_multidim(): + +def test_stabilized_vs_sinkhorn_multidim(nx): # test if stable version matches sinkhorn # for multidimensional inputs n = 100 @@ -416,12 +475,21 @@ def test_stabilized_vs_sinkhorn_multidim(): M = ot.utils.dist0(n) M /= np.median(M) epsilon = 0.1 - G, log = ot.bregman.sinkhorn(a, b, M, reg=epsilon, + + ab = nx.from_numpy(a) + bb = nx.from_numpy(b) + Mb = nx.from_numpy(M, type_as=ab) + + G_np, _ = ot.bregman.sinkhorn(a, b, M, reg=epsilon, method="sinkhorn", log=True) + G, log = ot.bregman.sinkhorn(ab, bb, Mb, reg=epsilon, method="sinkhorn_stabilized", log=True) - G2, log2 = ot.bregman.sinkhorn(a, b, M, epsilon, + G = nx.to_numpy(G) + G2, log2 = ot.bregman.sinkhorn(ab, bb, Mb, epsilon, method="sinkhorn", log=True) + G2 = nx.to_numpy(G2) + np.testing.assert_allclose(G_np, G2) np.testing.assert_allclose(G, G2) @@ -458,8 +526,9 @@ def test_implemented_methods(): ot.bregman.sinkhorn2(a, b, M, epsilon, method=method) +@pytest.skip_backend("jax") @pytest.mark.filterwarnings("ignore:Bottleneck") -def test_screenkhorn(): +def test_screenkhorn(nx): # test screenkhorn rng = np.random.RandomState(0) n = 100 @@ -468,17 +537,31 @@ def test_screenkhorn(): x = rng.randn(n, 2) M = ot.dist(x, x) + + ab = nx.from_numpy(a) + bb = nx.from_numpy(b) + Mb = nx.from_numpy(M, type_as=ab) + + # np sinkhorn + G_sink_np = ot.sinkhorn(a, b, M, 1e-03) # sinkhorn - G_sink = ot.sinkhorn(a, b, M, 1e-03) + G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, Mb, 1e-03)) # screenkhorn - G_screen = ot.bregman.screenkhorn(a, b, M, 1e-03, uniform=True, verbose=True) + G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, Mb, 1e-03, uniform=True, verbose=True)) # check marginals + np.testing.assert_allclose(G_sink_np, G_sink) np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02) np.testing.assert_allclose(G_sink.sum(1), G_screen.sum(1), atol=1e-02) -def test_convolutional_barycenter_non_square(): +def test_convolutional_barycenter_non_square(nx): # test for image with height not equal width A = np.ones((2, 2, 3)) / (2 * 3) - b = ot.bregman.convolutional_barycenter2d(A, 1e-03) + Ab = nx.from_numpy(A) + + b_np = ot.bregman.convolutional_barycenter2d(A, 1e-03) + b = nx.to_numpy(ot.bregman.convolutional_barycenter2d(Ab, 1e-03)) + + np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02) np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02) + np.testing.assert_allclose(b, b_np) diff --git a/test/test_partial.py b/test/test_partial.py index 3571e2a..97c611b 100755 --- a/test/test_partial.py +++ b/test/test_partial.py @@ -104,7 +104,7 @@ def test_partial_wasserstein(): w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=1, m=m, log=True, verbose=True) - # check constratints + # check constraints np.testing.assert_equal( w0.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein np.testing.assert_equal( @@ -127,7 +127,7 @@ def test_partial_wasserstein(): np.testing.assert_allclose(w0, w0_val, atol=1e-1, rtol=1e-1) - # check constratints + # check constraints np.testing.assert_equal( G.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein np.testing.assert_equal( @@ -194,7 +194,7 @@ def test_partial_gromov_wasserstein(): 100, m=m, log=True) - # check constratints + # check constraints np.testing.assert_equal( res0.sum(1) <= p, [True] * len(p)) # cf convergence wasserstein np.testing.assert_equal( diff --git a/test/test_smooth.py b/test/test_smooth.py index 2afa4f8..31e0b2e 100644 --- a/test/test_smooth.py +++ b/test/test_smooth.py @@ -25,16 +25,16 @@ def test_smooth_ot_dual(): Gl2, log = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type='l2', log=True, stopThr=1e-10) - # check constratints + # check constraints np.testing.assert_allclose( u, Gl2.sum(1), atol=1e-05) # cf convergence sinkhorn np.testing.assert_allclose( u, Gl2.sum(0), atol=1e-05) # cf convergence sinkhorn - # kl regyularisation + # kl regularisation G = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type='kl', stopThr=1e-10) - # check constratints + # check constraints np.testing.assert_allclose( u, G.sum(1), atol=1e-05) # cf convergence sinkhorn np.testing.assert_allclose( @@ -60,16 +60,16 @@ def test_smooth_ot_semi_dual(): Gl2, log = ot.smooth.smooth_ot_semi_dual(u, u, M, 1, reg_type='l2', log=True, stopThr=1e-10) - # check constratints + # check constraints np.testing.assert_allclose( u, Gl2.sum(1), atol=1e-05) # cf convergence sinkhorn np.testing.assert_allclose( u, Gl2.sum(0), atol=1e-05) # cf convergence sinkhorn - # kl regyularisation + # kl regularisation G = ot.smooth.smooth_ot_semi_dual(u, u, M, 1, reg_type='kl', stopThr=1e-10) - # check constratints + # check constraints np.testing.assert_allclose( u, G.sum(1), atol=1e-05) # cf convergence sinkhorn np.testing.assert_allclose( diff --git a/test/test_stochastic.py b/test/test_stochastic.py index 98e93ec..736df32 100644 --- a/test/test_stochastic.py +++ b/test/test_stochastic.py @@ -43,7 +43,7 @@ def test_stochastic_sag(): G = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, "sag", numItermax=numItermax) - # check constratints + # check constraints np.testing.assert_allclose( u, G.sum(1), atol=1e-03) # cf convergence sag np.testing.assert_allclose( @@ -73,7 +73,7 @@ def test_stochastic_asgd(): G, log = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, "asgd", numItermax=numItermax, log=True) - # check constratints + # check constraints np.testing.assert_allclose( u, G.sum(1), atol=1e-02) # cf convergence asgd np.testing.assert_allclose( @@ -105,7 +105,7 @@ def test_sag_asgd_sinkhorn(): numItermax=nb_iter) G_sinkhorn = ot.sinkhorn(u, u, M, reg) - # check constratints + # check constraints np.testing.assert_allclose( G_sag.sum(1), G_sinkhorn.sum(1), atol=1e-02) np.testing.assert_allclose( @@ -148,7 +148,7 @@ def test_stochastic_dual_sgd(): G, log = ot.stochastic.solve_dual_entropic(u, u, M, reg, batch_size, numItermax=numItermax, log=True) - # check constratints + # check constraints np.testing.assert_allclose( u, G.sum(1), atol=1e-03) # cf convergence sgd np.testing.assert_allclose( @@ -181,7 +181,7 @@ def test_dual_sgd_sinkhorn(): G_sinkhorn = ot.sinkhorn(u, u, M, reg) - # check constratints + # check constraints np.testing.assert_allclose( G_sgd.sum(1), G_sinkhorn.sum(1), atol=1e-02) np.testing.assert_allclose( @@ -206,7 +206,7 @@ def test_dual_sgd_sinkhorn(): G_sinkhorn = ot.sinkhorn(a, b, M, reg) - # check constratints + # check constraints np.testing.assert_allclose( G_sgd.sum(1), G_sinkhorn.sum(1), atol=1e-03) np.testing.assert_allclose( |