summaryrefslogtreecommitdiff
path: root/test/test_backend.py
diff options
context:
space:
mode:
authorncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com>2021-11-02 13:42:02 +0100
committerGitHub <noreply@github.com>2021-11-02 13:42:02 +0100
commita335324d008e8982be61d7ace937815a2bfa98f9 (patch)
tree83c7f637597f10f6f3d20b15532e53fc65b51f22 /test/test_backend.py
parent0cb2b2efe901ed74c614046d250518769f870313 (diff)
[MRG] Backend for gromov (#294)
* bregman: small correction * gromov backend first draft * Removing decorators * Reworked casting method * Bug solve * Removing casting * Bug solve * toarray renamed todense ; expand_dims removed * Warning (jax not supporting sparse matrix) moved * Mistake corrected * test backend * Sparsity test for older versions of pytorch * Trying pytorch/1.10 * Attempt to correct torch sparse bug * Backend version of gromov tests * Random state introduced for remaining gromov functions * review changes * code coverage * Docs (first draft, to be continued) * Gromov docs * Prettified docs * mistake corrected in the docs * little change Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'test/test_backend.py')
-rw-r--r--test/test_backend.py56
1 files changed, 56 insertions, 0 deletions
diff --git a/test/test_backend.py b/test/test_backend.py
index 5853282..0f11ace 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -207,6 +207,22 @@ def test_empty_backend():
nx.stack([M, M])
with pytest.raises(NotImplementedError):
nx.reshape(M, (5, 3, 2))
+ with pytest.raises(NotImplementedError):
+ nx.coo_matrix(M, M, M)
+ with pytest.raises(NotImplementedError):
+ nx.issparse(M)
+ with pytest.raises(NotImplementedError):
+ nx.tocsr(M)
+ with pytest.raises(NotImplementedError):
+ nx.eliminate_zeros(M)
+ with pytest.raises(NotImplementedError):
+ nx.todense(M)
+ with pytest.raises(NotImplementedError):
+ nx.where(M, M, M)
+ with pytest.raises(NotImplementedError):
+ nx.copy(M)
+ with pytest.raises(NotImplementedError):
+ nx.allclose(M, M)
def test_func_backends(nx):
@@ -216,6 +232,11 @@ def test_func_backends(nx):
v = rnd.randn(3)
val = np.array([1.0])
+ # Sparse tensors test
+ sp_row = np.array([0, 3, 1, 0, 3])
+ sp_col = np.array([0, 3, 1, 2, 2])
+ sp_data = np.array([4, 5, 7, 9, 0])
+
lst_tot = []
for nx in [ot.backend.NumpyBackend(), nx]:
@@ -229,6 +250,10 @@ def test_func_backends(nx):
vb = nx.from_numpy(v)
val = nx.from_numpy(val)
+ sp_rowb = nx.from_numpy(sp_row)
+ sp_colb = nx.from_numpy(sp_col)
+ sp_datab = nx.from_numpy(sp_data)
+
A = nx.set_gradients(val, v, v)
lst_b.append(nx.to_numpy(A))
lst_name.append('set_gradients')
@@ -438,6 +463,37 @@ def test_func_backends(nx):
lst_b.append(nx.to_numpy(A))
lst_name.append('reshape')
+ sp_Mb = nx.coo_matrix(sp_datab, sp_rowb, sp_colb, shape=(4, 4))
+ nx.todense(Mb)
+ lst_b.append(nx.to_numpy(nx.todense(sp_Mb)))
+ lst_name.append('coo_matrix')
+
+ assert not nx.issparse(Mb), 'Assert fail on: issparse (expected False)'
+ assert nx.issparse(sp_Mb) or nx.__name__ == "jax", 'Assert fail on: issparse (expected True)'
+
+ A = nx.tocsr(sp_Mb)
+ lst_b.append(nx.to_numpy(nx.todense(A)))
+ lst_name.append('tocsr')
+
+ A = nx.eliminate_zeros(nx.copy(sp_datab), threshold=5.)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('eliminate_zeros (dense)')
+
+ A = nx.eliminate_zeros(sp_Mb)
+ lst_b.append(nx.to_numpy(nx.todense(A)))
+ lst_name.append('eliminate_zeros (sparse)')
+
+ A = nx.where(Mb >= nx.stack([nx.linspace(0, 1, 10)] * 3, axis=1), Mb, 0.0)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('where')
+
+ A = nx.copy(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('copy')
+
+ assert nx.allclose(Mb, Mb), 'Assert fail on: allclose (expected True)'
+ assert not nx.allclose(2 * Mb, Mb), 'Assert fail on: allclose (expected False)'
+
lst_tot.append(lst_b)
lst_np = lst_tot[0]