author    Bastian Rieck <bastian@rieck.me>  2022-01-19 08:39:04 +0100
committer GitHub <noreply@github.com>       2022-01-19 08:39:04 +0100
commit    263c5842664c1dff4f8e58111d6bddb33927539e (patch)
tree      df644f0448d5d3ec6d7e065681ad2161947021f8
parent    5861209f27fe8e022eca2ed2c8d0bb1da4a1146b (diff)
[MRG] Fix instantiation of `ValFunction` (which raises a warning with PyTorch) (#338)
* Not instantiating `ValFunction`
`ValFunction` should not be instantiated, since `autograd` functions are
expected to define only static methods and to be used via the class itself.
This resolves a warning raised by PyTorch.
* Updated release information
* Fixed PR number
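The rule the commit message refers to can be sketched with a minimal custom `torch.autograd.Function` in the spirit of the patched `ValFunction` (the names and the toy gradient here are illustrative, not the exact POT implementation):

```python
import torch


class ValFunction(torch.autograd.Function):
    """A custom autograd function: only @staticmethod members, never instantiated."""

    @staticmethod
    def forward(ctx, val, grads, *inputs):
        # Stash externally computed gradients for the backward pass
        ctx.grads = grads
        return val

    @staticmethod
    def backward(ctx, grad_output):
        # No gradient for `val` or `grads`; scale the stored gradients by grad_output
        return (None, None) + tuple(g * grad_output for g in ctx.grads)


x = torch.tensor([1.0, 2.0], requires_grad=True)
val = (x ** 2).sum()

# Correct: reference the class itself -- this is the one-character fix
Func = ValFunction          # not ValFunction(), which triggers a PyTorch warning
res = Func.apply(val, (2 * x.detach(),), x)
res.backward()
```

Calling `.apply` on the class keeps PyTorch's expected static-method protocol intact; instantiating the class first is what produced the warning this commit removes.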
 RELEASES.md   | 4 ++++
 ot/backend.py | 2 +-
 2 files changed, 5 insertions(+), 1 deletion(-)
diff --git a/RELEASES.md b/RELEASES.md
index 9b92d97..c6ab9c3 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -6,6 +6,10 @@
 
 - Better list of related examples in quick start guide with `minigallery` (PR #334)
 
+#### Closed issues
+
+- Bug in instantiating an `autograd` function (`ValFunction`, Issue #337, PR #338)
+
 ## 0.8.1.0
 *December 2021*
diff --git a/ot/backend.py b/ot/backend.py
index 58b652b..6e0bc3d 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -1397,7 +1397,7 @@ class TorchBackend(Backend):
 
     def set_gradients(self, val, inputs, grads):
-        Func = self.ValFunction()
+        Func = self.ValFunction
 
         res = Func.apply(val, grads, *inputs)