From 5e7bfbcbc99ce5915873147677b434c0b1d10fc8 Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Mon, 24 Sep 2018 15:20:30 +0200 Subject: working test +92 percent tets coverege --- test/test_bregman.py | 2 +- test/test_gpu.py | 18 ++++++++++++++---- 2 files changed, 15 insertions(+), 5 deletions(-) (limited to 'test') diff --git a/test/test_bregman.py b/test/test_bregman.py index 01ec655..a141078 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -105,7 +105,7 @@ def test_bary(): ot.bregman.barycenter(A, M, reg, log=True, verbose=True) -def test_wassersteinbary(): +def test_wasserstein_bary_2d(): size = 100 # size of a square image a1 = np.random.randn(size, size) diff --git a/test/test_gpu.py b/test/test_gpu.py index 51a0cff..6b7fdd4 100644 --- a/test/test_gpu.py +++ b/test/test_gpu.py @@ -6,7 +6,6 @@ import numpy as np import ot -import time import pytest try: # test if cudamat installed @@ -31,7 +30,11 @@ def test_gpu_dist(): np.testing.assert_allclose(M, M2, rtol=1e-10) - M2 = ot.gpu.dist(a.copy(), b.copy(), to_numpy=False) + M2 = ot.gpu.dist(a.copy(), b.copy(), metric='euclidean', to_numpy=False) + + # check raise not implemented wrong metric + with pytest.raises(NotImplementedError): + M2 = ot.gpu.dist(a.copy(), b.copy(), metric='cityblock', to_numpy=False) @pytest.mark.skipif(nogpu, reason="No GPU available") @@ -46,6 +49,9 @@ def test_gpu_sinkhorn(): wa = ot.unif(n_samples // 4) wb = ot.unif(n_samples) + wb2 = np.random.rand(n_samples, 20) + wb2 /= wb2.sum(0, keepdims=True) + M = ot.dist(a.copy(), b.copy()) M2 = ot.gpu.dist(a.copy(), b.copy(), to_numpy=False) @@ -56,7 +62,11 @@ def test_gpu_sinkhorn(): np.testing.assert_allclose(G1, G, rtol=1e-10) - G2 = ot.gpu.sinkhorn(wa, wb, M2, reg, to_numpy=False) + # run all on gpu + ot.gpu.sinkhorn(wa, wb, M2, reg, to_numpy=False, log=True) + + # run sinkhorn for multiple targets + ot.gpu.sinkhorn(wa, wb2, M2, reg, to_numpy=False, log=True) @pytest.mark.skipif(nogpu, reason="No GPU available") @@ -83,4 +93,4 @@ def test_gpu_sinkhorn_lpl1(): np.testing.assert_allclose(G1, G, rtol=1e-10) - G2 = ot.gpu.da.sinkhorn_lpl1_mm(wa, labels_a, wb, M2, reg, to_numpy=False) + ot.gpu.da.sinkhorn_lpl1_mm(wa, labels_a, wb, M2, reg, to_numpy=False, log=True) -- cgit v1.2.3