summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--RELEASES.md1
-rw-r--r--ot/bregman.py4
2 files changed, 3 insertions, 2 deletions
diff --git a/RELEASES.md b/RELEASES.md
index 02fddad..97f4c44 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -15,6 +15,7 @@
- Fix gradients for "Wasserstein2 Minibatch GAN" example (PR #466)
- Faster Bures-Wasserstein distance with NumPy backend (PR #468)
- Fix issue backend for ot.sliced_wasserstein_sphere ot.sliced_wasserstein_sphere_unif (PR #471)
+- Fix issue with ot.barycenter_stabilized when used with PyTorch tensors and log=True (RP #474)
## 0.9.0
diff --git a/ot/bregman.py b/ot/bregman.py
index 4503ffc..29bcd58 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -1898,8 +1898,8 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000,
"Or a larger absorption threshold `tau`.")
if log:
log['niter'] = ii
- log['logu'] = np.log(u + 1e-16)
- log['logv'] = np.log(v + 1e-16)
+ log['logu'] = nx.log(u + 1e-16)
+ log['logv'] = nx.log(v + 1e-16)
return q, log
else:
return q