summaryrefslogtreecommitdiff
path: root/test/test_backend.py
diff options
context:
space:
mode:
authorncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com>2021-10-25 17:35:36 +0200
committerGitHub <noreply@github.com>2021-10-25 17:35:36 +0200
commit76450dddf8dd62b9714b72e99ae075516246d433 (patch)
tree67de8de1c185cc8e7fc33a1fc0613015824d1fbb /test/test_backend.py
parent7a65086dd340265d0223eb8ffb5c9a5152a82dff (diff)
[MRG] Backend for optim (#282)
* Backend for optim * Bug solve * Doc update * backend tests now with fixture * Unused imports removed * Docs * Docs * Docs * Outer product backend docs * Prettier docs * Pep8 * Mistakes corrected Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'test/test_backend.py')
-rw-r--r--test/test_backend.py22
1 files changed, 14 insertions, 8 deletions
diff --git a/test/test_backend.py b/test/test_backend.py
index 859da5a..5853282 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -17,9 +17,6 @@ from numpy.testing import assert_array_almost_equal_nulp
from ot.backend import get_backend, get_backend_list, to_numpy
-backend_list = get_backend_list()
-
-
def test_get_backend_list():
lst = get_backend_list()
@@ -28,7 +25,6 @@ def test_get_backend_list():
assert isinstance(lst[0], ot.backend.NumpyBackend)
-@pytest.mark.parametrize('nx', backend_list)
def test_to_numpy(nx):
v = nx.zeros(10)
@@ -92,7 +88,6 @@ def test_get_backend():
get_backend(A, B2)
-@pytest.mark.parametrize('nx', backend_list)
def test_convert_between_backends(nx):
A = np.zeros((3, 2))
@@ -181,6 +176,8 @@ def test_empty_backend():
with pytest.raises(NotImplementedError):
nx.flip(M)
with pytest.raises(NotImplementedError):
+ nx.outer(v, v)
+ with pytest.raises(NotImplementedError):
nx.clip(M, -1, 1)
with pytest.raises(NotImplementedError):
nx.repeat(M, 0, 1)
@@ -208,10 +205,11 @@ def test_empty_backend():
nx.logsumexp(M)
with pytest.raises(NotImplementedError):
nx.stack([M, M])
+ with pytest.raises(NotImplementedError):
+ nx.reshape(M, (5, 3, 2))
-@pytest.mark.parametrize('backend', backend_list)
-def test_func_backends(backend):
+def test_func_backends(nx):
rnd = np.random.RandomState(0)
M = rnd.randn(10, 3)
@@ -220,7 +218,7 @@ def test_func_backends(backend):
lst_tot = []
- for nx in [ot.backend.NumpyBackend(), backend]:
+ for nx in [ot.backend.NumpyBackend(), nx]:
print('Backend: ', nx.__name__)
@@ -371,6 +369,10 @@ def test_func_backends(backend):
lst_b.append(nx.to_numpy(A))
lst_name.append('flip')
+ A = nx.outer(vb, vb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('outer')
+
A = nx.clip(vb, 0, 1)
lst_b.append(nx.to_numpy(A))
lst_name.append('clip')
@@ -432,6 +434,10 @@ def test_func_backends(backend):
lst_b.append(nx.to_numpy(A))
lst_name.append('stack')
+ A = nx.reshape(Mb, (5, 3, 2))
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('reshape')
+
lst_tot.append(lst_b)
lst_np = lst_tot[0]