summaryrefslogtreecommitdiff
path: root/test/test_sliced.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_sliced.py')
-rw-r--r--test/test_sliced.py32
1 files changed, 8 insertions, 24 deletions
diff --git a/test/test_sliced.py b/test/test_sliced.py
index 91e0961..08ab4fb 100644
--- a/test/test_sliced.py
+++ b/test/test_sliced.py
@@ -123,9 +123,7 @@ def test_sliced_backend(nx):
n_projections = 20
- xb = nx.from_numpy(x)
- yb = nx.from_numpy(y)
- Pb = nx.from_numpy(P)
+ xb, yb, Pb = nx.from_numpy(x, y, P)
val0 = ot.sliced_wasserstein_distance(x, y, projections=P)
@@ -153,9 +151,7 @@ def test_sliced_backend_type_devices(nx):
for tp in nx.__type_list__:
print(nx.dtype_device(tp))
- xb = nx.from_numpy(x, type_as=tp)
- yb = nx.from_numpy(y, type_as=tp)
- Pb = nx.from_numpy(P, type_as=tp)
+ xb, yb, Pb = nx.from_numpy(x, y, P, type_as=tp)
valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb)
@@ -174,17 +170,13 @@ def test_sliced_backend_device_tf():
# Check that everything stays on the CPU
with tf.device("/CPU:0"):
- xb = nx.from_numpy(x)
- yb = nx.from_numpy(y)
- Pb = nx.from_numpy(P)
+ xb, yb, Pb = nx.from_numpy(x, y, P)
valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb)
nx.assert_same_dtype_device(xb, valb)
if len(tf.config.list_physical_devices('GPU')) > 0:
# Check that everything happens on the GPU
- xb = nx.from_numpy(x)
- yb = nx.from_numpy(y)
- Pb = nx.from_numpy(P)
+ xb, yb, Pb = nx.from_numpy(x, y, P)
valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb)
nx.assert_same_dtype_device(xb, valb)
assert nx.dtype_device(valb)[1].startswith("GPU")
@@ -203,9 +195,7 @@ def test_max_sliced_backend(nx):
n_projections = 20
- xb = nx.from_numpy(x)
- yb = nx.from_numpy(y)
- Pb = nx.from_numpy(P)
+ xb, yb, Pb = nx.from_numpy(x, y, P)
val0 = ot.max_sliced_wasserstein_distance(x, y, projections=P)
@@ -233,9 +223,7 @@ def test_max_sliced_backend_type_devices(nx):
for tp in nx.__type_list__:
print(nx.dtype_device(tp))
- xb = nx.from_numpy(x, type_as=tp)
- yb = nx.from_numpy(y, type_as=tp)
- Pb = nx.from_numpy(P, type_as=tp)
+ xb, yb, Pb = nx.from_numpy(x, y, P, type_as=tp)
valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)
@@ -254,17 +242,13 @@ def test_max_sliced_backend_device_tf():
# Check that everything stays on the CPU
with tf.device("/CPU:0"):
- xb = nx.from_numpy(x)
- yb = nx.from_numpy(y)
- Pb = nx.from_numpy(P)
+ xb, yb, Pb = nx.from_numpy(x, y, P)
valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)
nx.assert_same_dtype_device(xb, valb)
if len(tf.config.list_physical_devices('GPU')) > 0:
# Check that everything happens on the GPU
- xb = nx.from_numpy(x)
- yb = nx.from_numpy(y)
- Pb = nx.from_numpy(P)
+ xb, yb, Pb = nx.from_numpy(x, y, P)
valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)
nx.assert_same_dtype_device(xb, valb)
assert nx.dtype_device(valb)[1].startswith("GPU")