diff options
Diffstat (limited to 'ot/backend.py')
-rw-r--r-- | ot/backend.py | 34 |
1 files changed, 23 insertions, 11 deletions
diff --git a/ot/backend.py b/ot/backend.py index d68f5cf..8f46900 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -389,6 +389,26 @@ class TorchBackend(Backend): __name__ = 'torch' __type__ = torch_type + def __init__(self): + + from torch.autograd import Function + + # define a function that takes inputs val and grads + # ad returns a val tensor with proper gradients + class ValFunction(Function): + + @staticmethod + def forward(ctx, val, grads, *inputs): + ctx.grads = grads + return val + + @staticmethod + def backward(ctx, grad_output): + # the gradients are grad + return (None, None) + ctx.grads + + self.ValFunction = ValFunction + def to_numpy(self, a): return a.cpu().detach().numpy() @@ -399,20 +419,12 @@ class TorchBackend(Backend): return torch.as_tensor(a, dtype=type_as.dtype, device=type_as.device) def set_gradients(self, val, inputs, grads): - from torch.autograd import Function - # define a function that takes inputs and return val - class ValFunction(Function): - @staticmethod - def forward(ctx, *inputs): - return val + Func = self.ValFunction() - @staticmethod - def backward(ctx, grad_output): - # the gradients are grad - return grads + res = Func.apply(val, grads, *inputs) - return ValFunction.apply(*inputs) + return res def zeros(self, shape, type_as=None): if type_as is None: |