diff options
author | Hicham Janati <hicham.janati@inria.fr> | 2019-07-23 21:28:30 +0200 |
---|---|---|
committer | Hicham Janati <hicham.janati@inria.fr> | 2019-07-23 21:28:30 +0200 |
commit | 09f3f640fc46ba4905d5508b704f2e5a90dda295 (patch) | |
tree | c0b1e2c3644209e5a6c49d676c2a9884fb0bd115 /ot/bregman.py | |
parent | 92292231a4d9661c399dbfd97b22d6f7f890f698 (diff) |
fix issue 94 + add test
Diffstat (limited to 'ot/bregman.py')
-rw-r--r-- | ot/bregman.py | 10 |
1 files changed, 7 insertions, 3 deletions
diff --git a/ot/bregman.py b/ot/bregman.py index f39145d..70e4208 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -765,10 +765,14 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, cpt = cpt + 1 - # print('err=',err,' cpt=',cpt) if log: - log['logu'] = alpha / reg + np.log(u) - log['logv'] = beta / reg + np.log(v) + if nbb: + alpha = alpha[:, None] + beta = beta[:, None] + logu = alpha / reg + np.log(u) + logv = beta / reg + np.log(v) + log['logu'] = logu + log['logv'] = logv log['alpha'] = alpha + reg * np.log(u) log['beta'] = beta + reg * np.log(v) log['warmstart'] = (log['alpha'], log['beta']) |