summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2019-07-25 08:22:16 +0200
committerGitHub <noreply@github.com>2019-07-25 08:22:16 +0200
commitc64906b70d6059fd56836e8f760f4a836867d51b (patch)
tree96f505830c031022580d32d08d1d2081e9e45204 /ot
parent0063cb87a10293a24ad1c9483be121745958c24a (diff)
parenta507556b1901e16351c211e69b38d8d74ac2bc3d (diff)
Merge pull request #97 from hichamjanati/fix_mismatch_error_94
[MRG] Fix mismatch error in stabilized sinkhorn
Diffstat (limited to 'ot')
-rw-r--r--ot/bregman.py10
1 files changed, 7 insertions, 3 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index f39145d..70e4208 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -765,10 +765,14 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
cpt = cpt + 1
- # print('err=',err,' cpt=',cpt)
if log:
- log['logu'] = alpha / reg + np.log(u)
- log['logv'] = beta / reg + np.log(v)
+ if nbb:
+ alpha = alpha[:, None]
+ beta = beta[:, None]
+ logu = alpha / reg + np.log(u)
+ logv = beta / reg + np.log(v)
+ log['logu'] = logu
+ log['logv'] = logv
log['alpha'] = alpha + reg * np.log(u)
log['beta'] = beta + reg * np.log(v)
log['warmstart'] = (log['alpha'], log['beta'])