diff options
author | RĂ©mi Flamary <remi.flamary@gmail.com> | 2021-06-02 12:59:35 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-06-02 12:59:35 +0200 |
commit | d693ac25988dd557cb1ee7fc96f3a656f7d4301c (patch) | |
tree | 9dc6d91b816b4b379684fbd12bf0e853f22374c0 /ot | |
parent | 184f8f4f7ac78f1dd7f653496d2753211a4e3426 (diff) |
[WIP] Add Wasserstein GAN and debug memory leak (#254)
* add example and debug memory leak
* print stuff
* speedup gallery
* Apply suggestions from code review
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* test cells
* proper header gan exmaple
* cleanup sections
* last changes ?
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
Diffstat (limited to 'ot')
-rw-r--r-- | ot/backend.py | 34 |
1 files changed, 23 insertions, 11 deletions
diff --git a/ot/backend.py b/ot/backend.py index d68f5cf..8f46900 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -389,6 +389,26 @@ class TorchBackend(Backend): __name__ = 'torch' __type__ = torch_type + def __init__(self): + + from torch.autograd import Function + + # define a function that takes inputs val and grads + # ad returns a val tensor with proper gradients + class ValFunction(Function): + + @staticmethod + def forward(ctx, val, grads, *inputs): + ctx.grads = grads + return val + + @staticmethod + def backward(ctx, grad_output): + # the gradients are grad + return (None, None) + ctx.grads + + self.ValFunction = ValFunction + def to_numpy(self, a): return a.cpu().detach().numpy() @@ -399,20 +419,12 @@ class TorchBackend(Backend): return torch.as_tensor(a, dtype=type_as.dtype, device=type_as.device) def set_gradients(self, val, inputs, grads): - from torch.autograd import Function - # define a function that takes inputs and return val - class ValFunction(Function): - @staticmethod - def forward(ctx, *inputs): - return val + Func = self.ValFunction() - @staticmethod - def backward(ctx, grad_output): - # the gradients are grad - return grads + res = Func.apply(val, grads, *inputs) - return ValFunction.apply(*inputs) + return res def zeros(self, shape, type_as=None): if type_as is None: |