summaryrefslogtreecommitdiff
path: root/test/test_1d_solver.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_1d_solver.py')
-rw-r--r--test/test_1d_solver.py16
1 files changed, 5 insertions, 11 deletions
diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py
index 77b1234..cb85cb9 100644
--- a/test/test_1d_solver.py
+++ b/test/test_1d_solver.py
@@ -85,7 +85,6 @@ def test_wasserstein_1d(nx):
np.testing.assert_almost_equal(100 * res[0], res[1], decimal=4)
-@pytest.mark.parametrize('nx', backend_list)
def test_wasserstein_1d_type_devices(nx):
rng = np.random.RandomState(0)
@@ -98,8 +97,7 @@ def test_wasserstein_1d_type_devices(nx):
rho_v /= rho_v.sum()
for tp in nx.__type_list__:
-
- print(tp.dtype)
+ print(nx.dtype_device(tp))
xb = nx.from_numpy(x, type_as=tp)
rho_ub = nx.from_numpy(rho_u, type_as=tp)
@@ -107,8 +105,7 @@ def test_wasserstein_1d_type_devices(nx):
res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1)
- if not str(nx) == 'numpy':
- assert res.dtype == xb.dtype
+ nx.assert_same_dtype_device(xb, res)
def test_emd_1d_emd2_1d():
@@ -162,17 +159,14 @@ def test_emd1d_type_devices(nx):
rho_v /= rho_v.sum()
for tp in nx.__type_list__:
-
- print(tp.dtype)
+ print(nx.dtype_device(tp))
xb = nx.from_numpy(x, type_as=tp)
rho_ub = nx.from_numpy(rho_u, type_as=tp)
rho_vb = nx.from_numpy(rho_v, type_as=tp)
emd = ot.emd_1d(xb, xb, rho_ub, rho_vb)
-
emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb)
- assert emd.dtype == xb.dtype
- if not str(nx) == 'numpy':
- assert emd2.dtype == xb.dtype
+ nx.assert_same_dtype_device(xb, emd)
+ nx.assert_same_dtype_device(xb, emd2)