diff options
author | Oleksii Kachaiev <kachayev@gmail.com> | 2023-05-08 14:42:43 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-05-08 14:42:43 +0200 |
commit | 03341c6953c06608ba17d6bf7cd35666bc069988 (patch) | |
tree | fcb4fb8c3d4395d616e1b132306dfea42d5171ed | |
parent | f66299880f30240c55ffaa4ad5f85829f3b5b360 (diff) |
[MRG] Fix barycenter_stabilized with PyTorch and log set to True (#474)
* np -> nx for stabilized barycenters log
* Mention fix in RELEASES
-rw-r--r-- | RELEASES.md | 1 | ||||
-rw-r--r-- | ot/bregman.py | 4 |
2 files changed, 3 insertions, 2 deletions
diff --git a/RELEASES.md b/RELEASES.md index 02fddad..97f4c44 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -15,6 +15,7 @@ - Fix gradients for "Wasserstein2 Minibatch GAN" example (PR #466) - Faster Bures-Wasserstein distance with NumPy backend (PR #468) - Fix issue backend for ot.sliced_wasserstein_sphere ot.sliced_wasserstein_sphere_unif (PR #471) +- Fix issue with ot.barycenter_stabilized when used with PyTorch tensors and log=True (RP #474) ## 0.9.0 diff --git a/ot/bregman.py b/ot/bregman.py index 4503ffc..29bcd58 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -1898,8 +1898,8 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, "Or a larger absorption threshold `tau`.") if log: log['niter'] = ii - log['logu'] = np.log(u + 1e-16) - log['logv'] = np.log(v + 1e-16) + log['logu'] = nx.log(u + 1e-16) + log['logv'] = nx.log(v + 1e-16) return q, log else: return q |