summary | refs | log | tree | commit | diff
path: root/test
diff options
context:
space:
mode:
author: Nicolas Courty <ncourty@irisa.fr> 2021-11-02 14:19:57 +0100
committer: GitHub <noreply@github.com> 2021-11-02 14:19:57 +0100
commit: 6775a527f9d3c801f8cdd805d8f205b6a75551b9 (patch)
tree: c0ed5a7c297b4003688fec52d46f918ea0086a7d /test
parent: a335324d008e8982be61d7ace937815a2bfa98f9 (diff)
[MRG] Sliced and 1D Wasserstein distances : backend versions (#256)
* add numpy and torch backends * stat sets on functions * proper import * install recent torch on windows * install recent torch on windows * now testing all functions in backedn * add jax backedn * clenaup windowds * proper convert for jax backedn * pep8 * try again windows tests * test jax conversion * try proper widows tests * emd fuction ses backedn * better test partial OT * proper tests to_numpy and teplate Backend * pep8 * pep8 x2 * feaking sinkhorn works with torch * sinkhorn2 compatible * working ot.emd2 * important detach * it should work * jax autodiff emd * pep8 * no tast same for jax * new independat tests per backedn * freaking pep8 * add tests for gradients * deprecate ot.gpu * worging dist function * working dist * dist done in backedn * not in * remove indexing * change accuacy for jax * first pull backend * projection simplex * projection simplex * projection simplex * projection simplex no ci * projection simplex no ci * projection simplex no ci * pep8 * add backedn discusion to quickstart guide * projection simplex no ci * projection simplex no ci * projection simplex no ci * pep8 + better doc * proper links * corect doctest * big debug documentation * doctest again * doctest again bis * doctest again ter (last one or i kill myself) * backend test + doc proj simplex * correction test_utils * correction test_utils * correction cumsum * correction flip * correction flip v2 * more debug * more debug * more debug + pep8 * pep8 * argh * proj_simplex * backedn works for sort * proj simplex * jax sucks * update doc * Update test/test_utils.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/readme.rst Co-authored-by: 
Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update test/test_utils.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update ot/utils.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/readme.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update ot/lp/__init__.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * begin comment alex * comment alex part 2 * optimize test gromov * proj_simplex on vectors * add awesome gradient decsnt example on the weights * pep98 of course * proof read example by alex * pep8 again * encoding oos in translation * correct legend * new backend functions for sliced * small indent pb * Optimized backendversion of sliced W * error in sliced W * after master merge * error sliced * error sliced * pep8 * test_sliced pep8 * doctest + precision for sliced * doctest * type win test_backend gather * type win test_backend gather * Update sliced.py change argument of padding pad_width * Update backend.py update redefinition * Update backend.py pep8 * Update backend.py pep 8 again.... * pep8 * build docs * emd2_1D example * refectoring emd_1d and variants * remove unused previous wasserstein_1d * pep8 * upate example * move stuff * tesys should work + implemù random backend * test random generayor functions * correction * better random generation * update sliced * update sliced * proper tests sliced * max sliced * chae file nam * add stuff * example sliced flow and barycenter * correct typo + update readme * exemple sliced flow done * pep8 * solver1d works * pep8 Co-authored-by: Rémi Flamary <remi.flamary@gmail.com> Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
Diffstat (limited to 'test')
-rw-r--r-- test/test_1d_solver.py | 85
-rw-r--r-- test/test_backend.py | 36
-rw-r--r-- test/test_ot.py | 57
-rw-r--r-- test/test_sliced.py | 90
-rw-r--r-- test/test_utils.py | 2
5 files changed, 210 insertions, 60 deletions
diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py
new file mode 100644
index 0000000..2c470c2
--- /dev/null
+++ b/test/test_1d_solver.py
@@ -0,0 +1,85 @@
+"""Tests for module 1d Wasserstein solver"""
+
+# Author: Adrien Corenflos <adrien.corenflos@aalto.fi>
+# Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
+import numpy as np
+import pytest
+
+import ot
+from ot.lp import wasserstein_1d
+
+from ot.backend import get_backend_list
+from scipy.stats import wasserstein_distance
+
+backend_list = get_backend_list()
+
+
+def test_emd_1d_emd2_1d_with_weights():
+ # test emd1d gives similar results as emd
+ n = 20
+ m = 30
+ rng = np.random.RandomState(0)
+ u = rng.randn(n, 1)
+ v = rng.randn(m, 1)
+
+ w_u = rng.uniform(0., 1., n)
+ w_u = w_u / w_u.sum()
+
+ w_v = rng.uniform(0., 1., m)
+ w_v = w_v / w_v.sum()
+
+ M = ot.dist(u, v, metric='sqeuclidean')
+
+ G, log = ot.emd(w_u, w_v, M, log=True)
+ wass = log["cost"]
+ G_1d, log = ot.emd_1d(u, v, w_u, w_v, metric='sqeuclidean', log=True)
+ wass1d = log["cost"]
+ wass1d_emd2 = ot.emd2_1d(u, v, w_u, w_v, metric='sqeuclidean', log=False)
+ wass1d_euc = ot.emd2_1d(u, v, w_u, w_v, metric='euclidean', log=False)
+
+ # check loss is similar
+ np.testing.assert_allclose(wass, wass1d)
+ np.testing.assert_allclose(wass, wass1d_emd2)
+
+ # check loss is similar to scipy's implementation for Euclidean metric
+ wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)), w_u, w_v)
+ np.testing.assert_allclose(wass_sp, wass1d_euc)
+
+ # check constraints
+ np.testing.assert_allclose(w_u, G.sum(1))
+ np.testing.assert_allclose(w_v, G.sum(0))
+
+
+@pytest.mark.parametrize('nx', backend_list)
+def test_wasserstein_1d(nx):
+ from scipy.stats import wasserstein_distance
+
+ rng = np.random.RandomState(0)
+
+ n = 100
+ x = np.linspace(0, 5, n)
+ rho_u = np.abs(rng.randn(n))
+ rho_u /= rho_u.sum()
+ rho_v = np.abs(rng.randn(n))
+ rho_v /= rho_v.sum()
+
+ xb = nx.from_numpy(x)
+ rho_ub = nx.from_numpy(rho_u)
+ rho_vb = nx.from_numpy(rho_v)
+
+ # test 1 : wasserstein_1d should be close to scipy W_1 implementation
+ np.testing.assert_almost_equal(wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1),
+ wasserstein_distance(x, x, rho_u, rho_v))
+
+ # test 2 : wasserstein_1d should be close to one when only translating the support
+ np.testing.assert_almost_equal(wasserstein_1d(xb, xb + 1, p=2),
+ 1.)
+
+ # test 3 : arrays test
+ X = np.stack((np.linspace(0, 5, n), np.linspace(0, 5, n) * 10), -1)
+ Xb = nx.from_numpy(X)
+ res = wasserstein_1d(Xb, Xb, rho_ub, rho_vb, p=2)
+ np.testing.assert_almost_equal(100 * res[0], res[1], decimal=4)
diff --git a/test/test_backend.py b/test/test_backend.py
index 0f11ace..1832b91 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -208,6 +208,11 @@ def test_empty_backend():
with pytest.raises(NotImplementedError):
nx.reshape(M, (5, 3, 2))
with pytest.raises(NotImplementedError):
+ nx.seed(42)
+ with pytest.raises(NotImplementedError):
+ nx.rand()
+ with pytest.raises(NotImplementedError):
+ nx.randn()
nx.coo_matrix(M, M, M)
with pytest.raises(NotImplementedError):
nx.issparse(M)
@@ -248,6 +253,7 @@ def test_func_backends(nx):
Mb = nx.from_numpy(M)
vb = nx.from_numpy(v)
+
val = nx.from_numpy(val)
sp_rowb = nx.from_numpy(sp_row)
@@ -255,6 +261,7 @@ def test_func_backends(nx):
sp_datab = nx.from_numpy(sp_data)
A = nx.set_gradients(val, v, v)
+
lst_b.append(nx.to_numpy(A))
lst_name.append('set_gradients')
@@ -505,6 +512,35 @@ def test_func_backends(nx):
assert np.allclose(a1, a2, atol=1e-7)
+def test_random_backends(nx):
+
+ tmp_u = nx.rand()
+
+ assert tmp_u < 1
+
+ tmp_n = nx.randn()
+
+ nx.seed(0)
+ M1 = nx.to_numpy(nx.rand(5, 2))
+ nx.seed(0)
+ M2 = nx.to_numpy(nx.rand(5, 2, type_as=tmp_n))
+
+ assert np.all(M1 >= 0)
+ assert np.all(M1 < 1)
+ assert M1.shape == (5, 2)
+ assert np.allclose(M1, M2)
+
+ nx.seed(0)
+ M1 = nx.to_numpy(nx.randn(5, 2))
+ nx.seed(0)
+ M2 = nx.to_numpy(nx.randn(5, 2, type_as=tmp_u))
+
+ nx.seed(42)
+ v1 = nx.randn()
+ v2 = nx.randn()
+ assert v1 != v2
+
+
def test_gradients_backends():
rnd = np.random.RandomState(0)
diff --git a/test/test_ot.py b/test/test_ot.py
index 4dfc510..5bfde1d 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -8,11 +8,11 @@ import warnings
import numpy as np
import pytest
-from scipy.stats import wasserstein_distance
import ot
from ot.datasets import make_1D_gauss as gauss
from ot.backend import torch
+from scipy.stats import wasserstein_distance
def test_emd_dimension_and_mass_mismatch():
@@ -165,61 +165,6 @@ def test_emd_1d_emd2_1d():
ot.emd_1d(u, v, [], [])
-def test_emd_1d_emd2_1d_with_weights():
- # test emd1d gives similar results as emd
- n = 20
- m = 30
- rng = np.random.RandomState(0)
- u = rng.randn(n, 1)
- v = rng.randn(m, 1)
-
- w_u = rng.uniform(0., 1., n)
- w_u = w_u / w_u.sum()
-
- w_v = rng.uniform(0., 1., m)
- w_v = w_v / w_v.sum()
-
- M = ot.dist(u, v, metric='sqeuclidean')
-
- G, log = ot.emd(w_u, w_v, M, log=True)
- wass = log["cost"]
- G_1d, log = ot.emd_1d(u, v, w_u, w_v, metric='sqeuclidean', log=True)
- wass1d = log["cost"]
- wass1d_emd2 = ot.emd2_1d(u, v, w_u, w_v, metric='sqeuclidean', log=False)
- wass1d_euc = ot.emd2_1d(u, v, w_u, w_v, metric='euclidean', log=False)
-
- # check loss is similar
- np.testing.assert_allclose(wass, wass1d)
- np.testing.assert_allclose(wass, wass1d_emd2)
-
- # check loss is similar to scipy's implementation for Euclidean metric
- wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)), w_u, w_v)
- np.testing.assert_allclose(wass_sp, wass1d_euc)
-
- # check constraints
- np.testing.assert_allclose(w_u, G.sum(1))
- np.testing.assert_allclose(w_v, G.sum(0))
-
-
-def test_wass_1d():
- # test emd1d gives similar results as emd
- n = 20
- m = 30
- rng = np.random.RandomState(0)
- u = rng.randn(n, 1)
- v = rng.randn(m, 1)
-
- M = ot.dist(u, v, metric='sqeuclidean')
-
- G, log = ot.emd([], [], M, log=True)
- wass = log["cost"]
-
- wass1d = ot.wasserstein_1d(u, v, [], [], p=2.)
-
- # check loss is similar
- np.testing.assert_allclose(np.sqrt(wass), wass1d)
-
-
def test_emd_empty():
# test emd and emd2 for simple identity
n = 100
diff --git a/test/test_sliced.py b/test/test_sliced.py
index a07d975..0bd74ec 100644
--- a/test/test_sliced.py
+++ b/test/test_sliced.py
@@ -1,6 +1,7 @@
"""Tests for module sliced"""
# Author: Adrien Corenflos <adrien.corenflos@aalto.fi>
+# Nicolas Courty <ncourty@irisa.fr>
#
# License: MIT License
@@ -14,7 +15,7 @@ from ot.sliced import get_random_projections
def test_get_random_projections():
rng = np.random.RandomState(0)
projections = get_random_projections(1000, 50, rng)
- np.testing.assert_almost_equal(np.sum(projections ** 2, 1), 1.)
+ np.testing.assert_almost_equal(np.sum(projections ** 2, 0), 1.)
def test_sliced_same_dist():
@@ -48,12 +49,12 @@ def test_sliced_log():
y = rng.randn(n, 4)
u = ot.utils.unif(n)
- res, log = ot.sliced_wasserstein_distance(x, y, u, u, 10, seed=rng, log=True)
+ res, log = ot.sliced_wasserstein_distance(x, y, u, u, 10, p=1, seed=rng, log=True)
assert len(log) == 2
projections = log["projections"]
projected_emds = log["projected_emds"]
- assert len(projections) == len(projected_emds) == 10
+ assert projections.shape[1] == len(projected_emds) == 10
for emd in projected_emds:
assert emd > 0
@@ -83,3 +84,86 @@ def test_1d_sliced_equals_emd():
res = ot.sliced_wasserstein_distance(x, y, a, u, 10, seed=42)
expected = ot.emd2_1d(x.squeeze(), y.squeeze(), a, u)
np.testing.assert_almost_equal(res ** 2, expected)
+
+
+def test_max_sliced_same_dist():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ res = ot.max_sliced_wasserstein_distance(x, x, u, u, 10, seed=rng)
+ np.testing.assert_almost_equal(res, 0.)
+
+
+def test_max_sliced_different_dists():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+ y = rng.randn(n, 2)
+
+ res, log = ot.max_sliced_wasserstein_distance(x, y, u, u, 10, seed=rng, log=True)
+ assert res > 0.
+
+
+def test_sliced_backend(nx):
+
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ y = rng.randn(2 * n, 2)
+
+ P = rng.randn(2, 20)
+ P = P / np.sqrt((P**2).sum(0, keepdims=True))
+
+ n_projections = 20
+
+ xb = nx.from_numpy(x)
+ yb = nx.from_numpy(y)
+ Pb = nx.from_numpy(P)
+
+ val0 = ot.sliced_wasserstein_distance(x, y, projections=P)
+
+ val = ot.sliced_wasserstein_distance(xb, yb, n_projections=n_projections, seed=0)
+ val2 = ot.sliced_wasserstein_distance(xb, yb, n_projections=n_projections, seed=0)
+
+ assert val > 0
+ assert val == val2
+
+ valb = nx.to_numpy(ot.sliced_wasserstein_distance(xb, yb, projections=Pb))
+
+ assert np.allclose(val0, valb)
+
+
+def test_max_sliced_backend(nx):
+
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ y = rng.randn(2 * n, 2)
+
+ P = rng.randn(2, 20)
+ P = P / np.sqrt((P**2).sum(0, keepdims=True))
+
+ n_projections = 20
+
+ xb = nx.from_numpy(x)
+ yb = nx.from_numpy(y)
+ Pb = nx.from_numpy(P)
+
+ val0 = ot.max_sliced_wasserstein_distance(x, y, projections=P)
+
+ val = ot.max_sliced_wasserstein_distance(xb, yb, n_projections=n_projections, seed=0)
+ val2 = ot.max_sliced_wasserstein_distance(xb, yb, n_projections=n_projections, seed=0)
+
+ assert val > 0
+ assert val == val2
+
+ valb = nx.to_numpy(ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb))
+
+ assert np.allclose(val0, valb)
diff --git a/test/test_utils.py b/test/test_utils.py
index 0650ce2..40f4e49 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -109,7 +109,7 @@ def test_dist():
D2 = ot.dist(x, x)
D3 = ot.dist(x)
- D4 = ot.dist(x, x, metric='minkowski', p=0.5)
+ D4 = ot.dist(x, x, metric='minkowski', p=2)
assert D4[0, 1] == D4[1, 0]