diff options
Diffstat (limited to 'ot/backend.py')
-rw-r--r-- | ot/backend.py | 33 |
1 files changed, 33 insertions, 0 deletions
diff --git a/ot/backend.py b/ot/backend.py index 0779243..74f8366 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -951,6 +951,14 @@ class Backend(): """ raise NotImplementedError() + def detach(self, *args): + r""" + Detach tensors in arguments from the current graph. + + See: https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html + """ + raise NotImplementedError() + class NumpyBackend(Backend): """ @@ -1279,6 +1287,11 @@ class NumpyBackend(Backend): def transpose(self, a, axes=None): return np.transpose(a, axes) + def detach(self, *args): + if len(args) == 1: + return args[0] + return args + class JaxBackend(Backend): """ @@ -1626,6 +1639,11 @@ class JaxBackend(Backend): def transpose(self, a, axes=None): return jnp.transpose(a, axes) + def detach(self, *args): + if len(args) == 1: + return jax.lax.stop_gradient((args[0],))[0] + return [jax.lax.stop_gradient((a,))[0] for a in args] + class TorchBackend(Backend): """ @@ -2072,6 +2090,11 @@ class TorchBackend(Backend): axes = tuple(range(a.ndim)[::-1]) return a.permute(axes) + def detach(self, *args): + if len(args) == 1: + return args[0].detach() + return [a.detach() for a in args] + class CupyBackend(Backend): # pragma: no cover """ @@ -2443,6 +2466,11 @@ class CupyBackend(Backend): # pragma: no cover def transpose(self, a, axes=None): return cp.transpose(a, axes) + def detach(self, *args): + if len(args) == 1: + return args[0] + return args + class TensorflowBackend(Backend): @@ -2826,3 +2854,8 @@ class TensorflowBackend(Backend): def transpose(self, a, axes=None): return tf.transpose(a, perm=axes) + + def detach(self, *args): + if len(args) == 1: + return tf.stop_gradient(args[0]) + return [tf.stop_gradient(a) for a in args] |