summaryrefslogtreecommitdiff
path: root/test/test_backend.py
diff options
context:
space:
mode:
authorncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com>2021-10-25 11:36:21 +0200
committerGitHub <noreply@github.com>2021-10-25 11:36:21 +0200
commit7a65086dd340265d0223eb8ffb5c9a5152a82dff (patch)
tree300f4a1cd645516fba1e440691fe48830d781b5c /test/test_backend.py
parent7af8c2147d61349f4d99ca33318a8a125e4569aa (diff)
[MRG] Bregman backend (#280)
* Bregman * Resolve conflicts * Bug solve * Bregman updated for JAX compatibility * Tests coherence between backend improved * No longer enforcing 64 bits operations on Jax except for tests * Now using mixtures, to make backend dependent tests with less code * Better test skipping code * Pep8 + test optimizations * redundancy removed * Docs * Typo corrected * Typo * Typo * Docs * Docs * pep8 * Backend docs * Prettier docs * Mistake corrected * small changes * Better wording Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'test/test_backend.py')
-rw-r--r--test/test_backend.py102
1 files changed, 102 insertions, 0 deletions
diff --git a/test/test_backend.py b/test/test_backend.py
index cbfaf94..859da5a 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -1,6 +1,7 @@
"""Tests for backend module """
# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+# Nicolas Courty <ncourty@irisa.fr>
#
# License: MIT License
@@ -156,6 +157,8 @@ def test_empty_backend():
with pytest.raises(NotImplementedError):
nx.sqrt(M)
with pytest.raises(NotImplementedError):
+ nx.power(v, 2)
+ with pytest.raises(NotImplementedError):
nx.dot(v, v)
with pytest.raises(NotImplementedError):
nx.norm(M)
@@ -174,7 +177,37 @@ def test_empty_backend():
with pytest.raises(NotImplementedError):
nx.argsort(M)
with pytest.raises(NotImplementedError):
+ nx.searchsorted(v, v)
+ with pytest.raises(NotImplementedError):
nx.flip(M)
+ with pytest.raises(NotImplementedError):
+ nx.clip(M, -1, 1)
+ with pytest.raises(NotImplementedError):
+ nx.repeat(M, 0, 1)
+ with pytest.raises(NotImplementedError):
+ nx.take_along_axis(M, v, 0)
+ with pytest.raises(NotImplementedError):
+ nx.concatenate([v, v])
+ with pytest.raises(NotImplementedError):
+ nx.zero_pad(M, v)
+ with pytest.raises(NotImplementedError):
+ nx.argmax(M)
+ with pytest.raises(NotImplementedError):
+ nx.mean(M)
+ with pytest.raises(NotImplementedError):
+ nx.std(M)
+ with pytest.raises(NotImplementedError):
+ nx.linspace(0, 1, 50)
+ with pytest.raises(NotImplementedError):
+ nx.meshgrid(v, v)
+ with pytest.raises(NotImplementedError):
+ nx.diag(M)
+ with pytest.raises(NotImplementedError):
+ nx.unique([M, M])
+ with pytest.raises(NotImplementedError):
+ nx.logsumexp(M)
+ with pytest.raises(NotImplementedError):
+ nx.stack([M, M])
@pytest.mark.parametrize('backend', backend_list)
@@ -278,6 +311,10 @@ def test_func_backends(backend):
lst_b.append(nx.to_numpy(A))
lst_name.append('sqrt')
+ A = nx.power(Mb, 2)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('power')
+
A = nx.dot(vb, vb)
lst_b.append(nx.to_numpy(A))
lst_name.append('dot(v,v)')
@@ -326,10 +363,75 @@ def test_func_backends(backend):
lst_b.append(nx.to_numpy(A))
lst_name.append('argsort')
+ A = nx.searchsorted(Mb, Mb, 'right')
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('searchsorted')
+
A = nx.flip(Mb)
lst_b.append(nx.to_numpy(A))
lst_name.append('flip')
+ A = nx.clip(vb, 0, 1)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('clip')
+
+ A = nx.repeat(Mb, 0)
+ A = nx.repeat(Mb, 2, -1)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('repeat')
+
+ A = nx.take_along_axis(vb, nx.arange(3), -1)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('take_along_axis')
+
+ A = nx.concatenate((Mb, Mb), -1)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('concatenate')
+
+ A = nx.zero_pad(Mb, len(Mb.shape) * [(3, 3)])
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('zero_pad')
+
+ A = nx.argmax(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('argmax')
+
+ A = nx.mean(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('mean')
+
+ A = nx.std(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('std')
+
+ A = nx.linspace(0, 1, 50)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('linspace')
+
+ X, Y = nx.meshgrid(vb, vb)
+ lst_b.append(np.stack([nx.to_numpy(X), nx.to_numpy(Y)]))
+ lst_name.append('meshgrid')
+
+ A = nx.diag(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('diag2D')
+
+ A = nx.diag(vb, 1)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('diag1D')
+
+ A = nx.unique(nx.from_numpy(np.stack([M, M])))
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('unique')
+
+ A = nx.logsumexp(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('logsumexp')
+
+ A = nx.stack([Mb, Mb])
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('stack')
+
lst_tot.append(lst_b)
lst_np = lst_tot[0]