diff options
Diffstat (limited to 'test/test_sliced.py')
-rw-r--r-- | test/test_sliced.py | 32 |
1 files changed, 8 insertions, 24 deletions
diff --git a/test/test_sliced.py b/test/test_sliced.py index 91e0961..08ab4fb 100644 --- a/test/test_sliced.py +++ b/test/test_sliced.py @@ -123,9 +123,7 @@ def test_sliced_backend(nx): n_projections = 20 - xb = nx.from_numpy(x) - yb = nx.from_numpy(y) - Pb = nx.from_numpy(P) + xb, yb, Pb = nx.from_numpy(x, y, P) val0 = ot.sliced_wasserstein_distance(x, y, projections=P) @@ -153,9 +151,7 @@ def test_sliced_backend_type_devices(nx): for tp in nx.__type_list__: print(nx.dtype_device(tp)) - xb = nx.from_numpy(x, type_as=tp) - yb = nx.from_numpy(y, type_as=tp) - Pb = nx.from_numpy(P, type_as=tp) + xb, yb, Pb = nx.from_numpy(x, y, P, type_as=tp) valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb) @@ -174,17 +170,13 @@ def test_sliced_backend_device_tf(): # Check that everything stays on the CPU with tf.device("/CPU:0"): - xb = nx.from_numpy(x) - yb = nx.from_numpy(y) - Pb = nx.from_numpy(P) + xb, yb, Pb = nx.from_numpy(x, y, P) valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb) nx.assert_same_dtype_device(xb, valb) if len(tf.config.list_physical_devices('GPU')) > 0: # Check that everything happens on the GPU - xb = nx.from_numpy(x) - yb = nx.from_numpy(y) - Pb = nx.from_numpy(P) + xb, yb, Pb = nx.from_numpy(x, y, P) valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb) nx.assert_same_dtype_device(xb, valb) assert nx.dtype_device(valb)[1].startswith("GPU") @@ -203,9 +195,7 @@ def test_max_sliced_backend(nx): n_projections = 20 - xb = nx.from_numpy(x) - yb = nx.from_numpy(y) - Pb = nx.from_numpy(P) + xb, yb, Pb = nx.from_numpy(x, y, P) val0 = ot.max_sliced_wasserstein_distance(x, y, projections=P) @@ -233,9 +223,7 @@ def test_max_sliced_backend_type_devices(nx): for tp in nx.__type_list__: print(nx.dtype_device(tp)) - xb = nx.from_numpy(x, type_as=tp) - yb = nx.from_numpy(y, type_as=tp) - Pb = nx.from_numpy(P, type_as=tp) + xb, yb, Pb = nx.from_numpy(x, y, P, type_as=tp) valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb) @@ -254,17 +242,13 @@ def test_max_sliced_backend_device_tf(): # Check that everything stays on the CPU with tf.device("/CPU:0"): - xb = nx.from_numpy(x) - yb = nx.from_numpy(y) - Pb = nx.from_numpy(P) + xb, yb, Pb = nx.from_numpy(x, y, P) valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb) nx.assert_same_dtype_device(xb, valb) if len(tf.config.list_physical_devices('GPU')) > 0: # Check that everything happens on the GPU - xb = nx.from_numpy(x) - yb = nx.from_numpy(y) - Pb = nx.from_numpy(P) + xb, yb, Pb = nx.from_numpy(x, y, P) valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb) nx.assert_same_dtype_device(xb, valb) assert nx.dtype_device(valb)[1].startswith("GPU") |