From 7a65086dd340265d0223eb8ffb5c9a5152a82dff Mon Sep 17 00:00:00 2001 From: ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com> Date: Mon, 25 Oct 2021 11:36:21 +0200 Subject: [MRG] Bregman backend (#280) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Bregman * Resolve conflicts * Bug solve * Bregman updated for JAX compatibility * Tests coherence between backend improved * No longer enforcing 64 bits operations on Jax except for tests * Now using mixtures, to make backend dependent tests with less code * Better test skipping code * Pep8 + test optimizations * redundancy removed * Docs * Typo corrected * Typo * Typo * Docs * Docs * pep8 * Backend docs * Prettier docs * Mistake corrected * small changes * Better wording Co-authored-by: RĂ©mi Flamary --- test/test_backend.py | 102 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 102 insertions(+) (limited to 'test/test_backend.py') diff --git a/test/test_backend.py b/test/test_backend.py index cbfaf94..859da5a 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -1,6 +1,7 @@ """Tests for backend module """ # Author: Remi Flamary +# Nicolas Courty # # License: MIT License @@ -155,6 +156,8 @@ def test_empty_backend(): nx.exp(M) with pytest.raises(NotImplementedError): nx.sqrt(M) + with pytest.raises(NotImplementedError): + nx.power(v, 2) with pytest.raises(NotImplementedError): nx.dot(v, v) with pytest.raises(NotImplementedError): @@ -173,8 +176,38 @@ def test_empty_backend(): nx.sort(M) with pytest.raises(NotImplementedError): nx.argsort(M) + with pytest.raises(NotImplementedError): + nx.searchsorted(v, v) with pytest.raises(NotImplementedError): nx.flip(M) + with pytest.raises(NotImplementedError): + nx.clip(M, -1, 1) + with pytest.raises(NotImplementedError): + nx.repeat(M, 0, 1) + with pytest.raises(NotImplementedError): + nx.take_along_axis(M, v, 0) + with pytest.raises(NotImplementedError): + nx.concatenate([v, v]) + with pytest.raises(NotImplementedError): + nx.zero_pad(M, v) + with pytest.raises(NotImplementedError): + nx.argmax(M) + with pytest.raises(NotImplementedError): + nx.mean(M) + with pytest.raises(NotImplementedError): + nx.std(M) + with pytest.raises(NotImplementedError): + nx.linspace(0, 1, 50) + with pytest.raises(NotImplementedError): + nx.meshgrid(v, v) + with pytest.raises(NotImplementedError): + nx.diag(M) + with pytest.raises(NotImplementedError): + nx.unique([M, M]) + with pytest.raises(NotImplementedError): + nx.logsumexp(M) + with pytest.raises(NotImplementedError): + nx.stack([M, M]) @pytest.mark.parametrize('backend', backend_list) @@ -278,6 +311,10 @@ def test_func_backends(backend): lst_b.append(nx.to_numpy(A)) lst_name.append('sqrt') + A = nx.power(Mb, 2) + lst_b.append(nx.to_numpy(A)) + lst_name.append('power') + A = nx.dot(vb, vb) lst_b.append(nx.to_numpy(A)) lst_name.append('dot(v,v)') @@ -326,10 +363,75 @@ def test_func_backends(backend): lst_b.append(nx.to_numpy(A)) lst_name.append('argsort') + A = nx.searchsorted(Mb, Mb, 'right') + lst_b.append(nx.to_numpy(A)) + lst_name.append('searchsorted') + A = nx.flip(Mb) lst_b.append(nx.to_numpy(A)) lst_name.append('flip') + A = nx.clip(vb, 0, 1) + lst_b.append(nx.to_numpy(A)) + lst_name.append('clip') + + A = nx.repeat(Mb, 0) + A = nx.repeat(Mb, 2, -1) + lst_b.append(nx.to_numpy(A)) + lst_name.append('repeat') + + A = nx.take_along_axis(vb, nx.arange(3), -1) + lst_b.append(nx.to_numpy(A)) + lst_name.append('take_along_axis') + + A = nx.concatenate((Mb, Mb), -1) + lst_b.append(nx.to_numpy(A)) + lst_name.append('concatenate') + + A = nx.zero_pad(Mb, len(Mb.shape) * [(3, 3)]) + lst_b.append(nx.to_numpy(A)) + lst_name.append('zero_pad') + + A = nx.argmax(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('argmax') + + A = nx.mean(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('mean') + + A = nx.std(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('std') + + A = nx.linspace(0, 1, 50) + lst_b.append(nx.to_numpy(A)) + lst_name.append('linspace') + + X, Y = nx.meshgrid(vb, vb) + lst_b.append(np.stack([nx.to_numpy(X), nx.to_numpy(Y)])) + lst_name.append('meshgrid') + + A = nx.diag(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('diag2D') + + A = nx.diag(vb, 1) + lst_b.append(nx.to_numpy(A)) + lst_name.append('diag1D') + + A = nx.unique(nx.from_numpy(np.stack([M, M]))) + lst_b.append(nx.to_numpy(A)) + lst_name.append('unique') + + A = nx.logsumexp(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('logsumexp') + + A = nx.stack([Mb, Mb]) + lst_b.append(nx.to_numpy(A)) + lst_name.append('stack') + lst_tot.append(lst_b) lst_np = lst_tot[0] -- cgit v1.2.3