summaryrefslogtreecommitdiff
path: root/test/test_backend.py
diff options
context:
space:
mode:
authorNicolas Courty <ncourty@irisa.fr>2021-11-02 14:19:57 +0100
committerGitHub <noreply@github.com>2021-11-02 14:19:57 +0100
commit6775a527f9d3c801f8cdd805d8f205b6a75551b9 (patch)
treec0ed5a7c297b4003688fec52d46f918ea0086a7d /test/test_backend.py
parenta335324d008e8982be61d7ace937815a2bfa98f9 (diff)
[MRG] Sliced and 1D Wasserstein distances : backend versions (#256)
* add numpy and torch backends * stat sets on functions * proper import * install recent torch on windows * install recent torch on windows * now testing all functions in backedn * add jax backedn * clenaup windowds * proper convert for jax backedn * pep8 * try again windows tests * test jax conversion * try proper widows tests * emd fuction ses backedn * better test partial OT * proper tests to_numpy and teplate Backend * pep8 * pep8 x2 * feaking sinkhorn works with torch * sinkhorn2 compatible * working ot.emd2 * important detach * it should work * jax autodiff emd * pep8 * no tast same for jax * new independat tests per backedn * freaking pep8 * add tests for gradients * deprecate ot.gpu * worging dist function * working dist * dist done in backedn * not in * remove indexing * change accuacy for jax * first pull backend * projection simplex * projection simplex * projection simplex * projection simplex no ci * projection simplex no ci * projection simplex no ci * pep8 * add backedn discusion to quickstart guide * projection simplex no ci * projection simplex no ci * projection simplex no ci * pep8 + better doc * proper links * corect doctest * big debug documentation * doctest again * doctest again bis * doctest again ter (last one or i kill myself) * backend test + doc proj simplex * correction test_utils * correction test_utils * correction cumsum * correction flip * correction flip v2 * more debug * more debug * more debug + pep8 * pep8 * argh * proj_simplex * backedn works for sort * proj simplex * jax sucks * update doc * Update test/test_utils.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/readme.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update test/test_utils.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update ot/utils.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/readme.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update ot/lp/__init__.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * begin comment alex * comment alex part 2 * optimize test gromov * proj_simplex on vectors * add awesome gradient decsnt example on the weights * pep98 of course * proof read example by alex * pep8 again * encoding oos in translation * correct legend * new backend functions for sliced * small indent pb * Optimized backendversion of sliced W * error in sliced W * after master merge * error sliced * error sliced * pep8 * test_sliced pep8 * doctest + precision for sliced * doctest * type win test_backend gather * type win test_backend gather * Update sliced.py change argument of padding pad_width * Update backend.py update redefinition * Update backend.py pep8 * Update backend.py pep 8 again.... * pep8 * build docs * emd2_1D example * refectoring emd_1d and variants * remove unused previous wasserstein_1d * pep8 * upate example * move stuff * tesys should work + implemù random backend * test random generayor functions * correction * better random generation * update sliced * update sliced * proper tests sliced * max sliced * chae file nam * add stuff * example sliced flow and barycenter * correct typo + update readme * exemple sliced flow done * pep8 * solver1d works * pep8 Co-authored-by: Rémi Flamary <remi.flamary@gmail.com> Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
Diffstat (limited to 'test/test_backend.py')
-rw-r--r--test/test_backend.py36
1 files changed, 36 insertions, 0 deletions
diff --git a/test/test_backend.py b/test/test_backend.py
index 0f11ace..1832b91 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -208,6 +208,11 @@ def test_empty_backend():
with pytest.raises(NotImplementedError):
nx.reshape(M, (5, 3, 2))
with pytest.raises(NotImplementedError):
+ nx.seed(42)
+ with pytest.raises(NotImplementedError):
+ nx.rand()
+ with pytest.raises(NotImplementedError):
+ nx.randn()
nx.coo_matrix(M, M, M)
with pytest.raises(NotImplementedError):
nx.issparse(M)
@@ -248,6 +253,7 @@ def test_func_backends(nx):
Mb = nx.from_numpy(M)
vb = nx.from_numpy(v)
+
val = nx.from_numpy(val)
sp_rowb = nx.from_numpy(sp_row)
@@ -255,6 +261,7 @@ def test_func_backends(nx):
sp_datab = nx.from_numpy(sp_data)
A = nx.set_gradients(val, v, v)
+
lst_b.append(nx.to_numpy(A))
lst_name.append('set_gradients')
@@ -505,6 +512,35 @@ def test_func_backends(nx):
assert np.allclose(a1, a2, atol=1e-7)
+def test_random_backends(nx):
+
+ tmp_u = nx.rand()
+
+ assert tmp_u < 1
+
+ tmp_n = nx.randn()
+
+ nx.seed(0)
+ M1 = nx.to_numpy(nx.rand(5, 2))
+ nx.seed(0)
+ M2 = nx.to_numpy(nx.rand(5, 2, type_as=tmp_n))
+
+ assert np.all(M1 >= 0)
+ assert np.all(M1 < 1)
+ assert M1.shape == (5, 2)
+ assert np.allclose(M1, M2)
+
+ nx.seed(0)
+ M1 = nx.to_numpy(nx.randn(5, 2))
+ nx.seed(0)
+ M2 = nx.to_numpy(nx.randn(5, 2, type_as=tmp_u))
+
+ nx.seed(42)
+ v1 = nx.randn()
+ v2 = nx.randn()
+ assert v1 != v2
+
+
def test_gradients_backends():
rnd = np.random.RandomState(0)