diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2018-07-24 14:35:58 +0200 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2018-07-24 14:35:58 +0200 |
commit | 5e3392a029e675c7e19f8b1723fcfdb9aa9142aa (patch) | |
tree | f66d6675d1605c3b726cdaa5bb5991d4e69fc3be | |
parent | 603c0eee29db890b0092ea8c848473bf413e186f (diff) |
cancel einsum
-rw-r--r-- | ot/bregman.py | 13 |
1 files changed, 5 insertions, 8 deletions
diff --git a/ot/bregman.py b/ot/bregman.py index 57cedb2..1873c46 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -358,14 +358,11 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, while (err > stopThr and cpt < numItermax): uprev = u vprev = v - if nbb: - KtransposeU = np.einsum('ij,ik->jk',K,u)#np.dot(K.T, u) - v = np.divide(b, KtransposeU) - u = 1. / np.einsum('ij,jk->ik',Kp,v)#np.dot(Kp, v) - else: - KtransposeU = np.einsum('ij,i->j',K,u)#np.dot(K.T, u) - v = np.divide(b, KtransposeU) - u = 1. / np.einsum('ij,j->i',Kp,v)#np.dot(Kp, v) + + KtransposeU = np.dot(K.T, u) + v = np.divide(b, KtransposeU) + u = 1. / np.dot(Kp, v) + if (np.any(KtransposeU == 0) or np.any(np.isnan(u)) or np.any(np.isnan(v)) or |