diff options
author | ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com> | 2021-10-25 11:36:21 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-10-25 11:36:21 +0200 |
commit | 7a65086dd340265d0223eb8ffb5c9a5152a82dff (patch) | |
tree | 300f4a1cd645516fba1e440691fe48830d781b5c /test/test_backend.py | |
parent | 7af8c2147d61349f4d99ca33318a8a125e4569aa (diff) |
[MRG] Bregman backend (#280)
* Bregman
* Resolve conflicts
* Bug solve
* Bregman updated for JAX compatibility
* Tests coherence between backend improved
* No longer enforcing 64 bits operations on Jax except for tests
* Now using mixtures, to make backend dependent tests with less code
* Better test skipping code
* Pep8 + test optimizations
* redundancy removed
* Docs
* Typo corrected
* Typo
* Typo
* Docs
* Docs
* pep8
* Backend docs
* Prettier docs
* Mistake corrected
* small changes
* Better wording
Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'test/test_backend.py')
-rw-r--r-- | test/test_backend.py | 102 |
1 files changed, 102 insertions, 0 deletions
diff --git a/test/test_backend.py b/test/test_backend.py index cbfaf94..859da5a 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -1,6 +1,7 @@ """Tests for backend module """ # Author: Remi Flamary <remi.flamary@polytechnique.edu> +# Nicolas Courty <ncourty@irisa.fr> # # License: MIT License @@ -156,6 +157,8 @@ def test_empty_backend(): with pytest.raises(NotImplementedError): nx.sqrt(M) with pytest.raises(NotImplementedError): + nx.power(v, 2) + with pytest.raises(NotImplementedError): nx.dot(v, v) with pytest.raises(NotImplementedError): nx.norm(M) @@ -174,7 +177,37 @@ def test_empty_backend(): with pytest.raises(NotImplementedError): nx.argsort(M) with pytest.raises(NotImplementedError): + nx.searchsorted(v, v) + with pytest.raises(NotImplementedError): nx.flip(M) + with pytest.raises(NotImplementedError): + nx.clip(M, -1, 1) + with pytest.raises(NotImplementedError): + nx.repeat(M, 0, 1) + with pytest.raises(NotImplementedError): + nx.take_along_axis(M, v, 0) + with pytest.raises(NotImplementedError): + nx.concatenate([v, v]) + with pytest.raises(NotImplementedError): + nx.zero_pad(M, v) + with pytest.raises(NotImplementedError): + nx.argmax(M) + with pytest.raises(NotImplementedError): + nx.mean(M) + with pytest.raises(NotImplementedError): + nx.std(M) + with pytest.raises(NotImplementedError): + nx.linspace(0, 1, 50) + with pytest.raises(NotImplementedError): + nx.meshgrid(v, v) + with pytest.raises(NotImplementedError): + nx.diag(M) + with pytest.raises(NotImplementedError): + nx.unique([M, M]) + with pytest.raises(NotImplementedError): + nx.logsumexp(M) + with pytest.raises(NotImplementedError): + nx.stack([M, M]) @pytest.mark.parametrize('backend', backend_list) @@ -278,6 +311,10 @@ def test_func_backends(backend): lst_b.append(nx.to_numpy(A)) lst_name.append('sqrt') + A = nx.power(Mb, 2) + lst_b.append(nx.to_numpy(A)) + lst_name.append('power') + A = nx.dot(vb, vb) lst_b.append(nx.to_numpy(A)) lst_name.append('dot(v,v)') @@ -326,10 +363,75 @@ def test_func_backends(backend): lst_b.append(nx.to_numpy(A)) lst_name.append('argsort') + A = nx.searchsorted(Mb, Mb, 'right') + lst_b.append(nx.to_numpy(A)) + lst_name.append('searchsorted') + A = nx.flip(Mb) lst_b.append(nx.to_numpy(A)) lst_name.append('flip') + A = nx.clip(vb, 0, 1) + lst_b.append(nx.to_numpy(A)) + lst_name.append('clip') + + A = nx.repeat(Mb, 0) + A = nx.repeat(Mb, 2, -1) + lst_b.append(nx.to_numpy(A)) + lst_name.append('repeat') + + A = nx.take_along_axis(vb, nx.arange(3), -1) + lst_b.append(nx.to_numpy(A)) + lst_name.append('take_along_axis') + + A = nx.concatenate((Mb, Mb), -1) + lst_b.append(nx.to_numpy(A)) + lst_name.append('concatenate') + + A = nx.zero_pad(Mb, len(Mb.shape) * [(3, 3)]) + lst_b.append(nx.to_numpy(A)) + lst_name.append('zero_pad') + + A = nx.argmax(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('argmax') + + A = nx.mean(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('mean') + + A = nx.std(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('std') + + A = nx.linspace(0, 1, 50) + lst_b.append(nx.to_numpy(A)) + lst_name.append('linspace') + + X, Y = nx.meshgrid(vb, vb) + lst_b.append(np.stack([nx.to_numpy(X), nx.to_numpy(Y)])) + lst_name.append('meshgrid') + + A = nx.diag(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('diag2D') + + A = nx.diag(vb, 1) + lst_b.append(nx.to_numpy(A)) + lst_name.append('diag1D') + + A = nx.unique(nx.from_numpy(np.stack([M, M]))) + lst_b.append(nx.to_numpy(A)) + lst_name.append('unique') + + A = nx.logsumexp(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('logsumexp') + + A = nx.stack([Mb, Mb]) + lst_b.append(nx.to_numpy(A)) + lst_name.append('stack') + lst_tot.append(lst_b) lst_np = lst_tot[0] |