From 6775a527f9d3c801f8cdd805d8f205b6a75551b9 Mon Sep 17 00:00:00 2001
From: Nicolas Courty
Date: Tue, 2 Nov 2021 14:19:57 +0100
Subject: [MRG] Sliced and 1D Wasserstein distances : backend versions (#256)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* add numpy and torch backends
* stat sets on functions
* proper import
* install recent torch on windows
* install recent torch on windows
* now testing all functions in backend
* add jax backend
* cleanup windows
* proper convert for jax backend
* pep8
* try again windows tests
* test jax conversion
* try proper windows tests
* emd function uses backend
* better test partial OT
* proper tests to_numpy and template Backend
* pep8
* pep8 x2
* freaking sinkhorn works with torch
* sinkhorn2 compatible
* working ot.emd2
* important detach
* it should work
* jax autodiff emd
* pep8
* no test same for jax
* new independent tests per backend
* freaking pep8
* add tests for gradients
* deprecate ot.gpu
* working dist function
* working dist
* dist done in backend
* not in
* remove indexing
* change accuracy for jax
* first pull backend
* projection simplex
* projection simplex
* projection simplex
* projection simplex no ci
* projection simplex no ci
* projection simplex no ci
* pep8
* add backend discussion to quickstart guide
* projection simplex no ci
* projection simplex no ci
* projection simplex no ci
* pep8 + better doc
* proper links
* correct doctest
* big debug documentation
* doctest again
* doctest again bis
* doctest again ter (last one or i kill myself)
* backend test + doc proj simplex
* correction test_utils
* correction test_utils
* correction cumsum
* correction flip
* correction flip v2
* more debug
* more debug
* more debug + pep8
* pep8
* argh
* proj_simplex
* backend works for sort
* proj simplex
* jax sucks
* update doc
* Update test/test_utils.py Co-authored-by: Alexandre Gramfort
* Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort
* Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort
* Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort
* Update docs/source/readme.rst Co-authored-by: Alexandre Gramfort
* Update test/test_utils.py Co-authored-by: Alexandre Gramfort
* Update ot/utils.py Co-authored-by: Alexandre Gramfort
* Update docs/source/readme.rst Co-authored-by: Alexandre Gramfort
* Update ot/lp/__init__.py Co-authored-by: Alexandre Gramfort
* begin comment alex
* comment alex part 2
* optimize test gromov
* proj_simplex on vectors
* add awesome gradient descent example on the weights
* pep8 of course
* proof read example by alex
* pep8 again
* encoding oos in translation
* correct legend
* new backend functions for sliced
* small indent pb
* Optimized backend version of sliced W
* error in sliced W
* after master merge
* error sliced
* error sliced
* pep8
* test_sliced pep8
* doctest + precision for sliced
* doctest
* type win test_backend gather
* type win test_backend gather
* Update sliced.py change argument of padding pad_width
* Update backend.py update redefinition
* Update backend.py pep8
* Update backend.py pep 8 again....
* pep8
* build docs
* emd2_1D example
* refactoring emd_1d and variants
* remove unused previous wasserstein_1d
* pep8
* update example
* move stuff
* tests should work + implement random backend
* test random generator functions
* correction
* better random generation
* update sliced
* update sliced
* proper tests sliced
* max sliced
* change file name
* add stuff
* example sliced flow and barycenter
* correct typo + update readme
* example sliced flow done
* pep8
* solver1d works
* pep8

Co-authored-by: Rémi Flamary
Co-authored-by: Alexandre Gramfort
---
 test/test_backend.py | 36 ++++++++++++++++++++++++++++++++++++
 1 file changed, 36 insertions(+)
(limited to 'test/test_backend.py')

diff --git a/test/test_backend.py b/test/test_backend.py
index 0f11ace..1832b91 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -208,6 +208,11 @@ def test_empty_backend():
     with pytest.raises(NotImplementedError):
         nx.reshape(M, (5, 3, 2))
     with pytest.raises(NotImplementedError):
+        nx.seed(42)
+    with pytest.raises(NotImplementedError):
+        nx.rand()
+    with pytest.raises(NotImplementedError):
+        nx.randn()
         nx.coo_matrix(M, M, M)
     with pytest.raises(NotImplementedError):
         nx.issparse(M)
@@ -248,6 +253,7 @@ def test_func_backends(nx):
         Mb = nx.from_numpy(M)
         vb = nx.from_numpy(v)
+        val = nx.from_numpy(val)
         sp_rowb = nx.from_numpy(sp_row)
@@ -255,6 +261,7 @@ def test_func_backends(nx):
         sp_datab = nx.from_numpy(sp_data)
         A = nx.set_gradients(val, v, v)
+
         lst_b.append(nx.to_numpy(A))
         lst_name.append('set_gradients')
@@ -505,6 +512,35 @@ def test_func_backends(nx):
     assert np.allclose(a1, a2, atol=1e-7)
+def test_random_backends(nx):
+
+    tmp_u = nx.rand()
+
+    assert tmp_u < 1
+
+    tmp_n = nx.randn()
+
+    nx.seed(0)
+    M1 = nx.to_numpy(nx.rand(5, 2))
+    nx.seed(0)
+    M2 = nx.to_numpy(nx.rand(5, 2, type_as=tmp_n))
+
+    assert np.all(M1 >= 0)
+    assert np.all(M1 < 1)
+    assert M1.shape == (5, 2)
+    assert np.allclose(M1, M2)
+
+    nx.seed(0)
+    M1 = nx.to_numpy(nx.randn(5, 2))
+    nx.seed(0)
+    M2 = nx.to_numpy(nx.randn(5, 2, type_as=tmp_u))
+
+    nx.seed(42)
+    v1 = nx.randn()
+    v2 = nx.randn()
+    assert v1 != v2
+
+
 def test_gradients_backends():

     rnd = np.random.RandomState(0)
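Note (not part of the patch): the new test_random_backends above exercises the random-number helpers that this PR adds to the backend abstraction (nx.seed, nx.rand, nx.randn, and their type_as argument). A minimal sketch of how these helpers behave on the NumPy backend, assuming POT >= 0.8 where ot.backend.NumpyBackend exposes them:

    # Sketch only: mirrors the pattern used in test_random_backends.
    # Assumes POT >= 0.8 (ot.backend with seed/rand/randn from this PR).
    import numpy as np
    from ot.backend import NumpyBackend

    nx = NumpyBackend()

    nx.seed(0)
    M1 = nx.to_numpy(nx.rand(5, 2))   # uniform samples in [0, 1)
    nx.seed(0)
    M2 = nx.to_numpy(nx.rand(5, 2))   # same seed -> identical samples

    assert M1.shape == (5, 2)
    assert np.allclose(M1, M2)

    nx.seed(42)
    v1, v2 = nx.randn(), nx.randn()   # standard normal scalar draws
    assert v1 != v2                   # consecutive draws differ

Seeding goes through the backend (nx.seed) rather than numpy directly, so the same test body can run unchanged against the NumPy, PyTorch, and JAX backends parametrized by the nx fixture.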