diff options
Diffstat (limited to 'ot/bregman.py')
-rw-r--r-- | ot/bregman.py | 20 |
1 files changed, 11 insertions, 9 deletions
diff --git a/ot/bregman.py b/ot/bregman.py index a13345d..68be01c 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -416,6 +416,16 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war err=1 while loop: + + + uprev = u + vprev = v + + # sinkhorn update + v = b/(np.dot(K.T,u)+1e-16) + u = a/(np.dot(K,v)+1e-16) + + # remove numerical problems and store them in K if np.abs(u).max()>tau or np.abs(v).max()>tau: if nbb: @@ -428,12 +438,6 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war u,v = np.ones(na)/na,np.ones(nb)/nb K=get_K(alpha,beta) - uprev = u - vprev = v - - # sinkhorn update - v = b/np.dot(K.T,u) - u = a/np.dot(K,v) if cpt%print_period==0: # we can speed up the process by checking for the error only all the 10th iterations @@ -458,9 +462,7 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war loop=False - if (np.any(np.dot(K.T,u)==0) or - np.any(np.isnan(u)) or np.any(np.isnan(v)) or - np.any(np.isinf(u)) or np.any(np.isinf(v))): + if np.any(np.isnan(u)) or np.any(np.isnan(v)): # we have reached the machine precision # come back to previous solution and quit loop print('Warning: numerical errors at iteration', cpt) |