diff options
author | Nathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com> | 2021-12-03 12:37:05 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-12-03 12:37:05 +0100 |
commit | ca69658400dc2ef6a7d3e531acffcd107400085f (patch) | |
tree | b77a28821067be5240cec2082fa1f119b1cfd1cd /test/test_backend.py | |
parent | cb510644b2fd65e4ce216a7799ce7401f71548b8 (diff) |
[MRG] Cupy backend (#315)
* Cupy backend
* pep8 + bug
* working even if cupy not installed
* attempt to force codecov to ignore cupy because no gpu can be used for testing on github
* docstring
* readme
Diffstat (limited to 'test/test_backend.py')
-rw-r--r-- | test/test_backend.py | 21 |
1 files changed, 18 insertions, 3 deletions
diff --git a/test/test_backend.py b/test/test_backend.py index 1832b91..2e7eecc 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -7,7 +7,7 @@ import ot import ot.backend -from ot.backend import torch, jax +from ot.backend import torch, jax, cp import pytest @@ -87,6 +87,20 @@ def test_get_backend(): with pytest.raises(ValueError): get_backend(A, B2) + if cp: + A2 = cp.asarray(A) + B2 = cp.asarray(B) + + nx = get_backend(A2) + assert nx.__name__ == 'cupy' + + nx = get_backend(A2, B2) + assert nx.__name__ == 'cupy' + + # test not unique types in input + with pytest.raises(ValueError): + get_backend(A, B2) + def test_convert_between_backends(nx): @@ -240,7 +254,7 @@ def test_func_backends(nx): # Sparse tensors test sp_row = np.array([0, 3, 1, 0, 3]) sp_col = np.array([0, 3, 1, 2, 2]) - sp_data = np.array([4, 5, 7, 9, 0]) + sp_data = np.array([4, 5, 7, 9, 0], dtype=np.float64) lst_tot = [] @@ -393,7 +407,8 @@ def test_func_backends(nx): lst_b.append(nx.to_numpy(A)) lst_name.append('argsort') - A = nx.searchsorted(Mb, Mb, 'right') + tmp = nx.sort(Mb) + A = nx.searchsorted(tmp, tmp, 'right') lst_b.append(nx.to_numpy(A)) lst_name.append('searchsorted') |