summaryrefslogtreecommitdiff
path: root/test/test_gpu.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_gpu.py')
-rw-r--r--test/test_gpu.py18
1 files changed, 14 insertions, 4 deletions
diff --git a/test/test_gpu.py b/test/test_gpu.py
index 51a0cff..6b7fdd4 100644
--- a/test/test_gpu.py
+++ b/test/test_gpu.py
@@ -6,7 +6,6 @@
import numpy as np
import ot
-import time
import pytest
try: # test if cudamat installed
@@ -31,7 +30,11 @@ def test_gpu_dist():
np.testing.assert_allclose(M, M2, rtol=1e-10)
- M2 = ot.gpu.dist(a.copy(), b.copy(), to_numpy=False)
+ M2 = ot.gpu.dist(a.copy(), b.copy(), metric='euclidean', to_numpy=False)
+
+ # check raise not implemented wrong metric
+ with pytest.raises(NotImplementedError):
+ M2 = ot.gpu.dist(a.copy(), b.copy(), metric='cityblock', to_numpy=False)
@pytest.mark.skipif(nogpu, reason="No GPU available")
@@ -46,6 +49,9 @@ def test_gpu_sinkhorn():
wa = ot.unif(n_samples // 4)
wb = ot.unif(n_samples)
+ wb2 = np.random.rand(n_samples, 20)
+ wb2 /= wb2.sum(0, keepdims=True)
+
M = ot.dist(a.copy(), b.copy())
M2 = ot.gpu.dist(a.copy(), b.copy(), to_numpy=False)
@@ -56,7 +62,11 @@ def test_gpu_sinkhorn():
np.testing.assert_allclose(G1, G, rtol=1e-10)
- G2 = ot.gpu.sinkhorn(wa, wb, M2, reg, to_numpy=False)
+ # run all on gpu
+ ot.gpu.sinkhorn(wa, wb, M2, reg, to_numpy=False, log=True)
+
+ # run sinkhorn for multiple targets
+ ot.gpu.sinkhorn(wa, wb2, M2, reg, to_numpy=False, log=True)
@pytest.mark.skipif(nogpu, reason="No GPU available")
@@ -83,4 +93,4 @@ def test_gpu_sinkhorn_lpl1():
np.testing.assert_allclose(G1, G, rtol=1e-10)
- G2 = ot.gpu.da.sinkhorn_lpl1_mm(wa, labels_a, wb, M2, reg, to_numpy=False)
+ ot.gpu.da.sinkhorn_lpl1_mm(wa, labels_a, wb, M2, reg, to_numpy=False, log=True)