diff options
author | RĂ©mi Flamary <remi.flamary@gmail.com> | 2021-06-01 10:10:54 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-06-01 10:10:54 +0200 |
commit | 184f8f4f7ac78f1dd7f653496d2753211a4e3426 (patch) | |
tree | 483a7274c91030fd644de49b03a5fad04af9deba /test/test_ot.py | |
parent | 1f16614954e2522fbdb1598c5b1f5c3630c68472 (diff) |
[MRG] POT numpy/torch/jax backends (#249)
* add numpy and torch backends
* stat sets on functions
* proper import
* install recent torch on windows
* install recent torch on windows
* now testing all functions in backedn
* add jax backedn
* clenaup windowds
* proper convert for jax backedn
* pep8
* try again windows tests
* test jax conversion
* try proper widows tests
* emd fuction ses backedn
* better test partial OT
* proper tests to_numpy and teplate Backend
* pep8
* pep8 x2
* feaking sinkhorn works with torch
* sinkhorn2 compatible
* working ot.emd2
* important detach
* it should work
* jax autodiff emd
* pep8
* no tast same for jax
* new independat tests per backedn
* freaking pep8
* add tests for gradients
* deprecate ot.gpu
* worging dist function
* working dist
* dist done in backedn
* not in
* remove indexing
* change accuacy for jax
* first pull backend
* projection simplex
* projection simplex
* projection simplex
* projection simplex no ci
* projection simplex no ci
* projection simplex no ci
* pep8
* add backedn discusion to quickstart guide
* projection simplex no ci
* projection simplex no ci
* projection simplex no ci
* pep8 + better doc
* proper links
* corect doctest
* big debug documentation
* doctest again
* doctest again bis
* doctest again ter (last one or i kill myself)
* backend test + doc proj simplex
* correction test_utils
* correction test_utils
* correction cumsum
* correction flip
* correction flip v2
* more debug
* more debug
* more debug + pep8
* pep8
* argh
* proj_simplex
* backedn works for sort
* proj simplex
* jax sucks
* update doc
* Update test/test_utils.py
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update docs/source/quickstart.rst
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update docs/source/quickstart.rst
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update docs/source/quickstart.rst
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update docs/source/readme.rst
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update test/test_utils.py
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update ot/utils.py
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update docs/source/readme.rst
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update ot/lp/__init__.py
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* begin comment alex
* comment alex part 2
* optimize test gromov
* proj_simplex on vectors
* add awesome gradient decsnt example on the weights
* pep98 of course
* proof read example by alex
* pep8 again
* encoding oos in translation
* correct legend
Co-authored-by: Nicolas Courty <ncourty@irisa.fr>
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
Diffstat (limited to 'test/test_ot.py')
-rw-r--r-- | test/test_ot.py | 91 |
1 files changed, 79 insertions, 12 deletions
diff --git a/test/test_ot.py b/test/test_ot.py index f45e4c9..3e953dc 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -12,9 +12,12 @@ from scipy.stats import wasserstein_distance import ot from ot.datasets import make_1D_gauss as gauss +from ot.backend import get_backend_list, torch +backend_list = get_backend_list() -def test_emd_dimension_mismatch(): + +def test_emd_dimension_and_mass_mismatch(): # test emd and emd2 for dimension mismatch n_samples = 100 n_features = 2 @@ -29,6 +32,80 @@ def test_emd_dimension_mismatch(): np.testing.assert_raises(AssertionError, ot.emd2, a, a, M) + b = a.copy() + a[0] = 100 + np.testing.assert_raises(AssertionError, ot.emd, a, b, M) + + +@pytest.mark.parametrize('nx', backend_list) +def test_emd_backends(nx): + n_samples = 100 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples, n_features) + y = rng.randn(n_samples, n_features) + a = ot.utils.unif(n_samples) + + M = ot.dist(x, y) + + G = ot.emd(a, a, M) + + ab = nx.from_numpy(a) + Mb = nx.from_numpy(M) + + Gb = ot.emd(ab, ab, Mb) + + np.allclose(G, nx.to_numpy(Gb)) + + +@pytest.mark.parametrize('nx', backend_list) +def test_emd2_backends(nx): + n_samples = 100 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples, n_features) + y = rng.randn(n_samples, n_features) + a = ot.utils.unif(n_samples) + + M = ot.dist(x, y) + + val = ot.emd2(a, a, M) + + ab = nx.from_numpy(a) + Mb = nx.from_numpy(M) + + valb = ot.emd2(ab, ab, Mb) + + np.allclose(val, nx.to_numpy(valb)) + + +def test_emd2_gradients(): + n_samples = 100 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples, n_features) + y = rng.randn(n_samples, n_features) + a = ot.utils.unif(n_samples) + + M = ot.dist(x, y) + + if torch: + + a1 = torch.tensor(a, requires_grad=True) + b1 = torch.tensor(a, requires_grad=True) + M1 = torch.tensor(M, requires_grad=True) + + val = ot.emd2(a1, b1, M1) + + val.backward() + + assert a1.shape == a1.grad.shape + assert b1.shape == b1.grad.shape + assert M1.shape == M1.grad.shape + def test_emd_emd2(): # test emd and emd2 for simple identity @@ -83,7 +160,7 @@ def test_emd_1d_emd2_1d(): np.testing.assert_allclose(np.ones((m,)) / m, G.sum(0)) # check G is similar - np.testing.assert_allclose(G, G_1d) + np.testing.assert_allclose(G, G_1d, atol=1e-15) # check AssertionError is raised if called on non 1d arrays u = np.random.randn(n, 2) @@ -292,16 +369,6 @@ def test_warnings(): ot.emd(a, b, M, numItermax=1) assert "numItermax" in str(w[-1].message) #assert len(w) == 1 - a[0] = 100 - print('Computing {} EMD '.format(2)) - ot.emd(a, b, M) - assert "infeasible" in str(w[-1].message) - #assert len(w) == 2 - a[0] = -1 - print('Computing {} EMD '.format(2)) - ot.emd(a, b, M) - assert "infeasible" in str(w[-1].message) - #assert len(w) == 3 def test_dual_variables(): |