summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2017-04-18 16:23:48 +0200
committerGitHub <noreply@github.com>2017-04-18 16:23:48 +0200
commite30345b8592bfbe4260be658261c3ce7c03d56fa (patch)
tree9c39ed465aa358e97ae2991187d01347aac6f0c7
parentebade89480e6cc8f49e2dc5240e81540ff35880b (diff)
parent691c97a033e54359e8c933e3bdd34bf5cf40151d (diff)
Merge pull request #7 from aje/master
Performance improvement sinkhorn
-rw-r--r--ot/bregman.py7
1 files changed, 2 insertions, 5 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index b06eaeb..0453f14 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -94,8 +94,6 @@ def sinkhorn(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=Fa
Nini = len(a)
Nfin = len(b)
-
- cpt = 0
if log:
log={'err':[]}
@@ -108,8 +106,7 @@ def sinkhorn(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=Fa
K = np.exp(-M/reg)
#print(np.min(K))
- Kp = np.dot(np.diag(1/a),K)
- transp = K
+ Kp = (1/a).reshape(-1, 1) * K
cpt = 0
err=1
while (err>stopThr and cpt<numItermax):
@@ -128,7 +125,7 @@ def sinkhorn(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=Fa
break
if cpt%10==0:
# we can speed up the process by checking for the error only all the 10th iterations
- transp = np.dot(np.diag(u),np.dot(K,np.diag(v)))
+ transp = u.reshape(-1, 1) * (K * v)
err = np.linalg.norm((np.sum(transp,axis=0)-b))**2
if log:
log['err'].append(err)