summaryrefslogtreecommitdiff
path: root/test/test_1d_solver.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_1d_solver.py')
-rw-r--r--test/test_1d_solver.py68
1 files changed, 65 insertions, 3 deletions
diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py
index cb85cb9..6a42cfe 100644
--- a/test/test_1d_solver.py
+++ b/test/test_1d_solver.py
@@ -11,7 +11,7 @@ import pytest
import ot
from ot.lp import wasserstein_1d
-from ot.backend import get_backend_list
+from ot.backend import get_backend_list, tf
from scipy.stats import wasserstein_distance
backend_list = get_backend_list()
@@ -86,7 +86,6 @@ def test_wasserstein_1d(nx):
def test_wasserstein_1d_type_devices(nx):
-
rng = np.random.RandomState(0)
n = 10
@@ -108,6 +107,37 @@ def test_wasserstein_1d_type_devices(nx):
nx.assert_same_dtype_device(xb, res)
+@pytest.mark.skipif(not tf, reason="tf not installed")
+def test_wasserstein_1d_device_tf():
+ if not tf:
+ return
+ nx = ot.backend.TensorflowBackend()
+ rng = np.random.RandomState(0)
+ n = 10
+ x = np.linspace(0, 5, n)
+ rho_u = np.abs(rng.randn(n))
+ rho_u /= rho_u.sum()
+ rho_v = np.abs(rng.randn(n))
+ rho_v /= rho_v.sum()
+
+ # Check that everything stays on the CPU
+ with tf.device("/CPU:0"):
+ xb = nx.from_numpy(x)
+ rho_ub = nx.from_numpy(rho_u)
+ rho_vb = nx.from_numpy(rho_v)
+ res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1)
+ nx.assert_same_dtype_device(xb, res)
+
+ if len(tf.config.list_physical_devices('GPU')) > 0:
+ # Check that everything happens on the GPU
+ xb = nx.from_numpy(x)
+ rho_ub = nx.from_numpy(rho_u)
+ rho_vb = nx.from_numpy(rho_v)
+ res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1)
+ nx.assert_same_dtype_device(xb, res)
+ assert nx.dtype_device(res)[1].startswith("GPU")
+
+
def test_emd_1d_emd2_1d():
# test emd1d gives similar results as emd
n = 20
@@ -148,7 +178,6 @@ def test_emd_1d_emd2_1d():
def test_emd1d_type_devices(nx):
-
rng = np.random.RandomState(0)
n = 10
@@ -170,3 +199,36 @@ def test_emd1d_type_devices(nx):
nx.assert_same_dtype_device(xb, emd)
nx.assert_same_dtype_device(xb, emd2)
+
+
+@pytest.mark.skipif(not tf, reason="tf not installed")
+def test_emd1d_device_tf():
+ nx = ot.backend.TensorflowBackend()
+ rng = np.random.RandomState(0)
+ n = 10
+ x = np.linspace(0, 5, n)
+ rho_u = np.abs(rng.randn(n))
+ rho_u /= rho_u.sum()
+ rho_v = np.abs(rng.randn(n))
+ rho_v /= rho_v.sum()
+
+ # Check that everything stays on the CPU
+ with tf.device("/CPU:0"):
+ xb = nx.from_numpy(x)
+ rho_ub = nx.from_numpy(rho_u)
+ rho_vb = nx.from_numpy(rho_v)
+ emd = ot.emd_1d(xb, xb, rho_ub, rho_vb)
+ emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb)
+ nx.assert_same_dtype_device(xb, emd)
+ nx.assert_same_dtype_device(xb, emd2)
+
+ if len(tf.config.list_physical_devices('GPU')) > 0:
+ # Check that everything happens on the GPU
+ xb = nx.from_numpy(x)
+ rho_ub = nx.from_numpy(rho_u)
+ rho_vb = nx.from_numpy(rho_v)
+ emd = ot.emd_1d(xb, xb, rho_ub, rho_vb)
+ emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb)
+ nx.assert_same_dtype_device(xb, emd)
+ nx.assert_same_dtype_device(xb, emd2)
+ assert nx.dtype_device(emd)[1].startswith("GPU")