path: root/ot/backend.py
author    Rémi Flamary <remi.flamary@gmail.com>    2021-06-02 12:59:35 +0200
committer GitHub <noreply@github.com>    2021-06-02 12:59:35 +0200
commit    d693ac25988dd557cb1ee7fc96f3a656f7d4301c (patch)
tree      9dc6d91b816b4b379684fbd12bf0e853f22374c0 /ot/backend.py
parent    184f8f4f7ac78f1dd7f653496d2753211a4e3426 (diff)
[WIP] Add Wasserstein GAN and debug memory leak (#254)
* add example and debug memory leak
* print stuff
* speedup gallery
* Apply suggestions from code review
* test cells
* proper header for GAN example
* cleanup sections
* last changes?

Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
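The leak fix in ot/backend.py comes down to where the custom torch.autograd.Function is defined: before this patch a fresh Function subclass was created inside set_gradients on every call, and its closures over val and grads kept those tensors (and the throwaway class) alive. Below is a minimal sketch of the before/after pattern, a hedged reconstruction simplified from the diff; the class and argument names match the patch, but the leaky wrapper function is illustrative:

    import torch
    from torch.autograd import Function

    # Leaky pattern (pre-patch): a new Function subclass per call, whose
    # methods close over val and grads and keep them reachable through
    # the autograd graph.
    def set_gradients_leaky(val, inputs, grads):
        class ValFunction(Function):  # fresh class object every call
            @staticmethod
            def forward(ctx, *inputs):
                return val            # closure over val
            @staticmethod
            def backward(ctx, grad_output):
                return grads          # closure over grads
        return ValFunction.apply(*inputs)

    # Fixed pattern (this patch): define the class once and thread
    # val/grads through forward() so they live on ctx, not in a closure.
    class ValFunction(Function):
        @staticmethod
        def forward(ctx, val, grads, *inputs):
            ctx.grads = grads
            return val
        @staticmethod
        def backward(ctx, grad_output):
            # None for the val and grads arguments themselves; the stored
            # grads flow back to *inputs
            return (None, None) + ctx.grads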
Diffstat (limited to 'ot/backend.py')
-rw-r--r--    ot/backend.py    34
1 file changed, 23 insertions, 11 deletions
diff --git a/ot/backend.py b/ot/backend.py
index d68f5cf..8f46900 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -389,6 +389,26 @@ class TorchBackend(Backend):
__name__ = 'torch'
__type__ = torch_type
+ def __init__(self):
+
+ from torch.autograd import Function
+
+ # define a function that takes inputs val and grads
+ # and returns a val tensor with proper gradients
+ class ValFunction(Function):
+
+ @staticmethod
+ def forward(ctx, val, grads, *inputs):
+ ctx.grads = grads
+ return val
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ # no gradients for val and grads themselves; return the stored grads for *inputs
+ return (None, None) + ctx.grads
+
+ self.ValFunction = ValFunction
+
def to_numpy(self, a):
return a.cpu().detach().numpy()
@@ -399,20 +419,12 @@ class TorchBackend(Backend):
return torch.as_tensor(a, dtype=type_as.dtype, device=type_as.device)
def set_gradients(self, val, inputs, grads):
- from torch.autograd import Function
- # define a function that takes inputs and return val
- class ValFunction(Function):
- @staticmethod
- def forward(ctx, *inputs):
- return val
+ Func = self.ValFunction()
- @staticmethod
- def backward(ctx, grad_output):
- # the gradients are grad
- return grads
+ res = Func.apply(val, grads, *inputs)
- return ValFunction.apply(*inputs)
+ return res
def zeros(self, shape, type_as=None):
if type_as is None:
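For context, a hedged usage sketch of the patched set_gradients (this assumes TorchBackend is importable from ot.backend as in this file; the precomputed gradient below is made up purely for illustration):

    import torch
    from ot.backend import TorchBackend

    nx = TorchBackend()
    a = torch.rand(5, requires_grad=True)

    # Suppose a solver produced a detached value val = f(a) together with
    # a precomputed gradient of val w.r.t. a (ones here, illustrative).
    val = a.detach().sum()
    grads = (torch.ones_like(a),)

    # Re-attach the gradients so autograd can flow through val as usual.
    val = nx.set_gradients(val, (a,), grads)
    val.backward()
    assert torch.allclose(a.grad, torch.ones_like(a))

Because ValFunction is now created once in __init__, repeated calls like this no longer accumulate per-call class objects and captured tensors.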