author     Gard Spreemann <gspr@nonempty.org>  2021-11-09 17:05:13 +0100
committer  Gard Spreemann <gspr@nonempty.org>  2021-11-09 17:05:13 +0100
commit     a9fdc844907decddf54bed3ebeea8d8b2cf0fc5c (patch)
tree       449a03fce8fafb78b6badd12b6e633f1e5d73a64 /test
parent     a16b9471d7114ec08977479b7249efe747702b97 (diff)
parent     f1628794d521a8dfa00af383b5e06cd6d34af619 (diff)
Merge tag '0.8.0' into dfsg/latest
Diffstat (limited to 'test')
-rw-r--r--  test/conftest.py         62
-rw-r--r--  test/test_1d_solver.py   172
-rw-r--r--  test/test_backend.py     577
-rw-r--r--  test/test_bregman.py     718
-rw-r--r--  test/test_da.py          24
-rw-r--r--  test/test_dr.py          62
-rw-r--r--  test/test_gromov.py      523
-rw-r--r--  test/test_helpers.py     26
-rw-r--r--  test/test_optim.py       103
-rw-r--r--  test/test_ot.py          183
-rwxr-xr-x  test/test_partial.py     16
-rw-r--r--  test/test_regpath.py     64
-rw-r--r--  test/test_sliced.py      213
-rw-r--r--  test/test_smooth.py      12
-rw-r--r--  test/test_stochastic.py  52
-rw-r--r--  test/test_unbalanced.py  33
-rw-r--r--  test/test_utils.py       84
17 files changed, 2600 insertions, 324 deletions
diff --git a/test/conftest.py b/test/conftest.py
new file mode 100644
index 0000000..987d98e
--- /dev/null
+++ b/test/conftest.py
@@ -0,0 +1,62 @@
+# -*- coding: utf-8 -*-
+
+# Configuration file for pytest
+
+# License: MIT License
+
+import pytest
+from ot.backend import jax
+from ot.backend import get_backend_list
+import functools
+
+if jax:
+ from jax.config import config
+ config.update("jax_enable_x64", True)
+
+backend_list = get_backend_list()
+
+
+@pytest.fixture(params=backend_list)
+def nx(request):
+ backend = request.param
+
+ yield backend
+
+
+def skip_arg(arg, value, reason=None, getter=lambda x: x):
+ if isinstance(arg, tuple) or isinstance(arg, list):
+ n = len(arg)
+ else:
+ arg = (arg, )
+ n = 1
+ if n != 1 and (isinstance(value, tuple) or isinstance(value, list)):
+ pass
+ else:
+ value = (value, )
+ if isinstance(getter, tuple) or isinstance(getter, list):
+ pass
+ else:
+ getter = [getter] * n
+
+ if reason is None:
+ reason = f"Param {arg} should be skipped for value {value}"
+
+ def wrapper(function):
+
+ @functools.wraps(function)
+ def wrapped(*args, **kwargs):
+ if all(
+ arg[i] in kwargs.keys() and getter[i](kwargs[arg[i]]) == value[i]
+ for i in range(n)
+ ):
+ pytest.skip(reason)
+ return function(*args, **kwargs)
+
+ return wrapped
+
+ return wrapper
+
+
+def pytest_configure(config):
+ pytest.skip_arg = skip_arg
+ pytest.skip_backend = functools.partial(skip_arg, "nx", getter=str)
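Note: the short block below is an illustrative sketch, not part of the diff. It shows how a test module is expected to consume the parametrized `nx` fixture and the `pytest.skip_backend` marker registered by this conftest (the same pattern appears in test_bregman.py further down); the test body itself is invented for the example.

    import numpy as np
    import pytest
    import ot

    @pytest.skip_backend("jax")       # skip this test when the jax backend is parametrized in
    def test_example_solver(nx):      # `nx` is one backend instance per parametrized run
        a = ot.unif(5)
        M = np.abs(np.random.randn(5, 5))
        ab, Mb = nx.from_numpy(a), nx.from_numpy(M)        # move inputs to the backend
        G = nx.to_numpy(ot.sinkhorn(ab, ab, Mb, 1, stopThr=1e-10))
        np.testing.assert_allclose(a, G.sum(1), atol=1e-5)  # marginal constraint holds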
diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py
new file mode 100644
index 0000000..cb85cb9
--- /dev/null
+++ b/test/test_1d_solver.py
@@ -0,0 +1,172 @@
+"""Tests for module 1d Wasserstein solver"""
+
+# Author: Adrien Corenflos <adrien.corenflos@aalto.fi>
+# Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
+import numpy as np
+import pytest
+
+import ot
+from ot.lp import wasserstein_1d
+
+from ot.backend import get_backend_list
+from scipy.stats import wasserstein_distance
+
+backend_list = get_backend_list()
+
+
+def test_emd_1d_emd2_1d_with_weights():
+ # test emd1d gives similar results as emd
+ n = 20
+ m = 30
+ rng = np.random.RandomState(0)
+ u = rng.randn(n, 1)
+ v = rng.randn(m, 1)
+
+ w_u = rng.uniform(0., 1., n)
+ w_u = w_u / w_u.sum()
+
+ w_v = rng.uniform(0., 1., m)
+ w_v = w_v / w_v.sum()
+
+ M = ot.dist(u, v, metric='sqeuclidean')
+
+ G, log = ot.emd(w_u, w_v, M, log=True)
+ wass = log["cost"]
+ G_1d, log = ot.emd_1d(u, v, w_u, w_v, metric='sqeuclidean', log=True)
+ wass1d = log["cost"]
+ wass1d_emd2 = ot.emd2_1d(u, v, w_u, w_v, metric='sqeuclidean', log=False)
+ wass1d_euc = ot.emd2_1d(u, v, w_u, w_v, metric='euclidean', log=False)
+
+ # check loss is similar
+ np.testing.assert_allclose(wass, wass1d)
+ np.testing.assert_allclose(wass, wass1d_emd2)
+
+ # check loss is similar to scipy's implementation for Euclidean metric
+ wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)), w_u, w_v)
+ np.testing.assert_allclose(wass_sp, wass1d_euc)
+
+ # check constraints
+ np.testing.assert_allclose(w_u, G.sum(1))
+ np.testing.assert_allclose(w_v, G.sum(0))
+
+
+@pytest.mark.parametrize('nx', backend_list)
+def test_wasserstein_1d(nx):
+ from scipy.stats import wasserstein_distance
+
+ rng = np.random.RandomState(0)
+
+ n = 100
+ x = np.linspace(0, 5, n)
+ rho_u = np.abs(rng.randn(n))
+ rho_u /= rho_u.sum()
+ rho_v = np.abs(rng.randn(n))
+ rho_v /= rho_v.sum()
+
+ xb = nx.from_numpy(x)
+ rho_ub = nx.from_numpy(rho_u)
+ rho_vb = nx.from_numpy(rho_v)
+
+ # test 1 : wasserstein_1d should be close to scipy W_1 implementation
+ np.testing.assert_almost_equal(wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1),
+ wasserstein_distance(x, x, rho_u, rho_v))
+
+ # test 2 : wasserstein_1d should be close to one when only translating the support
+ np.testing.assert_almost_equal(wasserstein_1d(xb, xb + 1, p=2),
+ 1.)
+
+ # test 3 : arrays test
+ X = np.stack((np.linspace(0, 5, n), np.linspace(0, 5, n) * 10), -1)
+ Xb = nx.from_numpy(X)
+ res = wasserstein_1d(Xb, Xb, rho_ub, rho_vb, p=2)
+ np.testing.assert_almost_equal(100 * res[0], res[1], decimal=4)
+
+
+def test_wasserstein_1d_type_devices(nx):
+
+ rng = np.random.RandomState(0)
+
+ n = 10
+ x = np.linspace(0, 5, n)
+ rho_u = np.abs(rng.randn(n))
+ rho_u /= rho_u.sum()
+ rho_v = np.abs(rng.randn(n))
+ rho_v /= rho_v.sum()
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ xb = nx.from_numpy(x, type_as=tp)
+ rho_ub = nx.from_numpy(rho_u, type_as=tp)
+ rho_vb = nx.from_numpy(rho_v, type_as=tp)
+
+ res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1)
+
+ nx.assert_same_dtype_device(xb, res)
+
+
+def test_emd_1d_emd2_1d():
+ # test emd1d gives similar results as emd
+ n = 20
+ m = 30
+ rng = np.random.RandomState(0)
+ u = rng.randn(n, 1)
+ v = rng.randn(m, 1)
+
+ M = ot.dist(u, v, metric='sqeuclidean')
+
+ G, log = ot.emd([], [], M, log=True)
+ wass = log["cost"]
+ G_1d, log = ot.emd_1d(u, v, [], [], metric='sqeuclidean', log=True)
+ wass1d = log["cost"]
+ wass1d_emd2 = ot.emd2_1d(u, v, [], [], metric='sqeuclidean', log=False)
+ wass1d_euc = ot.emd2_1d(u, v, [], [], metric='euclidean', log=False)
+
+ # check loss is similar
+ np.testing.assert_allclose(wass, wass1d)
+ np.testing.assert_allclose(wass, wass1d_emd2)
+
+ # check loss is similar to scipy's implementation for Euclidean metric
+ wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)))
+ np.testing.assert_allclose(wass_sp, wass1d_euc)
+
+ # check constraints
+ np.testing.assert_allclose(np.ones((n,)) / n, G.sum(1))
+ np.testing.assert_allclose(np.ones((m,)) / m, G.sum(0))
+
+ # check G is similar
+ np.testing.assert_allclose(G, G_1d, atol=1e-15)
+
+ # check AssertionError is raised if called on non 1d arrays
+ u = np.random.randn(n, 2)
+ v = np.random.randn(m, 2)
+ with pytest.raises(AssertionError):
+ ot.emd_1d(u, v, [], [])
+
+
+def test_emd1d_type_devices(nx):
+
+ rng = np.random.RandomState(0)
+
+ n = 10
+ x = np.linspace(0, 5, n)
+ rho_u = np.abs(rng.randn(n))
+ rho_u /= rho_u.sum()
+ rho_v = np.abs(rng.randn(n))
+ rho_v /= rho_v.sum()
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ xb = nx.from_numpy(x, type_as=tp)
+ rho_ub = nx.from_numpy(rho_u, type_as=tp)
+ rho_vb = nx.from_numpy(rho_v, type_as=tp)
+
+ emd = ot.emd_1d(xb, xb, rho_ub, rho_vb)
+ emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb)
+
+ nx.assert_same_dtype_device(xb, emd)
+ nx.assert_same_dtype_device(xb, emd2)
diff --git a/test/test_backend.py b/test/test_backend.py
new file mode 100644
index 0000000..1832b91
--- /dev/null
+++ b/test/test_backend.py
@@ -0,0 +1,577 @@
+"""Tests for backend module """
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+# Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
+import ot
+import ot.backend
+from ot.backend import torch, jax
+
+import pytest
+
+import numpy as np
+from numpy.testing import assert_array_almost_equal_nulp
+
+from ot.backend import get_backend, get_backend_list, to_numpy
+
+
+def test_get_backend_list():
+
+ lst = get_backend_list()
+
+ assert len(lst) > 0
+ assert isinstance(lst[0], ot.backend.NumpyBackend)
+
+
+def test_to_numpy(nx):
+
+ v = nx.zeros(10)
+ M = nx.ones((10, 10))
+
+ v2 = to_numpy(v)
+ assert isinstance(v2, np.ndarray)
+
+ v2, M2 = to_numpy(v, M)
+ assert isinstance(M2, np.ndarray)
+
+
+def test_get_backend():
+
+ A = np.zeros((3, 2))
+ B = np.zeros((3, 1))
+
+ nx = get_backend(A)
+ assert nx.__name__ == 'numpy'
+
+ nx = get_backend(A, B)
+ assert nx.__name__ == 'numpy'
+
+ # error if no parameters
+ with pytest.raises(ValueError):
+ get_backend()
+
+ # error if unknown types
+ with pytest.raises(ValueError):
+ get_backend(1, 2.0)
+
+ # test torch
+ if torch:
+
+ A2 = torch.from_numpy(A)
+ B2 = torch.from_numpy(B)
+
+ nx = get_backend(A2)
+ assert nx.__name__ == 'torch'
+
+ nx = get_backend(A2, B2)
+ assert nx.__name__ == 'torch'
+
+ # test not unique types in input
+ with pytest.raises(ValueError):
+ get_backend(A, B2)
+
+ if jax:
+
+ A2 = jax.numpy.array(A)
+ B2 = jax.numpy.array(B)
+
+ nx = get_backend(A2)
+ assert nx.__name__ == 'jax'
+
+ nx = get_backend(A2, B2)
+ assert nx.__name__ == 'jax'
+
+ # test not unique types in input
+ with pytest.raises(ValueError):
+ get_backend(A, B2)
+
+
+def test_convert_between_backends(nx):
+
+ A = np.zeros((3, 2))
+ B = np.zeros((3, 1))
+
+ A2 = nx.from_numpy(A)
+ B2 = nx.from_numpy(B)
+
+ assert isinstance(A2, nx.__type__)
+ assert isinstance(B2, nx.__type__)
+
+ nx2 = get_backend(A2, B2)
+
+ assert nx2.__name__ == nx.__name__
+
+ assert_array_almost_equal_nulp(nx.to_numpy(A2), A)
+ assert_array_almost_equal_nulp(nx.to_numpy(B2), B)
+
+
+def test_empty_backend():
+
+ rnd = np.random.RandomState(0)
+ M = rnd.randn(10, 3)
+ v = rnd.randn(3)
+
+ nx = ot.backend.Backend()
+
+ with pytest.raises(NotImplementedError):
+ nx.from_numpy(M)
+ with pytest.raises(NotImplementedError):
+ nx.to_numpy(M)
+ with pytest.raises(NotImplementedError):
+ nx.set_gradients(0, 0, 0)
+ with pytest.raises(NotImplementedError):
+ nx.zeros((10, 3))
+ with pytest.raises(NotImplementedError):
+ nx.ones((10, 3))
+ with pytest.raises(NotImplementedError):
+ nx.arange(10, 1, 2)
+ with pytest.raises(NotImplementedError):
+ nx.full((10, 3), 3.14)
+ with pytest.raises(NotImplementedError):
+ nx.eye((10, 3))
+ with pytest.raises(NotImplementedError):
+ nx.sum(M)
+ with pytest.raises(NotImplementedError):
+ nx.cumsum(M)
+ with pytest.raises(NotImplementedError):
+ nx.max(M)
+ with pytest.raises(NotImplementedError):
+ nx.min(M)
+ with pytest.raises(NotImplementedError):
+ nx.maximum(v, v)
+ with pytest.raises(NotImplementedError):
+ nx.minimum(v, v)
+ with pytest.raises(NotImplementedError):
+ nx.abs(M)
+ with pytest.raises(NotImplementedError):
+ nx.log(M)
+ with pytest.raises(NotImplementedError):
+ nx.exp(M)
+ with pytest.raises(NotImplementedError):
+ nx.sqrt(M)
+ with pytest.raises(NotImplementedError):
+ nx.power(v, 2)
+ with pytest.raises(NotImplementedError):
+ nx.dot(v, v)
+ with pytest.raises(NotImplementedError):
+ nx.norm(M)
+ with pytest.raises(NotImplementedError):
+ nx.exp(M)
+ with pytest.raises(NotImplementedError):
+ nx.any(M)
+ with pytest.raises(NotImplementedError):
+ nx.isnan(M)
+ with pytest.raises(NotImplementedError):
+ nx.isinf(M)
+ with pytest.raises(NotImplementedError):
+ nx.einsum('ij->i', M)
+ with pytest.raises(NotImplementedError):
+ nx.sort(M)
+ with pytest.raises(NotImplementedError):
+ nx.argsort(M)
+ with pytest.raises(NotImplementedError):
+ nx.searchsorted(v, v)
+ with pytest.raises(NotImplementedError):
+ nx.flip(M)
+ with pytest.raises(NotImplementedError):
+ nx.outer(v, v)
+ with pytest.raises(NotImplementedError):
+ nx.clip(M, -1, 1)
+ with pytest.raises(NotImplementedError):
+ nx.repeat(M, 0, 1)
+ with pytest.raises(NotImplementedError):
+ nx.take_along_axis(M, v, 0)
+ with pytest.raises(NotImplementedError):
+ nx.concatenate([v, v])
+ with pytest.raises(NotImplementedError):
+ nx.zero_pad(M, v)
+ with pytest.raises(NotImplementedError):
+ nx.argmax(M)
+ with pytest.raises(NotImplementedError):
+ nx.mean(M)
+ with pytest.raises(NotImplementedError):
+ nx.std(M)
+ with pytest.raises(NotImplementedError):
+ nx.linspace(0, 1, 50)
+ with pytest.raises(NotImplementedError):
+ nx.meshgrid(v, v)
+ with pytest.raises(NotImplementedError):
+ nx.diag(M)
+ with pytest.raises(NotImplementedError):
+ nx.unique([M, M])
+ with pytest.raises(NotImplementedError):
+ nx.logsumexp(M)
+ with pytest.raises(NotImplementedError):
+ nx.stack([M, M])
+ with pytest.raises(NotImplementedError):
+ nx.reshape(M, (5, 3, 2))
+ with pytest.raises(NotImplementedError):
+ nx.seed(42)
+ with pytest.raises(NotImplementedError):
+ nx.rand()
+ with pytest.raises(NotImplementedError):
+ nx.randn()
+ with pytest.raises(NotImplementedError):
+ nx.coo_matrix(M, M, M)
+ with pytest.raises(NotImplementedError):
+ nx.issparse(M)
+ with pytest.raises(NotImplementedError):
+ nx.tocsr(M)
+ with pytest.raises(NotImplementedError):
+ nx.eliminate_zeros(M)
+ with pytest.raises(NotImplementedError):
+ nx.todense(M)
+ with pytest.raises(NotImplementedError):
+ nx.where(M, M, M)
+ with pytest.raises(NotImplementedError):
+ nx.copy(M)
+ with pytest.raises(NotImplementedError):
+ nx.allclose(M, M)
+
+
+def test_func_backends(nx):
+
+ rnd = np.random.RandomState(0)
+ M = rnd.randn(10, 3)
+ v = rnd.randn(3)
+ val = np.array([1.0])
+
+ # Sparse tensors test
+ sp_row = np.array([0, 3, 1, 0, 3])
+ sp_col = np.array([0, 3, 1, 2, 2])
+ sp_data = np.array([4, 5, 7, 9, 0])
+
+ lst_tot = []
+
+ for nx in [ot.backend.NumpyBackend(), nx]:
+
+ print('Backend: ', nx.__name__)
+
+ lst_b = []
+ lst_name = []
+
+ Mb = nx.from_numpy(M)
+ vb = nx.from_numpy(v)
+
+ val = nx.from_numpy(val)
+
+ sp_rowb = nx.from_numpy(sp_row)
+ sp_colb = nx.from_numpy(sp_col)
+ sp_datab = nx.from_numpy(sp_data)
+
+ A = nx.set_gradients(val, v, v)
+
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('set_gradients')
+
+ A = nx.zeros((10, 3))
+ A = nx.zeros((10, 3), type_as=Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('zeros')
+
+ A = nx.ones((10, 3))
+ A = nx.ones((10, 3), type_as=Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('ones')
+
+ A = nx.arange(10, 1, 2)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('arange')
+
+ A = nx.full((10, 3), 3.14)
+ A = nx.full((10, 3), 3.14, type_as=Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('full')
+
+ A = nx.eye(10, 3)
+ A = nx.eye(10, 3, type_as=Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('eye')
+
+ A = nx.sum(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('sum')
+
+ A = nx.sum(Mb, axis=1, keepdims=True)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('sum(axis)')
+
+ A = nx.cumsum(Mb, 0)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('cumsum(axis)')
+
+ A = nx.max(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('max')
+
+ A = nx.max(Mb, axis=1, keepdims=True)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('max(axis)')
+
+ A = nx.min(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('min')
+
+ A = nx.min(Mb, axis=1, keepdims=True)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('min(axis)')
+
+ A = nx.maximum(vb, 0)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('maximum')
+
+ A = nx.minimum(vb, 0)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('minimum')
+
+ A = nx.abs(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('abs')
+
+ A = nx.log(A)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('log')
+
+ A = nx.exp(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('exp')
+
+ A = nx.sqrt(nx.abs(Mb))
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('sqrt')
+
+ A = nx.power(Mb, 2)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('power')
+
+ A = nx.dot(vb, vb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('dot(v,v)')
+
+ A = nx.dot(Mb, vb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('dot(M,v)')
+
+ A = nx.dot(Mb, Mb.T)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('dot(M,M)')
+
+ A = nx.norm(vb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('norm')
+
+ A = nx.any(vb > 0)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('any')
+
+ A = nx.isnan(vb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('isnan')
+
+ A = nx.isinf(vb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('isinf')
+
+ A = nx.einsum('ij->i', Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('einsum(ij->i)')
+
+ A = nx.einsum('ij,j->i', Mb, vb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('nx.einsum(ij,j->i)')
+
+ A = nx.einsum('ij->i', Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('nx.einsum(ij->i)')
+
+ A = nx.sort(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('sort')
+
+ A = nx.argsort(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('argsort')
+
+ A = nx.searchsorted(Mb, Mb, 'right')
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('searchsorted')
+
+ A = nx.flip(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('flip')
+
+ A = nx.outer(vb, vb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('outer')
+
+ A = nx.clip(vb, 0, 1)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('clip')
+
+ A = nx.repeat(Mb, 0)
+ A = nx.repeat(Mb, 2, -1)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('repeat')
+
+ A = nx.take_along_axis(vb, nx.arange(3), -1)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('take_along_axis')
+
+ A = nx.concatenate((Mb, Mb), -1)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('concatenate')
+
+ A = nx.zero_pad(Mb, len(Mb.shape) * [(3, 3)])
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('zero_pad')
+
+ A = nx.argmax(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('argmax')
+
+ A = nx.mean(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('mean')
+
+ A = nx.std(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('std')
+
+ A = nx.linspace(0, 1, 50)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('linspace')
+
+ X, Y = nx.meshgrid(vb, vb)
+ lst_b.append(np.stack([nx.to_numpy(X), nx.to_numpy(Y)]))
+ lst_name.append('meshgrid')
+
+ A = nx.diag(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('diag2D')
+
+ A = nx.diag(vb, 1)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('diag1D')
+
+ A = nx.unique(nx.from_numpy(np.stack([M, M])))
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('unique')
+
+ A = nx.logsumexp(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('logsumexp')
+
+ A = nx.stack([Mb, Mb])
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('stack')
+
+ A = nx.reshape(Mb, (5, 3, 2))
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('reshape')
+
+ sp_Mb = nx.coo_matrix(sp_datab, sp_rowb, sp_colb, shape=(4, 4))
+ nx.todense(Mb)
+ lst_b.append(nx.to_numpy(nx.todense(sp_Mb)))
+ lst_name.append('coo_matrix')
+
+ assert not nx.issparse(Mb), 'Assert fail on: issparse (expected False)'
+ assert nx.issparse(sp_Mb) or nx.__name__ == "jax", 'Assert fail on: issparse (expected True)'
+
+ A = nx.tocsr(sp_Mb)
+ lst_b.append(nx.to_numpy(nx.todense(A)))
+ lst_name.append('tocsr')
+
+ A = nx.eliminate_zeros(nx.copy(sp_datab), threshold=5.)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('eliminate_zeros (dense)')
+
+ A = nx.eliminate_zeros(sp_Mb)
+ lst_b.append(nx.to_numpy(nx.todense(A)))
+ lst_name.append('eliminate_zeros (sparse)')
+
+ A = nx.where(Mb >= nx.stack([nx.linspace(0, 1, 10)] * 3, axis=1), Mb, 0.0)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('where')
+
+ A = nx.copy(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('copy')
+
+ assert nx.allclose(Mb, Mb), 'Assert fail on: allclose (expected True)'
+ assert not nx.allclose(2 * Mb, Mb), 'Assert fail on: allclose (expected False)'
+
+ lst_tot.append(lst_b)
+
+ lst_np = lst_tot[0]
+ lst_b = lst_tot[1]
+
+ for a1, a2, name in zip(lst_np, lst_b, lst_name):
+ if not np.allclose(a1, a2):
+ print('Assert fail on: ', name)
+ assert np.allclose(a1, a2, atol=1e-7)
+
+
+def test_random_backends(nx):
+
+ tmp_u = nx.rand()
+
+ assert tmp_u < 1
+
+ tmp_n = nx.randn()
+
+ nx.seed(0)
+ M1 = nx.to_numpy(nx.rand(5, 2))
+ nx.seed(0)
+ M2 = nx.to_numpy(nx.rand(5, 2, type_as=tmp_n))
+
+ assert np.all(M1 >= 0)
+ assert np.all(M1 < 1)
+ assert M1.shape == (5, 2)
+ assert np.allclose(M1, M2)
+
+ nx.seed(0)
+ M1 = nx.to_numpy(nx.randn(5, 2))
+ nx.seed(0)
+ M2 = nx.to_numpy(nx.randn(5, 2, type_as=tmp_u))
+
+ nx.seed(42)
+ v1 = nx.randn()
+ v2 = nx.randn()
+ assert v1 != v2
+
+
+def test_gradients_backends():
+
+ rnd = np.random.RandomState(0)
+ v = rnd.randn(10)
+ c = rnd.randn()
+ e = rnd.randn()
+
+ if torch:
+
+ nx = ot.backend.TorchBackend()
+
+ v2 = torch.tensor(v, requires_grad=True)
+ c2 = torch.tensor(c, requires_grad=True)
+
+ val = c2 * torch.sum(v2 * v2)
+
+ val2 = nx.set_gradients(val, (v2, c2), (v2, c2))
+
+ val2.backward()
+
+ assert torch.equal(v2.grad, v2)
+ assert torch.equal(c2.grad, c2)
+
+ if jax:
+ nx = ot.backend.JaxBackend()
+ with jax.checking_leaks():
+ def fun(a, b, d):
+ val = b * nx.sum(a ** 4) + d
+ return nx.set_gradients(val, (a, b, d), (a, b, 2 * d))
+ grad_val = jax.grad(fun, argnums=(0, 1, 2))(v, c, e)
+
+ np.testing.assert_almost_equal(fun(v, c, e), c * np.sum(v ** 4) + e, decimal=4)
+ np.testing.assert_allclose(grad_val[0], v, atol=1e-4)
+ np.testing.assert_allclose(grad_val[2], 2 * e, atol=1e-4)
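The backend tests above all exercise the same dispatch pattern: `get_backend` inspects its array arguments and returns a backend object (numpy, torch or jax) exposing a common array API, so downstream code is written once. A small sketch of that pattern (not part of the diff; it assumes only functions shown in the tests):

    import numpy as np
    from ot.backend import get_backend, to_numpy

    def normalize_rows(M):
        # works unchanged for numpy, torch or jax inputs:
        # `nx` exposes the same methods whatever the array type
        nx = get_backend(M)
        return M / nx.sum(M, axis=1, keepdims=True)

    M = np.abs(np.random.randn(10, 3))
    P = normalize_rows(M)          # result keeps the input's array type
    print(to_numpy(P).sum(1))      # each row now sums to 1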
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 6aa4e08..830052d 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -2,15 +2,21 @@
# Author: Remi Flamary <remi.flamary@unice.fr>
# Kilian Fatras <kilian.fatras@irisa.fr>
+# Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
#
# License: MIT License
+from itertools import product
+
import numpy as np
-import ot
import pytest
+import ot
+from ot.backend import torch
+
-def test_sinkhorn():
+@pytest.mark.parametrize("verbose, warn", product([True, False], [True, False]))
+def test_sinkhorn(verbose, warn):
# test sinkhorn
n = 100
rng = np.random.RandomState(0)
@@ -20,14 +26,189 @@ def test_sinkhorn():
M = ot.dist(x, x)
- G = ot.sinkhorn(u, u, M, 1, stopThr=1e-10)
+ G = ot.sinkhorn(u, u, M, 1, stopThr=1e-10, verbose=verbose, warn=warn)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
u, G.sum(1), atol=1e-05) # cf convergence sinkhorn
np.testing.assert_allclose(
u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
+ with pytest.warns(UserWarning):
+ ot.sinkhorn(u, u, M, 1, stopThr=0, numItermax=1)
+
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized",
+ "sinkhorn_epsilon_scaling",
+ "greenkhorn",
+ "sinkhorn_log"])
+def test_convergence_warning(method):
+ # test sinkhorn
+ n = 100
+ a1 = ot.datasets.make_1D_gauss(n, m=30, s=10)
+ a2 = ot.datasets.make_1D_gauss(n, m=40, s=10)
+ A = np.asarray([a1, a2]).T
+ M = ot.utils.dist0(n)
+
+ with pytest.warns(UserWarning):
+ ot.sinkhorn(a1, a2, M, 1., method=method, stopThr=0, numItermax=1)
+
+ if method in ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"]:
+ with pytest.warns(UserWarning):
+ ot.barycenter(A, M, 1, method=method, stopThr=0, numItermax=1)
+ with pytest.warns(UserWarning):
+ ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1)
+
+
+def test_not_impemented_method():
+ # test sinkhorn
+ w = 10
+ n = w ** 2
+ rng = np.random.RandomState(42)
+ A_img = rng.rand(2, w, w)
+ A_flat = A_img.reshape(n, 2)
+ a1, a2 = A_flat.T
+ M_flat = ot.utils.dist0(n)
+ not_implemented = "new_method"
+ reg = 0.01
+ with pytest.raises(ValueError):
+ ot.sinkhorn(a1, a2, M_flat, reg, method=not_implemented)
+ with pytest.raises(ValueError):
+ ot.sinkhorn2(a1, a2, M_flat, reg, method=not_implemented)
+ with pytest.raises(ValueError):
+ ot.barycenter(A_flat, M_flat, reg, method=not_implemented)
+ with pytest.raises(ValueError):
+ ot.bregman.barycenter_debiased(A_flat, M_flat, reg,
+ method=not_implemented)
+ with pytest.raises(ValueError):
+ ot.bregman.convolutional_barycenter2d(A_img, reg,
+ method=not_implemented)
+ with pytest.raises(ValueError):
+ ot.bregman.convolutional_barycenter2d_debiased(A_img, reg,
+ method=not_implemented)
+
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
+def test_nan_warning(method):
+ # test sinkhorn
+ n = 100
+ a1 = ot.datasets.make_1D_gauss(n, m=30, s=10)
+ a2 = ot.datasets.make_1D_gauss(n, m=40, s=10)
+
+ M = ot.utils.dist0(n)
+ reg = 0
+ with pytest.warns(UserWarning):
+ # warn set to False to avoid catching a convergence warning instead
+ ot.sinkhorn(a1, a2, M, reg, method=method, warn=False)
+
+
+def test_sinkhorn_stabilization():
+ # test sinkhorn
+ n = 100
+ a1 = ot.datasets.make_1D_gauss(n, m=30, s=10)
+ a2 = ot.datasets.make_1D_gauss(n, m=40, s=10)
+ M = ot.utils.dist0(n)
+ reg = 1e-5
+ loss1 = ot.sinkhorn2(a1, a2, M, reg, method="sinkhorn_log")
+ loss2 = ot.sinkhorn2(a1, a2, M, reg, tau=1, method="sinkhorn_stabilized")
+ np.testing.assert_allclose(
+ loss1, loss2, atol=1e-06) # cf convergence sinkhorn
+
+
+@pytest.mark.parametrize("method, verbose, warn",
+ product(["sinkhorn", "sinkhorn_stabilized",
+ "sinkhorn_log"],
+ [True, False], [True, False]))
+def test_sinkhorn_multi_b(method, verbose, warn):
+ # test sinkhorn
+ n = 10
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ b = rng.rand(n, 3)
+ b = b / np.sum(b, 0, keepdims=True)
+
+ M = ot.dist(x, x)
+
+ loss0, log = ot.sinkhorn(u, b, M, .1, method=method, stopThr=1e-10,
+ log=True)
+
+ loss = [ot.sinkhorn2(u, b[:, k], M, .1, method=method, stopThr=1e-10,
+ verbose=verbose, warn=warn) for k in range(3)]
+ # check constraints
+ np.testing.assert_allclose(
+ loss0, loss, atol=1e-4) # cf convergence sinkhorn
+
+
+def test_sinkhorn_backends(nx):
+ n_samples = 100
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
+
+ M = ot.dist(x, y)
+
+ G = ot.sinkhorn(a, a, M, 1)
+
+ ab = nx.from_numpy(a)
+ M_nx = nx.from_numpy(M)
+
+ Gb = ot.sinkhorn(ab, ab, M_nx, 1)
+
+ np.allclose(G, nx.to_numpy(Gb))
+
+
+def test_sinkhorn2_backends(nx):
+ n_samples = 100
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
+
+ M = ot.dist(x, y)
+
+ G = ot.sinkhorn(a, a, M, 1)
+
+ ab = nx.from_numpy(a)
+ M_nx = nx.from_numpy(M)
+
+ Gb = ot.sinkhorn2(ab, ab, M_nx, 1)
+
+ np.allclose(G, nx.to_numpy(Gb))
+
+
+def test_sinkhorn2_gradients():
+ n_samples = 100
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
+
+ M = ot.dist(x, y)
+
+ if torch:
+
+ a1 = torch.tensor(a, requires_grad=True)
+ b1 = torch.tensor(a, requires_grad=True)
+ M1 = torch.tensor(M, requires_grad=True)
+
+ val = ot.sinkhorn2(a1, b1, M1, 1)
+
+ val.backward()
+
+ assert a1.shape == a1.grad.shape
+ assert b1.shape == b1.grad.shape
+ assert M1.shape == M1.grad.shape
+
def test_sinkhorn_empty():
# test sinkhorn
@@ -39,21 +220,27 @@ def test_sinkhorn_empty():
M = ot.dist(x, x)
+ G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10, method="sinkhorn_log",
+ verbose=True, log=True)
+ # check constraints
+ np.testing.assert_allclose(u, G.sum(1), atol=1e-05)
+ np.testing.assert_allclose(u, G.sum(0), atol=1e-05)
+
G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10, verbose=True, log=True)
- # check constratints
+ # check constraints
np.testing.assert_allclose(u, G.sum(1), atol=1e-05)
np.testing.assert_allclose(u, G.sum(0), atol=1e-05)
G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10,
method='sinkhorn_stabilized', verbose=True, log=True)
- # check constratints
+ # check constraints
np.testing.assert_allclose(u, G.sum(1), atol=1e-05)
np.testing.assert_allclose(u, G.sum(0), atol=1e-05)
G, log = ot.sinkhorn(
[], [], M, 1, stopThr=1e-10, method='sinkhorn_epsilon_scaling',
verbose=True, log=True)
- # check constratints
+ # check constraints
np.testing.assert_allclose(u, G.sum(1), atol=1e-05)
np.testing.assert_allclose(u, G.sum(0), atol=1e-05)
@@ -61,7 +248,8 @@ def test_sinkhorn_empty():
ot.sinkhorn([], [], M, 1, method='greenkhorn', stopThr=1e-10, log=True)
-def test_sinkhorn_variants():
+@pytest.skip_backend("jax")
+def test_sinkhorn_variants(nx):
# test sinkhorn
n = 100
rng = np.random.RandomState(0)
@@ -71,22 +259,131 @@ def test_sinkhorn_variants():
M = ot.dist(x, x)
- G0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10)
- Gs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10)
- Ges = ot.sinkhorn(
- u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10)
- G_green = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10)
+ ub = nx.from_numpy(u)
+ M_nx = nx.from_numpy(M)
+
+ G = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10)
+ Gl = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
+ G0 = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn', stopThr=1e-10))
+ Gs = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))
+ Ges = nx.to_numpy(ot.sinkhorn(
+ ub, ub, M_nx, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10))
+ G_green = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='greenkhorn', stopThr=1e-10))
# check values
+ np.testing.assert_allclose(G, G0, atol=1e-05)
+ np.testing.assert_allclose(G, Gl, atol=1e-05)
np.testing.assert_allclose(G0, Gs, atol=1e-05)
np.testing.assert_allclose(G0, Ges, atol=1e-05)
np.testing.assert_allclose(G0, G_green, atol=1e-5)
- print(G0, G_green)
+
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized",
+ "sinkhorn_epsilon_scaling",
+ "greenkhorn",
+ "sinkhorn_log"])
+@pytest.skip_arg(("nx", "method"), ("jax", "sinkhorn_epsilon_scaling"), reason="jax does not support sinkhorn_epsilon_scaling", getter=str)
+@pytest.skip_arg(("nx", "method"), ("jax", "greenkhorn"), reason="jax does not support greenkhorn", getter=str)
+def test_sinkhorn_variants_dtype_device(nx, method):
+ n = 100
+
+ x = np.random.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ M = ot.dist(x, x)
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ ub = nx.from_numpy(u, type_as=tp)
+ Mb = nx.from_numpy(M, type_as=tp)
+
+ Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10)
+
+ nx.assert_same_dtype_device(Mb, Gb)
+
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"])
+def test_sinkhorn2_variants_dtype_device(nx, method):
+ n = 100
+
+ x = np.random.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ M = ot.dist(x, x)
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ ub = nx.from_numpy(u, type_as=tp)
+ Mb = nx.from_numpy(M, type_as=tp)
+
+ lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10)
+
+ nx.assert_same_dtype_device(Mb, lossb)
+
+
+@pytest.skip_backend("jax")
+def test_sinkhorn_variants_multi_b(nx):
+ # test sinkhorn
+ n = 50
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ b = rng.rand(n, 3)
+ b = b / np.sum(b, 0, keepdims=True)
+
+ M = ot.dist(x, x)
+
+ ub = nx.from_numpy(u)
+ bb = nx.from_numpy(b)
+ M_nx = nx.from_numpy(M)
+
+ G = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10)
+ Gl = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
+ G0 = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10))
+ Gs = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))
+
+ # check values
+ np.testing.assert_allclose(G, G0, atol=1e-05)
+ np.testing.assert_allclose(G, Gl, atol=1e-05)
+ np.testing.assert_allclose(G0, Gs, atol=1e-05)
+
+
+@pytest.skip_backend("jax")
+def test_sinkhorn2_variants_multi_b(nx):
+ # test sinkhorn
+ n = 50
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ b = rng.rand(n, 3)
+ b = b / np.sum(b, 0, keepdims=True)
+
+ M = ot.dist(x, x)
+
+ ub = nx.from_numpy(u)
+ bb = nx.from_numpy(b)
+ M_nx = nx.from_numpy(M)
+
+ G = ot.sinkhorn2(u, b, M, 1, method='sinkhorn', stopThr=1e-10)
+ Gl = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
+ G0 = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10))
+ Gs = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))
+
+ # check values
+ np.testing.assert_allclose(G, G0, atol=1e-05)
+ np.testing.assert_allclose(G, Gl, atol=1e-05)
+ np.testing.assert_allclose(G0, Gs, atol=1e-05)
def test_sinkhorn_variants_log():
# test sinkhorn
- n = 100
+ n = 50
rng = np.random.RandomState(0)
x = rng.randn(n, 2)
@@ -95,20 +392,87 @@ def test_sinkhorn_variants_log():
M = ot.dist(x, x)
G0, log0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10, log=True)
+ Gl, logl = ot.sinkhorn(u, u, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True)
Gs, logs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True)
Ges, loges = ot.sinkhorn(
- u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True)
+ u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True,)
G_green, loggreen = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10, log=True)
# check values
np.testing.assert_allclose(G0, Gs, atol=1e-05)
+ np.testing.assert_allclose(G0, Gl, atol=1e-05)
np.testing.assert_allclose(G0, Ges, atol=1e-05)
np.testing.assert_allclose(G0, G_green, atol=1e-5)
- print(G0, G_green)
-@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
-def test_barycenter(method):
+@pytest.mark.parametrize("verbose, warn", product([True, False], [True, False]))
+def test_sinkhorn_variants_log_multib(verbose, warn):
+ # test sinkhorn
+ n = 50
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+ b = rng.rand(n, 3)
+ b = b / np.sum(b, 0, keepdims=True)
+
+ M = ot.dist(x, x)
+
+ G0, log0 = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10, log=True)
+ Gl, logl = ot.sinkhorn(u, b, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True,
+ verbose=verbose, warn=warn)
+ Gs, logs = ot.sinkhorn(u, b, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True,
+ verbose=verbose, warn=warn)
+
+ # check values
+ np.testing.assert_allclose(G0, Gs, atol=1e-05)
+ np.testing.assert_allclose(G0, Gl, atol=1e-05)
+
+
+@pytest.mark.parametrize("method, verbose, warn",
+ product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"],
+ [True, False], [True, False]))
+def test_barycenter(nx, method, verbose, warn):
+ n_bins = 100 # nb bins
+
+ # Gaussian distributions
+ a1 = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std
+ a2 = ot.datasets.make_1D_gauss(n_bins, m=40, s=10)
+
+ # creating matrix A containing all distributions
+ A = np.vstack((a1, a2)).T
+
+ # loss matrix + normalization
+ M = ot.utils.dist0(n_bins)
+ M /= M.max()
+
+ alpha = 0.5 # 0<=alpha<=1
+ weights = np.array([1 - alpha, alpha])
+
+ A_nx = nx.from_numpy(A)
+ M_nx = nx.from_numpy(M)
+ weights_nx = nx.from_numpy(weights)
+ reg = 1e-2
+
+ if nx.__name__ == "jax" and method == "sinkhorn_log":
+ with pytest.raises(NotImplementedError):
+ ot.bregman.barycenter(A_nx, M_nx, reg, weights, method=method)
+ else:
+ # wasserstein
+ bary_wass_np = ot.bregman.barycenter(A, M, reg, weights, method=method, verbose=verbose, warn=warn)
+ bary_wass, _ = ot.bregman.barycenter(A_nx, M_nx, reg, weights_nx, method=method, log=True)
+ bary_wass = nx.to_numpy(bary_wass)
+
+ np.testing.assert_allclose(1, np.sum(bary_wass))
+ np.testing.assert_allclose(bary_wass, bary_wass_np)
+
+ ot.bregman.barycenter(A_nx, M_nx, reg, log=True)
+
+
+@pytest.mark.parametrize("method, verbose, warn",
+ product(["sinkhorn", "sinkhorn_log"],
+ [True, False], [True, False]))
+def test_barycenter_debiased(nx, method, verbose, warn):
n_bins = 100 # nb bins
# Gaussian distributions
@@ -125,16 +489,61 @@ def test_barycenter(method):
alpha = 0.5 # 0<=alpha<=1
weights = np.array([1 - alpha, alpha])
+ A_nx = nx.from_numpy(A)
+ M_nx = nx.from_numpy(M)
+ weights_nx = nx.from_numpy(weights)
+
# wasserstein
reg = 1e-2
- bary_wass, log = ot.bregman.barycenter(A, M, reg, weights, method=method, log=True)
+ if nx.__name__ == "jax" and method == "sinkhorn_log":
+ with pytest.raises(NotImplementedError):
+ ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights, method=method)
+ else:
+ bary_wass_np = ot.bregman.barycenter_debiased(A, M, reg, weights, method=method,
+ verbose=verbose, warn=warn)
+ bary_wass, _ = ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights_nx, method=method, log=True)
+ bary_wass = nx.to_numpy(bary_wass)
+
+ np.testing.assert_allclose(1, np.sum(bary_wass), atol=1e-3)
+ np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-5)
- np.testing.assert_allclose(1, np.sum(bary_wass))
+ ot.bregman.barycenter_debiased(A_nx, M_nx, reg, log=True, verbose=False)
- ot.bregman.barycenter(A, M, reg, log=True, verbose=True)
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"])
+def test_convergence_warning_barycenters(method):
+ w = 10
+ n_bins = w ** 2 # nb bins
+
+ # Gaussian distributions
+ a1 = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std
+ a2 = ot.datasets.make_1D_gauss(n_bins, m=40, s=10)
+
+ # creating matrix A containing all distributions
+ A = np.vstack((a1, a2)).T
+ A_img = A.reshape(2, w, w)
+ A_img /= A_img.sum((1, 2))[:, None, None]
+
+ # loss matrix + normalization
+ M = ot.utils.dist0(n_bins)
+ M /= M.max()
-def test_barycenter_stabilization():
+ alpha = 0.5 # 0<=alpha<=1
+ weights = np.array([1 - alpha, alpha])
+ reg = 0.1
+ with pytest.warns(UserWarning):
+ ot.bregman.barycenter_debiased(A, M, reg, weights, method=method, numItermax=1)
+ with pytest.warns(UserWarning):
+ ot.bregman.barycenter(A, M, reg, weights, method=method, numItermax=1)
+ with pytest.warns(UserWarning):
+ ot.bregman.convolutional_barycenter2d(A_img, reg, weights,
+ method=method, numItermax=1)
+ with pytest.warns(UserWarning):
+ ot.bregman.convolutional_barycenter2d_debiased(A_img, reg, weights,
+ method=method, numItermax=1)
+
+
+def test_barycenter_stabilization(nx):
n_bins = 100 # nb bins
# Gaussian distributions
@@ -151,22 +560,64 @@ def test_barycenter_stabilization():
alpha = 0.5 # 0<=alpha<=1
weights = np.array([1 - alpha, alpha])
+ A_nx = nx.from_numpy(A)
+ M_nx = nx.from_numpy(M)
+ weights_b = nx.from_numpy(weights)
+
# wasserstein
reg = 1e-2
- bar_stable = ot.bregman.barycenter(A, M, reg, weights,
- method="sinkhorn_stabilized",
- stopThr=1e-8, verbose=True)
- bar = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn",
- stopThr=1e-8, verbose=True)
+ bar_np = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn", stopThr=1e-8, verbose=True)
+ bar_stable = nx.to_numpy(ot.bregman.barycenter(
+ A_nx, M_nx, reg, weights_b, method="sinkhorn_stabilized",
+ stopThr=1e-8, verbose=True
+ ))
+ bar = nx.to_numpy(ot.bregman.barycenter(
+ A_nx, M_nx, reg, weights_b, method="sinkhorn",
+ stopThr=1e-8, verbose=True
+ ))
np.testing.assert_allclose(bar, bar_stable)
+ np.testing.assert_allclose(bar, bar_np)
+
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"])
+def test_wasserstein_bary_2d(nx, method):
+ size = 20 # size of a square image
+ a1 = np.random.rand(size, size)
+ a1 += a1.min()
+ a1 = a1 / np.sum(a1)
+ a2 = np.random.rand(size, size)
+ a2 += a2.min()
+ a2 = a2 / np.sum(a2)
+ # creating matrix A containing all distributions
+ A = np.zeros((2, size, size))
+ A[0, :, :] = a1
+ A[1, :, :] = a2
+ A_nx = nx.from_numpy(A)
-def test_wasserstein_bary_2d():
- size = 100 # size of a square image
- a1 = np.random.randn(size, size)
+ # wasserstein
+ reg = 1e-2
+ if nx.__name__ == "jax" and method == "sinkhorn_log":
+ with pytest.raises(NotImplementedError):
+ ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)
+ else:
+ bary_wass_np = ot.bregman.convolutional_barycenter2d(A, reg, method=method)
+ bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method))
+
+ np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
+ np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
+
+ # help in checking if log and verbose do not bug the function
+ ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)
+
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"])
+def test_wasserstein_bary_2d_debiased(nx, method):
+ size = 20 # size of a square image
+ a1 = np.random.rand(size, size)
a1 += a1.min()
a1 = a1 / np.sum(a1)
- a2 = np.random.randn(size, size)
+ a2 = np.random.rand(size, size)
a2 += a2.min()
a2 = a2 / np.sum(a2)
# creating matrix A containing all distributions
@@ -174,17 +625,25 @@ def test_wasserstein_bary_2d():
A[0, :, :] = a1
A[1, :, :] = a2
+ A_nx = nx.from_numpy(A)
+
# wasserstein
reg = 1e-2
- bary_wass = ot.bregman.convolutional_barycenter2d(A, reg)
+ if nx.__name__ == "jax" and method == "sinkhorn_log":
+ with pytest.raises(NotImplementedError):
+ ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)
+ else:
+ bary_wass_np = ot.bregman.convolutional_barycenter2d_debiased(A, reg, method=method)
+ bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method))
- np.testing.assert_allclose(1, np.sum(bary_wass))
+ np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
+ np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
- # help in checking if log and verbose do not bug the function
- ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)
+ # help in checking if log and verbose do not bug the function
+ ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)
-def test_unmix():
+def test_unmix(nx):
n_bins = 50 # nb bins
# Gaussian distributions
@@ -204,41 +663,58 @@ def test_unmix():
M0 /= M0.max()
h0 = ot.unif(2)
+ ab = nx.from_numpy(a)
+ Db = nx.from_numpy(D)
+ M_nx = nx.from_numpy(M)
+ M0b = nx.from_numpy(M0)
+ h0b = nx.from_numpy(h0)
+
# wasserstein
reg = 1e-3
- um = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01, )
+ um_np = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01)
+ um = nx.to_numpy(ot.bregman.unmix(ab, Db, M_nx, M0b, h0b, reg, 1, alpha=0.01))
np.testing.assert_allclose(1, np.sum(um), rtol=1e-03, atol=1e-03)
np.testing.assert_allclose([0.5, 0.5], um, rtol=1e-03, atol=1e-03)
+ np.testing.assert_allclose(um, um_np)
- ot.bregman.unmix(a, D, M, M0, h0, reg,
+ ot.bregman.unmix(ab, Db, M_nx, M0b, h0b, reg,
1, alpha=0.01, log=True, verbose=True)
-def test_empirical_sinkhorn():
+def test_empirical_sinkhorn(nx):
# test sinkhorn
- n = 100
+ n = 10
a = ot.unif(n)
b = ot.unif(n)
- X_s = np.reshape(np.arange(n), (n, 1))
- X_t = np.reshape(np.arange(0, n), (n, 1))
+ X_s = np.reshape(1.0 * np.arange(n), (n, 1))
+ X_t = np.reshape(1.0 * np.arange(0, n), (n, 1))
M = ot.dist(X_s, X_t)
- M_m = ot.dist(X_s, X_t, metric='minkowski')
+ M_m = ot.dist(X_s, X_t, metric='euclidean')
+
+ ab = nx.from_numpy(a)
+ bb = nx.from_numpy(b)
+ X_sb = nx.from_numpy(X_s)
+ X_tb = nx.from_numpy(X_t)
+ M_nx = nx.from_numpy(M, type_as=ab)
+ M_mb = nx.from_numpy(M_m, type_as=ab)
- G_sqe = ot.bregman.empirical_sinkhorn(X_s, X_t, 1)
- sinkhorn_sqe = ot.sinkhorn(a, b, M, 1)
+ G_sqe = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1))
+ sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1))
- G_log, log_es = ot.bregman.empirical_sinkhorn(X_s, X_t, 0.1, log=True)
- sinkhorn_log, log_s = ot.sinkhorn(a, b, M, 0.1, log=True)
+ G_log, log_es = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 0.1, log=True)
+ G_log = nx.to_numpy(G_log)
+ sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 0.1, log=True)
+ sinkhorn_log = nx.to_numpy(sinkhorn_log)
- G_m = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, metric='minkowski')
- sinkhorn_m = ot.sinkhorn(a, b, M_m, 1)
+ G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean'))
+ sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1))
- loss_emp_sinkhorn = ot.bregman.empirical_sinkhorn2(X_s, X_t, 1)
- loss_sinkhorn = ot.sinkhorn2(a, b, M, 1)
+ loss_emp_sinkhorn = nx.to_numpy(ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1))
+ loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1))
- # check constratints
+ # check constraints
np.testing.assert_allclose(
sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidian
np.testing.assert_allclose(
@@ -254,34 +730,98 @@ def test_empirical_sinkhorn():
np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05)
-def test_empirical_sinkhorn_divergence():
- # Test sinkhorn divergence
+def test_lazy_empirical_sinkhorn(nx):
+ # test sinkhorn
n = 10
a = ot.unif(n)
b = ot.unif(n)
+ numIterMax = 1000
+
+ X_s = np.reshape(np.arange(n, dtype=np.float64), (n, 1))
+ X_t = np.reshape(np.arange(0, n, dtype=np.float64), (n, 1))
+ M = ot.dist(X_s, X_t)
+ M_m = ot.dist(X_s, X_t, metric='euclidean')
+
+ ab = nx.from_numpy(a)
+ bb = nx.from_numpy(b)
+ X_sb = nx.from_numpy(X_s)
+ X_tb = nx.from_numpy(X_t)
+ M_nx = nx.from_numpy(M, type_as=ab)
+ M_mb = nx.from_numpy(M_m, type_as=ab)
+
+ f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True)
+ f, g = nx.to_numpy(f), nx.to_numpy(g)
+ G_sqe = np.exp(f[:, None] + g[None, :] - M / 1)
+ sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1))
+
+ f, g, log_es = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
+ f, g = nx.to_numpy(f), nx.to_numpy(g)
+ G_log = np.exp(f[:, None] + g[None, :] - M / 0.1)
+ sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 0.1, log=True)
+ sinkhorn_log = nx.to_numpy(sinkhorn_log)
+
+ f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean', numIterMax=numIterMax, isLazy=True, batchSize=1)
+ f, g = nx.to_numpy(f), nx.to_numpy(g)
+ G_m = np.exp(f[:, None] + g[None, :] - M_m / 1)
+ sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1))
+
+ loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
+ loss_emp_sinkhorn = nx.to_numpy(loss_emp_sinkhorn)
+ loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1))
+
+ # check constraints
+ np.testing.assert_allclose(
+ sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidian
+ np.testing.assert_allclose(
+ sinkhorn_sqe.sum(0), G_sqe.sum(0), atol=1e-05) # metric sqeuclidian
+ np.testing.assert_allclose(
+ sinkhorn_log.sum(1), G_log.sum(1), atol=1e-05) # log
+ np.testing.assert_allclose(
+ sinkhorn_log.sum(0), G_log.sum(0), atol=1e-05) # log
+ np.testing.assert_allclose(
+ sinkhorn_m.sum(1), G_m.sum(1), atol=1e-05) # metric euclidian
+ np.testing.assert_allclose(
+ sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05) # metric euclidian
+ np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05)
+
+
+def test_empirical_sinkhorn_divergence(nx):
+ # Test sinkhorn divergence
+ n = 10
+ a = np.linspace(1, n, n)
+ a /= a.sum()
+ b = ot.unif(n)
X_s = np.reshape(np.arange(n), (n, 1))
X_t = np.reshape(np.arange(0, n * 2, 2), (n, 1))
M = ot.dist(X_s, X_t)
M_s = ot.dist(X_s, X_s)
M_t = ot.dist(X_t, X_t)
- emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1)
- sinkhorn_div = (ot.sinkhorn2(a, b, M, 1) - 1 / 2 * ot.sinkhorn2(a, a, M_s, 1) - 1 / 2 * ot.sinkhorn2(b, b, M_t, 1))
-
- emp_sinkhorn_div_log, log_es = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, log=True)
- sink_div_log_ab, log_s_ab = ot.sinkhorn2(a, b, M, 1, log=True)
- sink_div_log_a, log_s_a = ot.sinkhorn2(a, a, M_s, 1, log=True)
- sink_div_log_b, log_s_b = ot.sinkhorn2(b, b, M_t, 1, log=True)
- sink_div_log = sink_div_log_ab - 1 / 2 * (sink_div_log_a + sink_div_log_b)
-
- # check constratints
+ ab = nx.from_numpy(a)
+ bb = nx.from_numpy(b)
+ X_sb = nx.from_numpy(X_s)
+ X_tb = nx.from_numpy(X_t)
+ M_nx = nx.from_numpy(M, type_as=ab)
+ M_sb = nx.from_numpy(M_s, type_as=ab)
+ M_tb = nx.from_numpy(M_t, type_as=ab)
+
+ emp_sinkhorn_div = nx.to_numpy(ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb))
+ sinkhorn_div = nx.to_numpy(
+ ot.sinkhorn2(ab, bb, M_nx, 1)
+ - 1 / 2 * ot.sinkhorn2(ab, ab, M_sb, 1)
+ - 1 / 2 * ot.sinkhorn2(bb, bb, M_tb, 1)
+ )
+ emp_sinkhorn_div_np = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, a=a, b=b)
+
+ # check constraints
+ np.testing.assert_allclose(emp_sinkhorn_div, emp_sinkhorn_div_np, atol=1e-05)
np.testing.assert_allclose(
emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn
- np.testing.assert_allclose(
- emp_sinkhorn_div_log, sink_div_log, atol=1e-05) # cf conv emp sinkhorn
+
+ ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb, log=True)
-def test_stabilized_vs_sinkhorn_multidim():
+def test_stabilized_vs_sinkhorn_multidim(nx):
# test if stable version matches sinkhorn
# for multidimensional inputs
n = 100
@@ -297,12 +837,21 @@ def test_stabilized_vs_sinkhorn_multidim():
M = ot.utils.dist0(n)
M /= np.median(M)
epsilon = 0.1
- G, log = ot.bregman.sinkhorn(a, b, M, reg=epsilon,
+
+ ab = nx.from_numpy(a)
+ bb = nx.from_numpy(b)
+ M_nx = nx.from_numpy(M, type_as=ab)
+
+ G_np, _ = ot.bregman.sinkhorn(a, b, M, reg=epsilon, method="sinkhorn", log=True)
+ G, log = ot.bregman.sinkhorn(ab, bb, M_nx, reg=epsilon,
method="sinkhorn_stabilized",
log=True)
- G2, log2 = ot.bregman.sinkhorn(a, b, M, epsilon,
+ G = nx.to_numpy(G)
+ G2, log2 = ot.bregman.sinkhorn(ab, bb, M_nx, epsilon,
method="sinkhorn", log=True)
+ G2 = nx.to_numpy(G2)
+ np.testing.assert_allclose(G_np, G2)
np.testing.assert_allclose(G, G2)
@@ -320,8 +869,9 @@ def test_implemented_methods():
# make dists unbalanced
b = ot.utils.unif(n)
A = rng.rand(n, 2)
+ A /= A.sum(0, keepdims=True)
M = ot.dist(x, x)
- epsilon = 1.
+ epsilon = 1.0
for method in IMPLEMENTED_METHODS:
ot.bregman.sinkhorn(a, b, M, epsilon, method=method)
@@ -338,7 +888,9 @@ def test_implemented_methods():
ot.bregman.sinkhorn2(a, b, M, epsilon, method=method)
-def test_screenkhorn():
+@pytest.skip_backend("jax")
+@pytest.mark.filterwarnings("ignore:Bottleneck")
+def test_screenkhorn(nx):
# test screenkhorn
rng = np.random.RandomState(0)
n = 100
@@ -347,17 +899,31 @@ def test_screenkhorn():
x = rng.randn(n, 2)
M = ot.dist(x, x)
+
+ ab = nx.from_numpy(a)
+ bb = nx.from_numpy(b)
+ M_nx = nx.from_numpy(M, type_as=ab)
+
+ # np sinkhorn
+ G_sink_np = ot.sinkhorn(a, b, M, 1e-03)
# sinkhorn
- G_sink = ot.sinkhorn(a, b, M, 1e-03)
+ G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-03))
# screenkhorn
- G_screen = ot.bregman.screenkhorn(a, b, M, 1e-03, uniform=True, verbose=True)
+ G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, M_nx, 1e-03, uniform=True, verbose=True))
# check marginals
+ np.testing.assert_allclose(G_sink_np, G_sink)
np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02)
np.testing.assert_allclose(G_sink.sum(1), G_screen.sum(1), atol=1e-02)
-def test_convolutional_barycenter_non_square():
+def test_convolutional_barycenter_non_square(nx):
# test for image with height not equal width
A = np.ones((2, 2, 3)) / (2 * 3)
- b = ot.bregman.convolutional_barycenter2d(A, 1e-03)
+ A_nx = nx.from_numpy(A)
+
+ b_np = ot.bregman.convolutional_barycenter2d(A, 1e-03)
+ b = nx.to_numpy(ot.bregman.convolutional_barycenter2d(A_nx, 1e-03))
+
+ np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02)
np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02)
+ np.testing.assert_allclose(b, b_np)
diff --git a/test/test_da.py b/test/test_da.py
index 3b28119..9f2bb50 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -6,11 +6,18 @@
import numpy as np
from numpy.testing import assert_allclose, assert_equal
+import pytest
import ot
from ot.datasets import make_data_classif
from ot.utils import unif
+try: # test if cudamat installed
+ import sklearn # noqa: F401
+ nosklearn = False
+except ImportError:
+ nosklearn = True
+
def test_sinkhorn_lpl1_transport_class():
"""test_sinkhorn_transport
@@ -99,8 +106,8 @@ def test_sinkhorn_l1l2_transport_class():
"""test_sinkhorn_transport
"""
- ns = 150
- nt = 200
+ ns = 50
+ nt = 100
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
@@ -441,8 +448,8 @@ def test_mapping_transport_class():
"""test_mapping_transport
"""
- ns = 60
- nt = 120
+ ns = 20
+ nt = 30
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
@@ -558,6 +565,14 @@ def test_mapping_transport_class():
otda.fit(Xs=Xs, Xt=Xt)
assert len(otda.log_.keys()) != 0
+ # check that it does not crash when derphi is very close to 0
+ np.random.seed(39)
+ Xs, ys = make_data_classif('3gauss', ns)
+ Xt, yt = make_data_classif('3gauss2', nt)
+ otda = ot.da.MappingTransport(kernel="gaussian", bias=False)
+ otda.fit(Xs=Xs, Xt=Xt)
+ np.random.seed(None)
+
def test_linear_mapping():
ns = 150
@@ -691,6 +706,7 @@ def test_jcpot_barycenter():
np.testing.assert_allclose(prop, [1 - pt, pt], rtol=1e-3, atol=1e-3)
+@pytest.mark.skipif(nosklearn, reason="No sklearn available")
def test_emd_laplace_class():
"""test_emd_laplace_transport
"""
diff --git a/test/test_dr.py b/test/test_dr.py
index c5df287..741f2ad 100644
--- a/test/test_dr.py
+++ b/test/test_dr.py
@@ -1,6 +1,7 @@
"""Tests for module dr on Dimensionality Reduction """
# Author: Remi Flamary <remi.flamary@unice.fr>
+# Minhui Huang <mhhuang@ucdavis.edu>
#
# License: MIT License
@@ -57,3 +58,64 @@ def test_wda():
projwda(xs)
np.testing.assert_allclose(np.sum(Pwda**2, 0), np.ones(p))
+
+
+@pytest.mark.skipif(nogo, reason="Missing modules (autograd or pymanopt)")
+def test_wda_normalized():
+
+ n_samples = 100 # nb samples in source and target datasets
+ np.random.seed(0)
+
+ # generate gaussian dataset
+ xs, ys = ot.datasets.make_data_classif('gaussrot', n_samples)
+
+ n_features_noise = 8
+
+ xs = np.hstack((xs, np.random.randn(n_samples, n_features_noise)))
+
+ p = 2
+
+ P0 = np.random.randn(10, p)
+ P0 /= P0.sum(0, keepdims=True)
+
+ Pwda, projwda = ot.dr.wda(xs, ys, p, maxiter=10, P0=P0, normalize=True)
+
+ projwda(xs)
+
+ np.testing.assert_allclose(np.sum(Pwda**2, 0), np.ones(p))
+
+
+@pytest.mark.skipif(nogo, reason="Missing modules (autograd or pymanopt)")
+def test_prw():
+ d = 100 # Dimension
+ n = 100 # Number samples
+ k = 3 # Subspace dimension
+ dim = 3
+
+ def fragmented_hypercube(n, d, dim):
+ assert dim <= d
+ assert dim >= 1
+ assert dim == int(dim)
+
+ a = (1. / n) * np.ones(n)
+ b = (1. / n) * np.ones(n)
+
+ # First measure : uniform on the hypercube
+ X = np.random.uniform(-1, 1, size=(n, d))
+
+ # Second measure : fragmentation
+ tmp_y = np.random.uniform(-1, 1, size=(n, d))
+ Y = tmp_y + 2 * np.sign(tmp_y) * np.array(dim * [1] + (d - dim) * [0])
+ return a, b, X, Y
+
+ a, b, X, Y = fragmented_hypercube(n, d, dim)
+
+ tau = 0.002
+ reg = 0.2
+
+ pi, U = ot.dr.projection_robust_wasserstein(X, Y, a, b, tau, reg=reg, k=k, maxiter=1000, verbose=1)
+
+ U0 = np.random.randn(d, k)
+ U0, _ = np.linalg.qr(U0)
+
+ pi, U = ot.dr.projection_robust_wasserstein(X, Y, a, b, tau, U0=U0, reg=reg, k=k, maxiter=1000, verbose=1)
diff --git a/test/test_gromov.py b/test/test_gromov.py
index 43da9fc..c4bc04c 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -8,9 +8,13 @@
import numpy as np
import ot
+from ot.backend import NumpyBackend
+from ot.backend import torch
+import pytest
-def test_gromov():
+
+def test_gromov(nx):
n_samples = 50 # nb samples
mu_s = np.array([0, 0])
@@ -29,37 +33,121 @@ def test_gromov():
C1 /= C1.max()
C2 /= C2.max()
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ pb = nx.from_numpy(p)
+ qb = nx.from_numpy(q)
+
G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', verbose=True)
+ Gb = nx.to_numpy(ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True))
- # check constratints
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
np.testing.assert_allclose(
- p, G.sum(1), atol=1e-04) # cf convergence gromov
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
- q, G.sum(0), atol=1e-04) # cf convergence gromov
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
Id = (1 / (1.0 * n_samples)) * np.eye(n_samples, n_samples)
- np.testing.assert_allclose(
- G, np.flipud(Id), atol=1e-04)
+ np.testing.assert_allclose(Gb, np.flipud(Id), atol=1e-04)
gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=True)
+ gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=True)
gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=False)
+ gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
G = log['T']
+ Gb = nx.to_numpy(logb['T'])
- np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1)
+ np.testing.assert_allclose(gw, gwb, atol=1e-06)
+ np.testing.assert_allclose(gwb, 0, atol=1e-1, rtol=1e-1)
- np.testing.assert_allclose(gw, gw_val, atol=1e-1, rtol=1e-1) # cf log=False
+ np.testing.assert_allclose(gw_val, gw_valb, atol=1e-06)
+ np.testing.assert_allclose(gwb, gw_valb, atol=1e-1, rtol=1e-1) # cf log=False
- # check constratints
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
np.testing.assert_allclose(
- p, G.sum(1), atol=1e-04) # cf convergence gromov
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
- q, G.sum(0), atol=1e-04) # cf convergence gromov
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+
+def test_gromov_dtype_device(nx):
+ # setup
+ n_samples = 50 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4)
+
+ xt = xs[::-1].copy()
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
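+ # run the solvers for every dtype/device supported by the backend and check
+ # that the outputs keep the dtype and device of the inputs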
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ C1b = nx.from_numpy(C1, type_as=tp)
+ C2b = nx.from_numpy(C2, type_as=tp)
+ pb = nx.from_numpy(p, type_as=tp)
+ qb = nx.from_numpy(q, type_as=tp)
+
+ Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True)
+ gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
+
+ nx.assert_same_dtype_device(C1b, Gb)
+ nx.assert_same_dtype_device(C1b, gw_valb)
+
+
+def test_gromov2_gradients():
+ n_samples = 50 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4)
+
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=5)
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
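+ # the gradient check below only runs when the torch backend is available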
+ if torch:
+
+ p1 = torch.tensor(p, requires_grad=True)
+ q1 = torch.tensor(q, requires_grad=True)
+ C11 = torch.tensor(C1, requires_grad=True)
+ C12 = torch.tensor(C2, requires_grad=True)
+
+ val = ot.gromov_wasserstein2(C11, C12, p1, q1)
+
+ val.backward()
+ assert q1.shape == q1.grad.shape
+ assert p1.shape == p1.grad.shape
+ assert C11.shape == C11.grad.shape
+ assert C12.shape == C12.grad.shape
-def test_entropic_gromov():
+
+@pytest.skip_backend("jax", reason="test very slow with jax backend")
+def test_entropic_gromov(nx):
n_samples = 50 # nb samples
mu_s = np.array([0, 0])
@@ -78,85 +166,278 @@ def test_entropic_gromov():
C1 /= C1.max()
C2 /= C2.max()
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ pb = nx.from_numpy(p)
+ qb = nx.from_numpy(q)
+
G = ot.gromov.entropic_gromov_wasserstein(
C1, C2, p, q, 'square_loss', epsilon=5e-4, verbose=True)
+ Gb = nx.to_numpy(ot.gromov.entropic_gromov_wasserstein(
+ C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True
+ ))
- # check constratints
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
np.testing.assert_allclose(
- p, G.sum(1), atol=1e-04) # cf convergence gromov
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
- q, G.sum(0), atol=1e-04) # cf convergence gromov
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
gw, log = ot.gromov.entropic_gromov_wasserstein2(
C1, C2, p, q, 'kl_loss', epsilon=1e-2, log=True)
+ gwb, logb = ot.gromov.entropic_gromov_wasserstein2(
+ C1b, C2b, pb, qb, 'kl_loss', epsilon=1e-2, log=True)
G = log['T']
+ Gb = nx.to_numpy(logb['T'])
+ np.testing.assert_allclose(gw, gwb, atol=1e-06)
np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1)
- # check constratints
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
np.testing.assert_allclose(
- p, G.sum(1), atol=1e-04) # cf convergence gromov
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
- q, G.sum(0), atol=1e-04) # cf convergence gromov
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
-def test_gromov_barycenter():
- ns = 50
- nt = 60
+@pytest.skip_backend("jax", reason="test very slow with jax backend")
+def test_entropic_gromov_dtype_device(nx):
+ # setup
+ n_samples = 50 # nb samples
- Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42)
- Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42)
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
- C1 = ot.dist(Xs)
- C2 = ot.dist(Xt)
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)
- n_samples = 3
- Cb = ot.gromov.gromov_barycenters(n_samples, [C1, C2],
- [ot.unif(ns), ot.unif(nt)
- ], ot.unif(n_samples), [.5, .5],
- 'square_loss', # 5e-4,
- max_iter=100, tol=1e-3,
- verbose=True)
- np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
+ xt = xs[::-1].copy()
- Cb2 = ot.gromov.gromov_barycenters(n_samples, [C1, C2],
- [ot.unif(ns), ot.unif(nt)
- ], ot.unif(n_samples), [.5, .5],
- 'kl_loss', # 5e-4,
- max_iter=100, tol=1e-3)
- np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples))
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
-def test_gromov_entropic_barycenter():
- ns = 50
- nt = 60
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ C1b = nx.from_numpy(C1, type_as=tp)
+ C2b = nx.from_numpy(C2, type_as=tp)
+ pb = nx.from_numpy(p, type_as=tp)
+ qb = nx.from_numpy(q, type_as=tp)
+
+ Gb = ot.gromov.entropic_gromov_wasserstein(
+ C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True
+ )
+ gw_valb = ot.gromov.entropic_gromov_wasserstein2(
+ C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True
+ )
+
+ nx.assert_same_dtype_device(C1b, Gb)
+ nx.assert_same_dtype_device(C1b, gw_valb)
+
+
+def test_pointwise_gromov(nx):
+ n_samples = 50 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)
+
+ xt = xs[::-1].copy()
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ pb = nx.from_numpy(p)
+ qb = nx.from_numpy(q)
+
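+ # same pointwise loss, defined once with numpy and once with the backend under test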
+ def loss(x, y):
+ return np.abs(x - y)
+
+ def lossb(x, y):
+ return nx.abs(x - y)
+
+ G, log = ot.gromov.pointwise_gromov_wasserstein(
+ C1, C2, p, q, loss, max_iter=100, log=True, verbose=True, random_state=42)
+ G = NumpyBackend().todense(G)
+ Gb, logb = ot.gromov.pointwise_gromov_wasserstein(
+ C1b, C2b, pb, qb, lossb, max_iter=100, log=True, verbose=True, random_state=42)
+ Gb = nx.to_numpy(nx.todense(Gb))
+
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ np.testing.assert_allclose(logb['gw_dist_estimated'], 0.0, atol=1e-08)
+ np.testing.assert_allclose(logb['gw_dist_std'], 0.0, atol=1e-08)
+
+ G, log = ot.gromov.pointwise_gromov_wasserstein(
+ C1, C2, p, q, loss, max_iter=100, alpha=0.1, log=True, verbose=True, random_state=42)
+ G = NumpyBackend().todense(G)
+ Gb, logb = ot.gromov.pointwise_gromov_wasserstein(
+ C1b, C2b, pb, qb, lossb, max_iter=100, alpha=0.1, log=True, verbose=True, random_state=42)
+ Gb = nx.to_numpy(nx.todense(Gb))
+
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(logb['gw_dist_estimated'], 0.10342276348494964, atol=1e-8)
+ np.testing.assert_allclose(logb['gw_dist_std'], 0.0015952535464736394, atol=1e-8)
+
+
+@pytest.skip_backend("jax", reason="test very slow with jax backend")
+def test_sampled_gromov(nx):
+ n_samples = 50 # nb samples
+
+ mu_s = np.array([0, 0], dtype=np.float64)
+ cov_s = np.array([[1, 0], [0, 1]], dtype=np.float64)
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)
+
+ xt = xs[::-1].copy()
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ pb = nx.from_numpy(p)
+ qb = nx.from_numpy(q)
+
+ def loss(x, y):
+ return np.abs(x - y)
+
+ def lossb(x, y):
+ return nx.abs(x - y)
+
+ G, log = ot.gromov.sampled_gromov_wasserstein(
+ C1, C2, p, q, loss, max_iter=100, epsilon=1, log=True, verbose=True, random_state=42)
+ Gb, logb = ot.gromov.sampled_gromov_wasserstein(
+ C1b, C2b, pb, qb, lossb, max_iter=100, epsilon=1, log=True, verbose=True, random_state=42)
+ Gb = nx.to_numpy(Gb)
+
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ np.testing.assert_allclose(logb['gw_dist_estimated'], 0.05679474884977278, atol=1e-08)
+ np.testing.assert_allclose(logb['gw_dist_std'], 0.0005986592106971995, atol=1e-08)
+
+
+def test_gromov_barycenter(nx):
+ ns = 10
+ nt = 20
Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42)
Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42)
C1 = ot.dist(Xs)
C2 = ot.dist(Xt)
-
+ p1 = ot.unif(ns)
+ p2 = ot.unif(nt)
n_samples = 3
- Cb = ot.gromov.entropic_gromov_barycenters(n_samples, [C1, C2],
- [ot.unif(ns), ot.unif(nt)
- ], ot.unif(n_samples), [.5, .5],
- 'square_loss', 2e-3,
- max_iter=100, tol=1e-3,
- verbose=True)
- np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
+ p = ot.unif(n_samples)
- Cb2 = ot.gromov.entropic_gromov_barycenters(n_samples, [C1, C2],
- [ot.unif(ns), ot.unif(nt)
- ], ot.unif(n_samples), [.5, .5],
- 'kl_loss', 2e-3,
- max_iter=100, tol=1e-3)
- np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples))
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ p1b = nx.from_numpy(p1)
+ p2b = nx.from_numpy(p2)
+ pb = nx.from_numpy(p)
+
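+ # barycenters computed with numpy and with the backend should coincide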
+ Cb = ot.gromov.gromov_barycenters(
+ n_samples, [C1, C2], [p1, p2], p, [.5, .5],
+ 'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42
+ )
+ Cbb = nx.to_numpy(ot.gromov.gromov_barycenters(
+ n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
+ 'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42
+ ))
+ np.testing.assert_allclose(Cb, Cbb, atol=1e-06)
+ np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples))
+
+ Cb2 = ot.gromov.gromov_barycenters(
+ n_samples, [C1, C2], [p1, p2], p, [.5, .5],
+ 'kl_loss', max_iter=100, tol=1e-3, random_state=42
+ )
+ Cb2b = nx.to_numpy(ot.gromov.gromov_barycenters(
+ n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
+ 'kl_loss', max_iter=100, tol=1e-3, random_state=42
+ ))
+ np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06)
+ np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples))
+
+
+@pytest.mark.filterwarnings("ignore:divide")
+def test_gromov_entropic_barycenter(nx):
+ ns = 10
+ nt = 20
+ Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42)
+ Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42)
-def test_fgw():
+ C1 = ot.dist(Xs)
+ C2 = ot.dist(Xt)
+ p1 = ot.unif(ns)
+ p2 = ot.unif(nt)
+ n_samples = 2
+ p = ot.unif(n_samples)
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ p1b = nx.from_numpy(p1)
+ p2b = nx.from_numpy(p2)
+ pb = nx.from_numpy(p)
+
+ Cb = ot.gromov.entropic_gromov_barycenters(
+ n_samples, [C1, C2], [p1, p2], p, [.5, .5],
+ 'square_loss', 1e-3, max_iter=50, tol=1e-3, verbose=True, random_state=42
+ )
+ Cbb = nx.to_numpy(ot.gromov.entropic_gromov_barycenters(
+ n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
+ 'square_loss', 1e-3, max_iter=50, tol=1e-3, verbose=True, random_state=42
+ ))
+ np.testing.assert_allclose(Cb, Cbb, atol=1e-06)
+ np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples))
+
+ Cb2 = ot.gromov.entropic_gromov_barycenters(
+ n_samples, [C1, C2], [p1, p2], p, [.5, .5],
+ 'kl_loss', 1e-3, max_iter=100, tol=1e-3, random_state=42
+ )
+ Cb2b = nx.to_numpy(ot.gromov.entropic_gromov_barycenters(
+ n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
+ 'kl_loss', 1e-3, max_iter=100, tol=1e-3, random_state=42
+ ))
+ np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06)
+ np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples))
+
+
+def test_fgw(nx):
n_samples = 50 # nb samples
mu_s = np.array([0, 0])
@@ -181,33 +462,85 @@ def test_fgw():
M = ot.dist(ys, yt)
M /= M.max()
- G = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5)
+ Mb = nx.from_numpy(M)
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ pb = nx.from_numpy(p)
+ qb = nx.from_numpy(q)
- # check constratints
+ G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True)
+ Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True)
+ Gb = nx.to_numpy(Gb)
+
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
np.testing.assert_allclose(
- p, G.sum(1), atol=1e-04) # cf convergence fgw
+ p, Gb.sum(1), atol=1e-04) # cf convergence fgw
np.testing.assert_allclose(
- q, G.sum(0), atol=1e-04) # cf convergence fgw
+ q, Gb.sum(0), atol=1e-04) # cf convergence fgw
Id = (1 / (1.0 * n_samples)) * np.eye(n_samples, n_samples)
np.testing.assert_allclose(
- G, np.flipud(Id), atol=1e-04) # cf convergence gromov
+ Gb, np.flipud(Id), atol=1e-04) # cf convergence gromov
fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True)
+ fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True)
G = log['T']
+ Gb = nx.to_numpy(logb['T'])
- np.testing.assert_allclose(fgw, 0, atol=1e-1, rtol=1e-1)
+ np.testing.assert_allclose(fgw, fgwb, atol=1e-08)
+ np.testing.assert_allclose(fgwb, 0, atol=1e-1, rtol=1e-1)
- # check constratints
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
np.testing.assert_allclose(
- p, G.sum(1), atol=1e-04) # cf convergence gromov
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
- q, G.sum(0), atol=1e-04) # cf convergence gromov
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+
+def test_fgw2_gradients():
+ n_samples = 50 # nb samples
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
-def test_fgw_barycenter():
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4)
+
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=5)
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+ M = ot.dist(xs, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ if torch:
+
+ p1 = torch.tensor(p, requires_grad=True)
+ q1 = torch.tensor(q, requires_grad=True)
+ C11 = torch.tensor(C1, requires_grad=True)
+ C12 = torch.tensor(C2, requires_grad=True)
+ M1 = torch.tensor(M, requires_grad=True)
+
+ val = ot.fused_gromov_wasserstein2(M1, C11, C12, p1, q1)
+
+ val.backward()
+
+ assert q1.shape == q1.grad.shape
+ assert p1.shape == p1.grad.shape
+ assert C11.shape == C11.grad.shape
+ assert C12.shape == C12.grad.shape
+ assert M1.shape == M1.grad.shape
+
+
+def test_fgw_barycenter(nx):
np.random.seed(42)
ns = 50
@@ -221,30 +554,44 @@ def test_fgw_barycenter():
C1 = ot.dist(Xs)
C2 = ot.dist(Xt)
-
+ p1, p2 = ot.unif(ns), ot.unif(nt)
n_samples = 3
- X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5,
- fixed_structure=False, fixed_features=False,
- p=ot.unif(n_samples), loss_fun='square_loss',
- max_iter=100, tol=1e-3)
- np.testing.assert_allclose(C.shape, (n_samples, n_samples))
- np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
+ p = ot.unif(n_samples)
+
+ ysb = nx.from_numpy(ys)
+ ytb = nx.from_numpy(yt)
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ p1b = nx.from_numpy(p1)
+ p2b = nx.from_numpy(p2)
+ pb = nx.from_numpy(p)
+
+ Xb, Cb = ot.gromov.fgw_barycenters(
+ n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, fixed_structure=False,
+ fixed_features=False, p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3, random_state=12345
+ )
xalea = np.random.randn(n_samples, 2)
init_C = ot.dist(xalea, xalea)
-
- X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], ps=[ot.unif(ns), ot.unif(nt)], lambdas=[.5, .5], alpha=0.5,
- fixed_structure=True, init_C=init_C, fixed_features=False,
- p=ot.unif(n_samples), loss_fun='square_loss',
- max_iter=100, tol=1e-3)
- np.testing.assert_allclose(C.shape, (n_samples, n_samples))
- np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
+ init_Cb = nx.from_numpy(init_C)
+
+ Xb, Cb = ot.gromov.fgw_barycenters(
+ n_samples, [ysb, ytb], [C1b, C2b], ps=[p1b, p2b], lambdas=[.5, .5],
+ alpha=0.5, fixed_structure=True, init_C=init_Cb, fixed_features=False,
+ p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3
+ )
+ Xb, Cb = nx.to_numpy(Xb), nx.to_numpy(Cb)
+ np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
+ np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1]))
init_X = np.random.randn(n_samples, ys.shape[1])
-
- X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5,
- fixed_structure=False, fixed_features=True, init_X=init_X,
- p=ot.unif(n_samples), loss_fun='square_loss',
- max_iter=100, tol=1e-3)
- np.testing.assert_allclose(C.shape, (n_samples, n_samples))
- np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
+ init_Xb = nx.from_numpy(init_X)
+
+ Xb, Cb, logb = ot.gromov.fgw_barycenters(
+ n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5,
+ fixed_structure=False, fixed_features=True, init_X=init_Xb,
+ p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3, log=True, random_state=98765
+ )
+ Xb, Cb = nx.to_numpy(Xb), nx.to_numpy(Cb)
+ np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
+ np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1]))
diff --git a/test/test_helpers.py b/test/test_helpers.py
new file mode 100644
index 0000000..cc4c90e
--- /dev/null
+++ b/test/test_helpers.py
@@ -0,0 +1,26 @@
+"""Tests for helpers functions """
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+import os
+import sys
+
+sys.path.append(os.path.join("ot", "helpers"))
+
+from openmp_helpers import get_openmp_flag, check_openmp_support # noqa
+from pre_build_helpers import _get_compiler, compile_test_program # noqa
+
+
+def test_helpers():
+
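+ # compile and run a minimal C program to check that the build helpers work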
+ compiler = _get_compiler()
+
+ get_openmp_flag(compiler)
+
+ s = '#include <stdio.h>\n#include <stdlib.h>\n\nint main(void) {\n\tprintf("Hello world!\\n");\n\treturn 0;\n}'
+ output, _ = compile_test_program(s)
+ assert len(output) == 1 and output[0] == "Hello world!"
+
+ check_openmp_support()
diff --git a/test/test_optim.py b/test/test_optim.py
index 87b0268..4efd9b1 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -8,7 +8,7 @@ import numpy as np
import ot
-def test_conditional_gradient():
+def test_conditional_gradient(nx):
n_bins = 100 # nb bins
np.random.seed(0)
@@ -29,16 +29,26 @@ def test_conditional_gradient():
def df(G):
return G
+ def fb(G):
+ return 0.5 * nx.sum(G ** 2)
+
+ ab = nx.from_numpy(a)
+ bb = nx.from_numpy(b)
+ Mb = nx.from_numpy(M, type_as=ab)
+
reg = 1e-1
G, log = ot.optim.cg(a, b, M, reg, f, df, verbose=True, log=True)
+ Gb, log = ot.optim.cg(ab, bb, Mb, reg, fb, df, verbose=True, log=True)
+ Gb = nx.to_numpy(Gb)
- np.testing.assert_allclose(a, G.sum(1))
- np.testing.assert_allclose(b, G.sum(0))
+ np.testing.assert_allclose(Gb, G)
+ np.testing.assert_allclose(a, Gb.sum(1))
+ np.testing.assert_allclose(b, Gb.sum(0))
-def test_conditional_gradient2():
- n = 1000 # nb samples
+def test_conditional_gradient_itermax(nx):
+ n = 100 # nb samples
mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])
@@ -61,16 +71,27 @@ def test_conditional_gradient2():
def df(G):
return G
+ def fb(G):
+ return 0.5 * nx.sum(G ** 2)
+
+ ab = nx.from_numpy(a)
+ bb = nx.from_numpy(b)
+ Mb = nx.from_numpy(M, type_as=ab)
+
reg = 1e-1
- G, log = ot.optim.cg(a, b, M, reg, f, df, numItermaxEmd=200000,
+ G, log = ot.optim.cg(a, b, M, reg, f, df, numItermaxEmd=10000,
verbose=True, log=True)
+ Gb, log = ot.optim.cg(ab, bb, Mb, reg, fb, df, numItermaxEmd=10000,
+ verbose=True, log=True)
+ Gb = nx.to_numpy(Gb)
- np.testing.assert_allclose(a, G.sum(1))
- np.testing.assert_allclose(b, G.sum(0))
+ np.testing.assert_allclose(Gb, G)
+ np.testing.assert_allclose(a, Gb.sum(1))
+ np.testing.assert_allclose(b, Gb.sum(0))
-def test_generalized_conditional_gradient():
+def test_generalized_conditional_gradient(nx):
n_bins = 100 # nb bins
np.random.seed(0)
@@ -91,16 +112,76 @@ def test_generalized_conditional_gradient():
def df(G):
return G
+ def fb(G):
+ return 0.5 * nx.sum(G ** 2)
+
reg1 = 1e-3
reg2 = 1e-1
+ ab = nx.from_numpy(a)
+ bb = nx.from_numpy(b)
+ Mb = nx.from_numpy(M, type_as=ab)
+
G, log = ot.optim.gcg(a, b, M, reg1, reg2, f, df, verbose=True, log=True)
+ Gb, log = ot.optim.gcg(ab, bb, Mb, reg1, reg2, fb, df, verbose=True, log=True)
+ Gb = nx.to_numpy(Gb)
- np.testing.assert_allclose(a, G.sum(1), atol=1e-05)
- np.testing.assert_allclose(b, G.sum(0), atol=1e-05)
+ np.testing.assert_allclose(Gb, G)
+ np.testing.assert_allclose(a, Gb.sum(1), atol=1e-05)
+ np.testing.assert_allclose(b, Gb.sum(0), atol=1e-05)
def test_solve_1d_linesearch_quad_funct():
np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(1, -1, 0), 0.5)
np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 5, 0), 0)
np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 0.5, 0), 1)
+
+
+def test_line_search_armijo(nx):
+ xk = np.array([[0.25, 0.25], [0.25, 0.25]])
+ pk = np.array([[-0.25, 0.25], [0.25, -0.25]])
+ gfk = np.array([[23.04273441, 23.0449082], [23.04273441, 23.0449082]])
+ old_fval = -123
+ # should not raise an exception, and alpha should be returned as None
+ alpha, a, b = ot.optim.line_search_armijo(
+ lambda x: 1, nx.from_numpy(xk), nx.from_numpy(pk), nx.from_numpy(gfk), old_fval
+ )
+ alpha_np, anp, bnp = ot.optim.line_search_armijo(
+ lambda x: 1, xk, pk, gfk, old_fval
+ )
+ assert a == anp
+ assert b == bnp
+ assert alpha is None
+
+ # check line search armijo
+ def f(x):
+ return nx.sum((x - 5.0) ** 2)
+
+ def grad(x):
+ return 2 * (x - 5.0)
+
+ xk = nx.from_numpy(np.array([[[-5.0, -5.0]]]))
+ pk = nx.from_numpy(np.array([[[100.0, 100.0]]]))
+ gfk = grad(xk)
+ old_fval = f(xk)
+
+ # check the case where the optimum lies along the search direction
+ alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval)
+ np.testing.assert_allclose(alpha, 0.1)
+
+ # check the case where the step along the direction does not reach the optimum
+ pk = nx.from_numpy(np.array([[[3.0, 3.0]]]))
+ alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval, alpha0=1.0)
+ np.testing.assert_allclose(alpha, 1.0)
+
+ # check the case where the search direction points the wrong way
+ alpha, _, _ = ot.optim.line_search_armijo(f, xk, -pk, gfk, old_fval)
+ assert alpha <= 0
+
+ # check the case where the point is a scalar (0-d array) rather than a vector
+ xk = nx.from_numpy(np.array(-5.0))
+ pk = nx.from_numpy(np.array(100.0))
+ gfk = grad(xk)
+ old_fval = f(xk)
+ alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval)
+ np.testing.assert_allclose(alpha, 0.1)
diff --git a/test/test_ot.py b/test/test_ot.py
index b7306f6..92f26a7 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -8,13 +8,13 @@ import warnings
import numpy as np
import pytest
-from scipy.stats import wasserstein_distance
import ot
from ot.datasets import make_1D_gauss as gauss
+from ot.backend import torch
-def test_emd_dimension_mismatch():
+def test_emd_dimension_and_mass_mismatch():
# test emd and emd2 for dimension mismatch
n_samples = 100
n_features = 2
@@ -29,122 +29,125 @@ def test_emd_dimension_mismatch():
np.testing.assert_raises(AssertionError, ot.emd2, a, a, M)
+ b = a.copy()
+ a[0] = 100
+ np.testing.assert_raises(AssertionError, ot.emd, a, b, M)
-def test_emd_emd2():
- # test emd and emd2 for simple identity
- n = 100
+
+def test_emd_backends(nx):
+ n_samples = 100
+ n_features = 2
rng = np.random.RandomState(0)
- x = rng.randn(n, 2)
- u = ot.utils.unif(n)
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
- M = ot.dist(x, x)
+ M = ot.dist(x, y)
- G = ot.emd(u, u, M)
+ G = ot.emd(a, a, M)
- # check G is identity
- np.testing.assert_allclose(G, np.eye(n) / n)
- # check constraints
- np.testing.assert_allclose(u, G.sum(1)) # cf convergence sinkhorn
- np.testing.assert_allclose(u, G.sum(0)) # cf convergence sinkhorn
+ ab = nx.from_numpy(a)
+ Mb = nx.from_numpy(M)
- w = ot.emd2(u, u, M)
- # check loss=0
- np.testing.assert_allclose(w, 0)
+ Gb = ot.emd(ab, ab, Mb)
+
+ np.testing.assert_allclose(G, nx.to_numpy(Gb)) # backend result should match numpy
-def test_emd_1d_emd2_1d():
- # test emd1d gives similar results as emd
- n = 20
- m = 30
+def test_emd2_backends(nx):
+ n_samples = 100
+ n_features = 2
rng = np.random.RandomState(0)
- u = rng.randn(n, 1)
- v = rng.randn(m, 1)
- M = ot.dist(u, v, metric='sqeuclidean')
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
+
+ M = ot.dist(x, y)
- G, log = ot.emd([], [], M, log=True)
- wass = log["cost"]
- G_1d, log = ot.emd_1d(u, v, [], [], metric='sqeuclidean', log=True)
- wass1d = log["cost"]
- wass1d_emd2 = ot.emd2_1d(u, v, [], [], metric='sqeuclidean', log=False)
- wass1d_euc = ot.emd2_1d(u, v, [], [], metric='euclidean', log=False)
+ val = ot.emd2(a, a, M)
- # check loss is similar
- np.testing.assert_allclose(wass, wass1d)
- np.testing.assert_allclose(wass, wass1d_emd2)
+ ab = nx.from_numpy(a)
+ Mb = nx.from_numpy(M)
- # check loss is similar to scipy's implementation for Euclidean metric
- wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)))
- np.testing.assert_allclose(wass_sp, wass1d_euc)
+ valb = ot.emd2(ab, ab, Mb)
- # check constraints
- np.testing.assert_allclose(np.ones((n,)) / n, G.sum(1))
- np.testing.assert_allclose(np.ones((m,)) / m, G.sum(0))
+ np.testing.assert_allclose(val, nx.to_numpy(valb)) # backend result should match numpy
+
+
+def test_emd_emd2_types_devices(nx):
+ n_samples = 100
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
+
+ M = ot.dist(x, y)
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ ab = nx.from_numpy(a, type_as=tp)
+ Mb = nx.from_numpy(M, type_as=tp)
- # check G is similar
- np.testing.assert_allclose(G, G_1d)
+ Gb = ot.emd(ab, ab, Mb)
- # check AssertionError is raised if called on non 1d arrays
- u = np.random.randn(n, 2)
- v = np.random.randn(m, 2)
- with pytest.raises(AssertionError):
- ot.emd_1d(u, v, [], [])
+ w = ot.emd2(ab, ab, Mb)
+ nx.assert_same_dtype_device(Mb, Gb)
+ nx.assert_same_dtype_device(Mb, w)
-def test_emd_1d_emd2_1d_with_weights():
- # test emd1d gives similar results as emd
- n = 20
- m = 30
+
+def test_emd2_gradients():
+ n_samples = 100
+ n_features = 2
rng = np.random.RandomState(0)
- u = rng.randn(n, 1)
- v = rng.randn(m, 1)
- w_u = rng.uniform(0., 1., n)
- w_u = w_u / w_u.sum()
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
- w_v = rng.uniform(0., 1., m)
- w_v = w_v / w_v.sum()
+ M = ot.dist(x, y)
- M = ot.dist(u, v, metric='sqeuclidean')
+ if torch:
- G, log = ot.emd(w_u, w_v, M, log=True)
- wass = log["cost"]
- G_1d, log = ot.emd_1d(u, v, w_u, w_v, metric='sqeuclidean', log=True)
- wass1d = log["cost"]
- wass1d_emd2 = ot.emd2_1d(u, v, w_u, w_v, metric='sqeuclidean', log=False)
- wass1d_euc = ot.emd2_1d(u, v, w_u, w_v, metric='euclidean', log=False)
+ a1 = torch.tensor(a, requires_grad=True)
+ b1 = torch.tensor(a, requires_grad=True)
+ M1 = torch.tensor(M, requires_grad=True)
- # check loss is similar
- np.testing.assert_allclose(wass, wass1d)
- np.testing.assert_allclose(wass, wass1d_emd2)
+ val = ot.emd2(a1, b1, M1)
- # check loss is similar to scipy's implementation for Euclidean metric
- wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)), w_u, w_v)
- np.testing.assert_allclose(wass_sp, wass1d_euc)
+ val.backward()
- # check constraints
- np.testing.assert_allclose(w_u, G.sum(1))
- np.testing.assert_allclose(w_v, G.sum(0))
+ assert a1.shape == a1.grad.shape
+ assert b1.shape == b1.grad.shape
+ assert M1.shape == M1.grad.shape
-def test_wass_1d():
- # test emd1d gives similar results as emd
- n = 20
- m = 30
+def test_emd_emd2():
+ # test emd and emd2 for simple identity
+ n = 100
rng = np.random.RandomState(0)
- u = rng.randn(n, 1)
- v = rng.randn(m, 1)
- M = ot.dist(u, v, metric='sqeuclidean')
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
- G, log = ot.emd([], [], M, log=True)
- wass = log["cost"]
+ M = ot.dist(x, x)
- wass1d = ot.wasserstein_1d(u, v, [], [], p=2.)
+ G = ot.emd(u, u, M)
+
+ # check G is identity
+ np.testing.assert_allclose(G, np.eye(n) / n)
+ # check constraints
+ np.testing.assert_allclose(u, G.sum(1)) # cf convergence sinkhorn
+ np.testing.assert_allclose(u, G.sum(0)) # cf convergence sinkhorn
- # check loss is similar
- np.testing.assert_allclose(np.sqrt(wass), wass1d)
+ w = ot.emd2(u, u, M)
+ # check loss=0
+ np.testing.assert_allclose(w, 0)
def test_emd_empty():
@@ -291,17 +294,7 @@ def test_warnings():
print('Computing {} EMD '.format(1))
ot.emd(a, b, M, numItermax=1)
assert "numItermax" in str(w[-1].message)
- assert len(w) == 1
- a[0] = 100
- print('Computing {} EMD '.format(2))
- ot.emd(a, b, M)
- assert "infeasible" in str(w[-1].message)
- assert len(w) == 2
- a[0] = -1
- print('Computing {} EMD '.format(2))
- ot.emd(a, b, M)
- assert "infeasible" in str(w[-1].message)
- assert len(w) == 3
+ # assert len(w) == 1
def test_dual_variables():
diff --git a/test/test_partial.py b/test/test_partial.py
index 510e081..97c611b 100755
--- a/test/test_partial.py
+++ b/test/test_partial.py
@@ -51,10 +51,12 @@ def test_raise_errors():
ot.partial.partial_gromov_wasserstein(M, M, p, q, m=-1, log=True)
with pytest.raises(ValueError):
- ot.partial.entropic_partial_gromov_wasserstein(M, M, p, q, reg=1, m=2, log=True)
+ ot.partial.entropic_partial_gromov_wasserstein(M, M, p, q, reg=1, m=2,
+ log=True)
with pytest.raises(ValueError):
- ot.partial.entropic_partial_gromov_wasserstein(M, M, p, q, reg=1, m=-1, log=True)
+ ot.partial.entropic_partial_gromov_wasserstein(M, M, p, q, reg=1, m=-1,
+ log=True)
def test_partial_wasserstein_lagrange():
@@ -102,7 +104,7 @@ def test_partial_wasserstein():
w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=1, m=m,
log=True, verbose=True)
- # check constratints
+ # check constraints
np.testing.assert_equal(
w0.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein
np.testing.assert_equal(
@@ -125,11 +127,11 @@ def test_partial_wasserstein():
np.testing.assert_allclose(w0, w0_val, atol=1e-1, rtol=1e-1)
- # check constratints
+ # check constraints
np.testing.assert_equal(
- G.sum(1) <= p, [True] * len(p)) # cf convergence wasserstein
+ G.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein
np.testing.assert_equal(
- G.sum(0) <= q, [True] * len(q)) # cf convergence wasserstein
+ G.sum(0) - q <= 1e-5, [True] * len(q)) # cf convergence wasserstein
np.testing.assert_allclose(
np.sum(G), m, atol=1e-04)
@@ -192,7 +194,7 @@ def test_partial_gromov_wasserstein():
100, m=m,
log=True)
- # check constratints
+ # check constraints
np.testing.assert_equal(
res0.sum(1) <= p, [True] * len(p)) # cf convergence wasserstein
np.testing.assert_equal(
diff --git a/test/test_regpath.py b/test/test_regpath.py
new file mode 100644
index 0000000..967c27b
--- /dev/null
+++ b/test/test_regpath.py
@@ -0,0 +1,64 @@
+"""Tests for module regularization path"""
+
+# Author: Haoran Wu <haoran.wu@univ-ubs.fr>
+#
+# License: MIT License
+
+import numpy as np
+import ot
+
+
+def test_fully_relaxed_path():
+
+ n_source = 50 # nb source samples (gaussian)
+ n_target = 40 # nb target samples (gaussian)
+
+ mu = np.array([0, 0])
+ cov = np.array([[1, 0], [0, 2]])
+
+ np.random.seed(0)
+ xs = ot.datasets.make_2D_samples_gauss(n_source, mu, cov)
+ xt = ot.datasets.make_2D_samples_gauss(n_target, mu, cov)
+
+ # source and target distributions
+ a = ot.utils.unif(n_source)
+ b = ot.utils.unif(n_target)
+
+ # loss matrix
+ M = ot.dist(xs, xt)
+ M /= M.max()
+
+ t, _, _ = ot.regpath.regularization_path(a, b, M, reg=1e-8,
+ semi_relaxed=False)
+
+ G = t.reshape((n_source, n_target))
+ np.testing.assert_allclose(a, G.sum(1), atol=1e-05)
+ np.testing.assert_allclose(b, G.sum(0), atol=1e-05)
+
+
+def test_semi_relaxed_path():
+
+ n_source = 50 # nb source samples (gaussian)
+ n_target = 40 # nb target samples (gaussian)
+
+ mu = np.array([0, 0])
+ cov = np.array([[1, 0], [0, 2]])
+
+ np.random.seed(0)
+ xs = ot.datasets.make_2D_samples_gauss(n_source, mu, cov)
+ xt = ot.datasets.make_2D_samples_gauss(n_target, mu, cov)
+
+ # source and target distributions
+ a = ot.utils.unif(n_source)
+ b = ot.utils.unif(n_target)
+
+ # loss matrix
+ M = ot.dist(xs, xt)
+ M /= M.max()
+
+ t, _, _ = ot.regpath.regularization_path(a, b, M, reg=1e-8,
+ semi_relaxed=True)
+
+ G = t.reshape((n_source, n_target))
+ np.testing.assert_allclose(a, G.sum(1), atol=1e-05)
+ np.testing.assert_allclose(b, G.sum(0), atol=1e-10)
diff --git a/test/test_sliced.py b/test/test_sliced.py
new file mode 100644
index 0000000..245202c
--- /dev/null
+++ b/test/test_sliced.py
@@ -0,0 +1,213 @@
+"""Tests for module sliced"""
+
+# Author: Adrien Corenflos <adrien.corenflos@aalto.fi>
+# Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
+import numpy as np
+import pytest
+
+import ot
+from ot.sliced import get_random_projections
+
+
+def test_get_random_projections():
+ rng = np.random.RandomState(0)
+ projections = get_random_projections(1000, 50, rng)
+ np.testing.assert_almost_equal(np.sum(projections ** 2, 0), 1.)
+
+
+def test_sliced_same_dist():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ res = ot.sliced_wasserstein_distance(x, x, u, u, 10, seed=rng)
+ np.testing.assert_almost_equal(res, 0.)
+
+
+def test_sliced_bad_shapes():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ y = rng.randn(n, 4)
+ u = ot.utils.unif(n)
+
+ with pytest.raises(ValueError):
+ _ = ot.sliced_wasserstein_distance(x, y, u, u, 10, seed=rng)
+
+
+def test_sliced_log():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 4)
+ y = rng.randn(n, 4)
+ u = ot.utils.unif(n)
+
+ res, log = ot.sliced_wasserstein_distance(x, y, u, u, 10, p=1, seed=rng, log=True)
+ assert len(log) == 2
+ projections = log["projections"]
+ projected_emds = log["projected_emds"]
+
+ assert projections.shape[1] == len(projected_emds) == 10
+ for emd in projected_emds:
+ assert emd > 0
+
+
+def test_sliced_different_dists():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+ y = rng.randn(n, 2)
+
+ res = ot.sliced_wasserstein_distance(x, y, u, u, 10, seed=rng)
+ assert res > 0.
+
+
+def test_1d_sliced_equals_emd():
+ n = 100
+ m = 120
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 1)
+ a = rng.uniform(0, 1, n)
+ a /= a.sum()
+ y = rng.randn(m, 1)
+ u = ot.utils.unif(m)
+ res = ot.sliced_wasserstein_distance(x, y, a, u, 10, seed=42)
+ expected = ot.emd2_1d(x.squeeze(), y.squeeze(), a, u)
+ np.testing.assert_almost_equal(res ** 2, expected)
+
+
+def test_max_sliced_same_dist():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ res = ot.max_sliced_wasserstein_distance(x, x, u, u, 10, seed=rng)
+ np.testing.assert_almost_equal(res, 0.)
+
+
+def test_max_sliced_different_dists():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+ y = rng.randn(n, 2)
+
+ res, log = ot.max_sliced_wasserstein_distance(x, y, u, u, 10, seed=rng, log=True)
+ assert res > 0.
+
+
+def test_sliced_backend(nx):
+
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ y = rng.randn(2 * n, 2)
+
+ P = rng.randn(2, 20)
+ P = P / np.sqrt((P**2).sum(0, keepdims=True))
+
+ n_projections = 20
+
+ xb = nx.from_numpy(x)
+ yb = nx.from_numpy(y)
+ Pb = nx.from_numpy(P)
+
+ val0 = ot.sliced_wasserstein_distance(x, y, projections=P)
+
+ val = ot.sliced_wasserstein_distance(xb, yb, n_projections=n_projections, seed=0)
+ val2 = ot.sliced_wasserstein_distance(xb, yb, n_projections=n_projections, seed=0)
+
+ assert val > 0
+ assert val == val2
+
+ valb = nx.to_numpy(ot.sliced_wasserstein_distance(xb, yb, projections=Pb))
+
+ assert np.allclose(val0, valb)
+
+
+def test_sliced_backend_type_devices(nx):
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ y = rng.randn(2 * n, 2)
+
+ P = rng.randn(2, 20)
+ P = P / np.sqrt((P**2).sum(0, keepdims=True))
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ xb = nx.from_numpy(x, type_as=tp)
+ yb = nx.from_numpy(y, type_as=tp)
+ Pb = nx.from_numpy(P, type_as=tp)
+
+ valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb)
+
+ nx.assert_same_dtype_device(xb, valb)
+
+
+def test_max_sliced_backend(nx):
+
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ y = rng.randn(2 * n, 2)
+
+ P = rng.randn(2, 20)
+ P = P / np.sqrt((P**2).sum(0, keepdims=True))
+
+ n_projections = 20
+
+ xb = nx.from_numpy(x)
+ yb = nx.from_numpy(y)
+ Pb = nx.from_numpy(P)
+
+ val0 = ot.max_sliced_wasserstein_distance(x, y, projections=P)
+
+ val = ot.max_sliced_wasserstein_distance(xb, yb, n_projections=n_projections, seed=0)
+ val2 = ot.max_sliced_wasserstein_distance(xb, yb, n_projections=n_projections, seed=0)
+
+ assert val > 0
+ assert val == val2
+
+ valb = nx.to_numpy(ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb))
+
+ assert np.allclose(val0, valb)
+
+
+def test_max_sliced_backend_type_devices(nx):
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ y = rng.randn(2 * n, 2)
+
+ P = rng.randn(2, 20)
+ P = P / np.sqrt((P**2).sum(0, keepdims=True))
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ xb = nx.from_numpy(x, type_as=tp)
+ yb = nx.from_numpy(y, type_as=tp)
+ Pb = nx.from_numpy(P, type_as=tp)
+
+ valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)
+
+ nx.assert_same_dtype_device(xb, valb)
diff --git a/test/test_smooth.py b/test/test_smooth.py
index 2afa4f8..31e0b2e 100644
--- a/test/test_smooth.py
+++ b/test/test_smooth.py
@@ -25,16 +25,16 @@ def test_smooth_ot_dual():
Gl2, log = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type='l2', log=True, stopThr=1e-10)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
u, Gl2.sum(1), atol=1e-05) # cf convergence sinkhorn
np.testing.assert_allclose(
u, Gl2.sum(0), atol=1e-05) # cf convergence sinkhorn
- # kl regyularisation
+ # kl regularisation
G = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type='kl', stopThr=1e-10)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
u, G.sum(1), atol=1e-05) # cf convergence sinkhorn
np.testing.assert_allclose(
@@ -60,16 +60,16 @@ def test_smooth_ot_semi_dual():
Gl2, log = ot.smooth.smooth_ot_semi_dual(u, u, M, 1, reg_type='l2', log=True, stopThr=1e-10)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
u, Gl2.sum(1), atol=1e-05) # cf convergence sinkhorn
np.testing.assert_allclose(
u, Gl2.sum(0), atol=1e-05) # cf convergence sinkhorn
- # kl regyularisation
+ # kl regularisation
G = ot.smooth.smooth_ot_semi_dual(u, u, M, 1, reg_type='kl', stopThr=1e-10)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
u, G.sum(1), atol=1e-05) # cf convergence sinkhorn
np.testing.assert_allclose(
diff --git a/test/test_stochastic.py b/test/test_stochastic.py
index 155622c..736df32 100644
--- a/test/test_stochastic.py
+++ b/test/test_stochastic.py
@@ -30,7 +30,7 @@ import ot
def test_stochastic_sag():
# test sag
- n = 15
+ n = 10
reg = 1
numItermax = 30000
rng = np.random.RandomState(0)
@@ -43,11 +43,11 @@ def test_stochastic_sag():
G = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, "sag",
numItermax=numItermax)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
- u, G.sum(1), atol=1e-04) # cf convergence sag
+ u, G.sum(1), atol=1e-03) # cf convergence sag
np.testing.assert_allclose(
- u, G.sum(0), atol=1e-04) # cf convergence sag
+ u, G.sum(0), atol=1e-03) # cf convergence sag
#############################################################################
@@ -60,9 +60,9 @@ def test_stochastic_sag():
def test_stochastic_asgd():
# test asgd
- n = 15
+ n = 10
reg = 1
- numItermax = 100000
+ numItermax = 10000
rng = np.random.RandomState(0)
x = rng.randn(n, 2)
@@ -73,11 +73,11 @@ def test_stochastic_asgd():
G, log = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, "asgd",
numItermax=numItermax, log=True)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
- u, G.sum(1), atol=1e-03) # cf convergence asgd
+ u, G.sum(1), atol=1e-02) # cf convergence asgd
np.testing.assert_allclose(
- u, G.sum(0), atol=1e-03) # cf convergence asgd
+ u, G.sum(0), atol=1e-02) # cf convergence asgd
#############################################################################
@@ -90,9 +90,9 @@ def test_stochastic_asgd():
def test_sag_asgd_sinkhorn():
# test all algorithms
- n = 15
+ n = 10
reg = 1
- nb_iter = 100000
+ nb_iter = 10000
rng = np.random.RandomState(0)
x = rng.randn(n, 2)
@@ -105,19 +105,19 @@ def test_sag_asgd_sinkhorn():
numItermax=nb_iter)
G_sinkhorn = ot.sinkhorn(u, u, M, reg)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
- G_sag.sum(1), G_sinkhorn.sum(1), atol=1e-03)
+ G_sag.sum(1), G_sinkhorn.sum(1), atol=1e-02)
np.testing.assert_allclose(
- G_sag.sum(0), G_sinkhorn.sum(0), atol=1e-03)
+ G_sag.sum(0), G_sinkhorn.sum(0), atol=1e-02)
np.testing.assert_allclose(
- G_asgd.sum(1), G_sinkhorn.sum(1), atol=1e-03)
+ G_asgd.sum(1), G_sinkhorn.sum(1), atol=1e-02)
np.testing.assert_allclose(
- G_asgd.sum(0), G_sinkhorn.sum(0), atol=1e-03)
+ G_asgd.sum(0), G_sinkhorn.sum(0), atol=1e-02)
np.testing.assert_allclose(
- G_sag, G_sinkhorn, atol=1e-03) # cf convergence sag
+ G_sag, G_sinkhorn, atol=1e-02) # cf convergence sag
np.testing.assert_allclose(
- G_asgd, G_sinkhorn, atol=1e-03) # cf convergence asgd
+ G_asgd, G_sinkhorn, atol=1e-02) # cf convergence asgd
#############################################################################
@@ -136,7 +136,7 @@ def test_stochastic_dual_sgd():
# test sgd
n = 10
reg = 1
- numItermax = 15000
+ numItermax = 5000
batch_size = 10
rng = np.random.RandomState(0)
@@ -148,7 +148,7 @@ def test_stochastic_dual_sgd():
G, log = ot.stochastic.solve_dual_entropic(u, u, M, reg, batch_size,
numItermax=numItermax, log=True)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
u, G.sum(1), atol=1e-03) # cf convergence sgd
np.testing.assert_allclose(
@@ -167,7 +167,7 @@ def test_dual_sgd_sinkhorn():
# test all dual algorithms
n = 10
reg = 1
- nb_iter = 15000
+ nb_iter = 5000
batch_size = 10
rng = np.random.RandomState(0)
@@ -181,13 +181,13 @@ def test_dual_sgd_sinkhorn():
G_sinkhorn = ot.sinkhorn(u, u, M, reg)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
- G_sgd.sum(1), G_sinkhorn.sum(1), atol=1e-03)
+ G_sgd.sum(1), G_sinkhorn.sum(1), atol=1e-02)
np.testing.assert_allclose(
- G_sgd.sum(0), G_sinkhorn.sum(0), atol=1e-03)
+ G_sgd.sum(0), G_sinkhorn.sum(0), atol=1e-02)
np.testing.assert_allclose(
- G_sgd, G_sinkhorn, atol=1e-03) # cf convergence sgd
+ G_sgd, G_sinkhorn, atol=1e-02) # cf convergence sgd
# Test gaussian
n = 30
@@ -206,7 +206,7 @@ def test_dual_sgd_sinkhorn():
G_sinkhorn = ot.sinkhorn(a, b, M, reg)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
G_sgd.sum(1), G_sinkhorn.sum(1), atol=1e-03)
np.testing.assert_allclose(
diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py
index dfeaad9..e8349d1 100644
--- a/test/test_unbalanced.py
+++ b/test/test_unbalanced.py
@@ -115,7 +115,8 @@ def test_stabilized_vs_sinkhorn():
G, log = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, reg=epsilon,
method="sinkhorn_stabilized",
reg_m=reg_m,
- log=True)
+ log=True,
+ verbose=True)
G2, log2 = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
method="sinkhorn", log=True)
@@ -138,7 +139,7 @@ def test_unbalanced_barycenter(method):
reg_m = 1.
q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
- method=method, log=True)
+ method=method, log=True, verbose=True)
# check fixed point equations
fi = reg_m / (reg_m + epsilon)
logA = np.log(A + 1e-16)
@@ -173,6 +174,7 @@ def test_barycenter_stabilized_vs_sinkhorn():
reg_m=reg_m, log=True,
tau=100,
method="sinkhorn_stabilized",
+ verbose=True
)
q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
method="sinkhorn",
@@ -182,6 +184,33 @@ def test_barycenter_stabilized_vs_sinkhorn():
q, qstable, atol=1e-05)
+def test_wrong_method():
+
+ n = 10
+ rng = np.random.RandomState(42)
+
+ x = rng.randn(n, 2)
+ a = ot.utils.unif(n)
+
+ # make dists unbalanced
+ b = ot.utils.unif(n) * 1.5
+
+ M = ot.dist(x, x)
+ epsilon = 1.
+ reg_m = 1.
+
+ with pytest.raises(ValueError):
+ ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon,
+ reg_m=reg_m,
+ method='badmethod',
+ log=True,
+ verbose=True)
+ with pytest.raises(ValueError):
+ ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
+ method='badmethod',
+ verbose=True)
+
+
def test_implemented_methods():
IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized']
TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_reg_scaling']
diff --git a/test/test_utils.py b/test/test_utils.py
index db9cda6..40f4e49 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -4,10 +4,41 @@
#
# License: MIT License
-
import ot
import numpy as np
import sys
+import pytest
+
+
+def test_proj_simplex(nx):
+ n = 10
+ rng = np.random.RandomState(0)
+
+ # test on matrix when projection is done on axis 0
+ x = rng.randn(n, 2)
+ x1 = nx.from_numpy(x)
+
+ # all projections should sum to 1
+ proj = ot.utils.proj_simplex(x1)
+ l1 = np.sum(nx.to_numpy(proj), axis=0)
+ l2 = np.ones(2)
+ np.testing.assert_allclose(l1, l2, atol=1e-5)
+
+ # all projections should sum to 3
+ proj = ot.utils.proj_simplex(x1, 3)
+ l1 = np.sum(nx.to_numpy(proj), axis=0)
+ l2 = 3 * np.ones(2)
+ np.testing.assert_allclose(l1, l2, atol=1e-5)
+
+ # test on a vector input
+ x = rng.randn(n)
+ x1 = nx.from_numpy(x)
+
+ # all projections should sum to 1
+ proj = ot.utils.proj_simplex(x1)
+ l1 = np.sum(nx.to_numpy(proj), axis=0)
+ l2 = np.ones(2)
+ np.testing.assert_allclose(l1, l2, atol=1e-5)
def test_parmap():
@@ -45,8 +76,8 @@ def test_tic_toc():
def test_kernel():
n = 100
-
- x = np.random.randn(n, 2)
+ rng = np.random.RandomState(0)
+ x = rng.randn(n, 2)
K = ot.utils.kernel(x, x)
@@ -67,7 +98,8 @@ def test_dist():
n = 100
- x = np.random.randn(n, 2)
+ rng = np.random.RandomState(0)
+ x = rng.randn(n, 2)
D = np.zeros((n, n))
for i in range(n):
@@ -77,9 +109,31 @@ def test_dist():
D2 = ot.dist(x, x)
D3 = ot.dist(x)
+ D4 = ot.dist(x, x, metric='minkowski', p=2)
+
+ assert D4[0, 1] == D4[1, 0]
+
# dist should return squared euclidean by default
- np.testing.assert_allclose(D, D2)
- np.testing.assert_allclose(D, D3)
+ np.testing.assert_allclose(D, D2, atol=1e-14)
+ np.testing.assert_allclose(D, D3, atol=1e-14)
+
+
+def test_dist_backends(nx):
+
+ n = 100
+ rng = np.random.RandomState(0)
+ x = rng.randn(n, 2)
+ x1 = nx.from_numpy(x)
+
+ lst_metric = ['euclidean', 'sqeuclidean']
+
+ for metric in lst_metric:
+
+ D = ot.dist(x, x, metric=metric)
+ D1 = ot.dist(x1, x1, metric=metric)
+
+ # low atol because jax forces float32
+ np.testing.assert_allclose(D, nx.to_numpy(D1), atol=1e-5)
def test_dist0():
@@ -95,9 +149,11 @@ def test_dots():
n1, n2, n3, n4 = 100, 50, 200, 100
- A = np.random.randn(n1, n2)
- B = np.random.randn(n2, n3)
- C = np.random.randn(n3, n4)
+ rng = np.random.RandomState(0)
+
+ A = rng.randn(n1, n2)
+ B = rng.randn(n2, n3)
+ C = rng.randn(n3, n4)
X1 = ot.utils.dots(A, B, C)
@@ -169,6 +225,13 @@ def test_deprecated_func():
class Class():
pass
+ with pytest.warns(DeprecationWarning):
+ fun()
+
+ with pytest.warns(DeprecationWarning):
+ cl = Class()
+ print(cl)
+
if sys.version_info < (3, 5):
print('Not tested')
else:
@@ -199,4 +262,7 @@ def test_BaseEstimator():
params['first'] = 'spam again'
cl.set_params(**params)
+ with pytest.raises(ValueError):
+ cl.set_params(bibi=10)
+
assert cl.first == 'spam again'