summaryrefslogtreecommitdiff
path: root/test/test_backend.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_backend.py')
-rw-r--r--test/test_backend.py102
1 files changed, 102 insertions, 0 deletions
diff --git a/test/test_backend.py b/test/test_backend.py
index cbfaf94..859da5a 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -1,6 +1,7 @@
"""Tests for backend module """
# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+# Nicolas Courty <ncourty@irisa.fr>
#
# License: MIT License
@@ -156,6 +157,8 @@ def test_empty_backend():
with pytest.raises(NotImplementedError):
nx.sqrt(M)
with pytest.raises(NotImplementedError):
+ nx.power(v, 2)
+ with pytest.raises(NotImplementedError):
nx.dot(v, v)
with pytest.raises(NotImplementedError):
nx.norm(M)
@@ -174,7 +177,37 @@ def test_empty_backend():
with pytest.raises(NotImplementedError):
nx.argsort(M)
with pytest.raises(NotImplementedError):
+ nx.searchsorted(v, v)
+ with pytest.raises(NotImplementedError):
nx.flip(M)
+ with pytest.raises(NotImplementedError):
+ nx.clip(M, -1, 1)
+ with pytest.raises(NotImplementedError):
+ nx.repeat(M, 0, 1)
+ with pytest.raises(NotImplementedError):
+ nx.take_along_axis(M, v, 0)
+ with pytest.raises(NotImplementedError):
+ nx.concatenate([v, v])
+ with pytest.raises(NotImplementedError):
+ nx.zero_pad(M, v)
+ with pytest.raises(NotImplementedError):
+ nx.argmax(M)
+ with pytest.raises(NotImplementedError):
+ nx.mean(M)
+ with pytest.raises(NotImplementedError):
+ nx.std(M)
+ with pytest.raises(NotImplementedError):
+ nx.linspace(0, 1, 50)
+ with pytest.raises(NotImplementedError):
+ nx.meshgrid(v, v)
+ with pytest.raises(NotImplementedError):
+ nx.diag(M)
+ with pytest.raises(NotImplementedError):
+ nx.unique([M, M])
+ with pytest.raises(NotImplementedError):
+ nx.logsumexp(M)
+ with pytest.raises(NotImplementedError):
+ nx.stack([M, M])
@pytest.mark.parametrize('backend', backend_list)
@@ -278,6 +311,10 @@ def test_func_backends(backend):
lst_b.append(nx.to_numpy(A))
lst_name.append('sqrt')
+ A = nx.power(Mb, 2)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('power')
+
A = nx.dot(vb, vb)
lst_b.append(nx.to_numpy(A))
lst_name.append('dot(v,v)')
@@ -326,10 +363,75 @@ def test_func_backends(backend):
lst_b.append(nx.to_numpy(A))
lst_name.append('argsort')
+ A = nx.searchsorted(Mb, Mb, 'right')
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('searchsorted')
+
A = nx.flip(Mb)
lst_b.append(nx.to_numpy(A))
lst_name.append('flip')
+ A = nx.clip(vb, 0, 1)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('clip')
+
+ A = nx.repeat(Mb, 0)
+ A = nx.repeat(Mb, 2, -1)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('repeat')
+
+ A = nx.take_along_axis(vb, nx.arange(3), -1)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('take_along_axis')
+
+ A = nx.concatenate((Mb, Mb), -1)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('concatenate')
+
+ A = nx.zero_pad(Mb, len(Mb.shape) * [(3, 3)])
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('zero_pad')
+
+ A = nx.argmax(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('argmax')
+
+ A = nx.mean(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('mean')
+
+ A = nx.std(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('std')
+
+ A = nx.linspace(0, 1, 50)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('linspace')
+
+ X, Y = nx.meshgrid(vb, vb)
+ lst_b.append(np.stack([nx.to_numpy(X), nx.to_numpy(Y)]))
+ lst_name.append('meshgrid')
+
+ A = nx.diag(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('diag2D')
+
+ A = nx.diag(vb, 1)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('diag1D')
+
+ A = nx.unique(nx.from_numpy(np.stack([M, M])))
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('unique')
+
+ A = nx.logsumexp(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('logsumexp')
+
+ A = nx.stack([Mb, Mb])
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('stack')
+
lst_tot.append(lst_b)
lst_np = lst_tot[0]