diff options
author | AdrienCorenflos <adrien.corenflos@gmail.com> | 2021-10-22 15:05:14 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-10-22 14:05:14 +0200 |
commit | d50d8145a5c0cf69d438b018cd5f1b914905e784 (patch) | |
tree | 391692ded33fbb5d2eca643218ff16ba98534edc /ot/backend.py | |
parent | 14c30d4cfac060ff0bf8c64d4c88c77df32aad86 (diff) |
Add set_gradients method for JAX backend. (#278)
Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'ot/backend.py')
-rw-r--r-- | ot/backend.py | 16 |
1 files changed, 8 insertions, 8 deletions
diff --git a/ot/backend.py b/ot/backend.py index 8f46900..2ed40af 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -287,16 +287,16 @@ class JaxBackend(Backend): return jnp.array(a).astype(type_as.dtype) def set_gradients(self, val, inputs, grads): - # no gradients for jax because it is functional + from jax.flatten_util import ravel_pytree + val, = jax.lax.stop_gradient((val,)) - # does not work - # from jax import custom_jvp - # @custom_jvp - # def f(*inputs): - # return val - # f.defjvps(*grads) - # return f(*inputs) + ravelled_inputs, _ = ravel_pytree(inputs) + ravelled_grads, _ = ravel_pytree(grads) + aux = jnp.sum(ravelled_inputs * ravelled_grads) / 2 + aux = aux - jax.lax.stop_gradient(aux) + + val, = jax.tree_map(lambda z: z + aux, (val,)) return val def zeros(self, shape, type_as=None): |