diff options
Diffstat (limited to 'ot/backend.py')
-rw-r--r-- | ot/backend.py | 2 |
1 files changed, 1 insertions, 1 deletions
diff --git a/ot/backend.py b/ot/backend.py index a044f84..fa164c3 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -1203,7 +1203,7 @@ class TorchBackend(Backend): @staticmethod def backward(ctx, grad_output): # the gradients are grad - return (None, None) + ctx.grads + return (None, None) + tuple(g * grad_output for g in ctx.grads) self.ValFunction = ValFunction |