summaryrefslogtreecommitdiff
path: root/test/test_backend.py
diff options
context:
space:
mode:
authorNathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com>2021-12-09 17:55:12 +0100
committerGitHub <noreply@github.com>2021-12-09 17:55:12 +0100
commitf8d871e8c6f15009f559ece6a12eb8d8891c60fb (patch)
tree9aa46b2fcc8046c6cddd8e9159a6f607dcf0e1e9 /test/test_backend.py
parentb3dc68feac355fa94c4237f4ecad65edc9f7a7e8 (diff)
[MRG] Tensorflow backend & Benchmarker & Myst_parser (#316)
* First batch of tf methods (to be continued) * Second batch of method (yet to debug) * tensorflow for cpu * add tf requirement * pep8 + bug * small changes * attempt to solve pymanopt bug with tf2 * attempt #2 * attempt #3 * attempt 4 * docstring * correct pep8 violation introduced in merge conflicts resolution * attempt 5 * attempt 6 * just a random try * Revert "just a random try" This reverts commit 8223e768bfe33635549fb66cca2267514a60ebbf. * GPU tests for tensorflow * pep8 * attempt to solve issue with m2r2 * Remove transpose backend method * first draft of benchmarker (need to correct time measurement) * prettier bench table * Bitsize and prettier device methods * prettified table bench * Bug corrected (results were mixed up in the final table) * Better perf counter (for GPU support) * pep8 * EMD bench * solve bug if no GPU available * pep8 * warning about tensorflow numpy api being required in the backend.py docstring * Bug solve in backend docstring * not covering code which requires a GPU * Tensorflow gradients manipulation tested * Number of warmup runs is now customizable * typo * Remove some warnings while building docs * Change prettier_device to device_type in backend * Correct JAX mistakes preventing to see the CPU if a GPU is present * Attempt to solve JAX bug in case no GPU is found * Reworked benchmarks order and results storage & clear GPU after usage by benchmark * Add bench to backend docstring * better benchs * remove useless stuff * Better device_type * Now using MYST_PARSER and solving links issue in the README.md / online docs
Diffstat (limited to 'test/test_backend.py')
-rw-r--r--test/test_backend.py52
1 files changed, 50 insertions, 2 deletions
diff --git a/test/test_backend.py b/test/test_backend.py
index 2e7eecc..027c4cd 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -7,7 +7,7 @@
import ot
import ot.backend
-from ot.backend import torch, jax, cp
+from ot.backend import torch, jax, cp, tf
import pytest
@@ -101,6 +101,20 @@ def test_get_backend():
with pytest.raises(ValueError):
get_backend(A, B2)
+ if tf:
+ A2 = tf.convert_to_tensor(A)
+ B2 = tf.convert_to_tensor(B)
+
+ nx = get_backend(A2)
+ assert nx.__name__ == 'tf'
+
+ nx = get_backend(A2, B2)
+ assert nx.__name__ == 'tf'
+
+ # test not unique types in input
+ with pytest.raises(ValueError):
+ get_backend(A, B2)
+
def test_convert_between_backends(nx):
@@ -242,6 +256,14 @@ def test_empty_backend():
nx.copy(M)
with pytest.raises(NotImplementedError):
nx.allclose(M, M)
+ with pytest.raises(NotImplementedError):
+ nx.squeeze(M)
+ with pytest.raises(NotImplementedError):
+ nx.bitsize(M)
+ with pytest.raises(NotImplementedError):
+ nx.device_type(M)
+ with pytest.raises(NotImplementedError):
+ nx._bench(lambda x: x, M, n_runs=1)
def test_func_backends(nx):
@@ -491,7 +513,7 @@ def test_func_backends(nx):
lst_name.append('coo_matrix')
assert not nx.issparse(Mb), 'Assert fail on: issparse (expected False)'
- assert nx.issparse(sp_Mb) or nx.__name__ == "jax", 'Assert fail on: issparse (expected True)'
+ assert nx.issparse(sp_Mb) or nx.__name__ in ("jax", "tf"), 'Assert fail on: issparse (expected True)'
A = nx.tocsr(sp_Mb)
lst_b.append(nx.to_numpy(nx.todense(A)))
@@ -516,6 +538,18 @@ def test_func_backends(nx):
assert nx.allclose(Mb, Mb), 'Assert fail on: allclose (expected True)'
assert not nx.allclose(2 * Mb, Mb), 'Assert fail on: allclose (expected False)'
+ A = nx.squeeze(nx.zeros((3, 1, 4, 1)))
+ assert tuple(A.shape) == (3, 4), 'Assert fail on: squeeze'
+
+ A = nx.bitsize(Mb)
+ lst_b.append(float(A))
+ lst_name.append("bitsize")
+
+ A = nx.device_type(Mb)
+ assert A in ("CPU", "GPU")
+
+ nx._bench(lambda x: x, M, n_runs=1)
+
lst_tot.append(lst_b)
lst_np = lst_tot[0]
@@ -590,3 +624,17 @@ def test_gradients_backends():
np.testing.assert_almost_equal(fun(v, c, e), c * np.sum(v ** 4) + e, decimal=4)
np.testing.assert_allclose(grad_val[0], v, atol=1e-4)
np.testing.assert_allclose(grad_val[2], 2 * e, atol=1e-4)
+
+ if tf:
+ nx = ot.backend.TensorflowBackend()
+ w = tf.Variable(tf.random.normal((3, 2)), name='w')
+ b = tf.Variable(tf.random.normal((2,), dtype=tf.float32), name='b')
+ x = tf.random.normal((1, 3), dtype=tf.float32)
+
+ with tf.GradientTape() as tape:
+ y = x @ w + b
+ loss = tf.reduce_mean(y ** 2)
+ manipulated_loss = nx.set_gradients(loss, (w, b), (w, b))
+ [dl_dw, dl_db] = tape.gradient(manipulated_loss, [w, b])
+ assert nx.allclose(dl_dw, w)
+ assert nx.allclose(dl_db, b)