summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBastian Rieck <bastian@rieck.me>2022-01-19 08:39:04 +0100
committerGitHub <noreply@github.com>2022-01-19 08:39:04 +0100
commit263c5842664c1dff4f8e58111d6bddb33927539e (patch)
treedf644f0448d5d3ec6d7e065681ad2161947021f8
parent5861209f27fe8e022eca2ed2c8d0bb1da4a1146b (diff)
[MRG] Fix instantiation of `ValFunction` (which raises a warning with PyTorch) (#338)
* Not instantiating `ValFunction` `ValFunction` should not be instantiated since `autograd` functions are supposed to only ever use static methods. This solves a warning message raised by PyTorch. * Updated release information * Fixed PR number
-rw-r--r--RELEASES.md4
-rw-r--r--ot/backend.py2
2 files changed, 5 insertions, 1 deletions
diff --git a/RELEASES.md b/RELEASES.md
index 9b92d97..c6ab9c3 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -6,6 +6,10 @@
- Better list of related examples in quick start guide with `minigallery` (PR #334)
+#### Closed issues
+
+- Bug in instantiating an `autograd` function (`ValFunction`, Issue #337, PR #338)
+
## 0.8.1.0
*December 2021*
diff --git a/ot/backend.py b/ot/backend.py
index 58b652b..6e0bc3d 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -1397,7 +1397,7 @@ class TorchBackend(Backend):
def set_gradients(self, val, inputs, grads):
- Func = self.ValFunction()
+ Func = self.ValFunction
res = Func.apply(val, grads, *inputs)