summaryrefslogtreecommitdiff
path: root/test/test_ot.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_ot.py')
-rw-r--r--test/test_ot.py36
1 files changed, 35 insertions, 1 deletions
diff --git a/test/test_ot.py b/test/test_ot.py
index c4d7713..53edf4f 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -11,7 +11,7 @@ import pytest
import ot
from ot.datasets import make_1D_gauss as gauss
-from ot.backend import torch
+from ot.backend import torch, tf
def test_emd_dimension_and_mass_mismatch():
@@ -101,6 +101,40 @@ def test_emd_emd2_types_devices(nx):
nx.assert_same_dtype_device(Mb, w)
+@pytest.mark.skipif(not tf, reason="tf not installed")
+def test_emd_emd2_devices_tf():
+ if not tf:
+ return
+ nx = ot.backend.TensorflowBackend()
+
+ n_samples = 100
+ n_features = 2
+ rng = np.random.RandomState(0)
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
+ M = ot.dist(x, y)
+
+ # Check that everything stays on the CPU
+ with tf.device("/CPU:0"):
+ ab = nx.from_numpy(a)
+ Mb = nx.from_numpy(M)
+ Gb = ot.emd(ab, ab, Mb)
+ w = ot.emd2(ab, ab, Mb)
+ nx.assert_same_dtype_device(Mb, Gb)
+ nx.assert_same_dtype_device(Mb, w)
+
+ if len(tf.config.list_physical_devices('GPU')) > 0:
+ # Check that everything happens on the GPU
+ ab = nx.from_numpy(a)
+ Mb = nx.from_numpy(M)
+ Gb = ot.emd(ab, ab, Mb)
+ w = ot.emd2(ab, ab, Mb)
+ nx.assert_same_dtype_device(Mb, Gb)
+ nx.assert_same_dtype_device(Mb, w)
+ assert nx.dtype_device(Gb)[1].startswith("GPU")
+
+
def test_emd2_gradients():
n_samples = 100
n_features = 2