diff options
Diffstat (limited to 'test/test_bregman.py')
-rw-r--r-- | test/test_bregman.py | 46 |
1 files changed, 41 insertions, 5 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py index 830052d..6e90aa4 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -12,7 +12,7 @@ import numpy as np import pytest import ot -from ot.backend import torch +from ot.backend import torch, tf @pytest.mark.parametrize("verbose, warn", product([True, False], [True, False])) @@ -248,6 +248,7 @@ def test_sinkhorn_empty(): ot.sinkhorn([], [], M, 1, method='greenkhorn', stopThr=1e-10, log=True) +@pytest.skip_backend('tf') @pytest.skip_backend("jax") def test_sinkhorn_variants(nx): # test sinkhorn @@ -282,6 +283,8 @@ def test_sinkhorn_variants(nx): "sinkhorn_epsilon_scaling", "greenkhorn", "sinkhorn_log"]) +@pytest.skip_arg(("nx", "method"), ("tf", "sinkhorn_epsilon_scaling"), reason="tf does not support sinkhorn_epsilon_scaling", getter=str) +@pytest.skip_arg(("nx", "method"), ("tf", "greenkhorn"), reason="tf does not support greenkhorn", getter=str) @pytest.skip_arg(("nx", "method"), ("jax", "sinkhorn_epsilon_scaling"), reason="jax does not support sinkhorn_epsilon_scaling", getter=str) @pytest.skip_arg(("nx", "method"), ("jax", "greenkhorn"), reason="jax does not support greenkhorn", getter=str) def test_sinkhorn_variants_dtype_device(nx, method): @@ -323,6 +326,36 @@ def test_sinkhorn2_variants_dtype_device(nx, method): nx.assert_same_dtype_device(Mb, lossb) +@pytest.mark.skipif(not tf, reason="tf not installed") +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"]) +def test_sinkhorn2_variants_device_tf(method): + nx = ot.backend.TensorflowBackend() + n = 100 + x = np.random.randn(n, 2) + u = ot.utils.unif(n) + M = ot.dist(x, x) + + # Check that everything stays on the CPU + with tf.device("/CPU:0"): + ub = nx.from_numpy(u) + Mb = nx.from_numpy(M) + Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10) + lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10) + nx.assert_same_dtype_device(Mb, Gb) + nx.assert_same_dtype_device(Mb, lossb) + + if len(tf.config.list_physical_devices('GPU')) > 0: + # Check that everything happens on the GPU + ub = nx.from_numpy(u) + Mb = nx.from_numpy(M) + Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10) + lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10) + nx.assert_same_dtype_device(Mb, Gb) + nx.assert_same_dtype_device(Mb, lossb) + assert nx.dtype_device(Gb)[1].startswith("GPU") + + +@pytest.skip_backend('tf') @pytest.skip_backend("jax") def test_sinkhorn_variants_multi_b(nx): # test sinkhorn @@ -352,6 +385,7 @@ def test_sinkhorn_variants_multi_b(nx): np.testing.assert_allclose(G0, Gs, atol=1e-05) +@pytest.skip_backend('tf') @pytest.skip_backend("jax") def test_sinkhorn2_variants_multi_b(nx): # test sinkhorn @@ -454,7 +488,7 @@ def test_barycenter(nx, method, verbose, warn): weights_nx = nx.from_numpy(weights) reg = 1e-2 - if nx.__name__ == "jax" and method == "sinkhorn_log": + if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": with pytest.raises(NotImplementedError): ot.bregman.barycenter(A_nx, M_nx, reg, weights, method=method) else: @@ -495,7 +529,7 @@ def test_barycenter_debiased(nx, method, verbose, warn): # wasserstein reg = 1e-2 - if nx.__name__ == "jax" and method == "sinkhorn_log": + if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": with pytest.raises(NotImplementedError): ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights, method=method) else: @@ -597,7 +631,7 @@ def test_wasserstein_bary_2d(nx, method): # wasserstein reg = 1e-2 - if nx.__name__ == "jax" and method == "sinkhorn_log": + if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": with pytest.raises(NotImplementedError): ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method) else: @@ -629,7 +663,7 @@ def test_wasserstein_bary_2d_debiased(nx, method): # wasserstein reg = 1e-2 - if nx.__name__ == "jax" and method == "sinkhorn_log": + if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": with pytest.raises(NotImplementedError): ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method) else: @@ -888,6 +922,8 @@ def test_implemented_methods(): ot.bregman.sinkhorn2(a, b, M, epsilon, method=method) +@pytest.skip_backend('tf') +@pytest.skip_backend("cupy") @pytest.skip_backend("jax") @pytest.mark.filterwarnings("ignore:Bottleneck") def test_screenkhorn(nx): |