summaryrefslogtreecommitdiff
path: root/test/test_bregman.py
diff options
context:
space:
mode:
authorncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com>2021-11-05 15:57:08 +0100
committerGitHub <noreply@github.com>2021-11-05 15:57:08 +0100
commit0eac835c70cc1a13bb998f3b6cdb0515fafc05e1 (patch)
treeb0c0fbce0109ba460a67a6356dc0ff03e2b3c1d5 /test/test_bregman.py
parent0e431c203a66c6d48e6bb1efeda149460472a0f0 (diff)
[MRG] Tests with types/device on sliced/bregman/gromov functions (#303)
* First draft : making pytest use gpu for torch testing * bug solve * Revert "bug solve" This reverts commit 29b013abd162f8693128f26d8129186b79923609. * Revert "First draft : making pytest use gpu for torch testing" This reverts commit 2778175bcc338016c704efa4187d132fe5162e3a. * sliced * sliced * ot 1dsolver * bregman * better print * jax works with sinkhorn, sinkhorn_log and sinkhornn_stabilized, no need to skip them * gromov & entropic gromov
Diffstat (limited to 'test/test_bregman.py')
-rw-r--r--test/test_bregman.py45
1 files changed, 45 insertions, 0 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py
index edfe9c3..830052d 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -278,6 +278,51 @@ def test_sinkhorn_variants(nx):
np.testing.assert_allclose(G0, G_green, atol=1e-5)
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized",
+ "sinkhorn_epsilon_scaling",
+ "greenkhorn",
+ "sinkhorn_log"])
+@pytest.skip_arg(("nx", "method"), ("jax", "sinkhorn_epsilon_scaling"), reason="jax does not support sinkhorn_epsilon_scaling", getter=str)
+@pytest.skip_arg(("nx", "method"), ("jax", "greenkhorn"), reason="jax does not support greenkhorn", getter=str)
+def test_sinkhorn_variants_dtype_device(nx, method):
+ n = 100
+
+ x = np.random.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ M = ot.dist(x, x)
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ ub = nx.from_numpy(u, type_as=tp)
+ Mb = nx.from_numpy(M, type_as=tp)
+
+ Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10)
+
+ nx.assert_same_dtype_device(Mb, Gb)
+
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"])
+def test_sinkhorn2_variants_dtype_device(nx, method):
+ n = 100
+
+ x = np.random.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ M = ot.dist(x, x)
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ ub = nx.from_numpy(u, type_as=tp)
+ Mb = nx.from_numpy(M, type_as=tp)
+
+ lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10)
+
+ nx.assert_same_dtype_device(Mb, lossb)
+
+
@pytest.skip_backend("jax")
def test_sinkhorn_variants_multi_b(nx):
# test sinkhorn