diff options
Diffstat (limited to 'test/test_ot.py')
-rw-r--r-- | test/test_ot.py | 8 |
1 files changed, 3 insertions, 5 deletions
diff --git a/test/test_ot.py b/test/test_ot.py index dc3930a..92f26a7 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -88,8 +88,7 @@ def test_emd_emd2_types_devices(nx): M = ot.dist(x, y) for tp in nx.__type_list__: - - print(tp.dtype) + print(nx.dtype_device(tp)) ab = nx.from_numpy(a, type_as=tp) Mb = nx.from_numpy(M, type_as=tp) @@ -98,9 +97,8 @@ def test_emd_emd2_types_devices(nx): w = ot.emd2(ab, ab, Mb) - assert Gb.dtype == Mb.dtype - if not str(nx) == 'numpy': - assert w.dtype == Mb.dtype + nx.assert_same_dtype_device(Mb, Gb) + nx.assert_same_dtype_device(Mb, w) def test_emd2_gradients(): |