summaryrefslogtreecommitdiff
path: root/ot/bregman.py
diff options
context:
space:
mode:
authorLeo gautheron <gautheron@iv-cm-359.creatis.insa-lyon.fr>2017-04-20 12:12:15 +0200
committerLeo gautheron <gautheron@iv-cm-359.creatis.insa-lyon.fr>2017-04-20 12:12:15 +0200
commit16f51f971607efab2c73958d207c582b389406c8 (patch)
tree299a4f6f13faf8545d2144767e9a7791098aacf8 /ot/bregman.py
parent48ec27d8e1c2599bd6d9015d15f4204b8116af28 (diff)
sinkhorn GPU implementation
Diffstat (limited to 'ot/bregman.py')
-rw-r--r--ot/bregman.py6
1 files changed, 4 insertions, 2 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index 0453f14..c46e5dc 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -112,9 +112,11 @@ def sinkhorn(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=Fa
while (err>stopThr and cpt<numItermax):
uprev = u
vprev = v
- v = np.divide(b,np.dot(K.T,u))
+ KtransposeU = np.dot(K.T, u)
+ v = np.divide(b, KtransposeU)
u = 1./np.dot(Kp,v)
- if (np.any(np.dot(K.T,u)==0) or
+
+ if (np.any(KtransposeU==0) or
np.any(np.isnan(u)) or np.any(np.isnan(v)) or
np.any(np.isinf(u)) or np.any(np.isinf(v))):
# we have reached the machine precision