diff options
author | ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com> | 2021-11-02 13:42:02 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-11-02 13:42:02 +0100 |
commit | a335324d008e8982be61d7ace937815a2bfa98f9 (patch) | |
tree | 83c7f637597f10f6f3d20b15532e53fc65b51f22 /test/test_backend.py | |
parent | 0cb2b2efe901ed74c614046d250518769f870313 (diff) |
[MRG] Backend for gromov (#294)
* bregman: small correction
* gromov backend first draft
* Removing decorators
* Reworked casting method
* Bug solve
* Removing casting
* Bug solve
* toarray renamed todense ; expand_dims removed
* Warning (jax not supporting sparse matrix) moved
* Mistake corrected
* test backend
* Sparsity test for older versions of pytorch
* Trying pytorch/1.10
* Attempt to correct torch sparse bug
* Backend version of gromov tests
* Random state introduced for remaining gromov functions
* review changes
* code coverage
* Docs (first draft, to be continued)
* Gromov docs
* Prettified docs
* mistake corrected in the docs
* little change
Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'test/test_backend.py')
-rw-r--r-- | test/test_backend.py | 56 |
1 files changed, 56 insertions, 0 deletions
diff --git a/test/test_backend.py b/test/test_backend.py index 5853282..0f11ace 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -207,6 +207,22 @@ def test_empty_backend(): nx.stack([M, M]) with pytest.raises(NotImplementedError): nx.reshape(M, (5, 3, 2)) + with pytest.raises(NotImplementedError): + nx.coo_matrix(M, M, M) + with pytest.raises(NotImplementedError): + nx.issparse(M) + with pytest.raises(NotImplementedError): + nx.tocsr(M) + with pytest.raises(NotImplementedError): + nx.eliminate_zeros(M) + with pytest.raises(NotImplementedError): + nx.todense(M) + with pytest.raises(NotImplementedError): + nx.where(M, M, M) + with pytest.raises(NotImplementedError): + nx.copy(M) + with pytest.raises(NotImplementedError): + nx.allclose(M, M) def test_func_backends(nx): @@ -216,6 +232,11 @@ def test_func_backends(nx): v = rnd.randn(3) val = np.array([1.0]) + # Sparse tensors test + sp_row = np.array([0, 3, 1, 0, 3]) + sp_col = np.array([0, 3, 1, 2, 2]) + sp_data = np.array([4, 5, 7, 9, 0]) + lst_tot = [] for nx in [ot.backend.NumpyBackend(), nx]: @@ -229,6 +250,10 @@ def test_func_backends(nx): vb = nx.from_numpy(v) val = nx.from_numpy(val) + sp_rowb = nx.from_numpy(sp_row) + sp_colb = nx.from_numpy(sp_col) + sp_datab = nx.from_numpy(sp_data) + A = nx.set_gradients(val, v, v) lst_b.append(nx.to_numpy(A)) lst_name.append('set_gradients') @@ -438,6 +463,37 @@ def test_func_backends(nx): lst_b.append(nx.to_numpy(A)) lst_name.append('reshape') + sp_Mb = nx.coo_matrix(sp_datab, sp_rowb, sp_colb, shape=(4, 4)) + nx.todense(Mb) + lst_b.append(nx.to_numpy(nx.todense(sp_Mb))) + lst_name.append('coo_matrix') + + assert not nx.issparse(Mb), 'Assert fail on: issparse (expected False)' + assert nx.issparse(sp_Mb) or nx.__name__ == "jax", 'Assert fail on: issparse (expected True)' + + A = nx.tocsr(sp_Mb) + lst_b.append(nx.to_numpy(nx.todense(A))) + lst_name.append('tocsr') + + A = nx.eliminate_zeros(nx.copy(sp_datab), threshold=5.) + lst_b.append(nx.to_numpy(A)) + lst_name.append('eliminate_zeros (dense)') + + A = nx.eliminate_zeros(sp_Mb) + lst_b.append(nx.to_numpy(nx.todense(A))) + lst_name.append('eliminate_zeros (sparse)') + + A = nx.where(Mb >= nx.stack([nx.linspace(0, 1, 10)] * 3, axis=1), Mb, 0.0) + lst_b.append(nx.to_numpy(A)) + lst_name.append('where') + + A = nx.copy(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('copy') + + assert nx.allclose(Mb, Mb), 'Assert fail on: allclose (expected True)' + assert not nx.allclose(2 * Mb, Mb), 'Assert fail on: allclose (expected False)' + lst_tot.append(lst_b) lst_np = lst_tot[0] |