summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2018-07-24 14:35:58 +0200
committerRémi Flamary <remi.flamary@gmail.com>2018-07-24 14:35:58 +0200
commit5e3392a029e675c7e19f8b1723fcfdb9aa9142aa (patch)
treef66d6675d1605c3b726cdaa5bb5991d4e69fc3be /ot
parent603c0eee29db890b0092ea8c848473bf413e186f (diff)
cancel einsum
Diffstat (limited to 'ot')
-rw-r--r--ot/bregman.py13
1 files changed, 5 insertions, 8 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index 57cedb2..1873c46 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -358,14 +358,11 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
while (err > stopThr and cpt < numItermax):
uprev = u
vprev = v
- if nbb:
- KtransposeU = np.einsum('ij,ik->jk',K,u)#np.dot(K.T, u)
- v = np.divide(b, KtransposeU)
- u = 1. / np.einsum('ij,jk->ik',Kp,v)#np.dot(Kp, v)
- else:
- KtransposeU = np.einsum('ij,i->j',K,u)#np.dot(K.T, u)
- v = np.divide(b, KtransposeU)
- u = 1. / np.einsum('ij,j->i',Kp,v)#np.dot(Kp, v)
+
+ KtransposeU = np.dot(K.T, u)
+ v = np.divide(b, KtransposeU)
+ u = 1. / np.dot(Kp, v)
+
if (np.any(KtransposeU == 0) or
np.any(np.isnan(u)) or np.any(np.isnan(v)) or