diff options
Diffstat (limited to 'test/test_backend.py')
-rw-r--r-- | test/test_backend.py | 102 |
1 files changed, 102 insertions, 0 deletions
diff --git a/test/test_backend.py b/test/test_backend.py index cbfaf94..859da5a 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -1,6 +1,7 @@ """Tests for backend module """ # Author: Remi Flamary <remi.flamary@polytechnique.edu> +# Nicolas Courty <ncourty@irisa.fr> # # License: MIT License @@ -156,6 +157,8 @@ def test_empty_backend(): with pytest.raises(NotImplementedError): nx.sqrt(M) with pytest.raises(NotImplementedError): + nx.power(v, 2) + with pytest.raises(NotImplementedError): nx.dot(v, v) with pytest.raises(NotImplementedError): nx.norm(M) @@ -174,7 +177,37 @@ def test_empty_backend(): with pytest.raises(NotImplementedError): nx.argsort(M) with pytest.raises(NotImplementedError): + nx.searchsorted(v, v) + with pytest.raises(NotImplementedError): nx.flip(M) + with pytest.raises(NotImplementedError): + nx.clip(M, -1, 1) + with pytest.raises(NotImplementedError): + nx.repeat(M, 0, 1) + with pytest.raises(NotImplementedError): + nx.take_along_axis(M, v, 0) + with pytest.raises(NotImplementedError): + nx.concatenate([v, v]) + with pytest.raises(NotImplementedError): + nx.zero_pad(M, v) + with pytest.raises(NotImplementedError): + nx.argmax(M) + with pytest.raises(NotImplementedError): + nx.mean(M) + with pytest.raises(NotImplementedError): + nx.std(M) + with pytest.raises(NotImplementedError): + nx.linspace(0, 1, 50) + with pytest.raises(NotImplementedError): + nx.meshgrid(v, v) + with pytest.raises(NotImplementedError): + nx.diag(M) + with pytest.raises(NotImplementedError): + nx.unique([M, M]) + with pytest.raises(NotImplementedError): + nx.logsumexp(M) + with pytest.raises(NotImplementedError): + nx.stack([M, M]) @pytest.mark.parametrize('backend', backend_list) @@ -278,6 +311,10 @@ def test_func_backends(backend): lst_b.append(nx.to_numpy(A)) lst_name.append('sqrt') + A = nx.power(Mb, 2) + lst_b.append(nx.to_numpy(A)) + lst_name.append('power') + A = nx.dot(vb, vb) lst_b.append(nx.to_numpy(A)) lst_name.append('dot(v,v)') @@ -326,10 +363,75 @@ def test_func_backends(backend): lst_b.append(nx.to_numpy(A)) lst_name.append('argsort') + A = nx.searchsorted(Mb, Mb, 'right') + lst_b.append(nx.to_numpy(A)) + lst_name.append('searchsorted') + A = nx.flip(Mb) lst_b.append(nx.to_numpy(A)) lst_name.append('flip') + A = nx.clip(vb, 0, 1) + lst_b.append(nx.to_numpy(A)) + lst_name.append('clip') + + A = nx.repeat(Mb, 0) + A = nx.repeat(Mb, 2, -1) + lst_b.append(nx.to_numpy(A)) + lst_name.append('repeat') + + A = nx.take_along_axis(vb, nx.arange(3), -1) + lst_b.append(nx.to_numpy(A)) + lst_name.append('take_along_axis') + + A = nx.concatenate((Mb, Mb), -1) + lst_b.append(nx.to_numpy(A)) + lst_name.append('concatenate') + + A = nx.zero_pad(Mb, len(Mb.shape) * [(3, 3)]) + lst_b.append(nx.to_numpy(A)) + lst_name.append('zero_pad') + + A = nx.argmax(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('argmax') + + A = nx.mean(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('mean') + + A = nx.std(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('std') + + A = nx.linspace(0, 1, 50) + lst_b.append(nx.to_numpy(A)) + lst_name.append('linspace') + + X, Y = nx.meshgrid(vb, vb) + lst_b.append(np.stack([nx.to_numpy(X), nx.to_numpy(Y)])) + lst_name.append('meshgrid') + + A = nx.diag(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('diag2D') + + A = nx.diag(vb, 1) + lst_b.append(nx.to_numpy(A)) + lst_name.append('diag1D') + + A = nx.unique(nx.from_numpy(np.stack([M, M]))) + lst_b.append(nx.to_numpy(A)) + lst_name.append('unique') + + A = nx.logsumexp(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('logsumexp') + + A = nx.stack([Mb, Mb]) + lst_b.append(nx.to_numpy(A)) + lst_name.append('stack') + lst_tot.append(lst_b) lst_np = lst_tot[0] |