author    Gard Spreemann <gspr@nonempty.org>    2021-11-09 17:05:13 +0100
committer Gard Spreemann <gspr@nonempty.org>    2021-11-09 17:05:13 +0100
commit    a9fdc844907decddf54bed3ebeea8d8b2cf0fc5c (patch)
tree      449a03fce8fafb78b6badd12b6e633f1e5d73a64 /test
parent    a16b9471d7114ec08977479b7249efe747702b97 (diff)
parent    f1628794d521a8dfa00af383b5e06cd6d34af619 (diff)
Merge tag '0.8.0' into dfsg/latest
Diffstat (limited to 'test')
-rw-r--r--  test/conftest.py          |  62
-rw-r--r--  test/test_1d_solver.py    | 172
-rw-r--r--  test/test_backend.py      | 577
-rw-r--r--  test/test_bregman.py      | 718
-rw-r--r--  test/test_da.py           |  24
-rw-r--r--  test/test_dr.py           |  62
-rw-r--r--  test/test_gromov.py       | 523
-rw-r--r--  test/test_helpers.py      |  26
-rw-r--r--  test/test_optim.py        | 103
-rw-r--r--  test/test_ot.py           | 183
-rwxr-xr-x  test/test_partial.py      |  16
-rw-r--r--  test/test_regpath.py      |  64
-rw-r--r--  test/test_sliced.py       | 213
-rw-r--r--  test/test_smooth.py       |  12
-rw-r--r--  test/test_stochastic.py   |  52
-rw-r--r--  test/test_unbalanced.py   |  33
-rw-r--r--  test/test_utils.py        |  84
17 files changed, 2600 insertions, 324 deletions
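
The first file in the patch, test/conftest.py, is what lets the tests below take an `nx` backend argument: the `nx` fixture re-runs a test once per backend returned by `get_backend_list()`, and `pytest_configure` registers `pytest.skip_arg` / `pytest.skip_backend` helpers for skipping unsupported combinations. For orientation, a minimal sketch (not part of the patch; the test name and data are illustrative) of how the test files below consume that machinery:

    import numpy as np
    import pytest
    import ot

    @pytest.skip_backend("jax")  # helper registered in conftest.py's pytest_configure
    def test_sinkhorn_marginals(nx):  # hypothetical test; `nx` is the backend fixture
        a = ot.unif(5)  # uniform weights
        x = np.random.RandomState(0).randn(5, 2)
        M = ot.dist(x, x)  # squared Euclidean cost matrix
        # move inputs onto the current backend, solve, then come back to numpy
        G = nx.to_numpy(ot.sinkhorn(nx.from_numpy(a), nx.from_numpy(a), nx.from_numpy(M), 1))
        np.testing.assert_allclose(a, G.sum(1), atol=1e-5)  # marginal constraint

The fixture yields the backend object itself, so `nx.from_numpy` / `nx.to_numpy` are the only conversion points a test needs.
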
diff --git a/test/conftest.py b/test/conftest.py
new file mode 100644
index 0000000..987d98e
--- /dev/null
+++ b/test/conftest.py
@@ -0,0 +1,62 @@
+# -*- coding: utf-8 -*-
+
+# Configuration file for pytest
+
+# License: MIT License
+
+import pytest
+from ot.backend import jax
+from ot.backend import get_backend_list
+import functools
+
+if jax:
+    from jax.config import config
+    config.update("jax_enable_x64", True)
+
+backend_list = get_backend_list()
+
+
+@pytest.fixture(params=backend_list)
+def nx(request):
+    backend = request.param
+
+    yield backend
+
+
+def skip_arg(arg, value, reason=None, getter=lambda x: x):
+    if isinstance(arg, tuple) or isinstance(arg, list):
+        n = len(arg)
+    else:
+        arg = (arg, )
+        n = 1
+    if n != 1 and (isinstance(value, tuple) or isinstance(value, list)):
+        pass
+    else:
+        value = (value, )
+    if isinstance(getter, tuple) or isinstance(getter, list):
+        pass
+    else:
+        getter = [getter] * n
+
+    if reason is None:
+        reason = f"Param {arg} should be skipped for value {value}"
+
+    def wrapper(function):
+
+        @functools.wraps(function)
+        def wrapped(*args, **kwargs):
+            if all(
+                arg[i] in kwargs.keys() and getter[i](kwargs[arg[i]]) == value[i]
+                for i in range(n)
+            ):
+                pytest.skip(reason)
+            return function(*args, **kwargs)
+
+        return wrapped
+
+    return wrapper
+
+
+def pytest_configure(config):
+    pytest.skip_arg = skip_arg
+    pytest.skip_backend = functools.partial(skip_arg, "nx", getter=str)
diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py
new file mode 100644
index 0000000..cb85cb9
--- /dev/null
+++ b/test/test_1d_solver.py
@@ -0,0 +1,172 @@
+"""Tests for module 1d Wasserstein solver"""
+
+# Author: Adrien Corenflos <adrien.corenflos@aalto.fi>
+#         Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
+import numpy as np
+import pytest
+
+import ot
+from ot.lp import wasserstein_1d
+
+from ot.backend import get_backend_list
+from scipy.stats import wasserstein_distance
+
+backend_list = get_backend_list()
+
+
+def test_emd_1d_emd2_1d_with_weights():
+    # test emd1d gives similar results as emd
+    n = 20
+    m = 30
+    rng = np.random.RandomState(0)
+    u = rng.randn(n, 1)
+    v = rng.randn(m, 1)
+
+    w_u = rng.uniform(0., 1., n)
+    w_u = w_u / w_u.sum()
+
+    w_v = rng.uniform(0., 1., m)
+    w_v = w_v / w_v.sum()
+
+    M = ot.dist(u, v, metric='sqeuclidean')
+
+    G, log = ot.emd(w_u, w_v, M, log=True)
+    wass = log["cost"]
+    G_1d, log = ot.emd_1d(u, v, w_u, w_v, metric='sqeuclidean', log=True)
+    wass1d = log["cost"]
+    wass1d_emd2 = ot.emd2_1d(u, v, w_u, w_v, metric='sqeuclidean', log=False)
+    wass1d_euc = ot.emd2_1d(u, v, w_u, w_v, metric='euclidean', log=False)
+
+    # check loss is similar
+    np.testing.assert_allclose(wass, wass1d)
+    np.testing.assert_allclose(wass, wass1d_emd2)
+
+    # check loss is similar to scipy's implementation for Euclidean metric
+    wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)), w_u, w_v)
+    np.testing.assert_allclose(wass_sp, wass1d_euc)
+
+    # check constraints
+    np.testing.assert_allclose(w_u, G.sum(1))
+    np.testing.assert_allclose(w_v, G.sum(0))
+
+
+@pytest.mark.parametrize('nx', backend_list)
+def test_wasserstein_1d(nx):
+    from scipy.stats import wasserstein_distance
+
+    rng = np.random.RandomState(0)
+
+    n = 100
+    x = np.linspace(0, 5, n)
+    rho_u = np.abs(rng.randn(n))
+    rho_u /= rho_u.sum()
+    rho_v = np.abs(rng.randn(n))
+    rho_v /= rho_v.sum()
+
+    xb = nx.from_numpy(x)
+    rho_ub = nx.from_numpy(rho_u)
+    rho_vb = nx.from_numpy(rho_v)
+
+    # test 1 : wasserstein_1d should be close to scipy W_1 implementation
+    np.testing.assert_almost_equal(wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1),
+                                   wasserstein_distance(x, x, rho_u, rho_v))
+
+    # test 2 : wasserstein_1d should be close to one when only translating the support
+    np.testing.assert_almost_equal(wasserstein_1d(xb, xb + 1, p=2),
+                                   1.)
+
+    # test 3 : arrays test
+    X = np.stack((np.linspace(0, 5, n), np.linspace(0, 5, n) * 10), -1)
+    Xb = nx.from_numpy(X)
+    res = wasserstein_1d(Xb, Xb, rho_ub, rho_vb, p=2)
+    np.testing.assert_almost_equal(100 * res[0], res[1], decimal=4)
+
+
+def test_wasserstein_1d_type_devices(nx):
+
+    rng = np.random.RandomState(0)
+
+    n = 10
+    x = np.linspace(0, 5, n)
+    rho_u = np.abs(rng.randn(n))
+    rho_u /= rho_u.sum()
+    rho_v = np.abs(rng.randn(n))
+    rho_v /= rho_v.sum()
+
+    for tp in nx.__type_list__:
+        print(nx.dtype_device(tp))
+
+        xb = nx.from_numpy(x, type_as=tp)
+        rho_ub = nx.from_numpy(rho_u, type_as=tp)
+        rho_vb = nx.from_numpy(rho_v, type_as=tp)
+
+        res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1)
+
+        nx.assert_same_dtype_device(xb, res)
+
+
+def test_emd_1d_emd2_1d():
+    # test emd1d gives similar results as emd
+    n = 20
+    m = 30
+    rng = np.random.RandomState(0)
+    u = rng.randn(n, 1)
+    v = rng.randn(m, 1)
+
+    M = ot.dist(u, v, metric='sqeuclidean')
+
+    G, log = ot.emd([], [], M, log=True)
+    wass = log["cost"]
+    G_1d, log = ot.emd_1d(u, v, [], [], metric='sqeuclidean', log=True)
+    wass1d = log["cost"]
+    wass1d_emd2 = ot.emd2_1d(u, v, [], [], metric='sqeuclidean', log=False)
+    wass1d_euc = ot.emd2_1d(u, v, [], [], metric='euclidean', log=False)
+
+    # check loss is similar
+    np.testing.assert_allclose(wass, wass1d)
+    np.testing.assert_allclose(wass, wass1d_emd2)
+
+    # check loss is similar to scipy's implementation for Euclidean metric
+    wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)))
+    np.testing.assert_allclose(wass_sp, wass1d_euc)
+
+    # check constraints
+    np.testing.assert_allclose(np.ones((n,)) / n, G.sum(1))
+    np.testing.assert_allclose(np.ones((m,)) / m, G.sum(0))
+
+    # check G is similar
+    np.testing.assert_allclose(G, G_1d, atol=1e-15)
+
+    # check AssertionError is raised if called on non 1d arrays
+    u = np.random.randn(n, 2)
+    v = np.random.randn(m, 2)
+    with pytest.raises(AssertionError):
+        ot.emd_1d(u, v, [], [])
+
+
+def test_emd1d_type_devices(nx):
+
+    rng = np.random.RandomState(0)
+
+    n = 10
+    x = np.linspace(0, 5, n)
+    rho_u = np.abs(rng.randn(n))
+    rho_u /= rho_u.sum()
+    rho_v = np.abs(rng.randn(n))
+    rho_v /= rho_v.sum()
+
+    for tp in nx.__type_list__:
+        print(nx.dtype_device(tp))
+
+        xb = nx.from_numpy(x, type_as=tp)
+        rho_ub = nx.from_numpy(rho_u, type_as=tp)
+        rho_vb = nx.from_numpy(rho_v, type_as=tp)
+
+        emd = ot.emd_1d(xb, xb, rho_ub, rho_vb)
+        emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb)
+
+        nx.assert_same_dtype_device(xb, emd)
+        nx.assert_same_dtype_device(xb, emd2)
diff --git a/test/test_backend.py b/test/test_backend.py
new file mode 100644
index 0000000..1832b91
--- /dev/null
+++ b/test/test_backend.py
@@ -0,0 +1,577 @@
+"""Tests for backend module """
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+#         Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
+import ot
+import ot.backend
+from ot.backend import torch, jax
+
+import pytest
+
+import numpy as np
+from numpy.testing import assert_array_almost_equal_nulp
+
+from ot.backend import get_backend, get_backend_list, to_numpy
+
+
+def test_get_backend_list():
+
+    lst = get_backend_list()
+
+    assert len(lst) > 0
+    assert isinstance(lst[0], ot.backend.NumpyBackend)
+
+
+def test_to_numpy(nx):
+
+    v = nx.zeros(10)
+    M = nx.ones((10, 10))
+
+    v2 = to_numpy(v)
+    assert isinstance(v2, np.ndarray)
+
+    v2, M2 = to_numpy(v, M)
+    assert isinstance(M2, np.ndarray)
+
+
+def test_get_backend():
+
+    A = np.zeros((3, 2))
+    B = np.zeros((3, 1))
+
+    nx = get_backend(A)
+    assert nx.__name__ == 'numpy'
+
+    nx = get_backend(A, B)
+    assert nx.__name__ == 'numpy'
+
+    # error if no parameters
+    with pytest.raises(ValueError):
+        get_backend()
+
+    # error if unknown types
+    with pytest.raises(ValueError):
+        get_backend(1, 2.0)
+
+    # test torch
+    if torch:
+
+        A2 = torch.from_numpy(A)
+        B2 = torch.from_numpy(B)
+
+        nx = get_backend(A2)
+        assert nx.__name__ == 'torch'
+
+        nx = get_backend(A2, B2)
+        assert nx.__name__ == 'torch'
+
+        # test not unique types in input
+        with pytest.raises(ValueError):
+            get_backend(A, B2)
+
+    if jax:
+
+        A2 = jax.numpy.array(A)
+        B2 = jax.numpy.array(B)
+
+        nx = get_backend(A2)
+        assert nx.__name__ == 'jax'
+
+        nx = get_backend(A2, B2)
+        assert nx.__name__ == 'jax'
+
+        # test not unique types in input
+        with pytest.raises(ValueError):
+            get_backend(A, B2)
+
+
+def test_convert_between_backends(nx):
+
+    A = np.zeros((3, 2))
+    B = np.zeros((3, 1))
+
+    A2 = nx.from_numpy(A)
+    B2 = nx.from_numpy(B)
+
+    assert isinstance(A2, nx.__type__)
+    assert isinstance(B2, nx.__type__)
+
+    nx2 = get_backend(A2, B2)
+
+    assert nx2.__name__ == nx.__name__
+
+    assert_array_almost_equal_nulp(nx.to_numpy(A2), A)
+    assert_array_almost_equal_nulp(nx.to_numpy(B2), B)
+
+
+def test_empty_backend():
+
+    rnd = np.random.RandomState(0)
+    M = rnd.randn(10, 3)
+    v = rnd.randn(3)
+
+    nx = ot.backend.Backend()
+
+    with pytest.raises(NotImplementedError):
+        nx.from_numpy(M)
+    with pytest.raises(NotImplementedError):
+        nx.to_numpy(M)
+    with pytest.raises(NotImplementedError):
+        nx.set_gradients(0, 0, 0)
+    with pytest.raises(NotImplementedError):
+        nx.zeros((10, 3))
+    with pytest.raises(NotImplementedError):
+        nx.ones((10, 3))
+    with pytest.raises(NotImplementedError):
+        nx.arange(10, 1, 2)
+    with pytest.raises(NotImplementedError):
+        nx.full((10, 3), 3.14)
+    with pytest.raises(NotImplementedError):
+        nx.eye((10, 3))
+    with pytest.raises(NotImplementedError):
+        nx.sum(M)
+    with pytest.raises(NotImplementedError):
+        nx.cumsum(M)
+    with pytest.raises(NotImplementedError):
+        nx.max(M)
+    with pytest.raises(NotImplementedError):
+        nx.min(M)
+    with pytest.raises(NotImplementedError):
+        nx.maximum(v, v)
+    with pytest.raises(NotImplementedError):
+        nx.minimum(v, v)
+    with pytest.raises(NotImplementedError):
+        nx.abs(M)
+    with pytest.raises(NotImplementedError):
+        nx.log(M)
+    with pytest.raises(NotImplementedError):
+        nx.exp(M)
+    with pytest.raises(NotImplementedError):
+        nx.sqrt(M)
+    with pytest.raises(NotImplementedError):
+        nx.power(v, 2)
+    with pytest.raises(NotImplementedError):
+        nx.dot(v, v)
+    with pytest.raises(NotImplementedError):
+        nx.norm(M)
+    with pytest.raises(NotImplementedError):
+        nx.exp(M)
+    with pytest.raises(NotImplementedError):
+        nx.any(M)
+    with pytest.raises(NotImplementedError):
+        nx.isnan(M)
+    with pytest.raises(NotImplementedError):
+        nx.isinf(M)
+    with pytest.raises(NotImplementedError):
+        nx.einsum('ij->i', M)
+    with pytest.raises(NotImplementedError):
+        nx.sort(M)
+    with pytest.raises(NotImplementedError):
+        nx.argsort(M)
+    with pytest.raises(NotImplementedError):
+        nx.searchsorted(v, v)
+    with pytest.raises(NotImplementedError):
+        nx.flip(M)
+    with pytest.raises(NotImplementedError):
+        nx.outer(v, v)
+    with pytest.raises(NotImplementedError):
+        nx.clip(M, -1, 1)
+    with pytest.raises(NotImplementedError):
+        nx.repeat(M, 0, 1)
+    with pytest.raises(NotImplementedError):
+        nx.take_along_axis(M, v, 0)
+    with pytest.raises(NotImplementedError):
+        nx.concatenate([v, v])
+    with pytest.raises(NotImplementedError):
+        nx.zero_pad(M, v)
+    with pytest.raises(NotImplementedError):
+        nx.argmax(M)
+    with pytest.raises(NotImplementedError):
+        nx.mean(M)
+    with pytest.raises(NotImplementedError):
+        nx.std(M)
+    with pytest.raises(NotImplementedError):
+        nx.linspace(0, 1, 50)
+    with pytest.raises(NotImplementedError):
+        nx.meshgrid(v, v)
+    with pytest.raises(NotImplementedError):
+        nx.diag(M)
+    with pytest.raises(NotImplementedError):
+        nx.unique([M, M])
+    with pytest.raises(NotImplementedError):
+        nx.logsumexp(M)
+    with pytest.raises(NotImplementedError):
+        nx.stack([M, M])
+    with pytest.raises(NotImplementedError):
+        nx.reshape(M, (5, 3, 2))
+    with pytest.raises(NotImplementedError):
+        nx.seed(42)
+    with pytest.raises(NotImplementedError):
+        nx.rand()
+    with pytest.raises(NotImplementedError):
+        nx.randn()
+        nx.coo_matrix(M, M, M)
+    with pytest.raises(NotImplementedError):
+        nx.issparse(M)
+    with pytest.raises(NotImplementedError):
+        nx.tocsr(M)
+    with pytest.raises(NotImplementedError):
+        nx.eliminate_zeros(M)
+    with pytest.raises(NotImplementedError):
+        nx.todense(M)
+    with pytest.raises(NotImplementedError):
+        nx.where(M, M, M)
+    with pytest.raises(NotImplementedError):
+        nx.copy(M)
+    with pytest.raises(NotImplementedError):
+        nx.allclose(M, M)
+
+
+def test_func_backends(nx):
+
+    rnd = np.random.RandomState(0)
+    M = rnd.randn(10, 3)
+    v = rnd.randn(3)
+    val = np.array([1.0])
+
+    # Sparse tensors test
+    sp_row = np.array([0, 3, 1, 0, 3])
+    sp_col = np.array([0, 3, 1, 2, 2])
+    sp_data = np.array([4, 5, 7, 9, 0])
+
+    lst_tot = []
+
+    for nx in [ot.backend.NumpyBackend(), nx]:
+
+        print('Backend: ', nx.__name__)
+
+        lst_b = []
+        lst_name = []
+
+        Mb = nx.from_numpy(M)
+        vb = nx.from_numpy(v)
+
+        val = nx.from_numpy(val)
+
+        sp_rowb = nx.from_numpy(sp_row)
+        sp_colb = nx.from_numpy(sp_col)
+        sp_datab = nx.from_numpy(sp_data)
+
+        A = nx.set_gradients(val, v, v)
+
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('set_gradients')
+
+        A = nx.zeros((10, 3))
+        A = nx.zeros((10, 3), type_as=Mb)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('zeros')
+
+        A = nx.ones((10, 3))
+        A = nx.ones((10, 3), type_as=Mb)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('ones')
+
+        A = nx.arange(10, 1, 2)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('arange')
+
+        A = nx.full((10, 3), 3.14)
+        A = nx.full((10, 3), 3.14, type_as=Mb)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('full')
+
+        A = nx.eye(10, 3)
+        A = nx.eye(10, 3, type_as=Mb)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('eye')
+
+        A = nx.sum(Mb)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('sum')
+
+        A = nx.sum(Mb, axis=1, keepdims=True)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('sum(axis)')
+
+        A = nx.cumsum(Mb, 0)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('cumsum(axis)')
+
+        A = nx.max(Mb)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('max')
+
+        A = nx.max(Mb, axis=1, keepdims=True)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('max(axis)')
+
+        A = nx.min(Mb)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('min')
+
+        A = nx.min(Mb, axis=1, keepdims=True)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('min(axis)')
+
+        A = nx.maximum(vb, 0)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('maximum')
+
+        A = nx.minimum(vb, 0)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('minimum')
+
+        A = nx.abs(Mb)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('abs')
+
+        A = nx.log(A)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('log')
+
+        A = nx.exp(Mb)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('exp')
+
+        A = nx.sqrt(nx.abs(Mb))
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('sqrt')
+
+        A = nx.power(Mb, 2)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('power')
+
+        A = nx.dot(vb, vb)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('dot(v,v)')
+
+        A = nx.dot(Mb, vb)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('dot(M,v)')
+
+        A = nx.dot(Mb, Mb.T)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('dot(M,M)')
+
+        A = nx.norm(vb)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('norm')
+
+        A = nx.any(vb > 0)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('any')
+
+        A = nx.isnan(vb)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('isnan')
+
+        A = nx.isinf(vb)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('isinf')
+
+        A = nx.einsum('ij->i', Mb)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('einsum(ij->i)')
+
+        A = nx.einsum('ij,j->i', Mb, vb)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('nx.einsum(ij,j->i)')
+
+        A = nx.einsum('ij->i', Mb)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('nx.einsum(ij->i)')
+
+        A = nx.sort(Mb)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('sort')
+
+        A = nx.argsort(Mb)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('argsort')
+
+        A = nx.searchsorted(Mb, Mb, 'right')
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('searchsorted')
+
+        A = nx.flip(Mb)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('flip')
+
+        A = nx.outer(vb, vb)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('outer')
+
+        A = nx.clip(vb, 0, 1)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('clip')
+
+        A = nx.repeat(Mb, 0)
+        A = nx.repeat(Mb, 2, -1)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('repeat')
+
+        A = nx.take_along_axis(vb, nx.arange(3), -1)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('take_along_axis')
+
+        A = nx.concatenate((Mb, Mb), -1)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('concatenate')
+
+        A = nx.zero_pad(Mb, len(Mb.shape) * [(3, 3)])
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('zero_pad')
+
+        A = nx.argmax(Mb)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('argmax')
+
+        A = nx.mean(Mb)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('mean')
+
+        A = nx.std(Mb)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('std')
+
+        A = nx.linspace(0, 1, 50)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('linspace')
+
+        X, Y = nx.meshgrid(vb, vb)
+        lst_b.append(np.stack([nx.to_numpy(X), nx.to_numpy(Y)]))
+        lst_name.append('meshgrid')
+
+        A = nx.diag(Mb)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('diag2D')
+
+        A = nx.diag(vb, 1)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('diag1D')
+
+        A = nx.unique(nx.from_numpy(np.stack([M, M])))
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('unique')
+
+        A = nx.logsumexp(Mb)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('logsumexp')
+
+        A = nx.stack([Mb, Mb])
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('stack')
+
+        A = nx.reshape(Mb, (5, 3, 2))
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('reshape')
+
+        sp_Mb = nx.coo_matrix(sp_datab, sp_rowb, sp_colb, shape=(4, 4))
+        nx.todense(Mb)
+        lst_b.append(nx.to_numpy(nx.todense(sp_Mb)))
+        lst_name.append('coo_matrix')
+
+        assert not nx.issparse(Mb), 'Assert fail on: issparse (expected False)'
+        assert nx.issparse(sp_Mb) or nx.__name__ == "jax", 'Assert fail on: issparse (expected True)'
+
+        A = nx.tocsr(sp_Mb)
+        lst_b.append(nx.to_numpy(nx.todense(A)))
+        lst_name.append('tocsr')
+
+        A = nx.eliminate_zeros(nx.copy(sp_datab), threshold=5.)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('eliminate_zeros (dense)')
+
+        A = nx.eliminate_zeros(sp_Mb)
+        lst_b.append(nx.to_numpy(nx.todense(A)))
+        lst_name.append('eliminate_zeros (sparse)')
+
+        A = nx.where(Mb >= nx.stack([nx.linspace(0, 1, 10)] * 3, axis=1), Mb, 0.0)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('where')
+
+        A = nx.copy(Mb)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('copy')
+
+        assert nx.allclose(Mb, Mb), 'Assert fail on: allclose (expected True)'
+        assert not nx.allclose(2 * Mb, Mb), 'Assert fail on: allclose (expected False)'
+
+        lst_tot.append(lst_b)
+
+    lst_np = lst_tot[0]
+    lst_b = lst_tot[1]
+
+    for a1, a2, name in zip(lst_np, lst_b, lst_name):
+        if not np.allclose(a1, a2):
+            print('Assert fail on: ', name)
+        assert np.allclose(a1, a2, atol=1e-7)
+
+
+def test_random_backends(nx):
+
+    tmp_u = nx.rand()
+
+    assert tmp_u < 1
+
+    tmp_n = nx.randn()
+
+    nx.seed(0)
+    M1 = nx.to_numpy(nx.rand(5, 2))
+    nx.seed(0)
+    M2 = nx.to_numpy(nx.rand(5, 2, type_as=tmp_n))
+
+    assert np.all(M1 >= 0)
+    assert np.all(M1 < 1)
+    assert M1.shape == (5, 2)
+    assert np.allclose(M1, M2)
+
+    nx.seed(0)
+    M1 = nx.to_numpy(nx.randn(5, 2))
+    nx.seed(0)
+    M2 = nx.to_numpy(nx.randn(5, 2, type_as=tmp_u))
+
+    nx.seed(42)
+    v1 = nx.randn()
+    v2 = nx.randn()
+    assert v1 != v2
+
+
+def test_gradients_backends():
+
+    rnd = np.random.RandomState(0)
+    v = rnd.randn(10)
+    c = rnd.randn()
+    e = rnd.randn()
+
+    if torch:
+
+        nx = ot.backend.TorchBackend()
+
+        v2 = torch.tensor(v, requires_grad=True)
+        c2 = torch.tensor(c, requires_grad=True)
+
+        val = c2 * torch.sum(v2 * v2)
+
+        val2 = nx.set_gradients(val, (v2, c2), (v2, c2))
+
+        val2.backward()
+
+        assert torch.equal(v2.grad, v2)
+        assert torch.equal(c2.grad, c2)
+
+    if jax:
+        nx = ot.backend.JaxBackend()
+        with jax.checking_leaks():
+            def fun(a, b, d):
+                val = b * nx.sum(a ** 4) + d
+                return nx.set_gradients(val, (a, b, d), (a, b, 2 * d))
+            grad_val = jax.grad(fun, argnums=(0, 1, 2))(v, c, e)
+
+        np.testing.assert_almost_equal(fun(v, c, e), c * np.sum(v ** 4) + e, decimal=4)
+        np.testing.assert_allclose(grad_val[0], v, atol=1e-4)
+        np.testing.assert_allclose(grad_val[2], 2 * e, atol=1e-4)
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 6aa4e08..830052d 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -2,15 +2,21 @@
 
 # Author: Remi Flamary <remi.flamary@unice.fr>
 #         Kilian Fatras <kilian.fatras@irisa.fr>
+#         Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
 #
 # License: MIT License
 
+from itertools import product
+
 import numpy as np
-import ot
 import pytest
 
+import ot
+from ot.backend import torch
+
 
-def test_sinkhorn():
+@pytest.mark.parametrize("verbose, warn", product([True, False], [True, False]))
+def test_sinkhorn(verbose, warn):
     # test sinkhorn
     n = 100
     rng = np.random.RandomState(0)
@@ -20,14 +26,189 @@ def test_sinkhorn():
 
     M = ot.dist(x, x)
 
-    G = ot.sinkhorn(u, u, M, 1, stopThr=1e-10)
+    G = ot.sinkhorn(u, u, M, 1, stopThr=1e-10, verbose=verbose, warn=warn)
 
-    # check constratints
+    # check constraints
     np.testing.assert_allclose(
         u, G.sum(1), atol=1e-05)  # cf convergence sinkhorn
     np.testing.assert_allclose(
         u, G.sum(0), atol=1e-05)  # cf convergence sinkhorn
 
+    with pytest.warns(UserWarning):
+        ot.sinkhorn(u, u, M, 1, stopThr=0, numItermax=1)
+
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized",
+                                    "sinkhorn_epsilon_scaling",
+                                    "greenkhorn",
+                                    "sinkhorn_log"])
+def test_convergence_warning(method):
+    # test sinkhorn
+    n = 100
+    a1 = ot.datasets.make_1D_gauss(n, m=30, s=10)
+    a2 = ot.datasets.make_1D_gauss(n, m=40, s=10)
+    A = np.asarray([a1, a2]).T
+    M = ot.utils.dist0(n)
+
+    with pytest.warns(UserWarning):
+        ot.sinkhorn(a1, a2, M, 1., method=method, stopThr=0, numItermax=1)
+
+    if method in ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"]:
+        with pytest.warns(UserWarning):
+            ot.barycenter(A, M, 1, method=method, stopThr=0, numItermax=1)
+        with pytest.warns(UserWarning):
+            ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1)
+
+
+def test_not_impemented_method():
+    # test sinkhorn
+    w = 10
+    n = w ** 2
+    rng = np.random.RandomState(42)
+    A_img = rng.rand(2, w, w)
+    A_flat = A_img.reshape(n, 2)
+    a1, a2 = A_flat.T
+    M_flat = ot.utils.dist0(n)
+    not_implemented = "new_method"
+    reg = 0.01
+    with pytest.raises(ValueError):
+        ot.sinkhorn(a1, a2, M_flat, reg, method=not_implemented)
+    with pytest.raises(ValueError):
+        ot.sinkhorn2(a1, a2, M_flat, reg, method=not_implemented)
+    with pytest.raises(ValueError):
+        ot.barycenter(A_flat, M_flat, reg, method=not_implemented)
+    with pytest.raises(ValueError):
+        ot.bregman.barycenter_debiased(A_flat, M_flat, reg,
+                                       method=not_implemented)
+    with pytest.raises(ValueError):
+        ot.bregman.convolutional_barycenter2d(A_img, reg,
+                                              method=not_implemented)
+    with pytest.raises(ValueError):
+        ot.bregman.convolutional_barycenter2d_debiased(A_img, reg,
+                                                       method=not_implemented)
+
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
+def test_nan_warning(method):
+    # test sinkhorn
+    n = 100
+    a1 = ot.datasets.make_1D_gauss(n, m=30, s=10)
+    a2 = ot.datasets.make_1D_gauss(n, m=40, s=10)
+
+    M = ot.utils.dist0(n)
+    reg = 0
+    with pytest.warns(UserWarning):
+        # warn set to False to avoid catching a convergence warning instead
+        ot.sinkhorn(a1, a2, M, reg, method=method, warn=False)
+
+
+def test_sinkhorn_stabilization():
+    # test sinkhorn
+    n = 100
+    a1 = ot.datasets.make_1D_gauss(n, m=30, s=10)
+    a2 = ot.datasets.make_1D_gauss(n, m=40, s=10)
+    M = ot.utils.dist0(n)
+    reg = 1e-5
+    loss1 = ot.sinkhorn2(a1, a2, M, reg, method="sinkhorn_log")
+    loss2 = ot.sinkhorn2(a1, a2, M, reg, tau=1, method="sinkhorn_stabilized")
+    np.testing.assert_allclose(
+        loss1, loss2, atol=1e-06)  # cf convergence sinkhorn
+
+
+@pytest.mark.parametrize("method, verbose, warn",
+                         product(["sinkhorn", "sinkhorn_stabilized",
+                                  "sinkhorn_log"],
+                                 [True, False], [True, False]))
+def test_sinkhorn_multi_b(method, verbose, warn):
+    # test sinkhorn
+    n = 10
+    rng = np.random.RandomState(0)
+
+    x = rng.randn(n, 2)
+    u = ot.utils.unif(n)
+
+    b = rng.rand(n, 3)
+    b = b / np.sum(b, 0, keepdims=True)
+
+    M = ot.dist(x, x)
+
+    loss0, log = ot.sinkhorn(u, b, M, .1, method=method, stopThr=1e-10,
+                             log=True)
+
+    loss = [ot.sinkhorn2(u, b[:, k], M, .1, method=method, stopThr=1e-10,
+                         verbose=verbose, warn=warn) for k in range(3)]
+    # check constraints
+    np.testing.assert_allclose(
+        loss0, loss, atol=1e-4)  # cf convergence sinkhorn
+
+
+def test_sinkhorn_backends(nx):
+    n_samples = 100
+    n_features = 2
+    rng = np.random.RandomState(0)
+
+    x = rng.randn(n_samples, n_features)
+    y = rng.randn(n_samples, n_features)
+    a = ot.utils.unif(n_samples)
+
+    M = ot.dist(x, y)
+
+    G = ot.sinkhorn(a, a, M, 1)
+
+    ab = nx.from_numpy(a)
+    M_nx = nx.from_numpy(M)
+
+    Gb = ot.sinkhorn(ab, ab, M_nx, 1)
+
+    np.allclose(G, nx.to_numpy(Gb))
+
+
+def test_sinkhorn2_backends(nx):
+    n_samples = 100
+    n_features = 2
+    rng = np.random.RandomState(0)
+
+    x = rng.randn(n_samples, n_features)
+    y = rng.randn(n_samples, n_features)
+    a = ot.utils.unif(n_samples)
+
+    M = ot.dist(x, y)
+
+    G = ot.sinkhorn(a, a, M, 1)
+
+    ab = nx.from_numpy(a)
+    M_nx = nx.from_numpy(M)
+
+    Gb = ot.sinkhorn2(ab, ab, M_nx, 1)
+
+    np.allclose(G, nx.to_numpy(Gb))
+
+
+def test_sinkhorn2_gradients():
+    n_samples = 100
+    n_features = 2
+    rng = np.random.RandomState(0)
+
+    x = rng.randn(n_samples, n_features)
+    y = rng.randn(n_samples, n_features)
+    a = ot.utils.unif(n_samples)
+
+    M = ot.dist(x, y)
+
+    if torch:
+
+        a1 = torch.tensor(a, requires_grad=True)
+        b1 = torch.tensor(a, requires_grad=True)
+        M1 = torch.tensor(M, requires_grad=True)
+
+        val = ot.sinkhorn2(a1, b1, M1, 1)
+
+        val.backward()
+
+        assert a1.shape == a1.grad.shape
+        assert b1.shape == b1.grad.shape
+        assert M1.shape == M1.grad.shape
+
 
 def test_sinkhorn_empty():
     # test sinkhorn
@@ -39,21 +220,27 @@ def test_sinkhorn_empty():
 
     M = ot.dist(x, x)
 
+    G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10, method="sinkhorn_log",
+                         verbose=True, log=True)
+    # check constraints
+    np.testing.assert_allclose(u, G.sum(1), atol=1e-05)
+    np.testing.assert_allclose(u, G.sum(0), atol=1e-05)
+
     G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10, verbose=True, log=True)
-    # check constratints
+    # check constraints
     np.testing.assert_allclose(u, G.sum(1), atol=1e-05)
    np.testing.assert_allclose(u, G.sum(0), atol=1e-05)
 
     G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10,
                          method='sinkhorn_stabilized', verbose=True, log=True)
-    # check constratints
+    # check constraints
     np.testing.assert_allclose(u, G.sum(1), atol=1e-05)
     np.testing.assert_allclose(u, G.sum(0), atol=1e-05)
 
     G, log = ot.sinkhorn(
         [], [], M, 1, stopThr=1e-10, method='sinkhorn_epsilon_scaling',
         verbose=True, log=True)
-    # check constratints
+    # check constraints
     np.testing.assert_allclose(u, G.sum(1), atol=1e-05)
     np.testing.assert_allclose(u, G.sum(0), atol=1e-05)
 
@@ -61,7 +248,8 @@ def test_sinkhorn_empty():
     ot.sinkhorn([], [], M, 1, method='greenkhorn', stopThr=1e-10, log=True)
 
 
-def test_sinkhorn_variants():
+@pytest.skip_backend("jax")
+def test_sinkhorn_variants(nx):
     # test sinkhorn
     n = 100
     rng = np.random.RandomState(0)
@@ -71,22 +259,131 @@ def test_sinkhorn_variants():
 
     M = ot.dist(x, x)
 
-    G0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10)
-    Gs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10)
-    Ges = ot.sinkhorn(
-        u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10)
-    G_green = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10)
+    ub = nx.from_numpy(u)
+    M_nx = nx.from_numpy(M)
+
+    G = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10)
+    Gl = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
+    G0 = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn', stopThr=1e-10))
+    Gs = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))
+    Ges = nx.to_numpy(ot.sinkhorn(
+        ub, ub, M_nx, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10))
+    G_green = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='greenkhorn', stopThr=1e-10))
 
     # check values
+    np.testing.assert_allclose(G, G0, atol=1e-05)
+    np.testing.assert_allclose(G, Gl, atol=1e-05)
     np.testing.assert_allclose(G0, Gs, atol=1e-05)
     np.testing.assert_allclose(G0, Ges, atol=1e-05)
     np.testing.assert_allclose(G0, G_green, atol=1e-5)
-    print(G0, G_green)
+
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized",
+                                    "sinkhorn_epsilon_scaling",
+                                    "greenkhorn",
+                                    "sinkhorn_log"])
+@pytest.skip_arg(("nx", "method"), ("jax", "sinkhorn_epsilon_scaling"), reason="jax does not support sinkhorn_epsilon_scaling", getter=str)
+@pytest.skip_arg(("nx", "method"), ("jax", "greenkhorn"), reason="jax does not support greenkhorn", getter=str)
+def test_sinkhorn_variants_dtype_device(nx, method):
+    n = 100
+
+    x = np.random.randn(n, 2)
+    u = ot.utils.unif(n)
+
+    M = ot.dist(x, x)
+
+    for tp in nx.__type_list__:
+        print(nx.dtype_device(tp))
+
+        ub = nx.from_numpy(u, type_as=tp)
+        Mb = nx.from_numpy(M, type_as=tp)
+
+        Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10)
+
+        nx.assert_same_dtype_device(Mb, Gb)
+
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"])
+def test_sinkhorn2_variants_dtype_device(nx, method):
+    n = 100
+
+    x = np.random.randn(n, 2)
+    u = ot.utils.unif(n)
+
+    M = ot.dist(x, x)
+
+    for tp in nx.__type_list__:
+        print(nx.dtype_device(tp))
+
+        ub = nx.from_numpy(u, type_as=tp)
+        Mb = nx.from_numpy(M, type_as=tp)
+
+        lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10)
+
+        nx.assert_same_dtype_device(Mb, lossb)
+
+
+@pytest.skip_backend("jax")
+def test_sinkhorn_variants_multi_b(nx):
+    # test sinkhorn
+    n = 50
+    rng = np.random.RandomState(0)
+
+    x = rng.randn(n, 2)
+    u = ot.utils.unif(n)
+
+    b = rng.rand(n, 3)
+    b = b / np.sum(b, 0, keepdims=True)
+
+    M = ot.dist(x, x)
+
+    ub = nx.from_numpy(u)
+    bb = nx.from_numpy(b)
+    M_nx = nx.from_numpy(M)
+
+    G = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10)
+    Gl = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
+    G0 = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10))
+    Gs = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))
+
+    # check values
+    np.testing.assert_allclose(G, G0, atol=1e-05)
+    np.testing.assert_allclose(G, Gl, atol=1e-05)
+    np.testing.assert_allclose(G0, Gs, atol=1e-05)
+
+
+@pytest.skip_backend("jax")
+def test_sinkhorn2_variants_multi_b(nx):
+    # test sinkhorn
+    n = 50
+    rng = np.random.RandomState(0)
+
+    x = rng.randn(n, 2)
+    u = ot.utils.unif(n)
+
+    b = rng.rand(n, 3)
+    b = b / np.sum(b, 0, keepdims=True)
+
+    M = ot.dist(x, x)
+
+    ub = nx.from_numpy(u)
+    bb = nx.from_numpy(b)
+    M_nx = nx.from_numpy(M)
+
+    G = ot.sinkhorn2(u, b, M, 1, method='sinkhorn', stopThr=1e-10)
+    Gl = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
+    G0 = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10))
+    Gs = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))
+
+    # check values
+    np.testing.assert_allclose(G, G0, atol=1e-05)
+    np.testing.assert_allclose(G, Gl, atol=1e-05)
+    np.testing.assert_allclose(G0, Gs, atol=1e-05)
 
 
 def test_sinkhorn_variants_log():
     # test sinkhorn
-    n = 100
+    n = 50
     rng = np.random.RandomState(0)
 
     x = rng.randn(n, 2)
@@ -95,20 +392,87 @@
 
     M = ot.dist(x, x)
 
     G0, log0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn',
                            stopThr=1e-10, log=True)
+    Gl, logl = ot.sinkhorn(u, u, M, 1, method='sinkhorn_log',
+                           stopThr=1e-10, log=True)
     Gs, logs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized',
                            stopThr=1e-10, log=True)
     Ges, loges = ot.sinkhorn(
-        u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True)
+        u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True,)
     G_green, loggreen = ot.sinkhorn(u, u, M, 1, method='greenkhorn',
                                     stopThr=1e-10, log=True)
 
     # check values
     np.testing.assert_allclose(G0, Gs, atol=1e-05)
+    np.testing.assert_allclose(G0, Gl, atol=1e-05)
     np.testing.assert_allclose(G0, Ges, atol=1e-05)
     np.testing.assert_allclose(G0, G_green, atol=1e-5)
-    print(G0, G_green)
 
-@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
-def test_barycenter(method):
+
+@pytest.mark.parametrize("verbose, warn", product([True, False], [True, False]))
+def test_sinkhorn_variants_log_multib(verbose, warn):
+    # test sinkhorn
+    n = 50
+    rng = np.random.RandomState(0)
+
+    x = rng.randn(n, 2)
+    u = ot.utils.unif(n)
+    b = rng.rand(n, 3)
+    b = b / np.sum(b, 0, keepdims=True)
+
+    M = ot.dist(x, x)
+
+    G0, log0 = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10, log=True)
+    Gl, logl = ot.sinkhorn(u, b, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True,
+                           verbose=verbose, warn=warn)
+    Gs, logs = ot.sinkhorn(u, b, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True,
+                           verbose=verbose, warn=warn)
+
+    # check values
+    np.testing.assert_allclose(G0, Gs, atol=1e-05)
+    np.testing.assert_allclose(G0, Gl, atol=1e-05)
+
+
+@pytest.mark.parametrize("method, verbose, warn",
+                         product(["sinkhorn", "sinkhorn_stabilized",
+                                  "sinkhorn_log"],
+                                 [True, False], [True, False]))
+def test_barycenter(nx, method, verbose, warn):
+    n_bins = 100  # nb bins
+
+    # Gaussian distributions
+    a1 = ot.datasets.make_1D_gauss(n_bins, m=30, s=10)  # m= mean, s= std
+    a2 = ot.datasets.make_1D_gauss(n_bins, m=40, s=10)
+
+    # creating matrix A containing all distributions
+    A = np.vstack((a1, a2)).T
+
+    # loss matrix + normalization
+    M = ot.utils.dist0(n_bins)
+    M /= M.max()
+
+    alpha = 0.5  # 0<=alpha<=1
+    weights = np.array([1 - alpha, alpha])
+
+    A_nx = nx.from_numpy(A)
+    M_nx = nx.from_numpy(M)
+    weights_nx = nx.from_numpy(weights)
+    reg = 1e-2
+
+    if nx.__name__ == "jax" and method == "sinkhorn_log":
+        with pytest.raises(NotImplementedError):
+            ot.bregman.barycenter(A_nx, M_nx, reg, weights, method=method)
+    else:
+        # wasserstein
+        bary_wass_np = ot.bregman.barycenter(A, M, reg, weights, method=method, verbose=verbose, warn=warn)
+        bary_wass, _ = ot.bregman.barycenter(A_nx, M_nx, reg, weights_nx, method=method, log=True)
+        bary_wass = nx.to_numpy(bary_wass)
+
+        np.testing.assert_allclose(1, np.sum(bary_wass))
+        np.testing.assert_allclose(bary_wass, bary_wass_np)
+
+        ot.bregman.barycenter(A_nx, M_nx, reg, log=True)
+
+
+@pytest.mark.parametrize("method, verbose, warn",
+                         product(["sinkhorn", "sinkhorn_log"],
+                                 [True, False], [True, False]))
+def test_barycenter_debiased(nx, method, verbose, warn):
     n_bins = 100  # nb bins
 
     # Gaussian distributions
@@ -125,16 +489,61 @@
 
     alpha = 0.5  # 0<=alpha<=1
     weights = np.array([1 - alpha, alpha])
 
+    A_nx = nx.from_numpy(A)
+    M_nx = nx.from_numpy(M)
+    weights_nx = nx.from_numpy(weights)
+
     # wasserstein
     reg = 1e-2
-    bary_wass, log = ot.bregman.barycenter(A, M, reg, weights, method=method, log=True)
+    if nx.__name__ == "jax" and method == "sinkhorn_log":
+        with pytest.raises(NotImplementedError):
+            ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights, method=method)
+    else:
+        bary_wass_np = ot.bregman.barycenter_debiased(A, M, reg, weights, method=method,
+                                                      verbose=verbose, warn=warn)
+        bary_wass, _ = ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights_nx, method=method, log=True)
+        bary_wass = nx.to_numpy(bary_wass)
+
+        np.testing.assert_allclose(1, np.sum(bary_wass), atol=1e-3)
+        np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-5)
 
-    np.testing.assert_allclose(1, np.sum(bary_wass))
+        ot.bregman.barycenter_debiased(A_nx, M_nx, reg, log=True, verbose=False)
 
-    ot.bregman.barycenter(A, M, reg, log=True, verbose=True)
 
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"])
+def test_convergence_warning_barycenters(method):
+    w = 10
+    n_bins = w ** 2  # nb bins
+
+    # Gaussian distributions
+    a1 = ot.datasets.make_1D_gauss(n_bins, m=30, s=10)  # m= mean, s= std
+    a2 = ot.datasets.make_1D_gauss(n_bins, m=40, s=10)
+
+    # creating matrix A containing all distributions
+    A = np.vstack((a1, a2)).T
+    A_img = A.reshape(2, w, w)
+    A_img /= A_img.sum((1, 2))[:, None, None]
+
+    # loss matrix + normalization
+    M = ot.utils.dist0(n_bins)
+    M /= M.max()
 
-def test_barycenter_stabilization():
+    alpha = 0.5  # 0<=alpha<=1
+    weights = np.array([1 - alpha, alpha])
+    reg = 0.1
+    with pytest.warns(UserWarning):
+        ot.bregman.barycenter_debiased(A, M, reg, weights, method=method, numItermax=1)
+    with pytest.warns(UserWarning):
+        ot.bregman.barycenter(A, M, reg, weights, method=method, numItermax=1)
+    with pytest.warns(UserWarning):
+        ot.bregman.convolutional_barycenter2d(A_img, reg, weights,
+                                              method=method, numItermax=1)
+    with pytest.warns(UserWarning):
+        ot.bregman.convolutional_barycenter2d_debiased(A_img, reg, weights,
+                                                       method=method, numItermax=1)
+
+
+def test_barycenter_stabilization(nx):
     n_bins = 100  # nb bins
 
     # Gaussian distributions
@@ -151,22 +560,64 @@
 
     alpha = 0.5  # 0<=alpha<=1
     weights = np.array([1 - alpha, alpha])
 
+    A_nx = nx.from_numpy(A)
+    M_nx = nx.from_numpy(M)
+    weights_b = nx.from_numpy(weights)
+
     # wasserstein
     reg = 1e-2
-    bar_stable = ot.bregman.barycenter(A, M, reg, weights,
-                                       method="sinkhorn_stabilized",
-                                       stopThr=1e-8, verbose=True)
-    bar = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn",
-                                stopThr=1e-8, verbose=True)
+    bar_np = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn", stopThr=1e-8, verbose=True)
+    bar_stable = nx.to_numpy(ot.bregman.barycenter(
+        A_nx, M_nx, reg, weights_b, method="sinkhorn_stabilized",
+        stopThr=1e-8, verbose=True
+    ))
+    bar = nx.to_numpy(ot.bregman.barycenter(
+        A_nx, M_nx, reg, weights_b, method="sinkhorn",
+        stopThr=1e-8, verbose=True
+    ))
     np.testing.assert_allclose(bar, bar_stable)
+    np.testing.assert_allclose(bar, bar_np)
+
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"])
+def test_wasserstein_bary_2d(nx, method):
+    size = 20  # size of a square image
+    a1 = np.random.rand(size, size)
+    a1 += a1.min()
+    a1 = a1 / np.sum(a1)
+    a2 = np.random.rand(size, size)
+    a2 += a2.min()
+    a2 = a2 / np.sum(a2)
+    # creating matrix A containing all distributions
+    A = np.zeros((2, size, size))
+    A[0, :, :] = a1
+    A[1, :, :] = a2
+    A_nx = nx.from_numpy(A)
+
+    # wasserstein
+    reg = 1e-2
+    if nx.__name__ == "jax" and method == "sinkhorn_log":
+        with pytest.raises(NotImplementedError):
+            ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)
+    else:
+        bary_wass_np = ot.bregman.convolutional_barycenter2d(A, reg, method=method)
+        bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method))
+
+        np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
+        np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
+
+        # help in checking if log and verbose do not bug the function
+        ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)
+
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"])
+def test_wasserstein_bary_2d_debiased(nx, method):
+    size = 20  # size of a square image
+    a1 = np.random.rand(size, size)
-def test_wasserstein_bary_2d():
-    size = 100  # size of a square image
-    a1 = np.random.randn(size, size)
     a1 += a1.min()
     a1 = a1 / np.sum(a1)
-    a2 = np.random.randn(size, size)
+    a2 = np.random.rand(size, size)
     a2 += a2.min()
     a2 = a2 / np.sum(a2)
     # creating matrix A containing all distributions
@@ -174,17 +625,25 @@
     A[0, :, :] = a1
     A[1, :, :] = a2
 
+    A_nx = nx.from_numpy(A)
+
     # wasserstein
     reg = 1e-2
-    bary_wass = ot.bregman.convolutional_barycenter2d(A, reg)
+    if nx.__name__ == "jax" and method == "sinkhorn_log":
+        with pytest.raises(NotImplementedError):
+            ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)
+    else:
+        bary_wass_np = ot.bregman.convolutional_barycenter2d_debiased(A, reg, method=method)
+        bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method))
 
-    np.testing.assert_allclose(1, np.sum(bary_wass))
+        np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
+        np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
 
-    # help in checking if log and verbose do not bug the function
-    ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)
+        # help in checking if log and verbose do not bug the function
+        ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)
 
 
-def test_unmix():
+def test_unmix(nx):
    n_bins = 50  # nb bins
 
     # Gaussian distributions
@@ -204,41 +663,58 @@
     M0 /= M0.max()
     h0 = ot.unif(2)
 
+    ab = nx.from_numpy(a)
+    Db = nx.from_numpy(D)
+    M_nx = nx.from_numpy(M)
+    M0b = nx.from_numpy(M0)
+    h0b = nx.from_numpy(h0)
+
     # wasserstein
     reg = 1e-3
-    um = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01, )
+    um_np = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01)
+    um = nx.to_numpy(ot.bregman.unmix(ab, Db, M_nx, M0b, h0b, reg, 1, alpha=0.01))
 
     np.testing.assert_allclose(1, np.sum(um), rtol=1e-03, atol=1e-03)
     np.testing.assert_allclose([0.5, 0.5], um, rtol=1e-03, atol=1e-03)
+    np.testing.assert_allclose(um, um_np)
 
-    ot.bregman.unmix(a, D, M, M0, h0, reg,
+    ot.bregman.unmix(ab, Db, M_nx, M0b, h0b, reg,
                      1, alpha=0.01, log=True, verbose=True)
 
 
-def test_empirical_sinkhorn():
+def test_empirical_sinkhorn(nx):
     # test sinkhorn
-    n = 100
+    n = 10
     a = ot.unif(n)
     b = ot.unif(n)
 
-    X_s = np.reshape(np.arange(n), (n, 1))
-    X_t = np.reshape(np.arange(0, n), (n, 1))
+    X_s = np.reshape(1.0 * np.arange(n), (n, 1))
+    X_t = np.reshape(1.0 * np.arange(0, n), (n, 1))
     M = ot.dist(X_s, X_t)
-    M_m = ot.dist(X_s, X_t, metric='minkowski')
+    M_m = ot.dist(X_s, X_t, metric='euclidean')
+
+    ab = nx.from_numpy(a)
+    bb = nx.from_numpy(b)
+    X_sb = nx.from_numpy(X_s)
+    X_tb = nx.from_numpy(X_t)
+    M_nx = nx.from_numpy(M, type_as=ab)
+    M_mb = nx.from_numpy(M_m, type_as=ab)
 
-    G_sqe = ot.bregman.empirical_sinkhorn(X_s, X_t, 1)
-    sinkhorn_sqe = ot.sinkhorn(a, b, M, 1)
+    G_sqe = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1))
+    sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1))
 
-    G_log, log_es = ot.bregman.empirical_sinkhorn(X_s, X_t, 0.1, log=True)
-    sinkhorn_log, log_s = ot.sinkhorn(a, b, M, 0.1, log=True)
+    G_log, log_es = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 0.1, log=True)
+    G_log = nx.to_numpy(G_log)
+    sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 0.1, log=True)
+    sinkhorn_log = nx.to_numpy(sinkhorn_log)
 
-    G_m = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, metric='minkowski')
-    sinkhorn_m = ot.sinkhorn(a, b, M_m, 1)
+    G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean'))
+    sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1))
 
-    loss_emp_sinkhorn = ot.bregman.empirical_sinkhorn2(X_s, X_t, 1)
-    loss_sinkhorn = ot.sinkhorn2(a, b, M, 1)
+    loss_emp_sinkhorn = nx.to_numpy(ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1))
+    loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1))
 
-    # check constratints
+    # check constraints
     np.testing.assert_allclose(
         sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05)  # metric sqeuclidian
     np.testing.assert_allclose(
@@ -254,34 +730,98 @@
     np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05)
 
 
-def test_empirical_sinkhorn_divergence():
-    # Test sinkhorn divergence
+def test_lazy_empirical_sinkhorn(nx):
+    # test sinkhorn
     n = 10
     a = ot.unif(n)
     b = ot.unif(n)
+    numIterMax = 1000
+
+    X_s = np.reshape(np.arange(n, dtype=np.float64), (n, 1))
+    X_t = np.reshape(np.arange(0, n, dtype=np.float64), (n, 1))
+    M = ot.dist(X_s, X_t)
+    M_m = ot.dist(X_s, X_t, metric='euclidean')
+
+    ab = nx.from_numpy(a)
+    bb = nx.from_numpy(b)
+    X_sb = nx.from_numpy(X_s)
+    X_tb = nx.from_numpy(X_t)
+    M_nx = nx.from_numpy(M, type_as=ab)
+    M_mb = nx.from_numpy(M_m, type_as=ab)
+
+    f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True)
+    f, g = nx.to_numpy(f), nx.to_numpy(g)
+    G_sqe = np.exp(f[:, None] + g[None, :] - M / 1)
+    sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1))
+
+    f, g, log_es = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
+    f, g = nx.to_numpy(f), nx.to_numpy(g)
+    G_log = np.exp(f[:, None] + g[None, :] - M / 0.1)
+    sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 0.1, log=True)
+    sinkhorn_log = nx.to_numpy(sinkhorn_log)
+
+    f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean', numIterMax=numIterMax, isLazy=True, batchSize=1)
+    f, g = nx.to_numpy(f), nx.to_numpy(g)
+    G_m = np.exp(f[:, None] + g[None, :] - M_m / 1)
+    sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1))
+
+    loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
+    loss_emp_sinkhorn = nx.to_numpy(loss_emp_sinkhorn)
+    loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1))
+
+    # check constraints
+    np.testing.assert_allclose(
+        sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05)  # metric sqeuclidian
+    np.testing.assert_allclose(
+        sinkhorn_sqe.sum(0), G_sqe.sum(0), atol=1e-05)  # metric sqeuclidian
+    np.testing.assert_allclose(
+        sinkhorn_log.sum(1), G_log.sum(1), atol=1e-05)  # log
+    np.testing.assert_allclose(
+        sinkhorn_log.sum(0), G_log.sum(0), atol=1e-05)  # log
+    np.testing.assert_allclose(
+        sinkhorn_m.sum(1), G_m.sum(1), atol=1e-05)  # metric euclidian
+    np.testing.assert_allclose(
+        sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05)  # metric euclidian
+    np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05)
+
+
+def test_empirical_sinkhorn_divergence(nx):
+    # Test sinkhorn divergence
+    n = 10
+    a = np.linspace(1, n, n)
+    a /= a.sum()
+    b = ot.unif(n)
     X_s = np.reshape(np.arange(n), (n, 1))
     X_t = np.reshape(np.arange(0, n * 2, 2), (n, 1))
     M = ot.dist(X_s, X_t)
     M_s = ot.dist(X_s, X_s)
     M_t = ot.dist(X_t, X_t)
 
-    emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1)
-    sinkhorn_div = (ot.sinkhorn2(a, b, M, 1) - 1 / 2 * ot.sinkhorn2(a, a, M_s, 1) - 1 / 2 * ot.sinkhorn2(b, b, M_t, 1))
-
-    emp_sinkhorn_div_log, log_es = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, log=True)
-    sink_div_log_ab, log_s_ab = ot.sinkhorn2(a, b, M, 1, log=True)
-    sink_div_log_a, log_s_a = ot.sinkhorn2(a, a, M_s, 1, log=True)
-    sink_div_log_b, log_s_b = ot.sinkhorn2(b, b, M_t, 1, log=True)
-    sink_div_log = sink_div_log_ab - 1 / 2 * (sink_div_log_a + sink_div_log_b)
-
-    # check constratints
+    ab = nx.from_numpy(a)
+    bb = nx.from_numpy(b)
+    X_sb = nx.from_numpy(X_s)
+    X_tb = nx.from_numpy(X_t)
+    M_nx = nx.from_numpy(M, type_as=ab)
+    M_sb = nx.from_numpy(M_s, type_as=ab)
+    M_tb = nx.from_numpy(M_t, type_as=ab)
+
+    emp_sinkhorn_div = nx.to_numpy(ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb))
+    sinkhorn_div = nx.to_numpy(
+        ot.sinkhorn2(ab, bb, M_nx, 1)
+        - 1 / 2 * ot.sinkhorn2(ab, ab, M_sb, 1)
+        - 1 / 2 * ot.sinkhorn2(bb, bb, M_tb, 1)
+    )
+    emp_sinkhorn_div_np = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, a=a, b=b)
+
+    # check constraints
+    np.testing.assert_allclose(emp_sinkhorn_div, emp_sinkhorn_div_np, atol=1e-05)
     np.testing.assert_allclose(
         emp_sinkhorn_div, sinkhorn_div, atol=1e-05)  # cf conv emp sinkhorn
-    np.testing.assert_allclose(
-        emp_sinkhorn_div_log, sink_div_log, atol=1e-05)  # cf conv emp sinkhorn
+
+    ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb, log=True)
 
 
-def test_stabilized_vs_sinkhorn_multidim():
+def test_stabilized_vs_sinkhorn_multidim(nx):
     # test if stable version matches sinkhorn
     # for multidimensional inputs
     n = 100
@@ -297,12 +837,21 @@
 
     M = ot.utils.dist0(n)
     M /= np.median(M)
     epsilon = 0.1
-    G, log = ot.bregman.sinkhorn(a, b, M, reg=epsilon,
+
+    ab = nx.from_numpy(a)
+    bb = nx.from_numpy(b)
+    M_nx = nx.from_numpy(M, type_as=ab)
+
+    G_np, _ = ot.bregman.sinkhorn(a, b, M, reg=epsilon,
+                                  method="sinkhorn", log=True)
+    G, log = ot.bregman.sinkhorn(ab, bb, M_nx, reg=epsilon,
                                  method="sinkhorn_stabilized", log=True)
-    G2, log2 = ot.bregman.sinkhorn(a, b, M, epsilon,
+    G = nx.to_numpy(G)
+    G2, log2 = ot.bregman.sinkhorn(ab, bb, M_nx, epsilon,
                                    method="sinkhorn", log=True)
+    G2 = nx.to_numpy(G2)
 
+    np.testing.assert_allclose(G_np, G2)
     np.testing.assert_allclose(G, G2)
 
@@ -320,8 +869,9 @@ def test_implemented_methods():
     # make dists unbalanced
     b = ot.utils.unif(n)
     A = rng.rand(n, 2)
+    A /= A.sum(0, keepdims=True)
     M = ot.dist(x, x)
-    epsilon = 1.
+    epsilon = 1.0
 
     for method in IMPLEMENTED_METHODS:
         ot.bregman.sinkhorn(a, b, M, epsilon, method=method)
@@ -338,7 +888,9 @@
         ot.bregman.sinkhorn2(a, b, M, epsilon, method=method)
 
 
-def test_screenkhorn():
+@pytest.skip_backend("jax")
+@pytest.mark.filterwarnings("ignore:Bottleneck")
+def test_screenkhorn(nx):
     # test screenkhorn
     rng = np.random.RandomState(0)
     n = 100
@@ -347,17 +899,31 @@
 
     x = rng.randn(n, 2)
     M = ot.dist(x, x)
+
+    ab = nx.from_numpy(a)
+    bb = nx.from_numpy(b)
+    M_nx = nx.from_numpy(M, type_as=ab)
+
+    # np sinkhorn
+    G_sink_np = ot.sinkhorn(a, b, M, 1e-03)
     # sinkhorn
-    G_sink = ot.sinkhorn(a, b, M, 1e-03)
+    G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-03))
     # screenkhorn
-    G_screen = ot.bregman.screenkhorn(a, b, M, 1e-03, uniform=True, verbose=True)
+    G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, M_nx, 1e-03, uniform=True, verbose=True))
     # check marginals
+    np.testing.assert_allclose(G_sink_np, G_sink)
     np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02)
     np.testing.assert_allclose(G_sink.sum(1), G_screen.sum(1), atol=1e-02)
 
 
-def test_convolutional_barycenter_non_square():
+def test_convolutional_barycenter_non_square(nx):
     # test for image with height not equal width
     A = np.ones((2, 2, 3)) / (2 * 3)
-    b = ot.bregman.convolutional_barycenter2d(A, 1e-03)
+    A_nx = nx.from_numpy(A)
+
+    b_np = ot.bregman.convolutional_barycenter2d(A, 1e-03)
+    b = nx.to_numpy(ot.bregman.convolutional_barycenter2d(A_nx, 1e-03))
+
+    np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02)
     np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02)
+    np.testing.assert_allclose(b, b_np)
diff --git a/test/test_da.py b/test/test_da.py
index 3b28119..9f2bb50 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -6,11 +6,18 @@
 
 import numpy as np
 from numpy.testing import assert_allclose, assert_equal
+import pytest
 
 import ot
 from ot.datasets import make_data_classif
 from ot.utils import unif
 
+try:  # test if sklearn installed
+    import sklearn  # noqa: F401
+    nosklearn = False
+except ImportError:
+    nosklearn = True
+
 
 def test_sinkhorn_lpl1_transport_class():
     """test_sinkhorn_transport
     """
@@ -99,8 +106,8 @@ def test_sinkhorn_l1l2_transport_class():
     """test_sinkhorn_transport
     """
 
-    ns = 150
-    nt = 200
+    ns = 50
+    nt = 100
 
     Xs, ys = make_data_classif('3gauss', ns)
     Xt, yt = make_data_classif('3gauss2', nt)
@@ -441,8 +448,8 @@ def test_mapping_transport_class():
     """test_mapping_transport
     """
 
-    ns = 60
-    nt = 120
+    ns = 20
+    nt = 30
 
     Xs, ys = make_data_classif('3gauss', ns)
     Xt, yt = make_data_classif('3gauss2', nt)
@@ -558,6 +565,14 @@
         otda.fit(Xs=Xs, Xt=Xt)
     assert len(otda.log_.keys()) != 0
 
+    # check that it does not crash when derphi is very close to 0
+    np.random.seed(39)
+    Xs, ys = make_data_classif('3gauss', ns)
+    Xt, yt = make_data_classif('3gauss2', nt)
+    otda = ot.da.MappingTransport(kernel="gaussian", bias=False)
+    otda.fit(Xs=Xs, Xt=Xt)
+    np.random.seed(None)
+
 
 def test_linear_mapping():
     ns = 150
@@ -691,6 +706,7 @@
     np.testing.assert_allclose(prop, [1 - pt, pt], rtol=1e-3, atol=1e-3)
 
 
+@pytest.mark.skipif(nosklearn, reason="No sklearn available")
 def test_emd_laplace_class():
     """test_emd_laplace_transport
     """
diff --git a/test/test_dr.py b/test/test_dr.py
index c5df287..741f2ad 100644
--- a/test/test_dr.py
+++ b/test/test_dr.py
@@ -1,6 +1,7 @@
 """Tests for module dr on Dimensionality Reduction """
 
 # Author: Remi Flamary <remi.flamary@unice.fr>
+#         Minhui Huang <mhhuang@ucdavis.edu>
 #
 # License: MIT License
 
@@ -57,3 +58,64 @@ def test_wda():
     projwda(xs)
 
     np.testing.assert_allclose(np.sum(Pwda**2, 0), np.ones(p))
+
+
+@pytest.mark.skipif(nogo, reason="Missing modules (autograd or pymanopt)")
+def test_wda_normalized():
+
+    n_samples = 100  # nb samples in source and target datasets
+    np.random.seed(0)
+
+    # generate gaussian dataset
+    xs, ys = ot.datasets.make_data_classif('gaussrot', n_samples)
+
+    n_features_noise = 8
+
+    xs = np.hstack((xs, np.random.randn(n_samples, n_features_noise)))
+
+    p = 2
+
+    P0 = np.random.randn(10, p)
+    P0 /= P0.sum(0, keepdims=True)
+
+    Pwda, projwda = ot.dr.wda(xs, ys, p, maxiter=10, P0=P0, normalize=True)
+
+    projwda(xs)
+
+    np.testing.assert_allclose(np.sum(Pwda**2, 0), np.ones(p))
+
+
+@pytest.mark.skipif(nogo, reason="Missing modules (autograd or pymanopt)")
+def test_prw():
+    d = 100  # Dimension
+    n = 100  # Number samples
+    k = 3  # Subspace dimension
+    dim = 3
+
+    def fragmented_hypercube(n, d, dim):
+        assert dim <= d
+        assert dim >= 1
+        assert dim == int(dim)
+
+        a = (1. / n) * np.ones(n)
+        b = (1. / n) * np.ones(n)
+
+        # First measure : uniform on the hypercube
+        X = np.random.uniform(-1, 1, size=(n, d))
+
+        # Second measure : fragmentation
+        tmp_y = np.random.uniform(-1, 1, size=(n, d))
+        Y = tmp_y + 2 * np.sign(tmp_y) * np.array(dim * [1] + (d - dim) * [0])
+        return a, b, X, Y
+
+    a, b, X, Y = fragmented_hypercube(n, d, dim)
+
+    tau = 0.002
+    reg = 0.2
+
+    pi, U = ot.dr.projection_robust_wasserstein(X, Y, a, b, tau, reg=reg, k=k, maxiter=1000, verbose=1)
+
+    U0 = np.random.randn(d, k)
+    U0, _ = np.linalg.qr(U0)
+
+    pi, U = ot.dr.projection_robust_wasserstein(X, Y, a, b, tau, U0=U0, reg=reg, k=k, maxiter=1000, verbose=1)
diff --git a/test/test_gromov.py b/test/test_gromov.py
index 43da9fc..c4bc04c 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -8,9 +8,13 @@
 
 import numpy as np
import ot
+from ot.backend import NumpyBackend
+from ot.backend import torch
+import pytest
-def test_gromov():
+
+def test_gromov(nx):
n_samples = 50 # nb samples
mu_s = np.array([0, 0])
@@ -29,37 +33,121 @@ def test_gromov(): C1 /= C1.max()
C2 /= C2.max()
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ pb = nx.from_numpy(p)
+ qb = nx.from_numpy(q)
+
G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', verbose=True)
+ Gb = nx.to_numpy(ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True))
- # check constratints
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
np.testing.assert_allclose(
- p, G.sum(1), atol=1e-04) # cf convergence gromov
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
- q, G.sum(0), atol=1e-04) # cf convergence gromov
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
Id = (1 / (1.0 * n_samples)) * np.eye(n_samples, n_samples)
- np.testing.assert_allclose(
- G, np.flipud(Id), atol=1e-04)
+ np.testing.assert_allclose(Gb, np.flipud(Id), atol=1e-04)
gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=True)
+ gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=True)
gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=False)
+ gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
G = log['T']
+ Gb = nx.to_numpy(logb['T'])
- np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1)
+ np.testing.assert_allclose(gw, gwb, atol=1e-06)
+ np.testing.assert_allclose(gwb, 0, atol=1e-1, rtol=1e-1)
- np.testing.assert_allclose(gw, gw_val, atol=1e-1, rtol=1e-1) # cf log=False
+ np.testing.assert_allclose(gw_val, gw_valb, atol=1e-06)
+ np.testing.assert_allclose(gwb, gw_valb, atol=1e-1, rtol=1e-1) # cf log=False
- # check constratints
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
np.testing.assert_allclose(
- p, G.sum(1), atol=1e-04) # cf convergence gromov
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
- q, G.sum(0), atol=1e-04) # cf convergence gromov
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+
+def test_gromov_dtype_device(nx):
+ # setup
+ n_samples = 50 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4)
+
+ xt = xs[::-1].copy()
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ C1b = nx.from_numpy(C1, type_as=tp)
+ C2b = nx.from_numpy(C2, type_as=tp)
+ pb = nx.from_numpy(p, type_as=tp)
+ qb = nx.from_numpy(q, type_as=tp)
+
+ Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True)
+ gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
+
+ nx.assert_same_dtype_device(C1b, Gb)
+ nx.assert_same_dtype_device(C1b, gw_valb)
+
+
+def test_gromov2_gradients():
+ n_samples = 50 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4)
+
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=5)
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ if torch:
+
+ p1 = torch.tensor(p, requires_grad=True)
+ q1 = torch.tensor(q, requires_grad=True)
+ C11 = torch.tensor(C1, requires_grad=True)
+ C12 = torch.tensor(C2, requires_grad=True)
+
+ val = ot.gromov_wasserstein2(C11, C12, p1, q1)
+
+ val.backward()
+ assert q1.shape == q1.grad.shape
+ assert p1.shape == p1.grad.shape
+ assert C11.shape == C11.grad.shape
+ assert C12.shape == C12.grad.shape
-def test_entropic_gromov():
+
+@pytest.skip_backend("jax", reason="test very slow with jax backend")
+def test_entropic_gromov(nx):
n_samples = 50 # nb samples
mu_s = np.array([0, 0])
@@ -78,85 +166,278 @@ def test_entropic_gromov():
C1 /= C1.max()
C2 /= C2.max()
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ pb = nx.from_numpy(p)
+ qb = nx.from_numpy(q)
+
G = ot.gromov.entropic_gromov_wasserstein(
C1, C2, p, q, 'square_loss', epsilon=5e-4, verbose=True)
+ Gb = nx.to_numpy(ot.gromov.entropic_gromov_wasserstein(
+ C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True
+ ))
- # check constratints
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
np.testing.assert_allclose(
- p, G.sum(1), atol=1e-04) # cf convergence gromov
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
- q, G.sum(0), atol=1e-04) # cf convergence gromov
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
gw, log = ot.gromov.entropic_gromov_wasserstein2(
C1, C2, p, q, 'kl_loss', epsilon=1e-2, log=True)
+ gwb, logb = ot.gromov.entropic_gromov_wasserstein2(
+ C1b, C2b, pb, qb, 'kl_loss', epsilon=1e-2, log=True)
G = log['T']
+ Gb = nx.to_numpy(logb['T'])
+ np.testing.assert_allclose(gw, gwb, atol=1e-06)
np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1)
- # check constratints
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
np.testing.assert_allclose(
- p, G.sum(1), atol=1e-04) # cf convergence gromov
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
- q, G.sum(0), atol=1e-04) # cf convergence gromov
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
-def test_gromov_barycenter():
- ns = 50
- nt = 60
+@pytest.skip_backend("jax", reason="test very slow with jax backend")
+def test_entropic_gromov_dtype_device(nx):
+ # setup
+ n_samples = 50 # nb samples
- Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42)
- Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42)
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
- C1 = ot.dist(Xs)
- C2 = ot.dist(Xt)
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)
- n_samples = 3
- Cb = ot.gromov.gromov_barycenters(n_samples, [C1, C2],
- [ot.unif(ns), ot.unif(nt)
- ], ot.unif(n_samples), [.5, .5],
- 'square_loss', # 5e-4,
- max_iter=100, tol=1e-3,
- verbose=True)
- np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
+ xt = xs[::-1].copy()
- Cb2 = ot.gromov.gromov_barycenters(n_samples, [C1, C2],
- [ot.unif(ns), ot.unif(nt)
- ], ot.unif(n_samples), [.5, .5],
- 'kl_loss', # 5e-4,
- max_iter=100, tol=1e-3)
- np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples))
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
-def test_gromov_entropic_barycenter():
- ns = 50
- nt = 60
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ C1b = nx.from_numpy(C1, type_as=tp)
+ C2b = nx.from_numpy(C2, type_as=tp)
+ pb = nx.from_numpy(p, type_as=tp)
+ qb = nx.from_numpy(q, type_as=tp)
+
+ Gb = ot.gromov.entropic_gromov_wasserstein(
+ C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True
+ )
+ gw_valb = ot.gromov.entropic_gromov_wasserstein2(
+ C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True
+ )
+
+ nx.assert_same_dtype_device(C1b, Gb)
+ nx.assert_same_dtype_device(C1b, gw_valb)
+
+
+def test_pointwise_gromov(nx):
+ n_samples = 50 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)
+
+ xt = xs[::-1].copy()
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ pb = nx.from_numpy(p)
+ qb = nx.from_numpy(q)
+
+ def loss(x, y):
+ return np.abs(x - y)
+
+ def lossb(x, y):
+ return nx.abs(x - y)
+
+ G, log = ot.gromov.pointwise_gromov_wasserstein(
+ C1, C2, p, q, loss, max_iter=100, log=True, verbose=True, random_state=42)
+ G = NumpyBackend().todense(G)
+ Gb, logb = ot.gromov.pointwise_gromov_wasserstein(
+ C1b, C2b, pb, qb, lossb, max_iter=100, log=True, verbose=True, random_state=42)
+ Gb = nx.to_numpy(nx.todense(Gb))
+
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ np.testing.assert_allclose(logb['gw_dist_estimated'], 0.0, atol=1e-08)
+ np.testing.assert_allclose(logb['gw_dist_std'], 0.0, atol=1e-08)
+
+ G, log = ot.gromov.pointwise_gromov_wasserstein(
+ C1, C2, p, q, loss, max_iter=100, alpha=0.1, log=True, verbose=True, random_state=42)
+ G = NumpyBackend().todense(G)
+ Gb, logb = ot.gromov.pointwise_gromov_wasserstein(
+ C1b, C2b, pb, qb, lossb, max_iter=100, alpha=0.1, log=True, verbose=True, random_state=42)
+ Gb = nx.to_numpy(nx.todense(Gb))
+
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(logb['gw_dist_estimated'], 0.10342276348494964, atol=1e-8)
+ np.testing.assert_allclose(logb['gw_dist_std'], 0.0015952535464736394, atol=1e-8)
+
+
+@pytest.skip_backend("jax", reason="test very slow with jax backend")
+def test_sampled_gromov(nx):
+ n_samples = 50 # nb samples
+
+ mu_s = np.array([0, 0], dtype=np.float64)
+ cov_s = np.array([[1, 0], [0, 1]], dtype=np.float64)
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)
+
+ xt = xs[::-1].copy()
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ pb = nx.from_numpy(p)
+ qb = nx.from_numpy(q)
+
+ def loss(x, y):
+ return np.abs(x - y)
+
+ def lossb(x, y):
+ return nx.abs(x - y)
+
+ G, log = ot.gromov.sampled_gromov_wasserstein(
+ C1, C2, p, q, loss, max_iter=100, epsilon=1, log=True, verbose=True, random_state=42)
+ Gb, logb = ot.gromov.sampled_gromov_wasserstein(
+ C1b, C2b, pb, qb, lossb, max_iter=100, epsilon=1, log=True, verbose=True, random_state=42)
+ Gb = nx.to_numpy(Gb)
+
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ np.testing.assert_allclose(logb['gw_dist_estimated'], 0.05679474884977278, atol=1e-08)
+ np.testing.assert_allclose(logb['gw_dist_std'], 0.0005986592106971995, atol=1e-08)
+
+
+def test_gromov_barycenter(nx):
+ ns = 10
+ nt = 20
Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42)
Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42)
C1 = ot.dist(Xs)
C2 = ot.dist(Xt)
-
+ p1 = ot.unif(ns)
+ p2 = ot.unif(nt)
n_samples = 3
- Cb = ot.gromov.entropic_gromov_barycenters(n_samples, [C1, C2],
- [ot.unif(ns), ot.unif(nt)
- ], ot.unif(n_samples), [.5, .5],
- 'square_loss', 2e-3,
- max_iter=100, tol=1e-3,
- verbose=True)
- np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
+ p = ot.unif(n_samples)
- Cb2 = ot.gromov.entropic_gromov_barycenters(n_samples, [C1, C2],
- [ot.unif(ns), ot.unif(nt)
- ], ot.unif(n_samples), [.5, .5],
- 'kl_loss', 2e-3,
- max_iter=100, tol=1e-3)
- np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples))
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ p1b = nx.from_numpy(p1)
+ p2b = nx.from_numpy(p2)
+ pb = nx.from_numpy(p)
+
+ Cb = ot.gromov.gromov_barycenters(
+ n_samples, [C1, C2], [p1, p2], p, [.5, .5],
+ 'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42
+ )
+ Cbb = nx.to_numpy(ot.gromov.gromov_barycenters(
+ n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
+ 'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42
+ ))
+ np.testing.assert_allclose(Cb, Cbb, atol=1e-06)
+ np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples))
+
+ Cb2 = ot.gromov.gromov_barycenters(
+ n_samples, [C1, C2], [p1, p2], p, [.5, .5],
+ 'kl_loss', max_iter=100, tol=1e-3, random_state=42
+ )
+ Cb2b = nx.to_numpy(ot.gromov.gromov_barycenters(
+ n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
+ 'kl_loss', max_iter=100, tol=1e-3, random_state=42
+ ))
+ np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06)
+ np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples))
+
+
+@pytest.mark.filterwarnings("ignore:divide")
+def test_gromov_entropic_barycenter(nx):
+ ns = 10
+ nt = 20
+ Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42)
+ Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42)
-def test_fgw():
+ C1 = ot.dist(Xs)
+ C2 = ot.dist(Xt)
+ p1 = ot.unif(ns)
+ p2 = ot.unif(nt)
+ n_samples = 2
+ p = ot.unif(n_samples)
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ p1b = nx.from_numpy(p1)
+ p2b = nx.from_numpy(p2)
+ pb = nx.from_numpy(p)
+
+ Cb = ot.gromov.entropic_gromov_barycenters(
+ n_samples, [C1, C2], [p1, p2], p, [.5, .5],
+ 'square_loss', 1e-3, max_iter=50, tol=1e-3, verbose=True, random_state=42
+ )
+ Cbb = nx.to_numpy(ot.gromov.entropic_gromov_barycenters(
+ n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
+ 'square_loss', 1e-3, max_iter=50, tol=1e-3, verbose=True, random_state=42
+ ))
+ np.testing.assert_allclose(Cb, Cbb, atol=1e-06)
+ np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples))
+
+ Cb2 = ot.gromov.entropic_gromov_barycenters(
+ n_samples, [C1, C2], [p1, p2], p, [.5, .5],
+ 'kl_loss', 1e-3, max_iter=100, tol=1e-3, random_state=42
+ )
+ Cb2b = nx.to_numpy(ot.gromov.entropic_gromov_barycenters(
+ n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
+ 'kl_loss', 1e-3, max_iter=100, tol=1e-3, random_state=42
+ ))
+ np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06)
+ np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples))
+
+
+def test_fgw(nx):
n_samples = 50 # nb samples
mu_s = np.array([0, 0])
@@ -181,33 +462,85 @@ def test_fgw():
M = ot.dist(ys, yt)
M /= M.max()
- G = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5)
+ Mb = nx.from_numpy(M)
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ pb = nx.from_numpy(p)
+ qb = nx.from_numpy(q)
- # check constratints
+ G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True)
+ Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True)
+ Gb = nx.to_numpy(Gb)
+
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
np.testing.assert_allclose(
- p, G.sum(1), atol=1e-04) # cf convergence fgw
+ p, Gb.sum(1), atol=1e-04) # cf convergence fgw
np.testing.assert_allclose(
- q, G.sum(0), atol=1e-04) # cf convergence fgw
+ q, Gb.sum(0), atol=1e-04) # cf convergence fgw
Id = (1 / (1.0 * n_samples)) * np.eye(n_samples, n_samples)
np.testing.assert_allclose(
- G, np.flipud(Id), atol=1e-04) # cf convergence gromov
+ Gb, np.flipud(Id), atol=1e-04) # cf convergence gromov
fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True)
+ fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True)
G = log['T']
+ Gb = nx.to_numpy(logb['T'])
- np.testing.assert_allclose(fgw, 0, atol=1e-1, rtol=1e-1)
+ np.testing.assert_allclose(fgw, fgwb, atol=1e-08)
+ np.testing.assert_allclose(fgwb, 0, atol=1e-1, rtol=1e-1)
- # check constratints
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
np.testing.assert_allclose(
- p, G.sum(1), atol=1e-04) # cf convergence gromov
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
- q, G.sum(0), atol=1e-04) # cf convergence gromov
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+
+def test_fgw2_gradients():
+ n_samples = 50 # nb samples
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
-def test_fgw_barycenter():
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4)
+
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=5)
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+ M = ot.dist(xs, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ if torch:
+
+ p1 = torch.tensor(p, requires_grad=True)
+ q1 = torch.tensor(q, requires_grad=True)
+ C11 = torch.tensor(C1, requires_grad=True)
+ C12 = torch.tensor(C2, requires_grad=True)
+ M1 = torch.tensor(M, requires_grad=True)
+
+ val = ot.fused_gromov_wasserstein2(M1, C11, C12, p1, q1)
+
+ val.backward()
+
+ assert q1.shape == q1.grad.shape
+ assert p1.shape == p1.grad.shape
+ assert C11.shape == C11.grad.shape
+ assert C12.shape == C12.grad.shape
+ assert M1.shape == M1.grad.shape
+
+
+def test_fgw_barycenter(nx):
np.random.seed(42)
ns = 50
@@ -221,30 +554,44 @@ def test_fgw_barycenter():
C1 = ot.dist(Xs)
C2 = ot.dist(Xt)
-
+ p1, p2 = ot.unif(ns), ot.unif(nt)
n_samples = 3
- X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5,
- fixed_structure=False, fixed_features=False,
- p=ot.unif(n_samples), loss_fun='square_loss',
- max_iter=100, tol=1e-3)
- np.testing.assert_allclose(C.shape, (n_samples, n_samples))
- np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
+ p = ot.unif(n_samples)
+
+ ysb = nx.from_numpy(ys)
+ ytb = nx.from_numpy(yt)
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ p1b = nx.from_numpy(p1)
+ p2b = nx.from_numpy(p2)
+ pb = nx.from_numpy(p)
+
+ Xb, Cb = ot.gromov.fgw_barycenters(
+ n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, fixed_structure=False,
+ fixed_features=False, p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3, random_state=12345
+ )
xalea = np.random.randn(n_samples, 2)
init_C = ot.dist(xalea, xalea)
-
- X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], ps=[ot.unif(ns), ot.unif(nt)], lambdas=[.5, .5], alpha=0.5,
- fixed_structure=True, init_C=init_C, fixed_features=False,
- p=ot.unif(n_samples), loss_fun='square_loss',
- max_iter=100, tol=1e-3)
- np.testing.assert_allclose(C.shape, (n_samples, n_samples))
- np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
+ init_Cb = nx.from_numpy(init_C)
+
+ Xb, Cb = ot.gromov.fgw_barycenters(
+ n_samples, [ysb, ytb], [C1b, C2b], ps=[p1b, p2b], lambdas=[.5, .5],
+ alpha=0.5, fixed_structure=True, init_C=init_Cb, fixed_features=False,
+ p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3
+ )
+ Xb, Cb = nx.to_numpy(Xb), nx.to_numpy(Cb)
+ np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
+ np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1]))
init_X = np.random.randn(n_samples, ys.shape[1])
-
- X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5,
- fixed_structure=False, fixed_features=True, init_X=init_X,
- p=ot.unif(n_samples), loss_fun='square_loss',
- max_iter=100, tol=1e-3)
- np.testing.assert_allclose(C.shape, (n_samples, n_samples))
- np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
+ init_Xb = nx.from_numpy(init_X)
+
+ Xb, Cb, logb = ot.gromov.fgw_barycenters(
+ n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5,
+ fixed_structure=False, fixed_features=True, init_X=init_Xb,
+ p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3, log=True, random_state=98765
+ )
+ Xb, Cb = nx.to_numpy(Xb), nx.to_numpy(Cb)
+ np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
+ np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1]))
diff --git a/test/test_helpers.py b/test/test_helpers.py
new file mode 100644
index 0000000..cc4c90e
--- /dev/null
+++ b/test/test_helpers.py
@@ -0,0 +1,26 @@
+"""Tests for helpers functions """
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+import os
+import sys
+
+sys.path.append(os.path.join("ot", "helpers"))
+
+from openmp_helpers import get_openmp_flag, check_openmp_support # noqa
+from pre_build_helpers import _get_compiler, compile_test_program # noqa
+
+
+def test_helpers():
+
+ compiler = _get_compiler()
+
+ get_openmp_flag(compiler)
+
+ s = '#include <stdio.h>\n#include <stdlib.h>\n\nint main(void) {\n\tprintf("Hello world!\\n");\n\treturn 0;\n}'
+ output, _ = compile_test_program(s)
+ assert len(output) == 1 and output[0] == "Hello world!"
+
+ check_openmp_support()
diff --git a/test/test_optim.py b/test/test_optim.py
index 87b0268..4efd9b1 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -8,7 +8,7 @@ import numpy as np
import ot
-def test_conditional_gradient():
+def test_conditional_gradient(nx):
n_bins = 100 # nb bins
np.random.seed(0)
@@ -29,16 +29,26 @@ def test_conditional_gradient():
def df(G):
return G
+ def fb(G):
+ return 0.5 * nx.sum(G ** 2)
+
+ ab = nx.from_numpy(a)
+ bb = nx.from_numpy(b)
+ Mb = nx.from_numpy(M, type_as=ab)
+
reg = 1e-1
G, log = ot.optim.cg(a, b, M, reg, f, df, verbose=True, log=True)
+ Gb, log = ot.optim.cg(ab, bb, Mb, reg, fb, df, verbose=True, log=True)
+ Gb = nx.to_numpy(Gb)
- np.testing.assert_allclose(a, G.sum(1))
- np.testing.assert_allclose(b, G.sum(0))
+ np.testing.assert_allclose(Gb, G)
+ np.testing.assert_allclose(a, Gb.sum(1))
+ np.testing.assert_allclose(b, Gb.sum(0))
-def test_conditional_gradient2():
- n = 1000 # nb samples
+def test_conditional_gradient_itermax(nx):
+ n = 100 # nb samples
mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])
@@ -61,16 +71,27 @@ def test_conditional_gradient2():
def df(G):
return G
+ def fb(G):
+ return 0.5 * nx.sum(G ** 2)
+
+ ab = nx.from_numpy(a)
+ bb = nx.from_numpy(b)
+ Mb = nx.from_numpy(M, type_as=ab)
+
reg = 1e-1
- G, log = ot.optim.cg(a, b, M, reg, f, df, numItermaxEmd=200000,
+ G, log = ot.optim.cg(a, b, M, reg, f, df, numItermaxEmd=10000,
verbose=True, log=True)
+ Gb, log = ot.optim.cg(ab, bb, Mb, reg, fb, df, numItermaxEmd=10000,
+ verbose=True, log=True)
+ Gb = nx.to_numpy(Gb)
- np.testing.assert_allclose(a, G.sum(1))
- np.testing.assert_allclose(b, G.sum(0))
+ np.testing.assert_allclose(Gb, G)
+ np.testing.assert_allclose(a, Gb.sum(1))
+ np.testing.assert_allclose(b, Gb.sum(0))
-def test_generalized_conditional_gradient():
+def test_generalized_conditional_gradient(nx):
n_bins = 100 # nb bins
np.random.seed(0)
@@ -91,16 +112,76 @@ def test_generalized_conditional_gradient():
def df(G):
return G
+ def fb(G):
+ return 0.5 * nx.sum(G ** 2)
+
reg1 = 1e-3
reg2 = 1e-1
+ ab = nx.from_numpy(a)
+ bb = nx.from_numpy(b)
+ Mb = nx.from_numpy(M, type_as=ab)
+
G, log = ot.optim.gcg(a, b, M, reg1, reg2, f, df, verbose=True, log=True)
+ Gb, log = ot.optim.gcg(ab, bb, Mb, reg1, reg2, fb, df, verbose=True, log=True)
+ Gb = nx.to_numpy(Gb)
- np.testing.assert_allclose(a, G.sum(1), atol=1e-05)
- np.testing.assert_allclose(b, G.sum(0), atol=1e-05)
+ np.testing.assert_allclose(Gb, G)
+ np.testing.assert_allclose(a, Gb.sum(1), atol=1e-05)
+ np.testing.assert_allclose(b, Gb.sum(0), atol=1e-05)
def test_solve_1d_linesearch_quad_funct():
np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(1, -1, 0), 0.5)
np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 5, 0), 0)
np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 0.5, 0), 1)
+
+
+def test_line_search_armijo(nx):
+ xk = np.array([[0.25, 0.25], [0.25, 0.25]])
+ pk = np.array([[-0.25, 0.25], [0.25, -0.25]])
+ gfk = np.array([[23.04273441, 23.0449082], [23.04273441, 23.0449082]])
+ old_fval = -123
+ # Should not throw an exception and return None for alpha
+ alpha, a, b = ot.optim.line_search_armijo(
+ lambda x: 1, nx.from_numpy(xk), nx.from_numpy(pk), nx.from_numpy(gfk), old_fval
+ )
+ alpha_np, anp, bnp = ot.optim.line_search_armijo(
+ lambda x: 1, xk, pk, gfk, old_fval
+ )
+ assert a == anp
+ assert b == bnp
+ assert alpha is None
+
+ # check line search armijo
+ def f(x):
+ return nx.sum((x - 5.0) ** 2)
+
+ def grad(x):
+ return 2 * (x - 5.0)
+
+ xk = nx.from_numpy(np.array([[[-5.0, -5.0]]]))
+ pk = nx.from_numpy(np.array([[[100.0, 100.0]]]))
+ gfk = grad(xk)
+ old_fval = f(xk)
+
+ # check the case where the optimum is on the direction
+ alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval)
+ np.testing.assert_allclose(alpha, 0.1)
+
+ # check the case where the direction is not far enough
+ pk = nx.from_numpy(np.array([[[3.0, 3.0]]]))
+ alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval, alpha0=1.0)
+ np.testing.assert_allclose(alpha, 1.0)
+
+ # check the case where checking the wrong direction
+ alpha, _, _ = ot.optim.line_search_armijo(f, xk, -pk, gfk, old_fval)
+ assert alpha <= 0
+
+ # check the case where the point is not a vector
+ xk = nx.from_numpy(np.array(-5.0))
+ pk = nx.from_numpy(np.array(100.0))
+ gfk = grad(xk)
+ old_fval = f(xk)
+ alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval)
+ np.testing.assert_allclose(alpha, 0.1)
diff --git a/test/test_ot.py b/test/test_ot.py
index b7306f6..92f26a7 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -8,13 +8,13 @@ import warnings
import numpy as np
import pytest
-from scipy.stats import wasserstein_distance
import ot
from ot.datasets import make_1D_gauss as gauss
+from ot.backend import torch
-def test_emd_dimension_mismatch():
+def test_emd_dimension_and_mass_mismatch():
# test emd and emd2 for dimension mismatch
n_samples = 100
n_features = 2
@@ -29,122 +29,125 @@ def test_emd_dimension_mismatch():
np.testing.assert_raises(AssertionError, ot.emd2, a, a, M)
+ b = a.copy()
+ a[0] = 100
+ np.testing.assert_raises(AssertionError, ot.emd, a, b, M)
-def test_emd_emd2():
- # test emd and emd2 for simple identity
- n = 100
+
+def test_emd_backends(nx):
+ n_samples = 100
+ n_features = 2
rng = np.random.RandomState(0)
- x = rng.randn(n, 2)
- u = ot.utils.unif(n)
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
- M = ot.dist(x, x)
+ M = ot.dist(x, y)
- G = ot.emd(u, u, M)
+ G = ot.emd(a, a, M)
- # check G is identity
- np.testing.assert_allclose(G, np.eye(n) / n)
- # check constraints
- np.testing.assert_allclose(u, G.sum(1)) # cf convergence sinkhorn
- np.testing.assert_allclose(u, G.sum(0)) # cf convergence sinkhorn
+ ab = nx.from_numpy(a)
+ Mb = nx.from_numpy(M)
- w = ot.emd2(u, u, M)
- # check loss=0
- np.testing.assert_allclose(w, 0)
+ Gb = ot.emd(ab, ab, Mb)
+
+ np.allclose(G, nx.to_numpy(Gb))
-def test_emd_1d_emd2_1d():
- # test emd1d gives similar results as emd
- n = 20
- m = 30
+def test_emd2_backends(nx):
+ n_samples = 100
+ n_features = 2
rng = np.random.RandomState(0)
- u = rng.randn(n, 1)
- v = rng.randn(m, 1)
- M = ot.dist(u, v, metric='sqeuclidean')
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
+
+ M = ot.dist(x, y)
- G, log = ot.emd([], [], M, log=True)
- wass = log["cost"]
- G_1d, log = ot.emd_1d(u, v, [], [], metric='sqeuclidean', log=True)
- wass1d = log["cost"]
- wass1d_emd2 = ot.emd2_1d(u, v, [], [], metric='sqeuclidean', log=False)
- wass1d_euc = ot.emd2_1d(u, v, [], [], metric='euclidean', log=False)
+ val = ot.emd2(a, a, M)
- # check loss is similar
- np.testing.assert_allclose(wass, wass1d)
- np.testing.assert_allclose(wass, wass1d_emd2)
+ ab = nx.from_numpy(a)
+ Mb = nx.from_numpy(M)
- # check loss is similar to scipy's implementation for Euclidean metric
- wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)))
- np.testing.assert_allclose(wass_sp, wass1d_euc)
+ valb = ot.emd2(ab, ab, Mb)
- # check constraints
- np.testing.assert_allclose(np.ones((n,)) / n, G.sum(1))
- np.testing.assert_allclose(np.ones((m,)) / m, G.sum(0))
+ np.allclose(val, nx.to_numpy(valb))
+
+
+def test_emd_emd2_types_devices(nx):
+ n_samples = 100
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
+
+ M = ot.dist(x, y)
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ ab = nx.from_numpy(a, type_as=tp)
+ Mb = nx.from_numpy(M, type_as=tp)
- # check G is similar
- np.testing.assert_allclose(G, G_1d)
+ Gb = ot.emd(ab, ab, Mb)
- # check AssertionError is raised if called on non 1d arrays
- u = np.random.randn(n, 2)
- v = np.random.randn(m, 2)
- with pytest.raises(AssertionError):
- ot.emd_1d(u, v, [], [])
+ w = ot.emd2(ab, ab, Mb)
+ nx.assert_same_dtype_device(Mb, Gb)
+ nx.assert_same_dtype_device(Mb, w)
-def test_emd_1d_emd2_1d_with_weights():
- # test emd1d gives similar results as emd
- n = 20
- m = 30
+
+def test_emd2_gradients():
+ n_samples = 100
+ n_features = 2
rng = np.random.RandomState(0)
- u = rng.randn(n, 1)
- v = rng.randn(m, 1)
- w_u = rng.uniform(0., 1., n)
- w_u = w_u / w_u.sum()
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
- w_v = rng.uniform(0., 1., m)
- w_v = w_v / w_v.sum()
+ M = ot.dist(x, y)
- M = ot.dist(u, v, metric='sqeuclidean')
+ if torch:
- G, log = ot.emd(w_u, w_v, M, log=True)
- wass = log["cost"]
- G_1d, log = ot.emd_1d(u, v, w_u, w_v, metric='sqeuclidean', log=True)
- wass1d = log["cost"]
- wass1d_emd2 = ot.emd2_1d(u, v, w_u, w_v, metric='sqeuclidean', log=False)
- wass1d_euc = ot.emd2_1d(u, v, w_u, w_v, metric='euclidean', log=False)
+ a1 = torch.tensor(a, requires_grad=True)
+ b1 = torch.tensor(a, requires_grad=True)
+ M1 = torch.tensor(M, requires_grad=True)
- # check loss is similar
- np.testing.assert_allclose(wass, wass1d)
- np.testing.assert_allclose(wass, wass1d_emd2)
+ val = ot.emd2(a1, b1, M1)
- # check loss is similar to scipy's implementation for Euclidean metric
- wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)), w_u, w_v)
- np.testing.assert_allclose(wass_sp, wass1d_euc)
+ val.backward()
- # check constraints
- np.testing.assert_allclose(w_u, G.sum(1))
- np.testing.assert_allclose(w_v, G.sum(0))
+ assert a1.shape == a1.grad.shape
+ assert b1.shape == b1.grad.shape
+ assert M1.shape == M1.grad.shape
-def test_wass_1d():
- # test emd1d gives similar results as emd
- n = 20
- m = 30
+def test_emd_emd2():
+ # test emd and emd2 for simple identity
+ n = 100
rng = np.random.RandomState(0)
- u = rng.randn(n, 1)
- v = rng.randn(m, 1)
- M = ot.dist(u, v, metric='sqeuclidean')
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
- G, log = ot.emd([], [], M, log=True)
- wass = log["cost"]
+ M = ot.dist(x, x)
- wass1d = ot.wasserstein_1d(u, v, [], [], p=2.)
+ G = ot.emd(u, u, M)
+
+ # check G is identity
+ np.testing.assert_allclose(G, np.eye(n) / n)
+ # check constraints
+ np.testing.assert_allclose(u, G.sum(1)) # cf convergence sinkhorn
+ np.testing.assert_allclose(u, G.sum(0)) # cf convergence sinkhorn
- # check loss is similar
- np.testing.assert_allclose(np.sqrt(wass), wass1d)
+ w = ot.emd2(u, u, M)
+ # check loss=0
+ np.testing.assert_allclose(w, 0)
def test_emd_empty():
@@ -291,17 +294,7 @@ def test_warnings():
print('Computing {} EMD '.format(1))
ot.emd(a, b, M, numItermax=1)
assert "numItermax" in str(w[-1].message)
- assert len(w) == 1
- a[0] = 100
- print('Computing {} EMD '.format(2))
- ot.emd(a, b, M)
- assert "infeasible" in str(w[-1].message)
- assert len(w) == 2
- a[0] = -1
- print('Computing {} EMD '.format(2))
- ot.emd(a, b, M)
- assert "infeasible" in str(w[-1].message)
- assert len(w) == 3
+ #assert len(w) == 1
def test_dual_variables():
diff --git a/test/test_partial.py b/test/test_partial.py
index 510e081..97c611b 100755
--- a/test/test_partial.py
+++ b/test/test_partial.py
@@ -51,10 +51,12 @@ def test_raise_errors():
ot.partial.partial_gromov_wasserstein(M, M, p, q, m=-1, log=True)
with pytest.raises(ValueError):
- ot.partial.entropic_partial_gromov_wasserstein(M, M, p, q, reg=1, m=2, log=True)
+ ot.partial.entropic_partial_gromov_wasserstein(M, M, p, q, reg=1, m=2,
+ log=True)
with pytest.raises(ValueError):
- ot.partial.entropic_partial_gromov_wasserstein(M, M, p, q, reg=1, m=-1, log=True)
+ ot.partial.entropic_partial_gromov_wasserstein(M, M, p, q, reg=1, m=-1,
+ log=True)
def test_partial_wasserstein_lagrange():
@@ -102,7 +104,7 @@ def test_partial_wasserstein():
w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=1, m=m,
log=True, verbose=True)
- # check constratints
+ # check constraints
np.testing.assert_equal(
w0.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein
np.testing.assert_equal(
@@ -125,11 +127,11 @@ def test_partial_wasserstein():
np.testing.assert_allclose(w0, w0_val, atol=1e-1, rtol=1e-1)
- # check constratints
+ # check constraints
np.testing.assert_equal(
- G.sum(1) <= p, [True] * len(p)) # cf convergence wasserstein
+ G.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein
np.testing.assert_equal(
- G.sum(0) <= q, [True] * len(q)) # cf convergence wasserstein
+ G.sum(0) - q <= 1e-5, [True] * len(q)) # cf convergence wasserstein
np.testing.assert_allclose(
np.sum(G), m, atol=1e-04)
@@ -192,7 +194,7 @@ def test_partial_gromov_wasserstein():
100, m=m,
log=True)
- # check constratints
+ # check constraints
np.testing.assert_equal(
res0.sum(1) <= p, [True] * len(p)) # cf convergence wasserstein
np.testing.assert_equal(
diff --git a/test/test_regpath.py b/test/test_regpath.py
new file mode 100644
index 0000000..967c27b
--- /dev/null
+++ b/test/test_regpath.py
@@ -0,0 +1,64 @@
+"""Tests for module regularization path"""
+
+# Author: Haoran Wu <haoran.wu@univ-ubs.fr>
+#
+# License: MIT License
+
+import numpy as np
+import ot
+
+
+def test_fully_relaxed_path():
+
+ n_source = 50 # nb source samples (gaussian)
+ n_target = 40 # nb target samples (gaussian)
+
+ mu = np.array([0, 0])
+ cov = np.array([[1, 0], [0, 2]])
+
+ np.random.seed(0)
+ xs = ot.datasets.make_2D_samples_gauss(n_source, mu, cov)
+ xt = ot.datasets.make_2D_samples_gauss(n_target, mu, cov)
+
+ # source and target distributions
+ a = ot.utils.unif(n_source)
+ b = ot.utils.unif(n_target)
+
+ # loss matrix
+ M = ot.dist(xs, xt)
+ M /= M.max()
+
+ t, _, _ = ot.regpath.regularization_path(a, b, M, reg=1e-8,
+ semi_relaxed=False)
+
+ G = t.reshape((n_source, n_target))
+ np.testing.assert_allclose(a, G.sum(1), atol=1e-05)
+ np.testing.assert_allclose(b, G.sum(0), atol=1e-05)
+
+
+def test_semi_relaxed_path():
+
+ n_source = 50 # nb source samples (gaussian)
+ n_target = 40 # nb target samples (gaussian)
+
+ mu = np.array([0, 0])
+ cov = np.array([[1, 0], [0, 2]])
+
+ np.random.seed(0)
+ xs = ot.datasets.make_2D_samples_gauss(n_source, mu, cov)
+ xt = ot.datasets.make_2D_samples_gauss(n_target, mu, cov)
+
+ # source and target distributions
+ a = ot.utils.unif(n_source)
+ b = ot.utils.unif(n_target)
+
+ # loss matrix
+ M = ot.dist(xs, xt)
+ M /= M.max()
+
+ t, _, _ = ot.regpath.regularization_path(a, b, M, reg=1e-8,
+ semi_relaxed=True)
+
+ G = t.reshape((n_source, n_target))
+ np.testing.assert_allclose(a, G.sum(1), atol=1e-05)
+ np.testing.assert_allclose(b, G.sum(0), atol=1e-10)
diff --git a/test/test_sliced.py b/test/test_sliced.py
new file mode 100644
index 0000000..245202c
--- /dev/null
+++ b/test/test_sliced.py
@@ -0,0 +1,213 @@
+"""Tests for module sliced"""
+
+# Author: Adrien Corenflos <adrien.corenflos@aalto.fi>
+# Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
+import numpy as np
+import pytest
+
+import ot
+from ot.sliced import get_random_projections
+
+
+def test_get_random_projections():
+ rng = np.random.RandomState(0)
+ projections = get_random_projections(1000, 50, rng)
+ np.testing.assert_almost_equal(np.sum(projections ** 2, 0), 1.)
+
+
+def test_sliced_same_dist():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ res = ot.sliced_wasserstein_distance(x, x, u, u, 10, seed=rng)
+ np.testing.assert_almost_equal(res, 0.)
+
+
+def test_sliced_bad_shapes():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ y = rng.randn(n, 4)
+ u = ot.utils.unif(n)
+
+ with pytest.raises(ValueError):
+ _ = ot.sliced_wasserstein_distance(x, y, u, u, 10, seed=rng)
+
+
+def test_sliced_log():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 4)
+ y = rng.randn(n, 4)
+ u = ot.utils.unif(n)
+
+ res, log = ot.sliced_wasserstein_distance(x, y, u, u, 10, p=1, seed=rng, log=True)
+ assert len(log) == 2
+ projections = log["projections"]
+ projected_emds = log["projected_emds"]
+
+ assert projections.shape[1] == len(projected_emds) == 10
+ for emd in projected_emds:
+ assert emd > 0
+
+
+def test_sliced_different_dists():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+ y = rng.randn(n, 2)
+
+ res = ot.sliced_wasserstein_distance(x, y, u, u, 10, seed=rng)
+ assert res > 0.
+
+
+def test_1d_sliced_equals_emd():
+ n = 100
+ m = 120
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 1)
+ a = rng.uniform(0, 1, n)
+ a /= a.sum()
+ y = rng.randn(m, 1)
+ u = ot.utils.unif(m)
+ res = ot.sliced_wasserstein_distance(x, y, a, u, 10, seed=42)
+ expected = ot.emd2_1d(x.squeeze(), y.squeeze(), a, u)
+ np.testing.assert_almost_equal(res ** 2, expected)
+
+
+def test_max_sliced_same_dist():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ res = ot.max_sliced_wasserstein_distance(x, x, u, u, 10, seed=rng)
+ np.testing.assert_almost_equal(res, 0.)
+
+
+def test_max_sliced_different_dists():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+ y = rng.randn(n, 2)
+
+ res, log = ot.max_sliced_wasserstein_distance(x, y, u, u, 10, seed=rng, log=True)
+ assert res > 0.
+
+
+def test_sliced_backend(nx):
+
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ y = rng.randn(2 * n, 2)
+
+ P = rng.randn(2, 20)
+ P = P / np.sqrt((P**2).sum(0, keepdims=True))
+
+ n_projections = 20
+
+ xb = nx.from_numpy(x)
+ yb = nx.from_numpy(y)
+ Pb = nx.from_numpy(P)
+
+ val0 = ot.sliced_wasserstein_distance(x, y, projections=P)
+
+ val = ot.sliced_wasserstein_distance(xb, yb, n_projections=n_projections, seed=0)
+ val2 = ot.sliced_wasserstein_distance(xb, yb, n_projections=n_projections, seed=0)
+
+ assert val > 0
+ assert val == val2
+
+ valb = nx.to_numpy(ot.sliced_wasserstein_distance(xb, yb, projections=Pb))
+
+ assert np.allclose(val0, valb)
+
+
+def test_sliced_backend_type_devices(nx):
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ y = rng.randn(2 * n, 2)
+
+ P = rng.randn(2, 20)
+ P = P / np.sqrt((P**2).sum(0, keepdims=True))
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ xb = nx.from_numpy(x, type_as=tp)
+ yb = nx.from_numpy(y, type_as=tp)
+ Pb = nx.from_numpy(P, type_as=tp)
+
+ valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb)
+
+ nx.assert_same_dtype_device(xb, valb)
+
+
+def test_max_sliced_backend(nx):
+
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ y = rng.randn(2 * n, 2)
+
+ P = rng.randn(2, 20)
+ P = P / np.sqrt((P**2).sum(0, keepdims=True))
+
+ n_projections = 20
+
+ xb = nx.from_numpy(x)
+ yb = nx.from_numpy(y)
+ Pb = nx.from_numpy(P)
+
+ val0 = ot.max_sliced_wasserstein_distance(x, y, projections=P)
+
+ val = ot.max_sliced_wasserstein_distance(xb, yb, n_projections=n_projections, seed=0)
+ val2 = ot.max_sliced_wasserstein_distance(xb, yb, n_projections=n_projections, seed=0)
+
+ assert val > 0
+ assert val == val2
+
+ valb = nx.to_numpy(ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb))
+
+ assert np.allclose(val0, valb)
+
+
+def test_max_sliced_backend_type_devices(nx):
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ y = rng.randn(2 * n, 2)
+
+ P = rng.randn(2, 20)
+ P = P / np.sqrt((P**2).sum(0, keepdims=True))
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ xb = nx.from_numpy(x, type_as=tp)
+ yb = nx.from_numpy(y, type_as=tp)
+ Pb = nx.from_numpy(P, type_as=tp)
+
+ valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)
+
+ nx.assert_same_dtype_device(xb, valb)
diff --git a/test/test_smooth.py b/test/test_smooth.py
index 2afa4f8..31e0b2e 100644
--- a/test/test_smooth.py
+++ b/test/test_smooth.py
@@ -25,16 +25,16 @@ def test_smooth_ot_dual():
Gl2, log = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type='l2',
log=True, stopThr=1e-10)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
u, Gl2.sum(1), atol=1e-05) # cf convergence sinkhorn
np.testing.assert_allclose(
u, Gl2.sum(0), atol=1e-05) # cf convergence sinkhorn
- # kl regyularisation
+ # kl regularisation
G = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type='kl', stopThr=1e-10)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
u, G.sum(1), atol=1e-05) # cf convergence sinkhorn
np.testing.assert_allclose(
@@ -60,16 +60,16 @@ def test_smooth_ot_semi_dual():
Gl2, log = ot.smooth.smooth_ot_semi_dual(u, u, M, 1, reg_type='l2',
log=True, stopThr=1e-10)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
u, Gl2.sum(1), atol=1e-05) # cf convergence sinkhorn
np.testing.assert_allclose(
u, Gl2.sum(0), atol=1e-05) # cf convergence sinkhorn
- # kl regyularisation
+ # kl regularisation
G = ot.smooth.smooth_ot_semi_dual(u, u, M, 1, reg_type='kl', stopThr=1e-10)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
u, G.sum(1), atol=1e-05) # cf convergence sinkhorn
np.testing.assert_allclose(
diff --git a/test/test_stochastic.py b/test/test_stochastic.py
index 155622c..736df32 100644
--- a/test/test_stochastic.py
+++ b/test/test_stochastic.py
@@ -30,7 +30,7 @@ import ot
def test_stochastic_sag():
# test sag
- n = 15
+ n = 10
reg = 1
numItermax = 30000
rng = np.random.RandomState(0)
@@ -43,11 +43,11 @@ def test_stochastic_sag():
G = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, "sag",
numItermax=numItermax)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
- u, G.sum(1), atol=1e-04) # cf convergence sag
+ u, G.sum(1), atol=1e-03) # cf convergence sag
np.testing.assert_allclose(
- u, G.sum(0), atol=1e-04) # cf convergence sag
+ u, G.sum(0), atol=1e-03) # cf convergence sag
#############################################################################
@@ -60,9 +60,9 @@ def test_stochastic_sag():
def test_stochastic_asgd():
# test asgd
- n = 15
+ n = 10
reg = 1
- numItermax = 100000
+ numItermax = 10000
rng = np.random.RandomState(0)
x = rng.randn(n, 2)
@@ -73,11 +73,11 @@ def test_stochastic_asgd():
G, log = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, "asgd",
numItermax=numItermax, log=True)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
- u, G.sum(1), atol=1e-03) # cf convergence asgd
+ u, G.sum(1), atol=1e-02) # cf convergence asgd
np.testing.assert_allclose(
- u, G.sum(0), atol=1e-03) # cf convergence asgd
+ u, G.sum(0), atol=1e-02) # cf convergence asgd
#############################################################################
@@ -90,9 +90,9 @@ def test_stochastic_asgd():
def test_sag_asgd_sinkhorn():
# test all algorithms
- n = 15
+ n = 10
reg = 1
- nb_iter = 100000
+ nb_iter = 10000
rng = np.random.RandomState(0)
x = rng.randn(n, 2)
@@ -105,19 +105,19 @@ def test_sag_asgd_sinkhorn():
numItermax=nb_iter)
G_sinkhorn = ot.sinkhorn(u, u, M, reg)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
- G_sag.sum(1), G_sinkhorn.sum(1), atol=1e-03)
+ G_sag.sum(1), G_sinkhorn.sum(1), atol=1e-02)
np.testing.assert_allclose(
- G_sag.sum(0), G_sinkhorn.sum(0), atol=1e-03)
+ G_sag.sum(0), G_sinkhorn.sum(0), atol=1e-02)
np.testing.assert_allclose(
- G_asgd.sum(1), G_sinkhorn.sum(1), atol=1e-03)
+ G_asgd.sum(1), G_sinkhorn.sum(1), atol=1e-02)
np.testing.assert_allclose(
- G_asgd.sum(0), G_sinkhorn.sum(0), atol=1e-03)
+ G_asgd.sum(0), G_sinkhorn.sum(0), atol=1e-02)
np.testing.assert_allclose(
- G_sag, G_sinkhorn, atol=1e-03) # cf convergence sag
+ G_sag, G_sinkhorn, atol=1e-02) # cf convergence sag
np.testing.assert_allclose(
- G_asgd, G_sinkhorn, atol=1e-03) # cf convergence asgd
+ G_asgd, G_sinkhorn, atol=1e-02) # cf convergence asgd
#############################################################################
@@ -136,7 +136,7 @@ def test_stochastic_dual_sgd():
# test sgd
n = 10
reg = 1
- numItermax = 15000
+ numItermax = 5000
batch_size = 10
rng = np.random.RandomState(0)
@@ -148,7 +148,7 @@ def test_stochastic_dual_sgd():
G, log = ot.stochastic.solve_dual_entropic(u, u, M, reg, batch_size,
numItermax=numItermax, log=True)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
u, G.sum(1), atol=1e-03) # cf convergence sgd
np.testing.assert_allclose(
@@ -167,7 +167,7 @@ def test_dual_sgd_sinkhorn():
# test all dual algorithms
n = 10
reg = 1
- nb_iter = 15000
+ nb_iter = 5000
batch_size = 10
rng = np.random.RandomState(0)
@@ -181,13 +181,13 @@ def test_dual_sgd_sinkhorn():
G_sinkhorn = ot.sinkhorn(u, u, M, reg)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
- G_sgd.sum(1), G_sinkhorn.sum(1), atol=1e-03)
+ G_sgd.sum(1), G_sinkhorn.sum(1), atol=1e-02)
np.testing.assert_allclose(
- G_sgd.sum(0), G_sinkhorn.sum(0), atol=1e-03)
+ G_sgd.sum(0), G_sinkhorn.sum(0), atol=1e-02)
np.testing.assert_allclose(
- G_sgd, G_sinkhorn, atol=1e-03) # cf convergence sgd
+ G_sgd, G_sinkhorn, atol=1e-02) # cf convergence sgd
# Test gaussian
n = 30
@@ -206,7 +206,7 @@ def test_dual_sgd_sinkhorn():
G_sinkhorn = ot.sinkhorn(a, b, M, reg)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
G_sgd.sum(1), G_sinkhorn.sum(1), atol=1e-03)
np.testing.assert_allclose(
diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py
index dfeaad9..e8349d1 100644
--- a/test/test_unbalanced.py
+++ b/test/test_unbalanced.py
@@ -115,7 +115,8 @@ def test_stabilized_vs_sinkhorn():
G, log = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, reg=epsilon,
method="sinkhorn_stabilized",
reg_m=reg_m,
- log=True)
+ log=True,
+ verbose=True)
G2, log2 = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
method="sinkhorn",
log=True)
@@ -138,7 +139,7 @@ def test_unbalanced_barycenter(method):
reg_m = 1.
q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
- method=method, log=True)
+ method=method, log=True, verbose=True)
# check fixed point equations
fi = reg_m / (reg_m + epsilon)
logA = np.log(A + 1e-16)
@@ -173,6 +174,7 @@ def test_barycenter_stabilized_vs_sinkhorn():
reg_m=reg_m, log=True,
tau=100,
method="sinkhorn_stabilized",
+ verbose=True
)
q, log = barycenter_unbalanced(A, M, reg=epsilon,
reg_m=reg_m, method="sinkhorn",
@@ -182,6 +184,33 @@ def test_barycenter_stabilized_vs_sinkhorn():
q, qstable,
atol=1e-05)
+def test_wrong_method():
+
+ n = 10
+ rng = np.random.RandomState(42)
+
+ x = rng.randn(n, 2)
+ a = ot.utils.unif(n)
+
+ # make dists unbalanced
+ b = ot.utils.unif(n) * 1.5
+
+ M = ot.dist(x, x)
+ epsilon = 1.
+ reg_m = 1.
+
+ with pytest.raises(ValueError):
+ ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon,
+ reg_m=reg_m,
+ method='badmethod',
+ log=True,
+ verbose=True)
+ with pytest.raises(ValueError):
+ ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
+ method='badmethod',
+ verbose=True)
+
+
def test_implemented_methods():
IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized']
TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_reg_scaling']
diff --git a/test/test_utils.py b/test/test_utils.py
index db9cda6..40f4e49 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -4,10 +4,41 @@
#
# License: MIT License
-
import ot
import numpy as np
import sys
+import pytest
+
+
+def test_proj_simplex(nx):
+ n = 10
+ rng = np.random.RandomState(0)
+
+ # test on matrix when projection is done on axis 0
+ x = rng.randn(n, 2)
+ x1 = nx.from_numpy(x)
+
+ # all projections should sum to 1
+ proj = ot.utils.proj_simplex(x1)
+ l1 = np.sum(nx.to_numpy(proj), axis=0)
+ l2 = np.ones(2)
+ np.testing.assert_allclose(l1, l2, atol=1e-5)
+
+ # all projections should sum to 3
+ proj = ot.utils.proj_simplex(x1, 3)
+ l1 = np.sum(nx.to_numpy(proj), axis=0)
+ l2 = 3 * np.ones(2)
+ np.testing.assert_allclose(l1, l2, atol=1e-5)
+
+ # test on vector
+ x = rng.randn(n)
+ x1 = nx.from_numpy(x)
+
+ # all projections should sum to 1
+ proj = ot.utils.proj_simplex(x1)
+ l1 = np.sum(nx.to_numpy(proj), axis=0)
+ l2 = np.ones(2)
+ np.testing.assert_allclose(l1, l2, atol=1e-5)
def test_parmap():
@@ -45,8 +76,8 @@ def test_tic_toc():
def test_kernel():
n = 100
-
- x = np.random.randn(n, 2)
+ rng = np.random.RandomState(0)
+ x = rng.randn(n, 2)
K = ot.utils.kernel(x, x)
@@ -67,7 +98,8 @@ def test_dist():
n = 100
- x = np.random.randn(n, 2)
+ rng = np.random.RandomState(0)
+ x = rng.randn(n, 2)
D = np.zeros((n, n))
for i in range(n):
@@ -77,9 +109,31 @@ def test_dist():
D2 = ot.dist(x, x)
D3 = ot.dist(x)
+ D4 = ot.dist(x, x, metric='minkowski', p=2)
+
+ assert D4[0, 1] == D4[1, 0]
+
# dist should return squared euclidean
- np.testing.assert_allclose(D, D2)
- np.testing.assert_allclose(D, D3)
+ np.testing.assert_allclose(D, D2, atol=1e-14)
+ np.testing.assert_allclose(D, D3, atol=1e-14)
+
+
+def test_dist_backends(nx):
+
+ n = 100
+ rng = np.random.RandomState(0)
+ x = rng.randn(n, 2)
+ x1 = nx.from_numpy(x)
+
+ lst_metric = ['euclidean', 'sqeuclidean']
+
+ for metric in lst_metric:
+
+ D = ot.dist(x, x, metric=metric)
+ D1 = ot.dist(x1, x1, metric=metric)
+
+ # low atol because jax forces float32
+ np.testing.assert_allclose(D, nx.to_numpy(D1), atol=1e-5)
def test_dist0():
@@ -95,9 +149,11 @@ def test_dots():
n1, n2, n3, n4 = 100, 50, 200, 100
- A = np.random.randn(n1, n2)
- B = np.random.randn(n2, n3)
- C = np.random.randn(n3, n4)
+ rng = np.random.RandomState(0)
+
+ A = rng.randn(n1, n2)
+ B = rng.randn(n2, n3)
+ C = rng.randn(n3, n4)
X1 = ot.utils.dots(A, B, C)
@@ -169,6 +225,13 @@ def test_deprecated_func():
class Class():
pass
+ with pytest.warns(DeprecationWarning):
+ fun()
+
+ with pytest.warns(DeprecationWarning):
+ cl = Class()
+ print(cl)
+
if sys.version_info < (3, 5):
print('Not tested')
else:
@@ -199,4 +262,7 @@ def test_BaseEstimator():
params['first'] = 'spam again'
cl.set_params(**params)
+ with pytest.raises(ValueError):
+ cl.set_params(bibi=10)
+
assert cl.first == 'spam again'