author    AdrienCorenflos <adrien.corenflos@gmail.com>    2021-10-22 15:05:14 +0300
committer GitHub <noreply@github.com>                     2021-10-22 14:05:14 +0200
commit    d50d8145a5c0cf69d438b018cd5f1b914905e784 (patch)
tree      391692ded33fbb5d2eca643218ff16ba98534edc /test/test_backend.py
parent    14c30d4cfac060ff0bf8c64d4c88c77df32aad86 (diff)
Add set_gradients method for JAX backend. (#278)
Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'test/test_backend.py')
-rw-r--r--  test/test_backend.py  15
1 file changed, 14 insertions(+), 1 deletion(-)
diff --git a/test/test_backend.py b/test/test_backend.py
index bc5b00c..cbfaf94 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -345,7 +345,8 @@ def test_gradients_backends():
     rnd = np.random.RandomState(0)
     v = rnd.randn(10)
-    c = rnd.randn(1)
+    c = rnd.randn()
+    e = rnd.randn()
     if torch:
@@ -362,3 +363,15 @@ def test_gradients_backends():
         assert torch.equal(v2.grad, v2)
         assert torch.equal(c2.grad, c2)
+
+    if jax:
+        nx = ot.backend.JaxBackend()
+        with jax.checking_leaks():
+            def fun(a, b, d):
+                val = b * nx.sum(a ** 4) + d
+                return nx.set_gradients(val, (a, b, d), (a, b, 2 * d))
+            grad_val = jax.grad(fun, argnums=(0, 1, 2))(v, c, e)
+
+        np.testing.assert_almost_equal(fun(v, c, e), c * np.sum(v ** 4) + e, decimal=4)
+        np.testing.assert_allclose(grad_val[0], v, atol=1e-4)
+        np.testing.assert_allclose(grad_val[2], 2 * e, atol=1e-4)
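
For context, nx.set_gradients(val, inputs, grads) is expected to return val unchanged on the forward pass while making jax.grad report the supplied grads for the corresponding inputs, which is what the assertions above check. The snippet below is a minimal, self-contained sketch of one way such behavior can be obtained in JAX with jax.lax.stop_gradient; the helper set_gradients_sketch is hypothetical and is not taken from POT's JaxBackend, whose actual implementation may differ.

import jax
import jax.numpy as jnp
import numpy as np


def set_gradients_sketch(val, inputs, grads):
    # Hypothetical helper, not POT's JaxBackend.set_gradients.
    # Detach the scalar value so its own gradient does not flow back.
    val = jax.lax.stop_gradient(val)
    # Add a term whose forward value is zero but whose gradient with
    # respect to each input equals the matching entry of grads.
    aux = sum(jnp.sum(x * jax.lax.stop_gradient(g)) for x, g in zip(inputs, grads))
    return val + aux - jax.lax.stop_gradient(aux)


def fun(a, b, d):
    # Mirrors the test above: value b * sum(a ** 4) + d, with custom
    # gradients (a, b, 2 * d) attached to the inputs (a, b, d).
    val = b * jnp.sum(a ** 4) + d
    return set_gradients_sketch(val, (a, b, d), (a, b, 2 * d))


rnd = np.random.RandomState(0)
v, c, e = rnd.randn(10), rnd.randn(), rnd.randn()
grad_val = jax.grad(fun, argnums=(0, 1, 2))(v, c, e)
np.testing.assert_allclose(grad_val[0], v, atol=1e-4)      # d(fun)/da == a
np.testing.assert_allclose(grad_val[1], c, atol=1e-4)      # d(fun)/db == b
np.testing.assert_allclose(grad_val[2], 2 * e, atol=1e-4)  # d(fun)/dd == 2 * d

The detached correction term leaves the value of fun untouched but overrides its gradients, which is why grad_val[0] equals v rather than the true derivative 4 * c * v ** 3.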