diff options
-rw-r--r-- | ot/gpu/bregman.py | 11 | ||||
-rw-r--r-- | test/test_gpu.py | 10 |
2 files changed, 10 insertions, 11 deletions
diff --git a/ot/gpu/bregman.py b/ot/gpu/bregman.py index 978b307..2e2df83 100644 --- a/ot/gpu/bregman.py +++ b/ot/gpu/bregman.py @@ -70,17 +70,6 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, log : dict log dictionary return only if log==True in parameters - Examples - -------- - - >>> import ot - >>> a=[.5,.5] - >>> b=[.5,.5] - >>> M=[[0.,1.],[1.,0.]] - >>> ot.sinkhorn(a,b,M,1) - array([[ 0.36552929, 0.13447071], - [ 0.13447071, 0.36552929]]) - References ---------- diff --git a/test/test_gpu.py b/test/test_gpu.py index 6b7fdd4..47b8b6d 100644 --- a/test/test_gpu.py +++ b/test/test_gpu.py @@ -16,6 +16,16 @@ except ImportError: @pytest.mark.skipif(nogpu, reason="No GPU available") +def test_gpu_old_doctests(): + a = [.5, .5] + b = [.5, .5] + M = [[0., 1.], [1., 0.]] + G = ot.sinkhorn(a, b, M, 1) + np.testing.assert_allclose(G, np.array([[0.36552929, 0.13447071], + [0.13447071, 0.36552929]])) + + +@pytest.mark.skipif(nogpu, reason="No GPU available") def test_gpu_dist(): rng = np.random.RandomState(0) |