summaryrefslogtreecommitdiff
path: root/ot/backend.py
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2021-11-16 13:07:38 +0100
committerGitHub <noreply@github.com>2021-11-16 13:07:38 +0100
commitf4b363d865a79c07248176c1e36990e0cb6814ea (patch)
tree37f51d94a01ae495e28cec55a78e1c9404ac48d9 /ot/backend.py
parent0c589912800b23609c730871c080ade0c807cdc1 (diff)
[WIP] Fix gradient scaling bug in emd (#310)
* orrect gradient bug in emd2 * small comment in test * deploy properly on tag release * subplot fail
Diffstat (limited to 'ot/backend.py')
-rw-r--r--ot/backend.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/ot/backend.py b/ot/backend.py
index a044f84..fa164c3 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -1203,7 +1203,7 @@ class TorchBackend(Backend):
@staticmethod
def backward(ctx, grad_output):
# the gradients are grad
- return (None, None) + ctx.grads
+ return (None, None) + tuple(g * grad_output for g in ctx.grads)
self.ValFunction = ValFunction