diff options
author | AdrienCorenflos <adrien.corenflos@gmail.com> | 2021-10-22 15:05:14 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-10-22 14:05:14 +0200 |
commit | d50d8145a5c0cf69d438b018cd5f1b914905e784 (patch) | |
tree | 391692ded33fbb5d2eca643218ff16ba98534edc /test/test_backend.py | |
parent | 14c30d4cfac060ff0bf8c64d4c88c77df32aad86 (diff) |
Add set_gradients method for JAX backend. (#278)
Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'test/test_backend.py')
-rw-r--r-- | test/test_backend.py | 15 |
1 files changed, 14 insertions, 1 deletions
diff --git a/test/test_backend.py b/test/test_backend.py index bc5b00c..cbfaf94 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -345,7 +345,8 @@ def test_gradients_backends(): rnd = np.random.RandomState(0) v = rnd.randn(10) - c = rnd.randn(1) + c = rnd.randn() + e = rnd.randn() if torch: @@ -362,3 +363,15 @@ def test_gradients_backends(): assert torch.equal(v2.grad, v2) assert torch.equal(c2.grad, c2) + + if jax: + nx = ot.backend.JaxBackend() + with jax.checking_leaks(): + def fun(a, b, d): + val = b * nx.sum(a ** 4) + d + return nx.set_gradients(val, (a, b, d), (a, b, 2 * d)) + grad_val = jax.grad(fun, argnums=(0, 1, 2))(v, c, e) + + np.testing.assert_almost_equal(fun(v, c, e), c * np.sum(v ** 4) + e, decimal=4) + np.testing.assert_allclose(grad_val[0], v, atol=1e-4) + np.testing.assert_allclose(grad_val[2], 2 * e, atol=1e-4) |