diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2017-07-03 14:51:47 +0200 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2017-07-03 14:51:47 +0200 |
commit | 4fba2c9c479e8f410a23ef24458effc29fc3f7f0 (patch) | |
tree | fa0f181e4e571204e2e38b59455403e28928c805 /ot/bregman.py | |
parent | 11239a9848e97201b3d4aa04224f3421b2c3974a (diff) |
debug bregman stabilized
Diffstat (limited to 'ot/bregman.py')
-rw-r--r-- | ot/bregman.py | 20 |
1 files changed, 11 insertions, 9 deletions
diff --git a/ot/bregman.py b/ot/bregman.py index a13345d..68be01c 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -416,6 +416,16 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war err=1 while loop: + + + uprev = u + vprev = v + + # sinkhorn update + v = b/(np.dot(K.T,u)+1e-16) + u = a/(np.dot(K,v)+1e-16) + + # remove numerical problems and store them in K if np.abs(u).max()>tau or np.abs(v).max()>tau: if nbb: @@ -428,12 +438,6 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war u,v = np.ones(na)/na,np.ones(nb)/nb K=get_K(alpha,beta) - uprev = u - vprev = v - - # sinkhorn update - v = b/np.dot(K.T,u) - u = a/np.dot(K,v) if cpt%print_period==0: # we can speed up the process by checking for the error only all the 10th iterations @@ -458,9 +462,7 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war loop=False - if (np.any(np.dot(K.T,u)==0) or - np.any(np.isnan(u)) or np.any(np.isnan(v)) or - np.any(np.isinf(u)) or np.any(np.isinf(v))): + if np.any(np.isnan(u)) or np.any(np.isnan(v)): # we have reached the machine precision # come back to previous solution and quit loop print('Warning: numerical errors at iteration', cpt) |