summaryrefslogtreecommitdiff
path: root/ot/bregman.py
diff options
context:
space:
mode:
authorLeo gautheron <gautheron@iv-cm-359.creatis.insa-lyon.fr>2017-04-18 15:05:39 +0200
committerLeo gautheron <gautheron@iv-cm-359.creatis.insa-lyon.fr>2017-04-18 15:05:39 +0200
commitebbd1868ba03e46f029ed485e949f4936eded85c (patch)
tree7768056a15aae2bc40d7f31ec419d398ba738f6e /ot/bregman.py
parent7bcae0fa807d67b36c44d5e0abeb45df8c65c3c6 (diff)
Performance improvement sinkhorn
Doing the computation this way is equivalent and allows to reduce the space complexity required from O(max(a, b)^2) to O(a*b) (especially usefull to transport a small number of sources example to a lot of target) This also allows to decrease the computation time.
Diffstat (limited to 'ot/bregman.py')
-rw-r--r--ot/bregman.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index b06eaeb..1deccce 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -108,7 +108,8 @@ def sinkhorn(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=Fa
K = np.exp(-M/reg)
#print(np.min(K))
- Kp = np.dot(np.diag(1/a),K)
+ Kp = (1/a).reshape(-1, 1) * K
+
transp = K
cpt = 0
err=1
@@ -128,7 +129,7 @@ def sinkhorn(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=Fa
break
if cpt%10==0:
# we can speed up the process by checking for the error only all the 10th iterations
- transp = np.dot(np.diag(u),np.dot(K,np.diag(v)))
+ transp = u.reshape(-1, 1) * (K * v)
err = np.linalg.norm((np.sum(transp,axis=0)-b))**2
if log:
log['err'].append(err)