diff options
Diffstat (limited to 'test/test_backend.py')
-rw-r--r-- | test/test_backend.py | 21 |
1 files changed, 18 insertions, 3 deletions
diff --git a/test/test_backend.py b/test/test_backend.py index 1832b91..2e7eecc 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -7,7 +7,7 @@ import ot import ot.backend -from ot.backend import torch, jax +from ot.backend import torch, jax, cp import pytest @@ -87,6 +87,20 @@ def test_get_backend(): with pytest.raises(ValueError): get_backend(A, B2) + if cp: + A2 = cp.asarray(A) + B2 = cp.asarray(B) + + nx = get_backend(A2) + assert nx.__name__ == 'cupy' + + nx = get_backend(A2, B2) + assert nx.__name__ == 'cupy' + + # test not unique types in input + with pytest.raises(ValueError): + get_backend(A, B2) + def test_convert_between_backends(nx): @@ -240,7 +254,7 @@ def test_func_backends(nx): # Sparse tensors test sp_row = np.array([0, 3, 1, 0, 3]) sp_col = np.array([0, 3, 1, 2, 2]) - sp_data = np.array([4, 5, 7, 9, 0]) + sp_data = np.array([4, 5, 7, 9, 0], dtype=np.float64) lst_tot = [] @@ -393,7 +407,8 @@ def test_func_backends(nx): lst_b.append(nx.to_numpy(A)) lst_name.append('argsort') - A = nx.searchsorted(Mb, Mb, 'right') + tmp = nx.sort(Mb) + A = nx.searchsorted(tmp, tmp, 'right') lst_b.append(nx.to_numpy(A)) lst_name.append('searchsorted') |