summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2018-09-24 15:20:30 +0200
committerRémi Flamary <remi.flamary@gmail.com>2018-09-24 15:20:30 +0200
commit5e7bfbcbc99ce5915873147677b434c0b1d10fc8 (patch)
treeaa0e9b0b86d1bd5efd56c19d7c6ca41196ec5f12 /test
parent75e78022d2df350ea220cee1b5e759ef9fc35a5b (diff)
working test +92 percent tets coverege
Diffstat (limited to 'test')
-rw-r--r--test/test_bregman.py2
-rw-r--r--test/test_gpu.py18
2 files changed, 15 insertions, 5 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 01ec655..a141078 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -105,7 +105,7 @@ def test_bary():
ot.bregman.barycenter(A, M, reg, log=True, verbose=True)
-def test_wassersteinbary():
+def test_wasserstein_bary_2d():
size = 100 # size of a square image
a1 = np.random.randn(size, size)
diff --git a/test/test_gpu.py b/test/test_gpu.py
index 51a0cff..6b7fdd4 100644
--- a/test/test_gpu.py
+++ b/test/test_gpu.py
@@ -6,7 +6,6 @@
import numpy as np
import ot
-import time
import pytest
try: # test if cudamat installed
@@ -31,7 +30,11 @@ def test_gpu_dist():
np.testing.assert_allclose(M, M2, rtol=1e-10)
- M2 = ot.gpu.dist(a.copy(), b.copy(), to_numpy=False)
+ M2 = ot.gpu.dist(a.copy(), b.copy(), metric='euclidean', to_numpy=False)
+
+ # check raise not implemented wrong metric
+ with pytest.raises(NotImplementedError):
+ M2 = ot.gpu.dist(a.copy(), b.copy(), metric='cityblock', to_numpy=False)
@pytest.mark.skipif(nogpu, reason="No GPU available")
@@ -46,6 +49,9 @@ def test_gpu_sinkhorn():
wa = ot.unif(n_samples // 4)
wb = ot.unif(n_samples)
+ wb2 = np.random.rand(n_samples, 20)
+ wb2 /= wb2.sum(0, keepdims=True)
+
M = ot.dist(a.copy(), b.copy())
M2 = ot.gpu.dist(a.copy(), b.copy(), to_numpy=False)
@@ -56,7 +62,11 @@ def test_gpu_sinkhorn():
np.testing.assert_allclose(G1, G, rtol=1e-10)
- G2 = ot.gpu.sinkhorn(wa, wb, M2, reg, to_numpy=False)
+ # run all on gpu
+ ot.gpu.sinkhorn(wa, wb, M2, reg, to_numpy=False, log=True)
+
+ # run sinkhorn for multiple targets
+ ot.gpu.sinkhorn(wa, wb2, M2, reg, to_numpy=False, log=True)
@pytest.mark.skipif(nogpu, reason="No GPU available")
@@ -83,4 +93,4 @@ def test_gpu_sinkhorn_lpl1():
np.testing.assert_allclose(G1, G, rtol=1e-10)
- G2 = ot.gpu.da.sinkhorn_lpl1_mm(wa, labels_a, wb, M2, reg, to_numpy=False)
+ ot.gpu.da.sinkhorn_lpl1_mm(wa, labels_a, wb, M2, reg, to_numpy=False, log=True)