diff options
author | Leo gautheron <gautheron@iv-cm-359.creatis.insa-lyon.fr> | 2017-04-13 10:53:29 +0200 |
---|---|---|
committer | Leo gautheron <gautheron@iv-cm-359.creatis.insa-lyon.fr> | 2017-04-13 10:53:29 +0200 |
commit | 7bcae0fa807d67b36c44d5e0abeb45df8c65c3c6 (patch) | |
tree | 87fa8ad872e6bb4affd021a6a49a747749298306 /ot | |
parent | 92538abf67e5118431396c92caf82071866dcbe5 (diff) |
update bregman file
- change commented prints to python3 compatible syntax
- Correct issue that could cause the sinkhorn algo to stop with u and v containing nan/infinite numbers:
- Assign uprev and vprev before changing u and v.
- Then update u and v.
- Then check if u and v contain nan, but ALSO infinite values.
- if there are issues, then display error (with 2 r, not 3 :p) along with the iteration number (there may have errors at iteration 0)
Diffstat (limited to 'ot')
-rw-r--r-- | ot/bregman.py | 50 |
1 files changed, 24 insertions, 26 deletions
diff --git a/ot/bregman.py b/ot/bregman.py index 27f8ff5..b06eaeb 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -102,31 +102,30 @@ def sinkhorn(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=Fa # we assume that no distances are null except those of the diagonal of distances u = np.ones(Nini)/Nini v = np.ones(Nfin)/Nfin - uprev=np.zeros(Nini) - vprev=np.zeros(Nini) - #print reg + #print(reg) K = np.exp(-M/reg) - #print np.min(K) + #print(np.min(K)) Kp = np.dot(np.diag(1/a),K) transp = K cpt = 0 err=1 while (err>stopThr and cpt<numItermax): - if np.any(np.dot(K.T,u)==0) or np.any(np.isnan(u)) or np.any(np.isnan(v)): - # we have reached the machine precision - # come back to previous solution and quit loop - print('Warning: numerical errrors') - if cpt!=0: - u = uprev - v = vprev - break uprev = u vprev = v v = np.divide(b,np.dot(K.T,u)) u = 1./np.dot(Kp,v) + if (np.any(np.dot(K.T,u)==0) or + np.any(np.isnan(u)) or np.any(np.isnan(v)) or + np.any(np.isinf(u)) or np.any(np.isinf(v))): + # we have reached the machine precision + # come back to previous solution and quit loop + print('Warning: numerical errors at iteration', cpt) + u = uprev + v = vprev + break if cpt%10==0: # we can speed up the process by checking for the error only all the 10th iterations transp = np.dot(np.diag(u),np.dot(K,np.diag(v))) @@ -142,8 +141,8 @@ def sinkhorn(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=Fa if log: log['u']=u log['v']=v - - #print 'err=',err,' cpt=',cpt + + #print('err=',err,' cpt=',cpt) if log: return u.reshape((-1,1))*K*v.reshape((1,-1)),log else: @@ -258,10 +257,8 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war else: alpha,beta=warmstart u,v = np.ones(na)/na,np.ones(nb)/nb - uprev,vprev=np.zeros(na),np.zeros(nb) - - #print reg + #print(reg) def get_K(alpha,beta): @@ -272,7 +269,7 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war """log space gamma computation""" return np.exp(-(M-alpha.reshape((na,1))-beta.reshape((1,nb)))/reg+np.log(u.reshape((na,1)))+np.log(v.reshape((1,nb)))) - #print np.min(K) + #print(np.min(K)) K=get_K(alpha,beta) transp = K @@ -313,17 +310,18 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war loop=False - if np.any(np.dot(K.T,u)==0) or np.any(np.isnan(u)) or np.any(np.isnan(v)): + if (np.any(np.dot(K.T,u)==0) or + np.any(np.isnan(u)) or np.any(np.isnan(v)) or + np.any(np.isinf(u)) or np.any(np.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop - print('Warning: numerical errrors') - if cpt!=0: - u = uprev - v = vprev + print('Warning: numerical errors at iteration', cpt) + u = uprev + v = vprev break cpt = cpt +1 - #print 'err=',err,' cpt=',cpt + #print('err=',err,' cpt=',cpt) if log: log['logu']=alpha/reg+np.log(u) log['logv']=beta/reg+np.log(v) @@ -456,7 +454,7 @@ def sinkhorn_epsilon_scaling(a,b, M, reg, numItermax = 100, epsilon0=1e4, numInn """log space computation""" return np.exp(-(M-alpha.reshape((na,1))-beta.reshape((1,nb)))/reg) - #print np.min(K) + #print(np.min(K)) def get_reg(n): # exponential decreasing return (epsilon0-reg)*np.exp(-n)+reg @@ -491,7 +489,7 @@ def sinkhorn_epsilon_scaling(a,b, M, reg, numItermax = 100, epsilon0=1e4, numInn loop=False cpt = cpt +1 - #print 'err=',err,' cpt=',cpt + #print('err=',err,' cpt=',cpt) if log: log['alpha']=alpha log['beta']=beta |