diff options
author | ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com> | 2021-10-25 17:35:36 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-10-25 17:35:36 +0200 |
commit | 76450dddf8dd62b9714b72e99ae075516246d433 (patch) | |
tree | 67de8de1c185cc8e7fc33a1fc0613015824d1fbb /test/test_backend.py | |
parent | 7a65086dd340265d0223eb8ffb5c9a5152a82dff (diff) |
[MRG] Backend for optim (#282)
* Backend for optim
* Bug solve
* Doc update
* backend tests now with fixture
* Unused imports removed
* Docs
* Docs
* Docs
* Outer product backend docs
* Prettier docs
* Pep8
* Mistakes corrected
Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'test/test_backend.py')
-rw-r--r-- | test/test_backend.py | 22 |
1 files changed, 14 insertions, 8 deletions
diff --git a/test/test_backend.py b/test/test_backend.py index 859da5a..5853282 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -17,9 +17,6 @@ from numpy.testing import assert_array_almost_equal_nulp from ot.backend import get_backend, get_backend_list, to_numpy -backend_list = get_backend_list() - - def test_get_backend_list(): lst = get_backend_list() @@ -28,7 +25,6 @@ def test_get_backend_list(): assert isinstance(lst[0], ot.backend.NumpyBackend) -@pytest.mark.parametrize('nx', backend_list) def test_to_numpy(nx): v = nx.zeros(10) @@ -92,7 +88,6 @@ def test_get_backend(): get_backend(A, B2) -@pytest.mark.parametrize('nx', backend_list) def test_convert_between_backends(nx): A = np.zeros((3, 2)) @@ -181,6 +176,8 @@ def test_empty_backend(): with pytest.raises(NotImplementedError): nx.flip(M) with pytest.raises(NotImplementedError): + nx.outer(v, v) + with pytest.raises(NotImplementedError): nx.clip(M, -1, 1) with pytest.raises(NotImplementedError): nx.repeat(M, 0, 1) @@ -208,10 +205,11 @@ def test_empty_backend(): nx.logsumexp(M) with pytest.raises(NotImplementedError): nx.stack([M, M]) + with pytest.raises(NotImplementedError): + nx.reshape(M, (5, 3, 2)) -@pytest.mark.parametrize('backend', backend_list) -def test_func_backends(backend): +def test_func_backends(nx): rnd = np.random.RandomState(0) M = rnd.randn(10, 3) @@ -220,7 +218,7 @@ def test_func_backends(backend): lst_tot = [] - for nx in [ot.backend.NumpyBackend(), backend]: + for nx in [ot.backend.NumpyBackend(), nx]: print('Backend: ', nx.__name__) @@ -371,6 +369,10 @@ def test_func_backends(backend): lst_b.append(nx.to_numpy(A)) lst_name.append('flip') + A = nx.outer(vb, vb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('outer') + A = nx.clip(vb, 0, 1) lst_b.append(nx.to_numpy(A)) lst_name.append('clip') @@ -432,6 +434,10 @@ def test_func_backends(backend): lst_b.append(nx.to_numpy(A)) lst_name.append('stack') + A = nx.reshape(Mb, (5, 3, 2)) + lst_b.append(nx.to_numpy(A)) + lst_name.append('reshape') + lst_tot.append(lst_b) lst_np = lst_tot[0] |