summaryrefslogtreecommitdiff
path: root/test/test_sliced.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_sliced.py')
-rw-r--r--test/test_sliced.py44
1 files changed, 44 insertions, 0 deletions
diff --git a/test/test_sliced.py b/test/test_sliced.py
index 0bd74ec..245202c 100644
--- a/test/test_sliced.py
+++ b/test/test_sliced.py
@@ -139,6 +139,28 @@ def test_sliced_backend(nx):
assert np.allclose(val0, valb)
+def test_sliced_backend_type_devices(nx):
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ y = rng.randn(2 * n, 2)
+
+ P = rng.randn(2, 20)
+ P = P / np.sqrt((P**2).sum(0, keepdims=True))
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ xb = nx.from_numpy(x, type_as=tp)
+ yb = nx.from_numpy(y, type_as=tp)
+ Pb = nx.from_numpy(P, type_as=tp)
+
+ valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb)
+
+ nx.assert_same_dtype_device(xb, valb)
+
+
def test_max_sliced_backend(nx):
n = 100
@@ -167,3 +189,25 @@ def test_max_sliced_backend(nx):
valb = nx.to_numpy(ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb))
assert np.allclose(val0, valb)
+
+
+def test_max_sliced_backend_type_devices(nx):
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ y = rng.randn(2 * n, 2)
+
+ P = rng.randn(2, 20)
+ P = P / np.sqrt((P**2).sum(0, keepdims=True))
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ xb = nx.from_numpy(x, type_as=tp)
+ yb = nx.from_numpy(y, type_as=tp)
+ Pb = nx.from_numpy(P, type_as=tp)
+
+ valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)
+
+ nx.assert_same_dtype_device(xb, valb)