summaryrefslogtreecommitdiff
path: root/test/test_1d_solver.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_1d_solver.py')
-rw-r--r--test/test_1d_solver.py28
1 files changed, 7 insertions, 21 deletions
diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py
index 6a42cfe..20f307a 100644
--- a/test/test_1d_solver.py
+++ b/test/test_1d_solver.py
@@ -66,9 +66,7 @@ def test_wasserstein_1d(nx):
rho_v = np.abs(rng.randn(n))
rho_v /= rho_v.sum()
- xb = nx.from_numpy(x)
- rho_ub = nx.from_numpy(rho_u)
- rho_vb = nx.from_numpy(rho_v)
+ xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v)
# test 1 : wasserstein_1d should be close to scipy W_1 implementation
np.testing.assert_almost_equal(wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1),
@@ -98,9 +96,7 @@ def test_wasserstein_1d_type_devices(nx):
for tp in nx.__type_list__:
print(nx.dtype_device(tp))
- xb = nx.from_numpy(x, type_as=tp)
- rho_ub = nx.from_numpy(rho_u, type_as=tp)
- rho_vb = nx.from_numpy(rho_v, type_as=tp)
+ xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp)
res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1)
@@ -122,17 +118,13 @@ def test_wasserstein_1d_device_tf():
# Check that everything stays on the CPU
with tf.device("/CPU:0"):
- xb = nx.from_numpy(x)
- rho_ub = nx.from_numpy(rho_u)
- rho_vb = nx.from_numpy(rho_v)
+ xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v)
res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1)
nx.assert_same_dtype_device(xb, res)
if len(tf.config.list_physical_devices('GPU')) > 0:
# Check that everything happens on the GPU
- xb = nx.from_numpy(x)
- rho_ub = nx.from_numpy(rho_u)
- rho_vb = nx.from_numpy(rho_v)
+ xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v)
res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1)
nx.assert_same_dtype_device(xb, res)
assert nx.dtype_device(res)[1].startswith("GPU")
@@ -190,9 +182,7 @@ def test_emd1d_type_devices(nx):
for tp in nx.__type_list__:
print(nx.dtype_device(tp))
- xb = nx.from_numpy(x, type_as=tp)
- rho_ub = nx.from_numpy(rho_u, type_as=tp)
- rho_vb = nx.from_numpy(rho_v, type_as=tp)
+ xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp)
emd = ot.emd_1d(xb, xb, rho_ub, rho_vb)
emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb)
@@ -214,9 +204,7 @@ def test_emd1d_device_tf():
# Check that everything stays on the CPU
with tf.device("/CPU:0"):
- xb = nx.from_numpy(x)
- rho_ub = nx.from_numpy(rho_u)
- rho_vb = nx.from_numpy(rho_v)
+ xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v)
emd = ot.emd_1d(xb, xb, rho_ub, rho_vb)
emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb)
nx.assert_same_dtype_device(xb, emd)
@@ -224,9 +212,7 @@ def test_emd1d_device_tf():
if len(tf.config.list_physical_devices('GPU')) > 0:
# Check that everything happens on the GPU
- xb = nx.from_numpy(x)
- rho_ub = nx.from_numpy(rho_u)
- rho_vb = nx.from_numpy(rho_v)
+ xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v)
emd = ot.emd_1d(xb, xb, rho_ub, rho_vb)
emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb)
nx.assert_same_dtype_device(xb, emd)