summaryrefslogtreecommitdiff
path: root/ot/backend.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/backend.py')
-rw-r--r--ot/backend.py33
1 files changed, 33 insertions, 0 deletions
diff --git a/ot/backend.py b/ot/backend.py
index 0779243..74f8366 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -951,6 +951,14 @@ class Backend():
"""
raise NotImplementedError()
+ def detach(self, *args):
+ r"""
+ Detach tensors in arguments from the current graph.
+
+ See: https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
+ """
+ raise NotImplementedError()
+
class NumpyBackend(Backend):
"""
@@ -1279,6 +1287,11 @@ class NumpyBackend(Backend):
def transpose(self, a, axes=None):
return np.transpose(a, axes)
+ def detach(self, *args):
+ if len(args) == 1:
+ return args[0]
+ return args
+
class JaxBackend(Backend):
"""
@@ -1626,6 +1639,11 @@ class JaxBackend(Backend):
def transpose(self, a, axes=None):
return jnp.transpose(a, axes)
+ def detach(self, *args):
+ if len(args) == 1:
+ return jax.lax.stop_gradient((args[0],))[0]
+ return [jax.lax.stop_gradient((a,))[0] for a in args]
+
class TorchBackend(Backend):
"""
@@ -2072,6 +2090,11 @@ class TorchBackend(Backend):
axes = tuple(range(a.ndim)[::-1])
return a.permute(axes)
+ def detach(self, *args):
+ if len(args) == 1:
+ return args[0].detach()
+ return [a.detach() for a in args]
+
class CupyBackend(Backend): # pragma: no cover
"""
@@ -2443,6 +2466,11 @@ class CupyBackend(Backend): # pragma: no cover
def transpose(self, a, axes=None):
return cp.transpose(a, axes)
+ def detach(self, *args):
+ if len(args) == 1:
+ return args[0]
+ return args
+
class TensorflowBackend(Backend):
@@ -2826,3 +2854,8 @@ class TensorflowBackend(Backend):
def transpose(self, a, axes=None):
return tf.transpose(a, perm=axes)
+
+ def detach(self, *args):
+ if len(args) == 1:
+ return tf.stop_gradient(args[0])
+ return [tf.stop_gradient(a) for a in args]