summaryrefslogtreecommitdiff
path: root/test/test_backend.py
diff options
context:
space:
mode:
authorNathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com>2021-12-03 12:37:05 +0100
committerGitHub <noreply@github.com>2021-12-03 12:37:05 +0100
commitca69658400dc2ef6a7d3e531acffcd107400085f (patch)
treeb77a28821067be5240cec2082fa1f119b1cfd1cd /test/test_backend.py
parentcb510644b2fd65e4ce216a7799ce7401f71548b8 (diff)
[MRG] Cupy backend (#315)
* Cupy backend * pep8 + bug * working even if cupy not installed * attempt to force codecov to ignore cupy because no gpu can be used for testing on github * docstring * readme
Diffstat (limited to 'test/test_backend.py')
-rw-r--r--test/test_backend.py21
1 files changed, 18 insertions, 3 deletions
diff --git a/test/test_backend.py b/test/test_backend.py
index 1832b91..2e7eecc 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -7,7 +7,7 @@
import ot
import ot.backend
-from ot.backend import torch, jax
+from ot.backend import torch, jax, cp
import pytest
@@ -87,6 +87,20 @@ def test_get_backend():
with pytest.raises(ValueError):
get_backend(A, B2)
+ if cp:
+ A2 = cp.asarray(A)
+ B2 = cp.asarray(B)
+
+ nx = get_backend(A2)
+ assert nx.__name__ == 'cupy'
+
+ nx = get_backend(A2, B2)
+ assert nx.__name__ == 'cupy'
+
+ # test not unique types in input
+ with pytest.raises(ValueError):
+ get_backend(A, B2)
+
def test_convert_between_backends(nx):
@@ -240,7 +254,7 @@ def test_func_backends(nx):
# Sparse tensors test
sp_row = np.array([0, 3, 1, 0, 3])
sp_col = np.array([0, 3, 1, 2, 2])
- sp_data = np.array([4, 5, 7, 9, 0])
+ sp_data = np.array([4, 5, 7, 9, 0], dtype=np.float64)
lst_tot = []
@@ -393,7 +407,8 @@ def test_func_backends(nx):
lst_b.append(nx.to_numpy(A))
lst_name.append('argsort')
- A = nx.searchsorted(Mb, Mb, 'right')
+ tmp = nx.sort(Mb)
+ A = nx.searchsorted(tmp, tmp, 'right')
lst_b.append(nx.to_numpy(A))
lst_name.append('searchsorted')