path: root/test/test_1d_solver.py
author     Nathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com>  2021-12-09 17:55:12 +0100
committer  GitHub <noreply@github.com>  2021-12-09 17:55:12 +0100
commit     f8d871e8c6f15009f559ece6a12eb8d8891c60fb (patch)
tree       9aa46b2fcc8046c6cddd8e9159a6f607dcf0e1e9 /test/test_1d_solver.py
parent     b3dc68feac355fa94c4237f4ecad65edc9f7a7e8 (diff)
[MRG] Tensorflow backend & Benchmarker & Myst_parser (#316)
* First batch of tf methods (to be continued)
* Second batch of method (yet to debug)
* tensorflow for cpu
* add tf requirement
* pep8 + bug
* small changes
* attempt to solve pymanopt bug with tf2
* attempt #2
* attempt #3
* attempt 4
* docstring
* correct pep8 violation introduced in merge conflicts resolution
* attempt 5
* attempt 6
* just a random try
* Revert "just a random try". This reverts commit 8223e768bfe33635549fb66cca2267514a60ebbf.
* GPU tests for tensorflow
* pep8
* attempt to solve issue with m2r2
* Remove transpose backend method
* first draft of benchmarker (need to correct time measurement)
* prettier bench table
* Bitsize and prettier device methods
* prettified table bench
* Bug corrected (results were mixed up in the final table)
* Better perf counter (for GPU support)
* pep8
* EMD bench
* solve bug if no GPU available
* pep8
* warning about tensorflow numpy api being required in the backend.py docstring
* Bug solve in backend docstring
* not covering code which requires a GPU
* Tensorflow gradients manipulation tested
* Number of warmup runs is now customizable
* typo
* Remove some warnings while building docs
* Change prettier_device to device_type in backend
* Correct JAX mistakes preventing to see the CPU if a GPU is present
* Attempt to solve JAX bug in case no GPU is found
* Reworked benchmarks order and results storage & clear GPU after usage by benchmark
* Add bench to backend docstring
* better benchs
* remove useless stuff
* Better device_type
* Now using MYST_PARSER and solving links issue in the README.md / online docs
Diffstat (limited to 'test/test_1d_solver.py')
-rw-r--r--  test/test_1d_solver.py  68
1 file changed, 65 insertions, 3 deletions
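
For reference, the device-placement pattern exercised by the new TensorFlow tests in this diff looks roughly like the sketch below. This is a minimal illustration rather than code from the commit: it assumes TensorFlow and POT are installed, and it reuses only the names that appear in the diff (ot.backend.TensorflowBackend, nx.from_numpy, wasserstein_1d, nx.dtype_device).

import numpy as np
import tensorflow as tf

import ot
from ot.lp import wasserstein_1d

# Two normalized 1D histograms on a common support, as in the tests.
rng = np.random.RandomState(0)
n = 10
x = np.linspace(0, 5, n)
rho_u = np.abs(rng.randn(n))
rho_u /= rho_u.sum()
rho_v = np.abs(rng.randn(n))
rho_v /= rho_v.sum()

nx = ot.backend.TensorflowBackend()

# Pin the computation to the CPU; the result is expected to stay on that device.
with tf.device("/CPU:0"):
    xb = nx.from_numpy(x)
    rho_ub = nx.from_numpy(rho_u)
    rho_vb = nx.from_numpy(rho_v)
    res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1)

# dtype_device reports the tensor's dtype and a device identifier; the GPU branch
# of the tests checks that the device string starts with "GPU".
print(nx.dtype_device(res))
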
diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py
index cb85cb9..6a42cfe 100644
--- a/test/test_1d_solver.py
+++ b/test/test_1d_solver.py
@@ -11,7 +11,7 @@ import pytest
import ot
from ot.lp import wasserstein_1d
-from ot.backend import get_backend_list
+from ot.backend import get_backend_list, tf
from scipy.stats import wasserstein_distance
backend_list = get_backend_list()
@@ -86,7 +86,6 @@ def test_wasserstein_1d(nx):
def test_wasserstein_1d_type_devices(nx):
-
rng = np.random.RandomState(0)
n = 10
@@ -108,6 +107,37 @@ def test_wasserstein_1d_type_devices(nx):
nx.assert_same_dtype_device(xb, res)
+@pytest.mark.skipif(not tf, reason="tf not installed")
+def test_wasserstein_1d_device_tf():
+ if not tf:
+ return
+ nx = ot.backend.TensorflowBackend()
+ rng = np.random.RandomState(0)
+ n = 10
+ x = np.linspace(0, 5, n)
+ rho_u = np.abs(rng.randn(n))
+ rho_u /= rho_u.sum()
+ rho_v = np.abs(rng.randn(n))
+ rho_v /= rho_v.sum()
+
+ # Check that everything stays on the CPU
+ with tf.device("/CPU:0"):
+ xb = nx.from_numpy(x)
+ rho_ub = nx.from_numpy(rho_u)
+ rho_vb = nx.from_numpy(rho_v)
+ res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1)
+ nx.assert_same_dtype_device(xb, res)
+
+ if len(tf.config.list_physical_devices('GPU')) > 0:
+ # Check that everything happens on the GPU
+ xb = nx.from_numpy(x)
+ rho_ub = nx.from_numpy(rho_u)
+ rho_vb = nx.from_numpy(rho_v)
+ res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1)
+ nx.assert_same_dtype_device(xb, res)
+ assert nx.dtype_device(res)[1].startswith("GPU")
+
+
def test_emd_1d_emd2_1d():
# test emd1d gives similar results as emd
n = 20
@@ -148,7 +178,6 @@ def test_emd_1d_emd2_1d():
def test_emd1d_type_devices(nx):
-
rng = np.random.RandomState(0)
n = 10
@@ -170,3 +199,36 @@ def test_emd1d_type_devices(nx):
nx.assert_same_dtype_device(xb, emd)
nx.assert_same_dtype_device(xb, emd2)
+
+
+@pytest.mark.skipif(not tf, reason="tf not installed")
+def test_emd1d_device_tf():
+ nx = ot.backend.TensorflowBackend()
+ rng = np.random.RandomState(0)
+ n = 10
+ x = np.linspace(0, 5, n)
+ rho_u = np.abs(rng.randn(n))
+ rho_u /= rho_u.sum()
+ rho_v = np.abs(rng.randn(n))
+ rho_v /= rho_v.sum()
+
+ # Check that everything stays on the CPU
+ with tf.device("/CPU:0"):
+ xb = nx.from_numpy(x)
+ rho_ub = nx.from_numpy(rho_u)
+ rho_vb = nx.from_numpy(rho_v)
+ emd = ot.emd_1d(xb, xb, rho_ub, rho_vb)
+ emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb)
+ nx.assert_same_dtype_device(xb, emd)
+ nx.assert_same_dtype_device(xb, emd2)
+
+ if len(tf.config.list_physical_devices('GPU')) > 0:
+ # Check that everything happens on the GPU
+ xb = nx.from_numpy(x)
+ rho_ub = nx.from_numpy(rho_u)
+ rho_vb = nx.from_numpy(rho_v)
+ emd = ot.emd_1d(xb, xb, rho_ub, rho_vb)
+ emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb)
+ nx.assert_same_dtype_device(xb, emd)
+ nx.assert_same_dtype_device(xb, emd2)
+ assert nx.dtype_device(emd)[1].startswith("GPU")
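
Both new tests are guarded by @pytest.mark.skipif(not tf, ...), so they are skipped automatically when TensorFlow is not installed, and the GPU assertions only run when tf.config.list_physical_devices('GPU') reports at least one device. To run just these tests, an invocation along the lines of pytest test/test_1d_solver.py -k device_tf should select them by name.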