summaryrefslogtreecommitdiff
path: root/test/test_sliced.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_sliced.py')
-rw-r--r--test/test_sliced.py57
1 files changed, 57 insertions, 0 deletions
diff --git a/test/test_sliced.py b/test/test_sliced.py
index 245202c..91e0961 100644
--- a/test/test_sliced.py
+++ b/test/test_sliced.py
@@ -10,6 +10,7 @@ import pytest
import ot
from ot.sliced import get_random_projections
+from ot.backend import tf
def test_get_random_projections():
@@ -161,6 +162,34 @@ def test_sliced_backend_type_devices(nx):
nx.assert_same_dtype_device(xb, valb)
+@pytest.mark.skipif(not tf, reason="tf not installed")
+def test_sliced_backend_device_tf():
+ nx = ot.backend.TensorflowBackend()
+ n = 100
+ rng = np.random.RandomState(0)
+ x = rng.randn(n, 2)
+ y = rng.randn(2 * n, 2)
+ P = rng.randn(2, 20)
+ P = P / np.sqrt((P**2).sum(0, keepdims=True))
+
+ # Check that everything stays on the CPU
+ with tf.device("/CPU:0"):
+ xb = nx.from_numpy(x)
+ yb = nx.from_numpy(y)
+ Pb = nx.from_numpy(P)
+ valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb)
+ nx.assert_same_dtype_device(xb, valb)
+
+ if len(tf.config.list_physical_devices('GPU')) > 0:
+ # Check that everything happens on the GPU
+ xb = nx.from_numpy(x)
+ yb = nx.from_numpy(y)
+ Pb = nx.from_numpy(P)
+ valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb)
+ nx.assert_same_dtype_device(xb, valb)
+ assert nx.dtype_device(valb)[1].startswith("GPU")
+
+
def test_max_sliced_backend(nx):
n = 100
@@ -211,3 +240,31 @@ def test_max_sliced_backend_type_devices(nx):
valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)
nx.assert_same_dtype_device(xb, valb)
+
+
+@pytest.mark.skipif(not tf, reason="tf not installed")
+def test_max_sliced_backend_device_tf():
+ nx = ot.backend.TensorflowBackend()
+ n = 100
+ rng = np.random.RandomState(0)
+ x = rng.randn(n, 2)
+ y = rng.randn(2 * n, 2)
+ P = rng.randn(2, 20)
+ P = P / np.sqrt((P**2).sum(0, keepdims=True))
+
+ # Check that everything stays on the CPU
+ with tf.device("/CPU:0"):
+ xb = nx.from_numpy(x)
+ yb = nx.from_numpy(y)
+ Pb = nx.from_numpy(P)
+ valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)
+ nx.assert_same_dtype_device(xb, valb)
+
+ if len(tf.config.list_physical_devices('GPU')) > 0:
+ # Check that everything happens on the GPU
+ xb = nx.from_numpy(x)
+ yb = nx.from_numpy(y)
+ Pb = nx.from_numpy(P)
+ valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)
+ nx.assert_same_dtype_device(xb, valb)
+ assert nx.dtype_device(valb)[1].startswith("GPU")