diff options
Diffstat (limited to 'test/test_gpu.py')
-rw-r--r-- | test/test_gpu.py | 18 |
1 files changed, 14 insertions, 4 deletions
diff --git a/test/test_gpu.py b/test/test_gpu.py index 312a2d4..49b98d0 100644 --- a/test/test_gpu.py +++ b/test/test_gpu.py @@ -3,8 +3,14 @@ import numpy as np import time import pytest +try: # test if cudamat installed + import ot.gpu + nogpu = False +except ImportError: + nogpu = True + -@pytest.mark.skip(reason="No way to test GPU on travis yet") +@pytest.mark.skipif(nogpu, reason="No GPU available") def test_gpu_sinkhorn(): import ot.gpu @@ -12,7 +18,7 @@ def test_gpu_sinkhorn(): print("min:{:.3E}, max::{:.3E}, mean::{:.3E}, std::{:.3E}".format( np.min(r), np.max(r), np.mean(r), np.std(r))) - for n in [5000]: + for n in [50, 100, 500, 1000]: print(n) a = np.random.rand(n // 4, 100) b = np.random.rand(n, 100) @@ -30,14 +36,16 @@ def test_gpu_sinkhorn(): print(" GPU sinkhorn, time: {:6.2f} sec ".format(time3 - time2)) describeRes(G2) + assert np.allclose(G1, G2, rtol=1e-5, atol=1e-5) -@pytest.mark.skip(reason="No way to test GPU on travis yet") + +@pytest.mark.skipif(nogpu, reason="No GPU available") def test_gpu_sinkhorn_lpl1(): def describeRes(r): print("min:{:.3E}, max:{:.3E}, mean:{:.3E}, std:{:.3E}" .format(np.min(r), np.max(r), np.mean(r), np.std(r))) - for n in [5000]: + for n in [50, 100, 500, 1000]: print(n) a = np.random.rand(n // 4, 100) labels_a = np.random.randint(10, size=(n // 4)) @@ -57,3 +65,5 @@ def test_gpu_sinkhorn_lpl1(): print(" GPU sinkhorn lpl1, time: {:6.2f} sec ".format( time3 - time2)) describeRes(G2) + + assert np.allclose(G1, G2, rtol=1e-5, atol=1e-5) |