diff options
author | ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com> | 2021-11-05 15:57:08 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-11-05 15:57:08 +0100 |
commit | 0eac835c70cc1a13bb998f3b6cdb0515fafc05e1 (patch) | |
tree | b0c0fbce0109ba460a67a6356dc0ff03e2b3c1d5 /test/test_1d_solver.py | |
parent | 0e431c203a66c6d48e6bb1efeda149460472a0f0 (diff) |
[MRG] Tests with types/device on sliced/bregman/gromov functions (#303)
* First draft : making pytest use gpu for torch testing
* bug solve
* Revert "bug solve"
This reverts commit 29b013abd162f8693128f26d8129186b79923609.
* Revert "First draft : making pytest use gpu for torch testing"
This reverts commit 2778175bcc338016c704efa4187d132fe5162e3a.
* sliced
* sliced
* ot 1dsolver
* bregman
* better print
* jax works with sinkhorn, sinkhorn_log and sinkhornn_stabilized, no need to skip them
* gromov & entropic gromov
Diffstat (limited to 'test/test_1d_solver.py')
-rw-r--r-- | test/test_1d_solver.py | 16 |
1 files changed, 5 insertions, 11 deletions
diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py index 77b1234..cb85cb9 100644 --- a/test/test_1d_solver.py +++ b/test/test_1d_solver.py @@ -85,7 +85,6 @@ def test_wasserstein_1d(nx): np.testing.assert_almost_equal(100 * res[0], res[1], decimal=4) -@pytest.mark.parametrize('nx', backend_list) def test_wasserstein_1d_type_devices(nx): rng = np.random.RandomState(0) @@ -98,8 +97,7 @@ def test_wasserstein_1d_type_devices(nx): rho_v /= rho_v.sum() for tp in nx.__type_list__: - - print(tp.dtype) + print(nx.dtype_device(tp)) xb = nx.from_numpy(x, type_as=tp) rho_ub = nx.from_numpy(rho_u, type_as=tp) @@ -107,8 +105,7 @@ def test_wasserstein_1d_type_devices(nx): res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1) - if not str(nx) == 'numpy': - assert res.dtype == xb.dtype + nx.assert_same_dtype_device(xb, res) def test_emd_1d_emd2_1d(): @@ -162,17 +159,14 @@ def test_emd1d_type_devices(nx): rho_v /= rho_v.sum() for tp in nx.__type_list__: - - print(tp.dtype) + print(nx.dtype_device(tp)) xb = nx.from_numpy(x, type_as=tp) rho_ub = nx.from_numpy(rho_u, type_as=tp) rho_vb = nx.from_numpy(rho_v, type_as=tp) emd = ot.emd_1d(xb, xb, rho_ub, rho_vb) - emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb) - assert emd.dtype == xb.dtype - if not str(nx) == 'numpy': - assert emd2.dtype == xb.dtype + nx.assert_same_dtype_device(xb, emd) + nx.assert_same_dtype_device(xb, emd2) |