From d50d8145a5c0cf69d438b018cd5f1b914905e784 Mon Sep 17 00:00:00 2001 From: AdrienCorenflos Date: Fri, 22 Oct 2021 15:05:14 +0300 Subject: Add set_gradients method for JAX backend. (#278) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: RĂ©mi Flamary --- ot/backend.py | 16 ++++++++-------- test/test_backend.py | 15 ++++++++++++++- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/ot/backend.py b/ot/backend.py index 8f46900..2ed40af 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -287,16 +287,16 @@ class JaxBackend(Backend): return jnp.array(a).astype(type_as.dtype) def set_gradients(self, val, inputs, grads): - # no gradients for jax because it is functional + from jax.flatten_util import ravel_pytree + val, = jax.lax.stop_gradient((val,)) - # does not work - # from jax import custom_jvp - # @custom_jvp - # def f(*inputs): - # return val - # f.defjvps(*grads) - # return f(*inputs) + ravelled_inputs, _ = ravel_pytree(inputs) + ravelled_grads, _ = ravel_pytree(grads) + aux = jnp.sum(ravelled_inputs * ravelled_grads) / 2 + aux = aux - jax.lax.stop_gradient(aux) + + val, = jax.tree_map(lambda z: z + aux, (val,)) return val def zeros(self, shape, type_as=None): diff --git a/test/test_backend.py b/test/test_backend.py index bc5b00c..cbfaf94 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -345,7 +345,8 @@ def test_gradients_backends(): rnd = np.random.RandomState(0) v = rnd.randn(10) - c = rnd.randn(1) + c = rnd.randn() + e = rnd.randn() if torch: @@ -362,3 +363,15 @@ def test_gradients_backends(): assert torch.equal(v2.grad, v2) assert torch.equal(c2.grad, c2) + + if jax: + nx = ot.backend.JaxBackend() + with jax.checking_leaks(): + def fun(a, b, d): + val = b * nx.sum(a ** 4) + d + return nx.set_gradients(val, (a, b, d), (a, b, 2 * d)) + grad_val = jax.grad(fun, argnums=(0, 1, 2))(v, c, e) + + np.testing.assert_almost_equal(fun(v, c, e), c * np.sum(v ** 4) + e, decimal=4) + np.testing.assert_allclose(grad_val[0], v, atol=1e-4) + np.testing.assert_allclose(grad_val[2], 2 * e, atol=1e-4) -- cgit v1.2.3