author     ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com>    2021-10-25 11:36:21 +0200
committer  GitHub <noreply@github.com>    2021-10-25 11:36:21 +0200
commit     7a65086dd340265d0223eb8ffb5c9a5152a82dff
tree       300f4a1cd645516fba1e440691fe48830d781b5c /test
parent     7af8c2147d61349f4d99ca33318a8a125e4569aa
[MRG] Bregman backend (#280)
* Bregman
* Resolve conflicts
* Bug solve
* Bregman updated for JAX compatibility
* Tests coherence between backend improved
* No longer enforcing 64 bits operations on Jax except for tests
* Now using mixtures, to make backend dependent tests with less code
* Better test skipping code
* Pep8 + test optimizations
* redundancy removed
* Docs
* Typo corrected
* Typo
* Typo
* Docs
* Docs
* pep8
* Backend docs
* Prettier docs
* Mistake corrected
* small changes
* Better wording

Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'test')
-rw-r--r--   test/conftest.py         49
-rw-r--r--   test/test_backend.py    102
-rw-r--r--   test/test_bregman.py    217
-rwxr-xr-x   test/test_partial.py      6
-rw-r--r--   test/test_smooth.py      12
-rw-r--r--   test/test_stochastic.py  12
6 files changed, 316 insertions, 82 deletions
diff --git a/test/conftest.py b/test/conftest.py
new file mode 100644
index 0000000..876b525
--- /dev/null
+++ b/test/conftest.py
@@ -0,0 +1,49 @@
+# -*- coding: utf-8 -*-
+
+# Configuration file for pytest
+
+# License: MIT License
+
+import pytest
+from ot.backend import jax
+from ot.backend import get_backend_list
+import functools
+
+if jax:
+ from jax.config import config
+
+backend_list = get_backend_list()
+
+
+@pytest.fixture(params=backend_list)
+def nx(request):
+ backend = request.param
+ if backend.__name__ == "jax":
+ config.update("jax_enable_x64", True)
+
+ yield backend
+
+ if backend.__name__ == "jax":
+ config.update("jax_enable_x64", False)
+
+
+def skip_arg(arg, value, reason=None, getter=lambda x: x):
+ if reason is None:
+ reason = f"Param {arg} should be skipped for value {value}"
+
+ def wrapper(function):
+
+ @functools.wraps(function)
+ def wrapped(*args, **kwargs):
+ if arg in kwargs.keys() and getter(kwargs[arg]) == value:
+ pytest.skip(reason)
+ return function(*args, **kwargs)
+
+ return wrapped
+
+ return wrapper
+
+
+def pytest_configure(config):
+ pytest.skip_arg = skip_arg
+ pytest.skip_backend = functools.partial(skip_arg, "nx", getter=str)
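For context, here is a minimal sketch (a hypothetical test, not part of this commit) of how the new nx fixture and the pytest.skip_backend helper registered above are consumed; it assumes the backends stringify to their name (e.g. "jax"), which is what getter=str relies on:

import numpy as np
import pytest
import ot

@pytest.skip_backend("jax")        # skipped whenever str(nx) == "jax"
def test_sinkhorn_marginals(nx):   # nx runs once per available backend
    a = ot.unif(5)
    M = ot.dist(np.arange(5.0)[:, None], np.arange(5.0)[:, None])
    ab, Mb = nx.from_numpy(a), nx.from_numpy(M)   # move inputs to the backend
    G = nx.to_numpy(ot.sinkhorn(ab, ab, Mb, 1))   # solve, pull back to numpy
    np.testing.assert_allclose(a, G.sum(1), atol=1e-5)

Note that the fixture also turns jax_enable_x64 on for the duration of each jax-backed test, so float64 comparisons against numpy stay meaningful.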
diff --git a/test/test_backend.py b/test/test_backend.py
index cbfaf94..859da5a 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -1,6 +1,7 @@
"""Tests for backend module """
# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+# Nicolas Courty <ncourty@irisa.fr>
#
# License: MIT License
@@ -156,6 +157,8 @@ def test_empty_backend():
with pytest.raises(NotImplementedError):
nx.sqrt(M)
with pytest.raises(NotImplementedError):
+ nx.power(v, 2)
+ with pytest.raises(NotImplementedError):
nx.dot(v, v)
with pytest.raises(NotImplementedError):
nx.norm(M)
@@ -174,7 +177,37 @@ def test_empty_backend():
with pytest.raises(NotImplementedError):
nx.argsort(M)
with pytest.raises(NotImplementedError):
+ nx.searchsorted(v, v)
+ with pytest.raises(NotImplementedError):
nx.flip(M)
+ with pytest.raises(NotImplementedError):
+ nx.clip(M, -1, 1)
+ with pytest.raises(NotImplementedError):
+ nx.repeat(M, 0, 1)
+ with pytest.raises(NotImplementedError):
+ nx.take_along_axis(M, v, 0)
+ with pytest.raises(NotImplementedError):
+ nx.concatenate([v, v])
+ with pytest.raises(NotImplementedError):
+ nx.zero_pad(M, v)
+ with pytest.raises(NotImplementedError):
+ nx.argmax(M)
+ with pytest.raises(NotImplementedError):
+ nx.mean(M)
+ with pytest.raises(NotImplementedError):
+ nx.std(M)
+ with pytest.raises(NotImplementedError):
+ nx.linspace(0, 1, 50)
+ with pytest.raises(NotImplementedError):
+ nx.meshgrid(v, v)
+ with pytest.raises(NotImplementedError):
+ nx.diag(M)
+ with pytest.raises(NotImplementedError):
+ nx.unique([M, M])
+ with pytest.raises(NotImplementedError):
+ nx.logsumexp(M)
+ with pytest.raises(NotImplementedError):
+ nx.stack([M, M])
@pytest.mark.parametrize('backend', backend_list)
@@ -278,6 +311,10 @@ def test_func_backends(backend):
lst_b.append(nx.to_numpy(A))
lst_name.append('sqrt')
+ A = nx.power(Mb, 2)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('power')
+
A = nx.dot(vb, vb)
lst_b.append(nx.to_numpy(A))
lst_name.append('dot(v,v)')
@@ -326,10 +363,75 @@ def test_func_backends(backend):
lst_b.append(nx.to_numpy(A))
lst_name.append('argsort')
+ A = nx.searchsorted(Mb, Mb, 'right')
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('searchsorted')
+
A = nx.flip(Mb)
lst_b.append(nx.to_numpy(A))
lst_name.append('flip')
+ A = nx.clip(vb, 0, 1)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('clip')
+
+ A = nx.repeat(Mb, 0)
+ A = nx.repeat(Mb, 2, -1)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('repeat')
+
+ A = nx.take_along_axis(vb, nx.arange(3), -1)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('take_along_axis')
+
+ A = nx.concatenate((Mb, Mb), -1)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('concatenate')
+
+ A = nx.zero_pad(Mb, len(Mb.shape) * [(3, 3)])
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('zero_pad')
+
+ A = nx.argmax(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('argmax')
+
+ A = nx.mean(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('mean')
+
+ A = nx.std(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('std')
+
+ A = nx.linspace(0, 1, 50)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('linspace')
+
+ X, Y = nx.meshgrid(vb, vb)
+ lst_b.append(np.stack([nx.to_numpy(X), nx.to_numpy(Y)]))
+ lst_name.append('meshgrid')
+
+ A = nx.diag(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('diag2D')
+
+ A = nx.diag(vb, 1)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('diag1D')
+
+ A = nx.unique(nx.from_numpy(np.stack([M, M])))
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('unique')
+
+ A = nx.logsumexp(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('logsumexp')
+
+ A = nx.stack([Mb, Mb])
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('stack')
+
lst_tot.append(lst_b)
lst_np = lst_tot[0]
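The comparison harness above follows one pattern throughout: run each op on every backend, pull the results back with nx.to_numpy, and check them entry-wise against the NumPy backend's results. A condensed, standalone sketch of that pattern (illustrative only; like the test itself, it assumes get_backend_list() yields the NumPy backend first, and it loosens tolerances to accommodate float32 backends):

import numpy as np
from ot.backend import get_backend_list

M = np.random.RandomState(0).rand(4, 4)

results = []
for nx in get_backend_list():
    Mb = nx.from_numpy(M)
    # each op executes on the backend, then is converted back to numpy
    results.append([nx.to_numpy(nx.sqrt(Mb)),
                    nx.to_numpy(nx.power(Mb, 2)),
                    nx.to_numpy(nx.logsumexp(Mb))])

ref = results[0]                          # numpy backend, used as reference
for res in results[1:]:
    for r_ref, r_b in zip(ref, res):
        np.testing.assert_allclose(r_ref, r_b, rtol=1e-5, atol=1e-5)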
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 88166a5..942cb6d 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -10,11 +10,8 @@ import numpy as np
import pytest
import ot
-from ot.backend import get_backend_list
from ot.backend import torch
-backend_list = get_backend_list()
-
def test_sinkhorn():
# test sinkhorn
@@ -28,14 +25,13 @@ def test_sinkhorn():
G = ot.sinkhorn(u, u, M, 1, stopThr=1e-10)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
u, G.sum(1), atol=1e-05) # cf convergence sinkhorn
np.testing.assert_allclose(
u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
-@pytest.mark.parametrize('nx', backend_list)
def test_sinkhorn_backends(nx):
n_samples = 100
n_features = 2
@@ -57,7 +53,6 @@ def test_sinkhorn_backends(nx):
np.allclose(G, nx.to_numpy(Gb))
-@pytest.mark.parametrize('nx', backend_list)
def test_sinkhorn2_backends(nx):
n_samples = 100
n_features = 2
@@ -116,20 +111,20 @@ def test_sinkhorn_empty():
M = ot.dist(x, x)
G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10, verbose=True, log=True)
- # check constratints
+ # check constraints
np.testing.assert_allclose(u, G.sum(1), atol=1e-05)
np.testing.assert_allclose(u, G.sum(0), atol=1e-05)
G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10,
method='sinkhorn_stabilized', verbose=True, log=True)
- # check constratints
+ # check constraints
np.testing.assert_allclose(u, G.sum(1), atol=1e-05)
np.testing.assert_allclose(u, G.sum(0), atol=1e-05)
G, log = ot.sinkhorn(
[], [], M, 1, stopThr=1e-10, method='sinkhorn_epsilon_scaling',
verbose=True, log=True)
- # check constratints
+ # check constraints
np.testing.assert_allclose(u, G.sum(1), atol=1e-05)
np.testing.assert_allclose(u, G.sum(0), atol=1e-05)
@@ -137,7 +132,8 @@ def test_sinkhorn_empty():
ot.sinkhorn([], [], M, 1, method='greenkhorn', stopThr=1e-10, log=True)
-def test_sinkhorn_variants():
+@pytest.skip_backend("jax")
+def test_sinkhorn_variants(nx):
# test sinkhorn
n = 100
rng = np.random.RandomState(0)
@@ -147,13 +143,18 @@ def test_sinkhorn_variants():
M = ot.dist(x, x)
- G0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10)
- Gs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10)
- Ges = ot.sinkhorn(
- u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10)
- G_green = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10)
+ ub = nx.from_numpy(u)
+ Mb = nx.from_numpy(M)
+
+ G = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10)
+ G0 = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn', stopThr=1e-10))
+ Gs = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn_stabilized', stopThr=1e-10))
+ Ges = nx.to_numpy(ot.sinkhorn(
+ ub, ub, Mb, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10))
+ G_green = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='greenkhorn', stopThr=1e-10))
# check values
+ np.testing.assert_allclose(G, G0, atol=1e-05)
np.testing.assert_allclose(G0, Gs, atol=1e-05)
np.testing.assert_allclose(G0, Ges, atol=1e-05)
np.testing.assert_allclose(G0, G_green, atol=1e-5)
@@ -184,7 +185,7 @@ def test_sinkhorn_variants_log():
@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
-def test_barycenter(method):
+def test_barycenter(nx, method):
n_bins = 100 # nb bins
# Gaussian distributions
@@ -201,16 +202,23 @@ def test_barycenter(method):
alpha = 0.5 # 0<=alpha<=1
weights = np.array([1 - alpha, alpha])
+ Ab = nx.from_numpy(A)
+ Mb = nx.from_numpy(M)
+ weightsb = nx.from_numpy(weights)
+
# wasserstein
reg = 1e-2
- bary_wass, log = ot.bregman.barycenter(A, M, reg, weights, method=method, log=True)
+ bary_wass_np, log = ot.bregman.barycenter(A, M, reg, weights, method=method, log=True)
+ bary_wass, _ = ot.bregman.barycenter(Ab, Mb, reg, weightsb, method=method, log=True)
+ bary_wass = nx.to_numpy(bary_wass)
np.testing.assert_allclose(1, np.sum(bary_wass))
+ np.testing.assert_allclose(bary_wass, bary_wass_np)
- ot.bregman.barycenter(A, M, reg, log=True, verbose=True)
+ ot.bregman.barycenter(Ab, Mb, reg, log=True, verbose=True)
-def test_barycenter_stabilization():
+def test_barycenter_stabilization(nx):
n_bins = 100 # nb bins
# Gaussian distributions
@@ -227,17 +235,26 @@ def test_barycenter_stabilization():
alpha = 0.5 # 0<=alpha<=1
weights = np.array([1 - alpha, alpha])
+ Ab = nx.from_numpy(A)
+ Mb = nx.from_numpy(M)
+ weights_b = nx.from_numpy(weights)
+
# wasserstein
reg = 1e-2
- bar_stable = ot.bregman.barycenter(A, M, reg, weights,
- method="sinkhorn_stabilized",
- stopThr=1e-8, verbose=True)
- bar = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn",
- stopThr=1e-8, verbose=True)
+ bar_np = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn", stopThr=1e-8, verbose=True)
+ bar_stable = nx.to_numpy(ot.bregman.barycenter(
+ Ab, Mb, reg, weights_b, method="sinkhorn_stabilized",
+ stopThr=1e-8, verbose=True
+ ))
+ bar = nx.to_numpy(ot.bregman.barycenter(
+ Ab, Mb, reg, weights_b, method="sinkhorn",
+ stopThr=1e-8, verbose=True
+ ))
np.testing.assert_allclose(bar, bar_stable)
+ np.testing.assert_allclose(bar, bar_np)
-def test_wasserstein_bary_2d():
+def test_wasserstein_bary_2d(nx):
size = 100 # size of a square image
a1 = np.random.randn(size, size)
a1 += a1.min()
@@ -250,17 +267,21 @@ def test_wasserstein_bary_2d():
A[0, :, :] = a1
A[1, :, :] = a2
+ Ab = nx.from_numpy(A)
+
# wasserstein
reg = 1e-2
- bary_wass = ot.bregman.convolutional_barycenter2d(A, reg)
+ bary_wass_np = ot.bregman.convolutional_barycenter2d(A, reg)
+ bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d(Ab, reg))
np.testing.assert_allclose(1, np.sum(bary_wass))
+ np.testing.assert_allclose(bary_wass, bary_wass_np)
# help in checking if log and verbose do not bug the function
ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)
-def test_unmix():
+def test_unmix(nx):
n_bins = 50 # nb bins
# Gaussian distributions
@@ -280,18 +301,26 @@ def test_unmix():
M0 /= M0.max()
h0 = ot.unif(2)
+ ab = nx.from_numpy(a)
+ Db = nx.from_numpy(D)
+ Mb = nx.from_numpy(M)
+ M0b = nx.from_numpy(M0)
+ h0b = nx.from_numpy(h0)
+
# wasserstein
reg = 1e-3
- um = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01, )
+ um_np = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01)
+ um = nx.to_numpy(ot.bregman.unmix(ab, Db, Mb, M0b, h0b, reg, 1, alpha=0.01))
np.testing.assert_allclose(1, np.sum(um), rtol=1e-03, atol=1e-03)
np.testing.assert_allclose([0.5, 0.5], um, rtol=1e-03, atol=1e-03)
+ np.testing.assert_allclose(um, um_np)
- ot.bregman.unmix(a, D, M, M0, h0, reg,
+ ot.bregman.unmix(ab, Db, Mb, M0b, h0b, reg,
1, alpha=0.01, log=True, verbose=True)
-def test_empirical_sinkhorn():
+def test_empirical_sinkhorn(nx):
# test sinkhorn
n = 10
a = ot.unif(n)
@@ -302,19 +331,28 @@ def test_empirical_sinkhorn():
M = ot.dist(X_s, X_t)
M_m = ot.dist(X_s, X_t, metric='minkowski')
- G_sqe = ot.bregman.empirical_sinkhorn(X_s, X_t, 1)
- sinkhorn_sqe = ot.sinkhorn(a, b, M, 1)
+ ab = nx.from_numpy(a)
+ bb = nx.from_numpy(b)
+ X_sb = nx.from_numpy(X_s)
+ X_tb = nx.from_numpy(X_t)
+ Mb = nx.from_numpy(M, type_as=ab)
+ M_mb = nx.from_numpy(M_m, type_as=ab)
+
+ G_sqe = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1))
+ sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, Mb, 1))
- G_log, log_es = ot.bregman.empirical_sinkhorn(X_s, X_t, 0.1, log=True)
- sinkhorn_log, log_s = ot.sinkhorn(a, b, M, 0.1, log=True)
+ G_log, log_es = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 0.1, log=True)
+ G_log = nx.to_numpy(G_log)
+ sinkhorn_log, log_s = ot.sinkhorn(ab, bb, Mb, 0.1, log=True)
+ sinkhorn_log = nx.to_numpy(sinkhorn_log)
- G_m = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, metric='minkowski')
- sinkhorn_m = ot.sinkhorn(a, b, M_m, 1)
+ G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='minkowski'))
+ sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1))
- loss_emp_sinkhorn = ot.bregman.empirical_sinkhorn2(X_s, X_t, 1)
- loss_sinkhorn = ot.sinkhorn2(a, b, M, 1)
+ loss_emp_sinkhorn = nx.to_numpy(ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1))
+ loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, Mb, 1))
- # check constratints
+ # check constraints
np.testing.assert_allclose(
sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidian
np.testing.assert_allclose(
@@ -330,7 +368,7 @@ def test_empirical_sinkhorn():
np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05)
-def test_lazy_empirical_sinkhorn():
+def test_lazy_empirical_sinkhorn(nx):
# test sinkhorn
n = 10
a = ot.unif(n)
@@ -342,22 +380,34 @@ def test_lazy_empirical_sinkhorn():
M = ot.dist(X_s, X_t)
M_m = ot.dist(X_s, X_t, metric='minkowski')
- f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True)
+ ab = nx.from_numpy(a)
+ bb = nx.from_numpy(b)
+ X_sb = nx.from_numpy(X_s)
+ X_tb = nx.from_numpy(X_t)
+ Mb = nx.from_numpy(M, type_as=ab)
+ M_mb = nx.from_numpy(M_m, type_as=ab)
+
+ f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True)
+ f, g = nx.to_numpy(f), nx.to_numpy(g)
G_sqe = np.exp(f[:, None] + g[None, :] - M / 1)
- sinkhorn_sqe = ot.sinkhorn(a, b, M, 1)
+ sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, Mb, 1))
- f, g, log_es = ot.bregman.empirical_sinkhorn(X_s, X_t, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
+ f, g, log_es = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
+ f, g = nx.to_numpy(f), nx.to_numpy(g)
G_log = np.exp(f[:, None] + g[None, :] - M / 0.1)
- sinkhorn_log, log_s = ot.sinkhorn(a, b, M, 0.1, log=True)
+ sinkhorn_log, log_s = ot.sinkhorn(ab, bb, Mb, 0.1, log=True)
+ sinkhorn_log = nx.to_numpy(sinkhorn_log)
- f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, metric='minkowski', numIterMax=numIterMax, isLazy=True, batchSize=1)
+ f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='minkowski', numIterMax=numIterMax, isLazy=True, batchSize=1)
+ f, g = nx.to_numpy(f), nx.to_numpy(g)
G_m = np.exp(f[:, None] + g[None, :] - M_m / 1)
- sinkhorn_m = ot.sinkhorn(a, b, M_m, 1)
+ sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1))
- loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2(X_s, X_t, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
- loss_sinkhorn = ot.sinkhorn2(a, b, M, 1)
+ loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
+ loss_emp_sinkhorn = nx.to_numpy(loss_emp_sinkhorn)
+ loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, Mb, 1))
- # check constratints
+ # check constraints
np.testing.assert_allclose(
sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidian
np.testing.assert_allclose(
@@ -373,7 +423,7 @@ def test_lazy_empirical_sinkhorn():
np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05)
-def test_empirical_sinkhorn_divergence():
+def test_empirical_sinkhorn_divergence(nx):
# Test sinkhorn divergence
n = 10
a = np.linspace(1, n, n)
@@ -385,22 +435,31 @@ def test_empirical_sinkhorn_divergence():
M_s = ot.dist(X_s, X_s)
M_t = ot.dist(X_t, X_t)
- emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, a=a, b=b)
- sinkhorn_div = (ot.sinkhorn2(a, b, M, 1) - 1 / 2 * ot.sinkhorn2(a, a, M_s, 1) - 1 / 2 * ot.sinkhorn2(b, b, M_t, 1))
+ ab = nx.from_numpy(a)
+ bb = nx.from_numpy(b)
+ X_sb = nx.from_numpy(X_s)
+ X_tb = nx.from_numpy(X_t)
+ Mb = nx.from_numpy(M, type_as=ab)
+ M_sb = nx.from_numpy(M_s, type_as=ab)
+ M_tb = nx.from_numpy(M_t, type_as=ab)
+
+ emp_sinkhorn_div = nx.to_numpy(ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb))
+ sinkhorn_div = nx.to_numpy(
+ ot.sinkhorn2(ab, bb, Mb, 1)
+ - 1 / 2 * ot.sinkhorn2(ab, ab, M_sb, 1)
+ - 1 / 2 * ot.sinkhorn2(bb, bb, M_tb, 1)
+ )
+ emp_sinkhorn_div_np = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, a=a, b=b)
- emp_sinkhorn_div_log, log_es = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, a=a, b=b, log=True)
- sink_div_log_ab, log_s_ab = ot.sinkhorn2(a, b, M, 1, log=True)
- sink_div_log_a, log_s_a = ot.sinkhorn2(a, a, M_s, 1, log=True)
- sink_div_log_b, log_s_b = ot.sinkhorn2(b, b, M_t, 1, log=True)
- sink_div_log = sink_div_log_ab - 1 / 2 * (sink_div_log_a + sink_div_log_b)
# check constraints
+ np.testing.assert_allclose(emp_sinkhorn_div, emp_sinkhorn_div_np, atol=1e-05)
np.testing.assert_allclose(
emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn
- np.testing.assert_allclose(
- emp_sinkhorn_div_log, sink_div_log, atol=1e-05) # cf conv emp sinkhorn
+ ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb, log=True)
-def test_stabilized_vs_sinkhorn_multidim():
+
+def test_stabilized_vs_sinkhorn_multidim(nx):
# test if stable version matches sinkhorn
# for multidimensional inputs
n = 100
@@ -416,12 +475,21 @@ def test_stabilized_vs_sinkhorn_multidim():
M = ot.utils.dist0(n)
M /= np.median(M)
epsilon = 0.1
- G, log = ot.bregman.sinkhorn(a, b, M, reg=epsilon,
+
+ ab = nx.from_numpy(a)
+ bb = nx.from_numpy(b)
+ Mb = nx.from_numpy(M, type_as=ab)
+
+ G_np, _ = ot.bregman.sinkhorn(a, b, M, reg=epsilon, method="sinkhorn", log=True)
+ G, log = ot.bregman.sinkhorn(ab, bb, Mb, reg=epsilon,
method="sinkhorn_stabilized",
log=True)
- G2, log2 = ot.bregman.sinkhorn(a, b, M, epsilon,
+ G = nx.to_numpy(G)
+ G2, log2 = ot.bregman.sinkhorn(ab, bb, Mb, epsilon,
method="sinkhorn", log=True)
+ G2 = nx.to_numpy(G2)
+ np.testing.assert_allclose(G_np, G2)
np.testing.assert_allclose(G, G2)
@@ -458,8 +526,9 @@ def test_implemented_methods():
ot.bregman.sinkhorn2(a, b, M, epsilon, method=method)
+@pytest.skip_backend("jax")
@pytest.mark.filterwarnings("ignore:Bottleneck")
-def test_screenkhorn():
+def test_screenkhorn(nx):
# test screenkhorn
rng = np.random.RandomState(0)
n = 100
@@ -468,17 +537,31 @@ def test_screenkhorn():
x = rng.randn(n, 2)
M = ot.dist(x, x)
+
+ ab = nx.from_numpy(a)
+ bb = nx.from_numpy(b)
+ Mb = nx.from_numpy(M, type_as=ab)
+
+ # np sinkhorn
+ G_sink_np = ot.sinkhorn(a, b, M, 1e-03)
# sinkhorn
- G_sink = ot.sinkhorn(a, b, M, 1e-03)
+ G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, Mb, 1e-03))
# screenkhorn
- G_screen = ot.bregman.screenkhorn(a, b, M, 1e-03, uniform=True, verbose=True)
+ G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, Mb, 1e-03, uniform=True, verbose=True))
# check marginals
+ np.testing.assert_allclose(G_sink_np, G_sink)
np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02)
np.testing.assert_allclose(G_sink.sum(1), G_screen.sum(1), atol=1e-02)
-def test_convolutional_barycenter_non_square():
+def test_convolutional_barycenter_non_square(nx):
# test for image with height not equal width
A = np.ones((2, 2, 3)) / (2 * 3)
- b = ot.bregman.convolutional_barycenter2d(A, 1e-03)
+ Ab = nx.from_numpy(A)
+
+ b_np = ot.bregman.convolutional_barycenter2d(A, 1e-03)
+ b = nx.to_numpy(ot.bregman.convolutional_barycenter2d(Ab, 1e-03))
+
+ np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02)
np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02)
+ np.testing.assert_allclose(b, b_np)
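One detail worth noting from the lazy tests above: with isLazy=True, empirical_sinkhorn returns the dual potentials (f, g) instead of a transport plan, and the tests rebuild the dense plan as exp(f[:, None] + g[None, :] - M / reg). A standalone sketch of that reconstruction (sizes and tolerances are illustrative):

import numpy as np
import ot

rng = np.random.RandomState(42)
X_s, X_t = rng.randn(10, 2), rng.randn(10, 2)
a = b = ot.unif(10)
M = ot.dist(X_s, X_t)
reg = 1.0

# lazy mode materializes only the dual potentials, never the full plan
f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, reg, isLazy=True, batchSize=1)
G_lazy = np.exp(f[:, None] + g[None, :] - M / reg)   # dense plan, rebuilt

G_dense = ot.sinkhorn(a, b, M, reg)                  # reference dense solver
np.testing.assert_allclose(G_dense.sum(1), G_lazy.sum(1), atol=1e-4)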
diff --git a/test/test_partial.py b/test/test_partial.py
index 3571e2a..97c611b 100755
--- a/test/test_partial.py
+++ b/test/test_partial.py
@@ -104,7 +104,7 @@ def test_partial_wasserstein():
w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=1, m=m,
log=True, verbose=True)
- # check constratints
+ # check constraints
np.testing.assert_equal(
w0.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein
np.testing.assert_equal(
@@ -127,7 +127,7 @@ def test_partial_wasserstein():
np.testing.assert_allclose(w0, w0_val, atol=1e-1, rtol=1e-1)
- # check constratints
+ # check constraints
np.testing.assert_equal(
G.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein
np.testing.assert_equal(
@@ -194,7 +194,7 @@ def test_partial_gromov_wasserstein():
100, m=m,
log=True)
- # check constratints
+ # check constraints
np.testing.assert_equal(
res0.sum(1) <= p, [True] * len(p)) # cf convergence wasserstein
np.testing.assert_equal(
diff --git a/test/test_smooth.py b/test/test_smooth.py
index 2afa4f8..31e0b2e 100644
--- a/test/test_smooth.py
+++ b/test/test_smooth.py
@@ -25,16 +25,16 @@ def test_smooth_ot_dual():
Gl2, log = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type='l2', log=True, stopThr=1e-10)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
u, Gl2.sum(1), atol=1e-05) # cf convergence sinkhorn
np.testing.assert_allclose(
u, Gl2.sum(0), atol=1e-05) # cf convergence sinkhorn
- # kl regyularisation
+ # kl regularisation
G = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type='kl', stopThr=1e-10)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
u, G.sum(1), atol=1e-05) # cf convergence sinkhorn
np.testing.assert_allclose(
@@ -60,16 +60,16 @@ def test_smooth_ot_semi_dual():
Gl2, log = ot.smooth.smooth_ot_semi_dual(u, u, M, 1, reg_type='l2', log=True, stopThr=1e-10)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
u, Gl2.sum(1), atol=1e-05) # cf convergence sinkhorn
np.testing.assert_allclose(
u, Gl2.sum(0), atol=1e-05) # cf convergence sinkhorn
- # kl regyularisation
+ # kl regularisation
G = ot.smooth.smooth_ot_semi_dual(u, u, M, 1, reg_type='kl', stopThr=1e-10)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
u, G.sum(1), atol=1e-05) # cf convergence sinkhorn
np.testing.assert_allclose(
diff --git a/test/test_stochastic.py b/test/test_stochastic.py
index 98e93ec..736df32 100644
--- a/test/test_stochastic.py
+++ b/test/test_stochastic.py
@@ -43,7 +43,7 @@ def test_stochastic_sag():
G = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, "sag",
numItermax=numItermax)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
u, G.sum(1), atol=1e-03) # cf convergence sag
np.testing.assert_allclose(
@@ -73,7 +73,7 @@ def test_stochastic_asgd():
G, log = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, "asgd",
numItermax=numItermax, log=True)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
u, G.sum(1), atol=1e-02) # cf convergence asgd
np.testing.assert_allclose(
@@ -105,7 +105,7 @@ def test_sag_asgd_sinkhorn():
numItermax=nb_iter)
G_sinkhorn = ot.sinkhorn(u, u, M, reg)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
G_sag.sum(1), G_sinkhorn.sum(1), atol=1e-02)
np.testing.assert_allclose(
@@ -148,7 +148,7 @@ def test_stochastic_dual_sgd():
G, log = ot.stochastic.solve_dual_entropic(u, u, M, reg, batch_size,
numItermax=numItermax, log=True)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
u, G.sum(1), atol=1e-03) # cf convergence sgd
np.testing.assert_allclose(
@@ -181,7 +181,7 @@ def test_dual_sgd_sinkhorn():
G_sinkhorn = ot.sinkhorn(u, u, M, reg)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
G_sgd.sum(1), G_sinkhorn.sum(1), atol=1e-02)
np.testing.assert_allclose(
@@ -206,7 +206,7 @@ def test_dual_sgd_sinkhorn():
G_sinkhorn = ot.sinkhorn(a, b, M, reg)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
G_sgd.sum(1), G_sinkhorn.sum(1), atol=1e-03)
np.testing.assert_allclose(