summaryrefslogtreecommitdiff
path: root/ot/backend.py
diff options
context:
space:
mode:
authorAdrienCorenflos <adrien.corenflos@gmail.com>2021-10-22 15:05:14 +0300
committerGitHub <noreply@github.com>2021-10-22 14:05:14 +0200
commitd50d8145a5c0cf69d438b018cd5f1b914905e784 (patch)
tree391692ded33fbb5d2eca643218ff16ba98534edc /ot/backend.py
parent14c30d4cfac060ff0bf8c64d4c88c77df32aad86 (diff)
Add set_gradients method for JAX backend. (#278)
Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'ot/backend.py')
-rw-r--r--ot/backend.py16
1 files changed, 8 insertions, 8 deletions
diff --git a/ot/backend.py b/ot/backend.py
index 8f46900..2ed40af 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -287,16 +287,16 @@ class JaxBackend(Backend):
return jnp.array(a).astype(type_as.dtype)
def set_gradients(self, val, inputs, grads):
- # no gradients for jax because it is functional
+ from jax.flatten_util import ravel_pytree
+ val, = jax.lax.stop_gradient((val,))
- # does not work
- # from jax import custom_jvp
- # @custom_jvp
- # def f(*inputs):
- # return val
- # f.defjvps(*grads)
- # return f(*inputs)
+ ravelled_inputs, _ = ravel_pytree(inputs)
+ ravelled_grads, _ = ravel_pytree(grads)
+ aux = jnp.sum(ravelled_inputs * ravelled_grads) / 2
+ aux = aux - jax.lax.stop_gradient(aux)
+
+ val, = jax.tree_map(lambda z: z + aux, (val,))
return val
def zeros(self, shape, type_as=None):