From 0eac835c70cc1a13bb998f3b6cdb0515fafc05e1 Mon Sep 17 00:00:00 2001 From: ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com> Date: Fri, 5 Nov 2021 15:57:08 +0100 Subject: [MRG] Tests with types/device on sliced/bregman/gromov functions (#303) * First draft : making pytest use gpu for torch testing * bug solve * Revert "bug solve" This reverts commit 29b013abd162f8693128f26d8129186b79923609. * Revert "First draft : making pytest use gpu for torch testing" This reverts commit 2778175bcc338016c704efa4187d132fe5162e3a. * sliced * sliced * ot 1dsolver * bregman * better print * jax works with sinkhorn, sinkhorn_log and sinkhornn_stabilized, no need to skip them * gromov & entropic gromov --- test/test_ot.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) (limited to 'test/test_ot.py') diff --git a/test/test_ot.py b/test/test_ot.py index dc3930a..92f26a7 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -88,8 +88,7 @@ def test_emd_emd2_types_devices(nx): M = ot.dist(x, y) for tp in nx.__type_list__: - - print(tp.dtype) + print(nx.dtype_device(tp)) ab = nx.from_numpy(a, type_as=tp) Mb = nx.from_numpy(M, type_as=tp) @@ -98,9 +97,8 @@ def test_emd_emd2_types_devices(nx): w = ot.emd2(ab, ab, Mb) - assert Gb.dtype == Mb.dtype - if not str(nx) == 'numpy': - assert w.dtype == Mb.dtype + nx.assert_same_dtype_device(Mb, Gb) + nx.assert_same_dtype_device(Mb, w) def test_emd2_gradients(): -- cgit v1.2.3