summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
authorLeo gautheron <gautheron@iv-cm-359.creatis.insa-lyon.fr>2017-04-13 10:53:29 +0200
committerLeo gautheron <gautheron@iv-cm-359.creatis.insa-lyon.fr>2017-04-13 10:53:29 +0200
commit7bcae0fa807d67b36c44d5e0abeb45df8c65c3c6 (patch)
tree87fa8ad872e6bb4affd021a6a49a747749298306 /ot
parent92538abf67e5118431396c92caf82071866dcbe5 (diff)
update bregman file
- change commented prints to python3 compatible syntax - Correct issue that could cause the sinkhorn algo to stop with u and v containing nan/infinite numbers: - Assign uprev and vprev before changing u and v. - Then update u and v. - Then check if u and v contain nan, but ALSO infinite values. - if there are issues, then display error (with 2 r, not 3 :p) along with the iteration number (there may have errors at iteration 0)
Diffstat (limited to 'ot')
-rw-r--r--ot/bregman.py50
1 files changed, 24 insertions, 26 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index 27f8ff5..b06eaeb 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -102,31 +102,30 @@ def sinkhorn(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=Fa
# we assume that no distances are null except those of the diagonal of distances
u = np.ones(Nini)/Nini
v = np.ones(Nfin)/Nfin
- uprev=np.zeros(Nini)
- vprev=np.zeros(Nini)
- #print reg
+ #print(reg)
K = np.exp(-M/reg)
- #print np.min(K)
+ #print(np.min(K))
Kp = np.dot(np.diag(1/a),K)
transp = K
cpt = 0
err=1
while (err>stopThr and cpt<numItermax):
- if np.any(np.dot(K.T,u)==0) or np.any(np.isnan(u)) or np.any(np.isnan(v)):
- # we have reached the machine precision
- # come back to previous solution and quit loop
- print('Warning: numerical errrors')
- if cpt!=0:
- u = uprev
- v = vprev
- break
uprev = u
vprev = v
v = np.divide(b,np.dot(K.T,u))
u = 1./np.dot(Kp,v)
+ if (np.any(np.dot(K.T,u)==0) or
+ np.any(np.isnan(u)) or np.any(np.isnan(v)) or
+ np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ # we have reached the machine precision
+ # come back to previous solution and quit loop
+ print('Warning: numerical errors at iteration', cpt)
+ u = uprev
+ v = vprev
+ break
if cpt%10==0:
# we can speed up the process by checking for the error only all the 10th iterations
transp = np.dot(np.diag(u),np.dot(K,np.diag(v)))
@@ -142,8 +141,8 @@ def sinkhorn(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=Fa
if log:
log['u']=u
log['v']=v
-
- #print 'err=',err,' cpt=',cpt
+
+ #print('err=',err,' cpt=',cpt)
if log:
return u.reshape((-1,1))*K*v.reshape((1,-1)),log
else:
@@ -258,10 +257,8 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war
else:
alpha,beta=warmstart
u,v = np.ones(na)/na,np.ones(nb)/nb
- uprev,vprev=np.zeros(na),np.zeros(nb)
-
- #print reg
+ #print(reg)
def get_K(alpha,beta):
@@ -272,7 +269,7 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war
"""log space gamma computation"""
return np.exp(-(M-alpha.reshape((na,1))-beta.reshape((1,nb)))/reg+np.log(u.reshape((na,1)))+np.log(v.reshape((1,nb))))
- #print np.min(K)
+ #print(np.min(K))
K=get_K(alpha,beta)
transp = K
@@ -313,17 +310,18 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war
loop=False
- if np.any(np.dot(K.T,u)==0) or np.any(np.isnan(u)) or np.any(np.isnan(v)):
+ if (np.any(np.dot(K.T,u)==0) or
+ np.any(np.isnan(u)) or np.any(np.isnan(v)) or
+ np.any(np.isinf(u)) or np.any(np.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
- print('Warning: numerical errrors')
- if cpt!=0:
- u = uprev
- v = vprev
+ print('Warning: numerical errors at iteration', cpt)
+ u = uprev
+ v = vprev
break
cpt = cpt +1
- #print 'err=',err,' cpt=',cpt
+ #print('err=',err,' cpt=',cpt)
if log:
log['logu']=alpha/reg+np.log(u)
log['logv']=beta/reg+np.log(v)
@@ -456,7 +454,7 @@ def sinkhorn_epsilon_scaling(a,b, M, reg, numItermax = 100, epsilon0=1e4, numInn
"""log space computation"""
return np.exp(-(M-alpha.reshape((na,1))-beta.reshape((1,nb)))/reg)
- #print np.min(K)
+ #print(np.min(K))
def get_reg(n): # exponential decreasing
return (epsilon0-reg)*np.exp(-n)+reg
@@ -491,7 +489,7 @@ def sinkhorn_epsilon_scaling(a,b, M, reg, numItermax = 100, epsilon0=1e4, numInn
loop=False
cpt = cpt +1
- #print 'err=',err,' cpt=',cpt
+ #print('err=',err,' cpt=',cpt)
if log:
log['alpha']=alpha
log['beta']=beta