diff options
Diffstat (limited to 'test/test_1d_solver.py')
-rw-r--r-- | test/test_1d_solver.py | 28 |
1 files changed, 7 insertions, 21 deletions
diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py index 6a42cfe..20f307a 100644 --- a/test/test_1d_solver.py +++ b/test/test_1d_solver.py @@ -66,9 +66,7 @@ def test_wasserstein_1d(nx): rho_v = np.abs(rng.randn(n)) rho_v /= rho_v.sum() - xb = nx.from_numpy(x) - rho_ub = nx.from_numpy(rho_u) - rho_vb = nx.from_numpy(rho_v) + xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v) # test 1 : wasserstein_1d should be close to scipy W_1 implementation np.testing.assert_almost_equal(wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1), @@ -98,9 +96,7 @@ def test_wasserstein_1d_type_devices(nx): for tp in nx.__type_list__: print(nx.dtype_device(tp)) - xb = nx.from_numpy(x, type_as=tp) - rho_ub = nx.from_numpy(rho_u, type_as=tp) - rho_vb = nx.from_numpy(rho_v, type_as=tp) + xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp) res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1) @@ -122,17 +118,13 @@ def test_wasserstein_1d_device_tf(): # Check that everything stays on the CPU with tf.device("/CPU:0"): - xb = nx.from_numpy(x) - rho_ub = nx.from_numpy(rho_u) - rho_vb = nx.from_numpy(rho_v) + xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v) res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1) nx.assert_same_dtype_device(xb, res) if len(tf.config.list_physical_devices('GPU')) > 0: # Check that everything happens on the GPU - xb = nx.from_numpy(x) - rho_ub = nx.from_numpy(rho_u) - rho_vb = nx.from_numpy(rho_v) + xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v) res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1) nx.assert_same_dtype_device(xb, res) assert nx.dtype_device(res)[1].startswith("GPU") @@ -190,9 +182,7 @@ def test_emd1d_type_devices(nx): for tp in nx.__type_list__: print(nx.dtype_device(tp)) - xb = nx.from_numpy(x, type_as=tp) - rho_ub = nx.from_numpy(rho_u, type_as=tp) - rho_vb = nx.from_numpy(rho_v, type_as=tp) + xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp) emd = ot.emd_1d(xb, xb, rho_ub, rho_vb) emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb) @@ -214,9 +204,7 @@ def test_emd1d_device_tf(): # Check that everything stays on the CPU with tf.device("/CPU:0"): - xb = nx.from_numpy(x) - rho_ub = nx.from_numpy(rho_u) - rho_vb = nx.from_numpy(rho_v) + xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v) emd = ot.emd_1d(xb, xb, rho_ub, rho_vb) emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb) nx.assert_same_dtype_device(xb, emd) @@ -224,9 +212,7 @@ def test_emd1d_device_tf(): if len(tf.config.list_physical_devices('GPU')) > 0: # Check that everything happens on the GPU - xb = nx.from_numpy(x) - rho_ub = nx.from_numpy(rho_u) - rho_vb = nx.from_numpy(rho_v) + xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v) emd = ot.emd_1d(xb, xb, rho_ub, rho_vb) emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb) nx.assert_same_dtype_device(xb, emd) |