summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorOleksii Kachaiev <kachayev@gmail.com>2023-05-08 14:42:43 +0200
committerGitHub <noreply@github.com>2023-05-08 14:42:43 +0200
commit03341c6953c06608ba17d6bf7cd35666bc069988 (patch)
treefcb4fb8c3d4395d616e1b132306dfea42d5171ed
parentf66299880f30240c55ffaa4ad5f85829f3b5b360 (diff)
[MRG] Fix barycenter_stabilized with PyTorch and log set to True (#474)
* np -> nx for stabilized barycenters log * Mention fix in RELEASES
-rw-r--r--RELEASES.md1
-rw-r--r--ot/bregman.py4
2 files changed, 3 insertions, 2 deletions
diff --git a/RELEASES.md b/RELEASES.md
index 02fddad..97f4c44 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -15,6 +15,7 @@
- Fix gradients for "Wasserstein2 Minibatch GAN" example (PR #466)
- Faster Bures-Wasserstein distance with NumPy backend (PR #468)
- Fix issue backend for ot.sliced_wasserstein_sphere ot.sliced_wasserstein_sphere_unif (PR #471)
+- Fix issue with ot.barycenter_stabilized when used with PyTorch tensors and log=True (RP #474)
## 0.9.0
diff --git a/ot/bregman.py b/ot/bregman.py
index 4503ffc..29bcd58 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -1898,8 +1898,8 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000,
"Or a larger absorption threshold `tau`.")
if log:
log['niter'] = ii
- log['logu'] = np.log(u + 1e-16)
- log['logv'] = np.log(v + 1e-16)
+ log['logu'] = nx.log(u + 1e-16)
+ log['logv'] = nx.log(v + 1e-16)
return q, log
else:
return q