summaryrefslogtreecommitdiff
path: root/ot/backend.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/backend.py')
-rw-r--r--ot/backend.py34
1 file changed, 23 insertions, 11 deletions
diff --git a/ot/backend.py b/ot/backend.py
index d68f5cf..8f46900 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -389,6 +389,26 @@ class TorchBackend(Backend):
__name__ = 'torch'
__type__ = torch_type
+ def __init__(self):
+
+ from torch.autograd import Function
+
+ # define a function that takes inputs val and grads
+    # and returns a val tensor with proper gradients
+ class ValFunction(Function):
+
+ @staticmethod
+ def forward(ctx, val, grads, *inputs):
+ ctx.grads = grads
+ return val
+
+ @staticmethod
+ def backward(ctx, grad_output):
+            # the gradients are the grads stored on the context in forward
+ return (None, None) + ctx.grads
+
+ self.ValFunction = ValFunction
+
def to_numpy(self, a):
return a.cpu().detach().numpy()
@@ -399,20 +419,12 @@ class TorchBackend(Backend):
return torch.as_tensor(a, dtype=type_as.dtype, device=type_as.device)
def set_gradients(self, val, inputs, grads):
- from torch.autograd import Function
- # define a function that takes inputs and return val
- class ValFunction(Function):
- @staticmethod
- def forward(ctx, *inputs):
- return val
+ Func = self.ValFunction()
- @staticmethod
- def backward(ctx, grad_output):
- # the gradients are grad
- return grads
+ res = Func.apply(val, grads, *inputs)
- return ValFunction.apply(*inputs)
+ return res
def zeros(self, shape, type_as=None):
if type_as is None: