summaryrefslogtreecommitdiff
path: root/test/test_ot.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_ot.py')
-rw-r--r--test/test_ot.py19
1 files changed, 7 insertions, 12 deletions
diff --git a/test/test_ot.py b/test/test_ot.py
index 3e2d845..bb258e2 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -47,8 +47,7 @@ def test_emd_backends(nx):
G = ot.emd(a, a, M)
- ab = nx.from_numpy(a)
- Mb = nx.from_numpy(M)
+ ab, Mb = nx.from_numpy(a, M)
Gb = ot.emd(ab, ab, Mb)
@@ -68,8 +67,7 @@ def test_emd2_backends(nx):
val = ot.emd2(a, a, M)
- ab = nx.from_numpy(a)
- Mb = nx.from_numpy(M)
+ ab, Mb = nx.from_numpy(a, M)
valb = ot.emd2(ab, ab, Mb)
@@ -90,8 +88,7 @@ def test_emd_emd2_types_devices(nx):
for tp in nx.__type_list__:
print(nx.dtype_device(tp))
- ab = nx.from_numpy(a, type_as=tp)
- Mb = nx.from_numpy(M, type_as=tp)
+ ab, Mb = nx.from_numpy(a, M, type_as=tp)
Gb = ot.emd(ab, ab, Mb)
@@ -117,8 +114,7 @@ def test_emd_emd2_devices_tf():
# Check that everything stays on the CPU
with tf.device("/CPU:0"):
- ab = nx.from_numpy(a)
- Mb = nx.from_numpy(M)
+ ab, Mb = nx.from_numpy(a, M)
Gb = ot.emd(ab, ab, Mb)
w = ot.emd2(ab, ab, Mb)
nx.assert_same_dtype_device(Mb, Gb)
@@ -126,8 +122,7 @@ def test_emd_emd2_devices_tf():
if len(tf.config.list_physical_devices('GPU')) > 0:
# Check that everything happens on the GPU
- ab = nx.from_numpy(a)
- Mb = nx.from_numpy(M)
+ ab, Mb = nx.from_numpy(a, M)
Gb = ot.emd(ab, ab, Mb)
w = ot.emd2(ab, ab, Mb)
nx.assert_same_dtype_device(Mb, Gb)
@@ -310,8 +305,8 @@ def test_free_support_barycenter_backends(nx):
X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init)
- measures_locations2 = [nx.from_numpy(x) for x in measures_locations]
- measures_weights2 = [nx.from_numpy(x) for x in measures_weights]
+ measures_locations2 = nx.from_numpy(*measures_locations)
+ measures_weights2 = nx.from_numpy(*measures_weights)
X_init2 = nx.from_numpy(X_init)
X2 = ot.lp.free_support_barycenter(measures_locations2, measures_weights2, X_init2)