From d693ac25988dd557cb1ee7fc96f3a656f7d4301c Mon Sep 17 00:00:00 2001
From: Rémi Flamary
Date: Wed, 2 Jun 2021 12:59:35 +0200
Subject: [WIP] Add Wasserstein GAN and debug memory leak (#254)

* add example and debug memory leak

* print stuff

* speedup gallery

* Apply suggestions from code review

Co-authored-by: Alexandre Gramfort

* test cells

* proper header gan example

* cleanup sections

* last changes?

Co-authored-by: Alexandre Gramfort
---
 ot/backend.py | 34 +++++++++++++++++++++++-----------
 1 file changed, 23 insertions(+), 11 deletions(-)

diff --git a/ot/backend.py b/ot/backend.py
index d68f5cf..8f46900 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -389,6 +389,26 @@ class TorchBackend(Backend):
     __name__ = 'torch'
     __type__ = torch_type
 
+    def __init__(self):
+
+        from torch.autograd import Function
+
+        # define a function that takes inputs val and grads
+        # and returns a val tensor with proper gradients
+        class ValFunction(Function):
+
+            @staticmethod
+            def forward(ctx, val, grads, *inputs):
+                ctx.grads = grads
+                return val
+
+            @staticmethod
+            def backward(ctx, grad_output):
+                # the gradients are grads
+                return (None, None) + ctx.grads
+
+        self.ValFunction = ValFunction
+
     def to_numpy(self, a):
         return a.cpu().detach().numpy()
 
@@ -399,20 +419,12 @@ class TorchBackend(Backend):
         return torch.as_tensor(a, dtype=type_as.dtype, device=type_as.device)
 
     def set_gradients(self, val, inputs, grads):
-        from torch.autograd import Function
 
-        # define a function that takes inputs and return val
-        class ValFunction(Function):
-            @staticmethod
-            def forward(ctx, *inputs):
-                return val
+        Func = self.ValFunction()
 
-            @staticmethod
-            def backward(ctx, grad_output):
-                # the gradients are grad
-                return grads
+        res = Func.apply(val, grads, *inputs)
 
-        return ValFunction.apply(*inputs)
+        return res
 
     def zeros(self, shape, type_as=None):
         if type_as is None:
--
cgit v1.2.3
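The change addresses the memory leak named in the subject by creating ValFunction once per backend instance instead of defining a fresh torch.autograd.Function subclass, which closes over val and grads, on every set_gradients call; in the new version val and grads are passed through forward() as arguments, so the class carries no per-call state and can be reused. A minimal usage sketch follows, assuming PyTorch and a POT checkout containing this patch are installed; the tensors and the injected gradients below are illustrative and not part of the commit:

import torch
from ot.backend import TorchBackend

nx = TorchBackend()

a = torch.rand(5, requires_grad=True)
b = torch.rand(5, requires_grad=True)

# Pretend `val` was computed outside autograd (e.g. by a solver working on
# detached arrays) and that its gradients w.r.t. a and b are already known.
val = (a * b).sum().detach()
grads = (b.detach(), a.detach())

# ValFunction.forward returns `val`; backward returns (None, None) + grads,
# so backpropagation sees the injected gradients for `a` and `b`.
loss = nx.set_gradients(val, (a, b), grads)
loss.backward()

assert torch.allclose(a.grad, b.detach())
assert torch.allclose(b.grad, a.detach())

Note that backward() in this patch ignores grad_output and returns ctx.grads as-is, so the sketch uses the wrapped value as a terminal scalar loss; chaining further operations on top of `val` would require scaling the returned gradients by grad_output.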