summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
authorHicham Janati <hicham.janati@inria.fr>2019-07-23 21:28:30 +0200
committerHicham Janati <hicham.janati@inria.fr>2019-07-23 21:28:30 +0200
commit09f3f640fc46ba4905d5508b704f2e5a90dda295 (patch)
treec0b1e2c3644209e5a6c49d676c2a9884fb0bd115 /ot
parent92292231a4d9661c399dbfd97b22d6f7f890f698 (diff)
fix issue 94 + add test
Diffstat (limited to 'ot')
-rw-r--r--ot/bregman.py10
1 files changed, 7 insertions, 3 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index f39145d..70e4208 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -765,10 +765,14 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
cpt = cpt + 1
- # print('err=',err,' cpt=',cpt)
if log:
- log['logu'] = alpha / reg + np.log(u)
- log['logv'] = beta / reg + np.log(v)
+ if nbb:
+ alpha = alpha[:, None]
+ beta = beta[:, None]
+ logu = alpha / reg + np.log(u)
+ logv = beta / reg + np.log(v)
+ log['logu'] = logu
+ log['logv'] = logv
log['alpha'] = alpha + reg * np.log(u)
log['beta'] = beta + reg * np.log(v)
log['warmstart'] = (log['alpha'], log['beta'])