summaryrefslogtreecommitdiff
path: root/test/test_utils.py
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2021-06-01 10:10:54 +0200
committerGitHub <noreply@github.com>2021-06-01 10:10:54 +0200
commit184f8f4f7ac78f1dd7f653496d2753211a4e3426 (patch)
tree483a7274c91030fd644de49b03a5fad04af9deba /test/test_utils.py
parent1f16614954e2522fbdb1598c5b1f5c3630c68472 (diff)
[MRG] POT numpy/torch/jax backends (#249)
* add numpy and torch backends * stat sets on functions * proper import * install recent torch on windows * install recent torch on windows * now testing all functions in backedn * add jax backedn * clenaup windowds * proper convert for jax backedn * pep8 * try again windows tests * test jax conversion * try proper widows tests * emd fuction ses backedn * better test partial OT * proper tests to_numpy and teplate Backend * pep8 * pep8 x2 * feaking sinkhorn works with torch * sinkhorn2 compatible * working ot.emd2 * important detach * it should work * jax autodiff emd * pep8 * no tast same for jax * new independat tests per backedn * freaking pep8 * add tests for gradients * deprecate ot.gpu * worging dist function * working dist * dist done in backedn * not in * remove indexing * change accuacy for jax * first pull backend * projection simplex * projection simplex * projection simplex * projection simplex no ci * projection simplex no ci * projection simplex no ci * pep8 * add backedn discusion to quickstart guide * projection simplex no ci * projection simplex no ci * projection simplex no ci * pep8 + better doc * proper links * corect doctest * big debug documentation * doctest again * doctest again bis * doctest again ter (last one or i kill myself) * backend test + doc proj simplex * correction test_utils * correction test_utils * correction cumsum * correction flip * correction flip v2 * more debug * more debug * more debug + pep8 * pep8 * argh * proj_simplex * backedn works for sort * proj simplex * jax sucks * update doc * Update test/test_utils.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/readme.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update test/test_utils.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update ot/utils.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/readme.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update ot/lp/__init__.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * begin comment alex * comment alex part 2 * optimize test gromov * proj_simplex on vectors * add awesome gradient decsnt example on the weights * pep98 of course * proof read example by alex * pep8 again * encoding oos in translation * correct legend Co-authored-by: Nicolas Courty <ncourty@irisa.fr> Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
Diffstat (limited to 'test/test_utils.py')
-rw-r--r--test/test_utils.py76
1 files changed, 67 insertions, 9 deletions
diff --git a/test/test_utils.py b/test/test_utils.py
index db9cda6..76b1faa 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -4,11 +4,47 @@
#
# License: MIT License
-
+import pytest
import ot
import numpy as np
import sys
+from ot.backend import get_backend_list
+
+backend_list = get_backend_list()
+
+
+@pytest.mark.parametrize('nx', backend_list)
+def test_proj_simplex(nx):
+ n = 10
+ rng = np.random.RandomState(0)
+
+ # test on matrix when projection is done on axis 0
+ x = rng.randn(n, 2)
+ x1 = nx.from_numpy(x)
+
+ # all projections should sum to 1
+ proj = ot.utils.proj_simplex(x1)
+ l1 = np.sum(nx.to_numpy(proj), axis=0)
+ l2 = np.ones(2)
+ np.testing.assert_allclose(l1, l2, atol=1e-5)
+
+ # all projections should sum to 3
+ proj = ot.utils.proj_simplex(x1, 3)
+ l1 = np.sum(nx.to_numpy(proj), axis=0)
+ l2 = 3 * np.ones(2)
+ np.testing.assert_allclose(l1, l2, atol=1e-5)
+
+ # tets on vector
+ x = rng.randn(n)
+ x1 = nx.from_numpy(x)
+
+ # all projections should sum to 1
+ proj = ot.utils.proj_simplex(x1)
+ l1 = np.sum(nx.to_numpy(proj), axis=0)
+ l2 = np.ones(2)
+ np.testing.assert_allclose(l1, l2, atol=1e-5)
+
def test_parmap():
@@ -45,8 +81,8 @@ def test_tic_toc():
def test_kernel():
n = 100
-
- x = np.random.randn(n, 2)
+ rng = np.random.RandomState(0)
+ x = rng.randn(n, 2)
K = ot.utils.kernel(x, x)
@@ -67,7 +103,8 @@ def test_dist():
n = 100
- x = np.random.randn(n, 2)
+ rng = np.random.RandomState(0)
+ x = rng.randn(n, 2)
D = np.zeros((n, n))
for i in range(n):
@@ -78,8 +115,27 @@ def test_dist():
D3 = ot.dist(x)
# dist shoul return squared euclidean
- np.testing.assert_allclose(D, D2)
- np.testing.assert_allclose(D, D3)
+ np.testing.assert_allclose(D, D2, atol=1e-14)
+ np.testing.assert_allclose(D, D3, atol=1e-14)
+
+
+@ pytest.mark.parametrize('nx', backend_list)
+def test_dist_backends(nx):
+
+ n = 100
+ rng = np.random.RandomState(0)
+ x = rng.randn(n, 2)
+ x1 = nx.from_numpy(x)
+
+ lst_metric = ['euclidean', 'sqeuclidean']
+
+ for metric in lst_metric:
+
+ D = ot.dist(x, x, metric=metric)
+ D1 = ot.dist(x1, x1, metric=metric)
+
+ # low atol because jax forces float32
+ np.testing.assert_allclose(D, nx.to_numpy(D1), atol=1e-5)
def test_dist0():
@@ -95,9 +151,11 @@ def test_dots():
n1, n2, n3, n4 = 100, 50, 200, 100
- A = np.random.randn(n1, n2)
- B = np.random.randn(n2, n3)
- C = np.random.randn(n3, n4)
+ rng = np.random.RandomState(0)
+
+ A = rng.randn(n1, n2)
+ B = rng.randn(n2, n3)
+ C = rng.randn(n3, n4)
X1 = ot.utils.dots(A, B, C)