From ebbd1868ba03e46f029ed485e949f4936eded85c Mon Sep 17 00:00:00 2001 From: Leo gautheron Date: Tue, 18 Apr 2017 15:05:39 +0200 Subject: Performance improvement sinkhorn Doing the computation this way is equivalent and allows to reduce the space complexity required from O(max(a, b)^2) to O(a*b) (especially usefull to transport a small number of sources example to a lot of target) This also allows to decrease the computation time. --- ot/bregman.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ot/bregman.py b/ot/bregman.py index b06eaeb..1deccce 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -108,7 +108,8 @@ def sinkhorn(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=Fa K = np.exp(-M/reg) #print(np.min(K)) - Kp = np.dot(np.diag(1/a),K) + Kp = (1/a).reshape(-1, 1) * K + transp = K cpt = 0 err=1 @@ -128,7 +129,7 @@ def sinkhorn(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=Fa break if cpt%10==0: # we can speed up the process by checking for the error only all the 10th iterations - transp = np.dot(np.diag(u),np.dot(K,np.diag(v))) + transp = u.reshape(-1, 1) * (K * v) err = np.linalg.norm((np.sum(transp,axis=0)-b))**2 if log: log['err'].append(err) -- cgit v1.2.3 From 691c97a033e54359e8c933e3bdd34bf5cf40151d Mon Sep 17 00:00:00 2001 From: Leo gautheron Date: Tue, 18 Apr 2017 16:04:03 +0200 Subject: little cleanup sinkhorn --- ot/bregman.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/ot/bregman.py b/ot/bregman.py index 1deccce..0453f14 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -94,8 +94,6 @@ def sinkhorn(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=Fa Nini = len(a) Nfin = len(b) - - cpt = 0 if log: log={'err':[]} @@ -109,8 +107,6 @@ def sinkhorn(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=Fa #print(np.min(K)) Kp = (1/a).reshape(-1, 1) * K - - transp = K cpt = 0 err=1 while (err>stopThr and cpt