diff options
author | ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com> | 2021-11-05 15:57:08 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-11-05 15:57:08 +0100 |
commit | 0eac835c70cc1a13bb998f3b6cdb0515fafc05e1 (patch) | |
tree | b0c0fbce0109ba460a67a6356dc0ff03e2b3c1d5 /test/test_bregman.py | |
parent | 0e431c203a66c6d48e6bb1efeda149460472a0f0 (diff) |
[MRG] Tests with types/device on sliced/bregman/gromov functions (#303)
* First draft : making pytest use gpu for torch testing
* bug solve
* Revert "bug solve"
This reverts commit 29b013abd162f8693128f26d8129186b79923609.
* Revert "First draft : making pytest use gpu for torch testing"
This reverts commit 2778175bcc338016c704efa4187d132fe5162e3a.
* sliced
* sliced
* ot 1dsolver
* bregman
* better print
* jax works with sinkhorn, sinkhorn_log and sinkhornn_stabilized, no need to skip them
* gromov & entropic gromov
Diffstat (limited to 'test/test_bregman.py')
-rw-r--r-- | test/test_bregman.py | 45 |
1 files changed, 45 insertions, 0 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py index edfe9c3..830052d 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -278,6 +278,51 @@ def test_sinkhorn_variants(nx): np.testing.assert_allclose(G0, G_green, atol=1e-5) +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", + "sinkhorn_epsilon_scaling", + "greenkhorn", + "sinkhorn_log"]) +@pytest.skip_arg(("nx", "method"), ("jax", "sinkhorn_epsilon_scaling"), reason="jax does not support sinkhorn_epsilon_scaling", getter=str) +@pytest.skip_arg(("nx", "method"), ("jax", "greenkhorn"), reason="jax does not support greenkhorn", getter=str) +def test_sinkhorn_variants_dtype_device(nx, method): + n = 100 + + x = np.random.randn(n, 2) + u = ot.utils.unif(n) + + M = ot.dist(x, x) + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + ub = nx.from_numpy(u, type_as=tp) + Mb = nx.from_numpy(M, type_as=tp) + + Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10) + + nx.assert_same_dtype_device(Mb, Gb) + + +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"]) +def test_sinkhorn2_variants_dtype_device(nx, method): + n = 100 + + x = np.random.randn(n, 2) + u = ot.utils.unif(n) + + M = ot.dist(x, x) + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + ub = nx.from_numpy(u, type_as=tp) + Mb = nx.from_numpy(M, type_as=tp) + + lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10) + + nx.assert_same_dtype_device(Mb, lossb) + + @pytest.skip_backend("jax") def test_sinkhorn_variants_multi_b(nx): # test sinkhorn |