summaryrefslogtreecommitdiff
path: root/test/test_backend.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_backend.py')
-rw-r--r--test/test_backend.py21
1 files changed, 18 insertions, 3 deletions
diff --git a/test/test_backend.py b/test/test_backend.py
index 1832b91..2e7eecc 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -7,7 +7,7 @@
import ot
import ot.backend
-from ot.backend import torch, jax
+from ot.backend import torch, jax, cp
import pytest
@@ -87,6 +87,20 @@ def test_get_backend():
with pytest.raises(ValueError):
get_backend(A, B2)
+ if cp:
+ A2 = cp.asarray(A)
+ B2 = cp.asarray(B)
+
+ nx = get_backend(A2)
+ assert nx.__name__ == 'cupy'
+
+ nx = get_backend(A2, B2)
+ assert nx.__name__ == 'cupy'
+
+ # test not unique types in input
+ with pytest.raises(ValueError):
+ get_backend(A, B2)
+
def test_convert_between_backends(nx):
@@ -240,7 +254,7 @@ def test_func_backends(nx):
# Sparse tensors test
sp_row = np.array([0, 3, 1, 0, 3])
sp_col = np.array([0, 3, 1, 2, 2])
- sp_data = np.array([4, 5, 7, 9, 0])
+ sp_data = np.array([4, 5, 7, 9, 0], dtype=np.float64)
lst_tot = []
@@ -393,7 +407,8 @@ def test_func_backends(nx):
lst_b.append(nx.to_numpy(A))
lst_name.append('argsort')
- A = nx.searchsorted(Mb, Mb, 'right')
+ tmp = nx.sort(Mb)
+ A = nx.searchsorted(tmp, tmp, 'right')
lst_b.append(nx.to_numpy(A))
lst_name.append('searchsorted')