summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorNathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com>2021-12-09 17:55:12 +0100
committerGitHub <noreply@github.com>2021-12-09 17:55:12 +0100
commitf8d871e8c6f15009f559ece6a12eb8d8891c60fb (patch)
tree9aa46b2fcc8046c6cddd8e9159a6f607dcf0e1e9 /test
parentb3dc68feac355fa94c4237f4ecad65edc9f7a7e8 (diff)
[MRG] Tensorflow backend & Benchmarker & Myst_parser (#316)
* First batch of tf methods (to be continued) * Second batch of method (yet to debug) * tensorflow for cpu * add tf requirement * pep8 + bug * small changes * attempt to solve pymanopt bug with tf2 * attempt #2 * attempt #3 * attempt 4 * docstring * correct pep8 violation introduced in merge conflicts resolution * attempt 5 * attempt 6 * just a random try * Revert "just a random try" This reverts commit 8223e768bfe33635549fb66cca2267514a60ebbf. * GPU tests for tensorflow * pep8 * attempt to solve issue with m2r2 * Remove transpose backend method * first draft of benchmarker (need to correct time measurement) * prettier bench table * Bitsize and prettier device methods * prettified table bench * Bug corrected (results were mixed up in the final table) * Better perf counter (for GPU support) * pep8 * EMD bench * solve bug if no GPU available * pep8 * warning about tensorflow numpy api being required in the backend.py docstring * Bug solve in backend docstring * not covering code which requires a GPU * Tensorflow gradients manipulation tested * Number of warmup runs is now customizable * typo * Remove some warnings while building docs * Change prettier_device to device_type in backend * Correct JAX mistakes preventing to see the CPU if a GPU is present * Attempt to solve JAX bug in case no GPU is found * Reworked benchmarks order and results storage & clear GPU after usage by benchmark * Add bench to backend docstring * better benchs * remove useless stuff * Better device_type * Now using MYST_PARSER and solving links issue in the README.md / online docs
Diffstat (limited to 'test')
-rw-r--r--test/conftest.py12
-rw-r--r--test/test_1d_solver.py68
-rw-r--r--test/test_backend.py52
-rw-r--r--test/test_bregman.py45
-rw-r--r--test/test_gromov.py44
-rw-r--r--test/test_ot.py36
-rw-r--r--test/test_sliced.py57
7 files changed, 298 insertions, 16 deletions
diff --git a/test/conftest.py b/test/conftest.py
index 987d98e..c0db8ab 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -5,7 +5,7 @@
# License: MIT License
import pytest
-from ot.backend import jax
+from ot.backend import jax, tf
from ot.backend import get_backend_list
import functools
@@ -13,6 +13,10 @@ if jax:
from jax.config import config
config.update("jax_enable_x64", True)
+if tf:
+ from tensorflow.python.ops.numpy_ops import np_config
+ np_config.enable_numpy_behavior()
+
backend_list = get_backend_list()
@@ -24,16 +28,16 @@ def nx(request):
def skip_arg(arg, value, reason=None, getter=lambda x: x):
- if isinstance(arg, tuple) or isinstance(arg, list):
+ if isinstance(arg, (tuple, list)):
n = len(arg)
else:
arg = (arg, )
n = 1
- if n != 1 and (isinstance(value, tuple) or isinstance(value, list)):
+ if n != 1 and isinstance(value, (tuple, list)):
pass
else:
value = (value, )
- if isinstance(getter, tuple) or isinstance(value, list):
+ if isinstance(getter, (tuple, list)):
pass
else:
getter = [getter] * n
diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py
index cb85cb9..6a42cfe 100644
--- a/test/test_1d_solver.py
+++ b/test/test_1d_solver.py
@@ -11,7 +11,7 @@ import pytest
import ot
from ot.lp import wasserstein_1d
-from ot.backend import get_backend_list
+from ot.backend import get_backend_list, tf
from scipy.stats import wasserstein_distance
backend_list = get_backend_list()
@@ -86,7 +86,6 @@ def test_wasserstein_1d(nx):
def test_wasserstein_1d_type_devices(nx):
-
rng = np.random.RandomState(0)
n = 10
@@ -108,6 +107,37 @@ def test_wasserstein_1d_type_devices(nx):
nx.assert_same_dtype_device(xb, res)
+@pytest.mark.skipif(not tf, reason="tf not installed")
+def test_wasserstein_1d_device_tf():
+ if not tf:
+ return
+ nx = ot.backend.TensorflowBackend()
+ rng = np.random.RandomState(0)
+ n = 10
+ x = np.linspace(0, 5, n)
+ rho_u = np.abs(rng.randn(n))
+ rho_u /= rho_u.sum()
+ rho_v = np.abs(rng.randn(n))
+ rho_v /= rho_v.sum()
+
+ # Check that everything stays on the CPU
+ with tf.device("/CPU:0"):
+ xb = nx.from_numpy(x)
+ rho_ub = nx.from_numpy(rho_u)
+ rho_vb = nx.from_numpy(rho_v)
+ res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1)
+ nx.assert_same_dtype_device(xb, res)
+
+ if len(tf.config.list_physical_devices('GPU')) > 0:
+ # Check that everything happens on the GPU
+ xb = nx.from_numpy(x)
+ rho_ub = nx.from_numpy(rho_u)
+ rho_vb = nx.from_numpy(rho_v)
+ res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1)
+ nx.assert_same_dtype_device(xb, res)
+ assert nx.dtype_device(res)[1].startswith("GPU")
+
+
def test_emd_1d_emd2_1d():
# test emd1d gives similar results as emd
n = 20
@@ -148,7 +178,6 @@ def test_emd_1d_emd2_1d():
def test_emd1d_type_devices(nx):
-
rng = np.random.RandomState(0)
n = 10
@@ -170,3 +199,36 @@ def test_emd1d_type_devices(nx):
nx.assert_same_dtype_device(xb, emd)
nx.assert_same_dtype_device(xb, emd2)
+
+
+@pytest.mark.skipif(not tf, reason="tf not installed")
+def test_emd1d_device_tf():
+ nx = ot.backend.TensorflowBackend()
+ rng = np.random.RandomState(0)
+ n = 10
+ x = np.linspace(0, 5, n)
+ rho_u = np.abs(rng.randn(n))
+ rho_u /= rho_u.sum()
+ rho_v = np.abs(rng.randn(n))
+ rho_v /= rho_v.sum()
+
+ # Check that everything stays on the CPU
+ with tf.device("/CPU:0"):
+ xb = nx.from_numpy(x)
+ rho_ub = nx.from_numpy(rho_u)
+ rho_vb = nx.from_numpy(rho_v)
+ emd = ot.emd_1d(xb, xb, rho_ub, rho_vb)
+ emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb)
+ nx.assert_same_dtype_device(xb, emd)
+ nx.assert_same_dtype_device(xb, emd2)
+
+ if len(tf.config.list_physical_devices('GPU')) > 0:
+ # Check that everything happens on the GPU
+ xb = nx.from_numpy(x)
+ rho_ub = nx.from_numpy(rho_u)
+ rho_vb = nx.from_numpy(rho_v)
+ emd = ot.emd_1d(xb, xb, rho_ub, rho_vb)
+ emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb)
+ nx.assert_same_dtype_device(xb, emd)
+ nx.assert_same_dtype_device(xb, emd2)
+ assert nx.dtype_device(emd)[1].startswith("GPU")
diff --git a/test/test_backend.py b/test/test_backend.py
index 2e7eecc..027c4cd 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -7,7 +7,7 @@
import ot
import ot.backend
-from ot.backend import torch, jax, cp
+from ot.backend import torch, jax, cp, tf
import pytest
@@ -101,6 +101,20 @@ def test_get_backend():
with pytest.raises(ValueError):
get_backend(A, B2)
+ if tf:
+ A2 = tf.convert_to_tensor(A)
+ B2 = tf.convert_to_tensor(B)
+
+ nx = get_backend(A2)
+ assert nx.__name__ == 'tf'
+
+ nx = get_backend(A2, B2)
+ assert nx.__name__ == 'tf'
+
+ # test not unique types in input
+ with pytest.raises(ValueError):
+ get_backend(A, B2)
+
def test_convert_between_backends(nx):
@@ -242,6 +256,14 @@ def test_empty_backend():
nx.copy(M)
with pytest.raises(NotImplementedError):
nx.allclose(M, M)
+ with pytest.raises(NotImplementedError):
+ nx.squeeze(M)
+ with pytest.raises(NotImplementedError):
+ nx.bitsize(M)
+ with pytest.raises(NotImplementedError):
+ nx.device_type(M)
+ with pytest.raises(NotImplementedError):
+ nx._bench(lambda x: x, M, n_runs=1)
def test_func_backends(nx):
@@ -491,7 +513,7 @@ def test_func_backends(nx):
lst_name.append('coo_matrix')
assert not nx.issparse(Mb), 'Assert fail on: issparse (expected False)'
- assert nx.issparse(sp_Mb) or nx.__name__ == "jax", 'Assert fail on: issparse (expected True)'
+ assert nx.issparse(sp_Mb) or nx.__name__ in ("jax", "tf"), 'Assert fail on: issparse (expected True)'
A = nx.tocsr(sp_Mb)
lst_b.append(nx.to_numpy(nx.todense(A)))
@@ -516,6 +538,18 @@ def test_func_backends(nx):
assert nx.allclose(Mb, Mb), 'Assert fail on: allclose (expected True)'
assert not nx.allclose(2 * Mb, Mb), 'Assert fail on: allclose (expected False)'
+ A = nx.squeeze(nx.zeros((3, 1, 4, 1)))
+ assert tuple(A.shape) == (3, 4), 'Assert fail on: squeeze'
+
+ A = nx.bitsize(Mb)
+ lst_b.append(float(A))
+ lst_name.append("bitsize")
+
+ A = nx.device_type(Mb)
+ assert A in ("CPU", "GPU")
+
+ nx._bench(lambda x: x, M, n_runs=1)
+
lst_tot.append(lst_b)
lst_np = lst_tot[0]
@@ -590,3 +624,17 @@ def test_gradients_backends():
np.testing.assert_almost_equal(fun(v, c, e), c * np.sum(v ** 4) + e, decimal=4)
np.testing.assert_allclose(grad_val[0], v, atol=1e-4)
np.testing.assert_allclose(grad_val[2], 2 * e, atol=1e-4)
+
+ if tf:
+ nx = ot.backend.TensorflowBackend()
+ w = tf.Variable(tf.random.normal((3, 2)), name='w')
+ b = tf.Variable(tf.random.normal((2,), dtype=tf.float32), name='b')
+ x = tf.random.normal((1, 3), dtype=tf.float32)
+
+ with tf.GradientTape() as tape:
+ y = x @ w + b
+ loss = tf.reduce_mean(y ** 2)
+ manipulated_loss = nx.set_gradients(loss, (w, b), (w, b))
+ [dl_dw, dl_db] = tape.gradient(manipulated_loss, [w, b])
+ assert nx.allclose(dl_dw, w)
+ assert nx.allclose(dl_db, b)
diff --git a/test/test_bregman.py b/test/test_bregman.py
index f42ac6f..6e90aa4 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -12,7 +12,7 @@ import numpy as np
import pytest
import ot
-from ot.backend import torch
+from ot.backend import torch, tf
@pytest.mark.parametrize("verbose, warn", product([True, False], [True, False]))
@@ -248,6 +248,7 @@ def test_sinkhorn_empty():
ot.sinkhorn([], [], M, 1, method='greenkhorn', stopThr=1e-10, log=True)
+@pytest.skip_backend('tf')
@pytest.skip_backend("jax")
def test_sinkhorn_variants(nx):
# test sinkhorn
@@ -282,6 +283,8 @@ def test_sinkhorn_variants(nx):
"sinkhorn_epsilon_scaling",
"greenkhorn",
"sinkhorn_log"])
+@pytest.skip_arg(("nx", "method"), ("tf", "sinkhorn_epsilon_scaling"), reason="tf does not support sinkhorn_epsilon_scaling", getter=str)
+@pytest.skip_arg(("nx", "method"), ("tf", "greenkhorn"), reason="tf does not support greenkhorn", getter=str)
@pytest.skip_arg(("nx", "method"), ("jax", "sinkhorn_epsilon_scaling"), reason="jax does not support sinkhorn_epsilon_scaling", getter=str)
@pytest.skip_arg(("nx", "method"), ("jax", "greenkhorn"), reason="jax does not support greenkhorn", getter=str)
def test_sinkhorn_variants_dtype_device(nx, method):
@@ -323,6 +326,36 @@ def test_sinkhorn2_variants_dtype_device(nx, method):
nx.assert_same_dtype_device(Mb, lossb)
+@pytest.mark.skipif(not tf, reason="tf not installed")
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"])
+def test_sinkhorn2_variants_device_tf(method):
+ nx = ot.backend.TensorflowBackend()
+ n = 100
+ x = np.random.randn(n, 2)
+ u = ot.utils.unif(n)
+ M = ot.dist(x, x)
+
+ # Check that everything stays on the CPU
+ with tf.device("/CPU:0"):
+ ub = nx.from_numpy(u)
+ Mb = nx.from_numpy(M)
+ Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10)
+ lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10)
+ nx.assert_same_dtype_device(Mb, Gb)
+ nx.assert_same_dtype_device(Mb, lossb)
+
+ if len(tf.config.list_physical_devices('GPU')) > 0:
+ # Check that everything happens on the GPU
+ ub = nx.from_numpy(u)
+ Mb = nx.from_numpy(M)
+ Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10)
+ lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10)
+ nx.assert_same_dtype_device(Mb, Gb)
+ nx.assert_same_dtype_device(Mb, lossb)
+ assert nx.dtype_device(Gb)[1].startswith("GPU")
+
+
+@pytest.skip_backend('tf')
@pytest.skip_backend("jax")
def test_sinkhorn_variants_multi_b(nx):
# test sinkhorn
@@ -352,6 +385,7 @@ def test_sinkhorn_variants_multi_b(nx):
np.testing.assert_allclose(G0, Gs, atol=1e-05)
+@pytest.skip_backend('tf')
@pytest.skip_backend("jax")
def test_sinkhorn2_variants_multi_b(nx):
# test sinkhorn
@@ -454,7 +488,7 @@ def test_barycenter(nx, method, verbose, warn):
weights_nx = nx.from_numpy(weights)
reg = 1e-2
- if nx.__name__ == "jax" and method == "sinkhorn_log":
+ if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
with pytest.raises(NotImplementedError):
ot.bregman.barycenter(A_nx, M_nx, reg, weights, method=method)
else:
@@ -495,7 +529,7 @@ def test_barycenter_debiased(nx, method, verbose, warn):
# wasserstein
reg = 1e-2
- if nx.__name__ == "jax" and method == "sinkhorn_log":
+ if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
with pytest.raises(NotImplementedError):
ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights, method=method)
else:
@@ -597,7 +631,7 @@ def test_wasserstein_bary_2d(nx, method):
# wasserstein
reg = 1e-2
- if nx.__name__ == "jax" and method == "sinkhorn_log":
+ if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
with pytest.raises(NotImplementedError):
ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)
else:
@@ -629,7 +663,7 @@ def test_wasserstein_bary_2d_debiased(nx, method):
# wasserstein
reg = 1e-2
- if nx.__name__ == "jax" and method == "sinkhorn_log":
+ if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
with pytest.raises(NotImplementedError):
ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)
else:
@@ -888,6 +922,7 @@ def test_implemented_methods():
ot.bregman.sinkhorn2(a, b, M, epsilon, method=method)
+@pytest.skip_backend('tf')
@pytest.skip_backend("cupy")
@pytest.skip_backend("jax")
@pytest.mark.filterwarnings("ignore:Bottleneck")
diff --git a/test/test_gromov.py b/test/test_gromov.py
index 38a7fd7..4b995d5 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -9,7 +9,7 @@
import numpy as np
import ot
from ot.backend import NumpyBackend
-from ot.backend import torch
+from ot.backend import torch, tf
import pytest
@@ -113,6 +113,45 @@ def test_gromov_dtype_device(nx):
nx.assert_same_dtype_device(C1b, gw_valb)
+@pytest.mark.skipif(not tf, reason="tf not installed")
+def test_gromov_device_tf():
+ nx = ot.backend.TensorflowBackend()
+ n_samples = 50 # nb samples
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4)
+ xt = xs[::-1].copy()
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ # Check that everything stays on the CPU
+ with tf.device("/CPU:0"):
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ pb = nx.from_numpy(p)
+ qb = nx.from_numpy(q)
+ Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True)
+ gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
+ nx.assert_same_dtype_device(C1b, Gb)
+ nx.assert_same_dtype_device(C1b, gw_valb)
+
+ if len(tf.config.list_physical_devices('GPU')) > 0:
+ # Check that everything happens on the GPU
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ pb = nx.from_numpy(p)
+ qb = nx.from_numpy(q)
+ Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True)
+ gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
+ nx.assert_same_dtype_device(C1b, Gb)
+ nx.assert_same_dtype_device(C1b, gw_valb)
+ assert nx.dtype_device(Gb)[1].startswith("GPU")
+
+
def test_gromov2_gradients():
n_samples = 50 # nb samples
@@ -150,6 +189,7 @@ def test_gromov2_gradients():
@pytest.skip_backend("jax", reason="test very slow with jax backend")
+@pytest.skip_backend("tf", reason="test very slow with tf backend")
def test_entropic_gromov(nx):
n_samples = 50 # nb samples
@@ -208,6 +248,7 @@ def test_entropic_gromov(nx):
@pytest.skip_backend("jax", reason="test very slow with jax backend")
+@pytest.skip_backend("tf", reason="test very slow with tf backend")
def test_entropic_gromov_dtype_device(nx):
# setup
n_samples = 50 # nb samples
@@ -306,6 +347,7 @@ def test_pointwise_gromov(nx):
np.testing.assert_allclose(float(logb['gw_dist_std']), 0.0015952535464736394, atol=1e-8)
+@pytest.skip_backend("tf", reason="test very slow with tf backend")
@pytest.skip_backend("jax", reason="test very slow with jax backend")
def test_sampled_gromov(nx):
n_samples = 50 # nb samples
diff --git a/test/test_ot.py b/test/test_ot.py
index c4d7713..53edf4f 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -11,7 +11,7 @@ import pytest
import ot
from ot.datasets import make_1D_gauss as gauss
-from ot.backend import torch
+from ot.backend import torch, tf
def test_emd_dimension_and_mass_mismatch():
@@ -101,6 +101,40 @@ def test_emd_emd2_types_devices(nx):
nx.assert_same_dtype_device(Mb, w)
+@pytest.mark.skipif(not tf, reason="tf not installed")
+def test_emd_emd2_devices_tf():
+ if not tf:
+ return
+ nx = ot.backend.TensorflowBackend()
+
+ n_samples = 100
+ n_features = 2
+ rng = np.random.RandomState(0)
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
+ M = ot.dist(x, y)
+
+ # Check that everything stays on the CPU
+ with tf.device("/CPU:0"):
+ ab = nx.from_numpy(a)
+ Mb = nx.from_numpy(M)
+ Gb = ot.emd(ab, ab, Mb)
+ w = ot.emd2(ab, ab, Mb)
+ nx.assert_same_dtype_device(Mb, Gb)
+ nx.assert_same_dtype_device(Mb, w)
+
+ if len(tf.config.list_physical_devices('GPU')) > 0:
+ # Check that everything happens on the GPU
+ ab = nx.from_numpy(a)
+ Mb = nx.from_numpy(M)
+ Gb = ot.emd(ab, ab, Mb)
+ w = ot.emd2(ab, ab, Mb)
+ nx.assert_same_dtype_device(Mb, Gb)
+ nx.assert_same_dtype_device(Mb, w)
+ assert nx.dtype_device(Gb)[1].startswith("GPU")
+
+
def test_emd2_gradients():
n_samples = 100
n_features = 2
diff --git a/test/test_sliced.py b/test/test_sliced.py
index 245202c..91e0961 100644
--- a/test/test_sliced.py
+++ b/test/test_sliced.py
@@ -10,6 +10,7 @@ import pytest
import ot
from ot.sliced import get_random_projections
+from ot.backend import tf
def test_get_random_projections():
@@ -161,6 +162,34 @@ def test_sliced_backend_type_devices(nx):
nx.assert_same_dtype_device(xb, valb)
+@pytest.mark.skipif(not tf, reason="tf not installed")
+def test_sliced_backend_device_tf():
+ nx = ot.backend.TensorflowBackend()
+ n = 100
+ rng = np.random.RandomState(0)
+ x = rng.randn(n, 2)
+ y = rng.randn(2 * n, 2)
+ P = rng.randn(2, 20)
+ P = P / np.sqrt((P**2).sum(0, keepdims=True))
+
+ # Check that everything stays on the CPU
+ with tf.device("/CPU:0"):
+ xb = nx.from_numpy(x)
+ yb = nx.from_numpy(y)
+ Pb = nx.from_numpy(P)
+ valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb)
+ nx.assert_same_dtype_device(xb, valb)
+
+ if len(tf.config.list_physical_devices('GPU')) > 0:
+ # Check that everything happens on the GPU
+ xb = nx.from_numpy(x)
+ yb = nx.from_numpy(y)
+ Pb = nx.from_numpy(P)
+ valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb)
+ nx.assert_same_dtype_device(xb, valb)
+ assert nx.dtype_device(valb)[1].startswith("GPU")
+
+
def test_max_sliced_backend(nx):
n = 100
@@ -211,3 +240,31 @@ def test_max_sliced_backend_type_devices(nx):
valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)
nx.assert_same_dtype_device(xb, valb)
+
+
+@pytest.mark.skipif(not tf, reason="tf not installed")
+def test_max_sliced_backend_device_tf():
+ nx = ot.backend.TensorflowBackend()
+ n = 100
+ rng = np.random.RandomState(0)
+ x = rng.randn(n, 2)
+ y = rng.randn(2 * n, 2)
+ P = rng.randn(2, 20)
+ P = P / np.sqrt((P**2).sum(0, keepdims=True))
+
+ # Check that everything stays on the CPU
+ with tf.device("/CPU:0"):
+ xb = nx.from_numpy(x)
+ yb = nx.from_numpy(y)
+ Pb = nx.from_numpy(P)
+ valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)
+ nx.assert_same_dtype_device(xb, valb)
+
+ if len(tf.config.list_physical_devices('GPU')) > 0:
+ # Check that everything happens on the GPU
+ xb = nx.from_numpy(x)
+ yb = nx.from_numpy(y)
+ Pb = nx.from_numpy(P)
+ valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)
+ nx.assert_same_dtype_device(xb, valb)
+ assert nx.dtype_device(valb)[1].startswith("GPU")