summaryrefslogtreecommitdiff
path: root/test/test_bregman.py
diff options
context:
space:
mode:
authorNathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com>2021-12-09 17:55:12 +0100
committerGitHub <noreply@github.com>2021-12-09 17:55:12 +0100
commitf8d871e8c6f15009f559ece6a12eb8d8891c60fb (patch)
tree9aa46b2fcc8046c6cddd8e9159a6f607dcf0e1e9 /test/test_bregman.py
parentb3dc68feac355fa94c4237f4ecad65edc9f7a7e8 (diff)
[MRG] Tensorflow backend & Benchmarker & Myst_parser (#316)
* First batch of tf methods (to be continued) * Second batch of method (yet to debug) * tensorflow for cpu * add tf requirement * pep8 + bug * small changes * attempt to solve pymanopt bug with tf2 * attempt #2 * attempt #3 * attempt 4 * docstring * correct pep8 violation introduced in merge conflicts resolution * attempt 5 * attempt 6 * just a random try * Revert "just a random try" This reverts commit 8223e768bfe33635549fb66cca2267514a60ebbf. * GPU tests for tensorflow * pep8 * attempt to solve issue with m2r2 * Remove transpose backend method * first draft of benchmarker (need to correct time measurement) * prettier bench table * Bitsize and prettier device methods * prettified table bench * Bug corrected (results were mixed up in the final table) * Better perf counter (for GPU support) * pep8 * EMD bench * solve bug if no GPU available * pep8 * warning about tensorflow numpy api being required in the backend.py docstring * Bug solve in backend docstring * not covering code which requires a GPU * Tensorflow gradients manipulation tested * Number of warmup runs is now customizable * typo * Remove some warnings while building docs * Change prettier_device to device_type in backend * Correct JAX mistakes preventing to see the CPU if a GPU is present * Attempt to solve JAX bug in case no GPU is found * Reworked benchmarks order and results storage & clear GPU after usage by benchmark * Add bench to backend docstring * better benchs * remove useless stuff * Better device_type * Now using MYST_PARSER and solving links issue in the README.md / online docs
Diffstat (limited to 'test/test_bregman.py')
-rw-r--r--test/test_bregman.py45
1 files changed, 40 insertions, 5 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py
index f42ac6f..6e90aa4 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -12,7 +12,7 @@ import numpy as np
import pytest
import ot
-from ot.backend import torch
+from ot.backend import torch, tf
@pytest.mark.parametrize("verbose, warn", product([True, False], [True, False]))
@@ -248,6 +248,7 @@ def test_sinkhorn_empty():
ot.sinkhorn([], [], M, 1, method='greenkhorn', stopThr=1e-10, log=True)
+@pytest.skip_backend('tf')
@pytest.skip_backend("jax")
def test_sinkhorn_variants(nx):
# test sinkhorn
@@ -282,6 +283,8 @@ def test_sinkhorn_variants(nx):
"sinkhorn_epsilon_scaling",
"greenkhorn",
"sinkhorn_log"])
+@pytest.skip_arg(("nx", "method"), ("tf", "sinkhorn_epsilon_scaling"), reason="tf does not support sinkhorn_epsilon_scaling", getter=str)
+@pytest.skip_arg(("nx", "method"), ("tf", "greenkhorn"), reason="tf does not support greenkhorn", getter=str)
@pytest.skip_arg(("nx", "method"), ("jax", "sinkhorn_epsilon_scaling"), reason="jax does not support sinkhorn_epsilon_scaling", getter=str)
@pytest.skip_arg(("nx", "method"), ("jax", "greenkhorn"), reason="jax does not support greenkhorn", getter=str)
def test_sinkhorn_variants_dtype_device(nx, method):
@@ -323,6 +326,36 @@ def test_sinkhorn2_variants_dtype_device(nx, method):
nx.assert_same_dtype_device(Mb, lossb)
+@pytest.mark.skipif(not tf, reason="tf not installed")
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"])
+def test_sinkhorn2_variants_device_tf(method):
+ nx = ot.backend.TensorflowBackend()
+ n = 100
+ x = np.random.randn(n, 2)
+ u = ot.utils.unif(n)
+ M = ot.dist(x, x)
+
+ # Check that everything stays on the CPU
+ with tf.device("/CPU:0"):
+ ub = nx.from_numpy(u)
+ Mb = nx.from_numpy(M)
+ Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10)
+ lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10)
+ nx.assert_same_dtype_device(Mb, Gb)
+ nx.assert_same_dtype_device(Mb, lossb)
+
+ if len(tf.config.list_physical_devices('GPU')) > 0:
+ # Check that everything happens on the GPU
+ ub = nx.from_numpy(u)
+ Mb = nx.from_numpy(M)
+ Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10)
+ lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10)
+ nx.assert_same_dtype_device(Mb, Gb)
+ nx.assert_same_dtype_device(Mb, lossb)
+ assert nx.dtype_device(Gb)[1].startswith("GPU")
+
+
+@pytest.skip_backend('tf')
@pytest.skip_backend("jax")
def test_sinkhorn_variants_multi_b(nx):
# test sinkhorn
@@ -352,6 +385,7 @@ def test_sinkhorn_variants_multi_b(nx):
np.testing.assert_allclose(G0, Gs, atol=1e-05)
+@pytest.skip_backend('tf')
@pytest.skip_backend("jax")
def test_sinkhorn2_variants_multi_b(nx):
# test sinkhorn
@@ -454,7 +488,7 @@ def test_barycenter(nx, method, verbose, warn):
weights_nx = nx.from_numpy(weights)
reg = 1e-2
- if nx.__name__ == "jax" and method == "sinkhorn_log":
+ if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
with pytest.raises(NotImplementedError):
ot.bregman.barycenter(A_nx, M_nx, reg, weights, method=method)
else:
@@ -495,7 +529,7 @@ def test_barycenter_debiased(nx, method, verbose, warn):
# wasserstein
reg = 1e-2
- if nx.__name__ == "jax" and method == "sinkhorn_log":
+ if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
with pytest.raises(NotImplementedError):
ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights, method=method)
else:
@@ -597,7 +631,7 @@ def test_wasserstein_bary_2d(nx, method):
# wasserstein
reg = 1e-2
- if nx.__name__ == "jax" and method == "sinkhorn_log":
+ if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
with pytest.raises(NotImplementedError):
ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)
else:
@@ -629,7 +663,7 @@ def test_wasserstein_bary_2d_debiased(nx, method):
# wasserstein
reg = 1e-2
- if nx.__name__ == "jax" and method == "sinkhorn_log":
+ if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
with pytest.raises(NotImplementedError):
ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)
else:
@@ -888,6 +922,7 @@ def test_implemented_methods():
ot.bregman.sinkhorn2(a, b, M, epsilon, method=method)
+@pytest.skip_backend('tf')
@pytest.skip_backend("cupy")
@pytest.skip_backend("jax")
@pytest.mark.filterwarnings("ignore:Bottleneck")