diff options
Diffstat (limited to 'test/test_bregman.py')
-rw-r--r-- | test/test_bregman.py | 45 |
1 files changed, 45 insertions, 0 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py index edfe9c3..830052d 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -278,6 +278,51 @@ def test_sinkhorn_variants(nx): np.testing.assert_allclose(G0, G_green, atol=1e-5) +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", + "sinkhorn_epsilon_scaling", + "greenkhorn", + "sinkhorn_log"]) +@pytest.skip_arg(("nx", "method"), ("jax", "sinkhorn_epsilon_scaling"), reason="jax does not support sinkhorn_epsilon_scaling", getter=str) +@pytest.skip_arg(("nx", "method"), ("jax", "greenkhorn"), reason="jax does not support greenkhorn", getter=str) +def test_sinkhorn_variants_dtype_device(nx, method): + n = 100 + + x = np.random.randn(n, 2) + u = ot.utils.unif(n) + + M = ot.dist(x, x) + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + ub = nx.from_numpy(u, type_as=tp) + Mb = nx.from_numpy(M, type_as=tp) + + Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10) + + nx.assert_same_dtype_device(Mb, Gb) + + +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"]) +def test_sinkhorn2_variants_dtype_device(nx, method): + n = 100 + + x = np.random.randn(n, 2) + u = ot.utils.unif(n) + + M = ot.dist(x, x) + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + ub = nx.from_numpy(u, type_as=tp) + Mb = nx.from_numpy(M, type_as=tp) + + lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10) + + nx.assert_same_dtype_device(Mb, lossb) + + @pytest.skip_backend("jax") def test_sinkhorn_variants_multi_b(nx): # test sinkhorn |