diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2018-07-24 14:33:32 +0200 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2018-07-24 14:33:32 +0200 |
commit | a04112c69a62182c061d4b65e71ebb43c866d3e1 (patch) | |
tree | 4dda23c943bb3a06ac5e3a66bb0ff61a4e42c4a1 /ot | |
parent | bbe411775b3d5abb5d6fb525262cccce3f73d345 (diff) |
correction size
Diffstat (limited to 'ot')
-rw-r--r-- | ot/bregman.py | 11 |
1 files changed, 8 insertions, 3 deletions
diff --git a/ot/bregman.py b/ot/bregman.py index d2ade46..29ca9fd 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -358,9 +358,14 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, while (err > stopThr and cpt < numItermax): uprev = u vprev = v - KtransposeU = np.einsum('ij,i->j',K,u)#np.dot(K.T, u) - v = np.divide(b, KtransposeU) - u = 1. / np.einsum('ij,j->i',Kp,v)#np.dot(Kp, v) + if nbb: + KtransposeU = np.einsum('ij,i,k->jk',K,u)#np.dot(K.T, u) + v = np.divide(b, KtransposeU) + u = 1. / np.einsum('ij,jk->ik',Kp,v)#np.dot(Kp, v) + else: + KtransposeU = np.einsum('ij,i->j',K,u)#np.dot(K.T, u) + v = np.divide(b, KtransposeU) + u = 1. / np.einsum('ij,j->i',Kp,v)#np.dot(Kp, v) if (np.any(KtransposeU == 0) or np.any(np.isnan(u)) or np.any(np.isnan(v)) or |