diff options
-rw-r--r-- | RELEASES.md | 1 | ||||
-rw-r--r-- | ot/bregman.py | 4 |
2 files changed, 3 insertions, 2 deletions
diff --git a/RELEASES.md b/RELEASES.md index 02fddad..97f4c44 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -15,6 +15,7 @@ - Fix gradients for "Wasserstein2 Minibatch GAN" example (PR #466) - Faster Bures-Wasserstein distance with NumPy backend (PR #468) - Fix issue backend for ot.sliced_wasserstein_sphere ot.sliced_wasserstein_sphere_unif (PR #471) +- Fix issue with ot.barycenter_stabilized when used with PyTorch tensors and log=True (RP #474) ## 0.9.0 diff --git a/ot/bregman.py b/ot/bregman.py index 4503ffc..29bcd58 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -1898,8 +1898,8 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, "Or a larger absorption threshold `tau`.") if log: log['niter'] = ii - log['logu'] = np.log(u + 1e-16) - log['logv'] = np.log(v + 1e-16) + log['logu'] = nx.log(u + 1e-16) + log['logv'] = nx.log(v + 1e-16) return q, log else: return q |