summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRomain Tavenard <romain.tavenard@univ-rennes2.fr>2019-07-01 11:06:26 +0200
committerRomain Tavenard <romain.tavenard@univ-rennes2.fr>2019-07-01 11:06:26 +0200
commitb05d315b0994d328029d4a4fc082f6994e7f06d1 (patch)
treed3b4c38aac983ad770dd53937d51cd2a3141392c
parent93a74fe4d477e1735e9ce21ee4113281f58b4dcf (diff)
Moved GPU doctests to test_gpu for tests not to fail if no GPU available
-rw-r--r--ot/gpu/bregman.py11
-rw-r--r--test/test_gpu.py10
2 files changed, 10 insertions, 11 deletions
diff --git a/ot/gpu/bregman.py b/ot/gpu/bregman.py
index 978b307..2e2df83 100644
--- a/ot/gpu/bregman.py
+++ b/ot/gpu/bregman.py
@@ -70,17 +70,6 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
log : dict
log dictionary return only if log==True in parameters
- Examples
- --------
-
- >>> import ot
- >>> a=[.5,.5]
- >>> b=[.5,.5]
- >>> M=[[0.,1.],[1.,0.]]
- >>> ot.sinkhorn(a,b,M,1)
- array([[ 0.36552929, 0.13447071],
- [ 0.13447071, 0.36552929]])
-
References
----------
diff --git a/test/test_gpu.py b/test/test_gpu.py
index 6b7fdd4..47b8b6d 100644
--- a/test/test_gpu.py
+++ b/test/test_gpu.py
@@ -16,6 +16,16 @@ except ImportError:
@pytest.mark.skipif(nogpu, reason="No GPU available")
+def test_gpu_old_doctests():
+ a = [.5, .5]
+ b = [.5, .5]
+ M = [[0., 1.], [1., 0.]]
+ G = ot.sinkhorn(a, b, M, 1)
+ np.testing.assert_allclose(G, np.array([[0.36552929, 0.13447071],
+ [0.13447071, 0.36552929]]))
+
+
+@pytest.mark.skipif(nogpu, reason="No GPU available")
def test_gpu_dist():
rng = np.random.RandomState(0)