summaryrefslogtreecommitdiff
path: root/test/test_gpu.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2018-09-24 15:20:30 +0200
committerRémi Flamary <remi.flamary@gmail.com>2018-09-24 15:20:30 +0200
commit5e7bfbcbc99ce5915873147677b434c0b1d10fc8 (patch)
treeaa0e9b0b86d1bd5efd56c19d7c6ca41196ec5f12 /test/test_gpu.py
parent75e78022d2df350ea220cee1b5e759ef9fc35a5b (diff)
working test +92 percent tets coverege
Diffstat (limited to 'test/test_gpu.py')
-rw-r--r--test/test_gpu.py18
1 files changed, 14 insertions, 4 deletions
diff --git a/test/test_gpu.py b/test/test_gpu.py
index 51a0cff..6b7fdd4 100644
--- a/test/test_gpu.py
+++ b/test/test_gpu.py
@@ -6,7 +6,6 @@
import numpy as np
import ot
-import time
import pytest
try: # test if cudamat installed
@@ -31,7 +30,11 @@ def test_gpu_dist():
np.testing.assert_allclose(M, M2, rtol=1e-10)
- M2 = ot.gpu.dist(a.copy(), b.copy(), to_numpy=False)
+ M2 = ot.gpu.dist(a.copy(), b.copy(), metric='euclidean', to_numpy=False)
+
+ # check raise not implemented wrong metric
+ with pytest.raises(NotImplementedError):
+ M2 = ot.gpu.dist(a.copy(), b.copy(), metric='cityblock', to_numpy=False)
@pytest.mark.skipif(nogpu, reason="No GPU available")
@@ -46,6 +49,9 @@ def test_gpu_sinkhorn():
wa = ot.unif(n_samples // 4)
wb = ot.unif(n_samples)
+ wb2 = np.random.rand(n_samples, 20)
+ wb2 /= wb2.sum(0, keepdims=True)
+
M = ot.dist(a.copy(), b.copy())
M2 = ot.gpu.dist(a.copy(), b.copy(), to_numpy=False)
@@ -56,7 +62,11 @@ def test_gpu_sinkhorn():
np.testing.assert_allclose(G1, G, rtol=1e-10)
- G2 = ot.gpu.sinkhorn(wa, wb, M2, reg, to_numpy=False)
+ # run all on gpu
+ ot.gpu.sinkhorn(wa, wb, M2, reg, to_numpy=False, log=True)
+
+ # run sinkhorn for multiple targets
+ ot.gpu.sinkhorn(wa, wb2, M2, reg, to_numpy=False, log=True)
@pytest.mark.skipif(nogpu, reason="No GPU available")
@@ -83,4 +93,4 @@ def test_gpu_sinkhorn_lpl1():
np.testing.assert_allclose(G1, G, rtol=1e-10)
- G2 = ot.gpu.da.sinkhorn_lpl1_mm(wa, labels_a, wb, M2, reg, to_numpy=False)
+ ot.gpu.da.sinkhorn_lpl1_mm(wa, labels_a, wb, M2, reg, to_numpy=False, log=True)