summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2021-06-01 10:10:54 +0200
committerGitHub <noreply@github.com>2021-06-01 10:10:54 +0200
commit184f8f4f7ac78f1dd7f653496d2753211a4e3426 (patch)
tree483a7274c91030fd644de49b03a5fad04af9deba /test
parent1f16614954e2522fbdb1598c5b1f5c3630c68472 (diff)
[MRG] POT numpy/torch/jax backends (#249)
* add numpy and torch backends * stat sets on functions * proper import * install recent torch on windows * install recent torch on windows * now testing all functions in backedn * add jax backedn * clenaup windowds * proper convert for jax backedn * pep8 * try again windows tests * test jax conversion * try proper widows tests * emd fuction ses backedn * better test partial OT * proper tests to_numpy and teplate Backend * pep8 * pep8 x2 * feaking sinkhorn works with torch * sinkhorn2 compatible * working ot.emd2 * important detach * it should work * jax autodiff emd * pep8 * no tast same for jax * new independat tests per backedn * freaking pep8 * add tests for gradients * deprecate ot.gpu * worging dist function * working dist * dist done in backedn * not in * remove indexing * change accuacy for jax * first pull backend * projection simplex * projection simplex * projection simplex * projection simplex no ci * projection simplex no ci * projection simplex no ci * pep8 * add backedn discusion to quickstart guide * projection simplex no ci * projection simplex no ci * projection simplex no ci * pep8 + better doc * proper links * corect doctest * big debug documentation * doctest again * doctest again bis * doctest again ter (last one or i kill myself) * backend test + doc proj simplex * correction test_utils * correction test_utils * correction cumsum * correction flip * correction flip v2 * more debug * more debug * more debug + pep8 * pep8 * argh * proj_simplex * backedn works for sort * proj simplex * jax sucks * update doc * Update test/test_utils.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/readme.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update test/test_utils.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update ot/utils.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/readme.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update ot/lp/__init__.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * begin comment alex * comment alex part 2 * optimize test gromov * proj_simplex on vectors * add awesome gradient decsnt example on the weights * pep98 of course * proof read example by alex * pep8 again * encoding oos in translation * correct legend Co-authored-by: Nicolas Courty <ncourty@irisa.fr> Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
Diffstat (limited to 'test')
-rw-r--r--test/test_backend.py364
-rw-r--r--test/test_bregman.py74
-rw-r--r--test/test_gromov.py10
-rw-r--r--test/test_ot.py91
-rwxr-xr-xtest/test_partial.py4
-rw-r--r--test/test_utils.py76
6 files changed, 591 insertions, 28 deletions
diff --git a/test/test_backend.py b/test/test_backend.py
new file mode 100644
index 0000000..bc5b00c
--- /dev/null
+++ b/test/test_backend.py
@@ -0,0 +1,364 @@
+"""Tests for backend module """
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+import ot
+import ot.backend
+from ot.backend import torch, jax
+
+import pytest
+
+import numpy as np
+from numpy.testing import assert_array_almost_equal_nulp
+
+from ot.backend import get_backend, get_backend_list, to_numpy
+
+
+backend_list = get_backend_list()
+
+
+def test_get_backend_list():
+
+ lst = get_backend_list()
+
+ assert len(lst) > 0
+ assert isinstance(lst[0], ot.backend.NumpyBackend)
+
+
+@pytest.mark.parametrize('nx', backend_list)
+def test_to_numpy(nx):
+
+ v = nx.zeros(10)
+ M = nx.ones((10, 10))
+
+ v2 = to_numpy(v)
+ assert isinstance(v2, np.ndarray)
+
+ v2, M2 = to_numpy(v, M)
+ assert isinstance(M2, np.ndarray)
+
+
+def test_get_backend():
+
+ A = np.zeros((3, 2))
+ B = np.zeros((3, 1))
+
+ nx = get_backend(A)
+ assert nx.__name__ == 'numpy'
+
+ nx = get_backend(A, B)
+ assert nx.__name__ == 'numpy'
+
+ # error if no parameters
+ with pytest.raises(ValueError):
+ get_backend()
+
+ # error if unknown types
+ with pytest.raises(ValueError):
+ get_backend(1, 2.0)
+
+ # test torch
+ if torch:
+
+ A2 = torch.from_numpy(A)
+ B2 = torch.from_numpy(B)
+
+ nx = get_backend(A2)
+ assert nx.__name__ == 'torch'
+
+ nx = get_backend(A2, B2)
+ assert nx.__name__ == 'torch'
+
+ # test not unique types in input
+ with pytest.raises(ValueError):
+ get_backend(A, B2)
+
+ if jax:
+
+ A2 = jax.numpy.array(A)
+ B2 = jax.numpy.array(B)
+
+ nx = get_backend(A2)
+ assert nx.__name__ == 'jax'
+
+ nx = get_backend(A2, B2)
+ assert nx.__name__ == 'jax'
+
+ # test not unique types in input
+ with pytest.raises(ValueError):
+ get_backend(A, B2)
+
+
+@pytest.mark.parametrize('nx', backend_list)
+def test_convert_between_backends(nx):
+
+ A = np.zeros((3, 2))
+ B = np.zeros((3, 1))
+
+ A2 = nx.from_numpy(A)
+ B2 = nx.from_numpy(B)
+
+ assert isinstance(A2, nx.__type__)
+ assert isinstance(B2, nx.__type__)
+
+ nx2 = get_backend(A2, B2)
+
+ assert nx2.__name__ == nx.__name__
+
+ assert_array_almost_equal_nulp(nx.to_numpy(A2), A)
+ assert_array_almost_equal_nulp(nx.to_numpy(B2), B)
+
+
+def test_empty_backend():
+
+ rnd = np.random.RandomState(0)
+ M = rnd.randn(10, 3)
+ v = rnd.randn(3)
+
+ nx = ot.backend.Backend()
+
+ with pytest.raises(NotImplementedError):
+ nx.from_numpy(M)
+ with pytest.raises(NotImplementedError):
+ nx.to_numpy(M)
+ with pytest.raises(NotImplementedError):
+ nx.set_gradients(0, 0, 0)
+ with pytest.raises(NotImplementedError):
+ nx.zeros((10, 3))
+ with pytest.raises(NotImplementedError):
+ nx.ones((10, 3))
+ with pytest.raises(NotImplementedError):
+ nx.arange(10, 1, 2)
+ with pytest.raises(NotImplementedError):
+ nx.full((10, 3), 3.14)
+ with pytest.raises(NotImplementedError):
+ nx.eye((10, 3))
+ with pytest.raises(NotImplementedError):
+ nx.sum(M)
+ with pytest.raises(NotImplementedError):
+ nx.cumsum(M)
+ with pytest.raises(NotImplementedError):
+ nx.max(M)
+ with pytest.raises(NotImplementedError):
+ nx.min(M)
+ with pytest.raises(NotImplementedError):
+ nx.maximum(v, v)
+ with pytest.raises(NotImplementedError):
+ nx.minimum(v, v)
+ with pytest.raises(NotImplementedError):
+ nx.abs(M)
+ with pytest.raises(NotImplementedError):
+ nx.log(M)
+ with pytest.raises(NotImplementedError):
+ nx.exp(M)
+ with pytest.raises(NotImplementedError):
+ nx.sqrt(M)
+ with pytest.raises(NotImplementedError):
+ nx.dot(v, v)
+ with pytest.raises(NotImplementedError):
+ nx.norm(M)
+ with pytest.raises(NotImplementedError):
+ nx.exp(M)
+ with pytest.raises(NotImplementedError):
+ nx.any(M)
+ with pytest.raises(NotImplementedError):
+ nx.isnan(M)
+ with pytest.raises(NotImplementedError):
+ nx.isinf(M)
+ with pytest.raises(NotImplementedError):
+ nx.einsum('ij->i', M)
+ with pytest.raises(NotImplementedError):
+ nx.sort(M)
+ with pytest.raises(NotImplementedError):
+ nx.argsort(M)
+ with pytest.raises(NotImplementedError):
+ nx.flip(M)
+
+
+@pytest.mark.parametrize('backend', backend_list)
+def test_func_backends(backend):
+
+ rnd = np.random.RandomState(0)
+ M = rnd.randn(10, 3)
+ v = rnd.randn(3)
+ val = np.array([1.0])
+
+ lst_tot = []
+
+ for nx in [ot.backend.NumpyBackend(), backend]:
+
+ print('Backend: ', nx.__name__)
+
+ lst_b = []
+ lst_name = []
+
+ Mb = nx.from_numpy(M)
+ vb = nx.from_numpy(v)
+ val = nx.from_numpy(val)
+
+ A = nx.set_gradients(val, v, v)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('set_gradients')
+
+ A = nx.zeros((10, 3))
+ A = nx.zeros((10, 3), type_as=Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('zeros')
+
+ A = nx.ones((10, 3))
+ A = nx.ones((10, 3), type_as=Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('ones')
+
+ A = nx.arange(10, 1, 2)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('arange')
+
+ A = nx.full((10, 3), 3.14)
+ A = nx.full((10, 3), 3.14, type_as=Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('full')
+
+ A = nx.eye(10, 3)
+ A = nx.eye(10, 3, type_as=Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('eye')
+
+ A = nx.sum(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('sum')
+
+ A = nx.sum(Mb, axis=1, keepdims=True)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('sum(axis)')
+
+ A = nx.cumsum(Mb, 0)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('cumsum(axis)')
+
+ A = nx.max(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('max')
+
+ A = nx.max(Mb, axis=1, keepdims=True)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('max(axis)')
+
+ A = nx.min(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('min')
+
+ A = nx.min(Mb, axis=1, keepdims=True)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('min(axis)')
+
+ A = nx.maximum(vb, 0)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('maximum')
+
+ A = nx.minimum(vb, 0)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('minimum')
+
+ A = nx.abs(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('abs')
+
+ A = nx.log(A)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('log')
+
+ A = nx.exp(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('exp')
+
+ A = nx.sqrt(nx.abs(Mb))
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('sqrt')
+
+ A = nx.dot(vb, vb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('dot(v,v)')
+
+ A = nx.dot(Mb, vb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('dot(M,v)')
+
+ A = nx.dot(Mb, Mb.T)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('dot(M,M)')
+
+ A = nx.norm(vb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('norm')
+
+ A = nx.any(vb > 0)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('any')
+
+ A = nx.isnan(vb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('isnan')
+
+ A = nx.isinf(vb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('isinf')
+
+ A = nx.einsum('ij->i', Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('einsum(ij->i)')
+
+ A = nx.einsum('ij,j->i', Mb, vb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('nx.einsum(ij,j->i)')
+
+ A = nx.einsum('ij->i', Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('nx.einsum(ij->i)')
+
+ A = nx.sort(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('sort')
+
+ A = nx.argsort(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('argsort')
+
+ A = nx.flip(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('flip')
+
+ lst_tot.append(lst_b)
+
+ lst_np = lst_tot[0]
+ lst_b = lst_tot[1]
+
+ for a1, a2, name in zip(lst_np, lst_b, lst_name):
+ if not np.allclose(a1, a2):
+ print('Assert fail on: ', name)
+ assert np.allclose(a1, a2, atol=1e-7)
+
+
+def test_gradients_backends():
+
+ rnd = np.random.RandomState(0)
+ v = rnd.randn(10)
+ c = rnd.randn(1)
+
+ if torch:
+
+ nx = ot.backend.TorchBackend()
+
+ v2 = torch.tensor(v, requires_grad=True)
+ c2 = torch.tensor(c, requires_grad=True)
+
+ val = c2 * torch.sum(v2 * v2)
+
+ val2 = nx.set_gradients(val, (v2, c2), (v2, c2))
+
+ val2.backward()
+
+ assert torch.equal(v2.grad, v2)
+ assert torch.equal(c2.grad, c2)
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 1ebd21f..7c5162a 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -9,6 +9,10 @@ import numpy as np
import pytest
import ot
+from ot.backend import get_backend_list
+from ot.backend import torch
+
+backend_list = get_backend_list()
def test_sinkhorn():
@@ -30,6 +34,76 @@ def test_sinkhorn():
u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
+@pytest.mark.parametrize('nx', backend_list)
+def test_sinkhorn_backends(nx):
+ n_samples = 100
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
+
+ M = ot.dist(x, y)
+
+ G = ot.sinkhorn(a, a, M, 1)
+
+ ab = nx.from_numpy(a)
+ Mb = nx.from_numpy(M)
+
+ Gb = ot.sinkhorn(ab, ab, Mb, 1)
+
+ np.allclose(G, nx.to_numpy(Gb))
+
+
+@pytest.mark.parametrize('nx', backend_list)
+def test_sinkhorn2_backends(nx):
+ n_samples = 100
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
+
+ M = ot.dist(x, y)
+
+ G = ot.sinkhorn(a, a, M, 1)
+
+ ab = nx.from_numpy(a)
+ Mb = nx.from_numpy(M)
+
+ Gb = ot.sinkhorn2(ab, ab, Mb, 1)
+
+ np.allclose(G, nx.to_numpy(Gb))
+
+
+def test_sinkhorn2_gradients():
+ n_samples = 100
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
+
+ M = ot.dist(x, y)
+
+ if torch:
+
+ a1 = torch.tensor(a, requires_grad=True)
+ b1 = torch.tensor(a, requires_grad=True)
+ M1 = torch.tensor(M, requires_grad=True)
+
+ val = ot.sinkhorn2(a1, b1, M1, 1)
+
+ val.backward()
+
+ assert a1.shape == a1.grad.shape
+ assert b1.shape == b1.grad.shape
+ assert M1.shape == M1.grad.shape
+
+
def test_sinkhorn_empty():
# test sinkhorn
n = 100
diff --git a/test/test_gromov.py b/test/test_gromov.py
index 43da9fc..81138ca 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -181,7 +181,7 @@ def test_fgw():
M = ot.dist(ys, yt)
M /= M.max()
- G = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5)
+ G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True)
# check constratints
np.testing.assert_allclose(
@@ -242,9 +242,9 @@ def test_fgw_barycenter():
init_X = np.random.randn(n_samples, ys.shape[1])
- X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5,
- fixed_structure=False, fixed_features=True, init_X=init_X,
- p=ot.unif(n_samples), loss_fun='square_loss',
- max_iter=100, tol=1e-3)
+ X, C, log = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5,
+ fixed_structure=False, fixed_features=True, init_X=init_X,
+ p=ot.unif(n_samples), loss_fun='square_loss',
+ max_iter=100, tol=1e-3, log=True)
np.testing.assert_allclose(C.shape, (n_samples, n_samples))
np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
diff --git a/test/test_ot.py b/test/test_ot.py
index f45e4c9..3e953dc 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -12,9 +12,12 @@ from scipy.stats import wasserstein_distance
import ot
from ot.datasets import make_1D_gauss as gauss
+from ot.backend import get_backend_list, torch
+backend_list = get_backend_list()
-def test_emd_dimension_mismatch():
+
+def test_emd_dimension_and_mass_mismatch():
# test emd and emd2 for dimension mismatch
n_samples = 100
n_features = 2
@@ -29,6 +32,80 @@ def test_emd_dimension_mismatch():
np.testing.assert_raises(AssertionError, ot.emd2, a, a, M)
+ b = a.copy()
+ a[0] = 100
+ np.testing.assert_raises(AssertionError, ot.emd, a, b, M)
+
+
+@pytest.mark.parametrize('nx', backend_list)
+def test_emd_backends(nx):
+ n_samples = 100
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
+
+ M = ot.dist(x, y)
+
+ G = ot.emd(a, a, M)
+
+ ab = nx.from_numpy(a)
+ Mb = nx.from_numpy(M)
+
+ Gb = ot.emd(ab, ab, Mb)
+
+ np.allclose(G, nx.to_numpy(Gb))
+
+
+@pytest.mark.parametrize('nx', backend_list)
+def test_emd2_backends(nx):
+ n_samples = 100
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
+
+ M = ot.dist(x, y)
+
+ val = ot.emd2(a, a, M)
+
+ ab = nx.from_numpy(a)
+ Mb = nx.from_numpy(M)
+
+ valb = ot.emd2(ab, ab, Mb)
+
+ np.allclose(val, nx.to_numpy(valb))
+
+
+def test_emd2_gradients():
+ n_samples = 100
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
+
+ M = ot.dist(x, y)
+
+ if torch:
+
+ a1 = torch.tensor(a, requires_grad=True)
+ b1 = torch.tensor(a, requires_grad=True)
+ M1 = torch.tensor(M, requires_grad=True)
+
+ val = ot.emd2(a1, b1, M1)
+
+ val.backward()
+
+ assert a1.shape == a1.grad.shape
+ assert b1.shape == b1.grad.shape
+ assert M1.shape == M1.grad.shape
+
def test_emd_emd2():
# test emd and emd2 for simple identity
@@ -83,7 +160,7 @@ def test_emd_1d_emd2_1d():
np.testing.assert_allclose(np.ones((m,)) / m, G.sum(0))
# check G is similar
- np.testing.assert_allclose(G, G_1d)
+ np.testing.assert_allclose(G, G_1d, atol=1e-15)
# check AssertionError is raised if called on non 1d arrays
u = np.random.randn(n, 2)
@@ -292,16 +369,6 @@ def test_warnings():
ot.emd(a, b, M, numItermax=1)
assert "numItermax" in str(w[-1].message)
#assert len(w) == 1
- a[0] = 100
- print('Computing {} EMD '.format(2))
- ot.emd(a, b, M)
- assert "infeasible" in str(w[-1].message)
- #assert len(w) == 2
- a[0] = -1
- print('Computing {} EMD '.format(2))
- ot.emd(a, b, M)
- assert "infeasible" in str(w[-1].message)
- #assert len(w) == 3
def test_dual_variables():
diff --git a/test/test_partial.py b/test/test_partial.py
index 121f345..3571e2a 100755
--- a/test/test_partial.py
+++ b/test/test_partial.py
@@ -129,9 +129,9 @@ def test_partial_wasserstein():
# check constratints
np.testing.assert_equal(
- G.sum(1) <= p, [True] * len(p)) # cf convergence wasserstein
+ G.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein
np.testing.assert_equal(
- G.sum(0) <= q, [True] * len(q)) # cf convergence wasserstein
+ G.sum(0) - q <= 1e-5, [True] * len(q)) # cf convergence wasserstein
np.testing.assert_allclose(
np.sum(G), m, atol=1e-04)
diff --git a/test/test_utils.py b/test/test_utils.py
index db9cda6..76b1faa 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -4,11 +4,47 @@
#
# License: MIT License
-
+import pytest
import ot
import numpy as np
import sys
+from ot.backend import get_backend_list
+
+backend_list = get_backend_list()
+
+
+@pytest.mark.parametrize('nx', backend_list)
+def test_proj_simplex(nx):
+ n = 10
+ rng = np.random.RandomState(0)
+
+ # test on matrix when projection is done on axis 0
+ x = rng.randn(n, 2)
+ x1 = nx.from_numpy(x)
+
+ # all projections should sum to 1
+ proj = ot.utils.proj_simplex(x1)
+ l1 = np.sum(nx.to_numpy(proj), axis=0)
+ l2 = np.ones(2)
+ np.testing.assert_allclose(l1, l2, atol=1e-5)
+
+ # all projections should sum to 3
+ proj = ot.utils.proj_simplex(x1, 3)
+ l1 = np.sum(nx.to_numpy(proj), axis=0)
+ l2 = 3 * np.ones(2)
+ np.testing.assert_allclose(l1, l2, atol=1e-5)
+
+ # tets on vector
+ x = rng.randn(n)
+ x1 = nx.from_numpy(x)
+
+ # all projections should sum to 1
+ proj = ot.utils.proj_simplex(x1)
+ l1 = np.sum(nx.to_numpy(proj), axis=0)
+ l2 = np.ones(2)
+ np.testing.assert_allclose(l1, l2, atol=1e-5)
+
def test_parmap():
@@ -45,8 +81,8 @@ def test_tic_toc():
def test_kernel():
n = 100
-
- x = np.random.randn(n, 2)
+ rng = np.random.RandomState(0)
+ x = rng.randn(n, 2)
K = ot.utils.kernel(x, x)
@@ -67,7 +103,8 @@ def test_dist():
n = 100
- x = np.random.randn(n, 2)
+ rng = np.random.RandomState(0)
+ x = rng.randn(n, 2)
D = np.zeros((n, n))
for i in range(n):
@@ -78,8 +115,27 @@ def test_dist():
D3 = ot.dist(x)
# dist shoul return squared euclidean
- np.testing.assert_allclose(D, D2)
- np.testing.assert_allclose(D, D3)
+ np.testing.assert_allclose(D, D2, atol=1e-14)
+ np.testing.assert_allclose(D, D3, atol=1e-14)
+
+
+@ pytest.mark.parametrize('nx', backend_list)
+def test_dist_backends(nx):
+
+ n = 100
+ rng = np.random.RandomState(0)
+ x = rng.randn(n, 2)
+ x1 = nx.from_numpy(x)
+
+ lst_metric = ['euclidean', 'sqeuclidean']
+
+ for metric in lst_metric:
+
+ D = ot.dist(x, x, metric=metric)
+ D1 = ot.dist(x1, x1, metric=metric)
+
+ # low atol because jax forces float32
+ np.testing.assert_allclose(D, nx.to_numpy(D1), atol=1e-5)
def test_dist0():
@@ -95,9 +151,11 @@ def test_dots():
n1, n2, n3, n4 = 100, 50, 200, 100
- A = np.random.randn(n1, n2)
- B = np.random.randn(n2, n3)
- C = np.random.randn(n3, n4)
+ rng = np.random.RandomState(0)
+
+ A = rng.randn(n1, n2)
+ B = rng.randn(n2, n3)
+ C = rng.randn(n3, n4)
X1 = ot.utils.dots(A, B, C)