diff options
Diffstat (limited to 'test/test_sliced.py')
-rw-r--r-- | test/test_sliced.py | 44 |
1 files changed, 44 insertions, 0 deletions
diff --git a/test/test_sliced.py b/test/test_sliced.py index 0bd74ec..245202c 100644 --- a/test/test_sliced.py +++ b/test/test_sliced.py @@ -139,6 +139,28 @@ def test_sliced_backend(nx): assert np.allclose(val0, valb) +def test_sliced_backend_type_devices(nx): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + y = rng.randn(2 * n, 2) + + P = rng.randn(2, 20) + P = P / np.sqrt((P**2).sum(0, keepdims=True)) + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + xb = nx.from_numpy(x, type_as=tp) + yb = nx.from_numpy(y, type_as=tp) + Pb = nx.from_numpy(P, type_as=tp) + + valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb) + + nx.assert_same_dtype_device(xb, valb) + + def test_max_sliced_backend(nx): n = 100 @@ -167,3 +189,25 @@ def test_max_sliced_backend(nx): valb = nx.to_numpy(ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)) assert np.allclose(val0, valb) + + +def test_max_sliced_backend_type_devices(nx): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + y = rng.randn(2 * n, 2) + + P = rng.randn(2, 20) + P = P / np.sqrt((P**2).sum(0, keepdims=True)) + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + xb = nx.from_numpy(x, type_as=tp) + yb = nx.from_numpy(y, type_as=tp) + Pb = nx.from_numpy(P, type_as=tp) + + valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb) + + nx.assert_same_dtype_device(xb, valb) |