summaryrefslogtreecommitdiff
path: root/test/test_ot.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_ot.py')
-rw-r--r--test/test_ot.py8
1 files changed, 3 insertions, 5 deletions
diff --git a/test/test_ot.py b/test/test_ot.py
index dc3930a..92f26a7 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -88,8 +88,7 @@ def test_emd_emd2_types_devices(nx):
M = ot.dist(x, y)
for tp in nx.__type_list__:
-
- print(tp.dtype)
+ print(nx.dtype_device(tp))
ab = nx.from_numpy(a, type_as=tp)
Mb = nx.from_numpy(M, type_as=tp)
@@ -98,9 +97,8 @@ def test_emd_emd2_types_devices(nx):
w = ot.emd2(ab, ab, Mb)
- assert Gb.dtype == Mb.dtype
- if not str(nx) == 'numpy':
- assert w.dtype == Mb.dtype
+ nx.assert_same_dtype_device(Mb, Gb)
+ nx.assert_same_dtype_device(Mb, w)
def test_emd2_gradients():