summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2017-07-03 14:51:47 +0200
committerRémi Flamary <remi.flamary@gmail.com>2017-07-03 14:51:47 +0200
commit4fba2c9c479e8f410a23ef24458effc29fc3f7f0 (patch)
treefa0f181e4e571204e2e38b59455403e28928c805
parent11239a9848e97201b3d4aa04224f3421b2c3974a (diff)
debug bregman stabilized
-rw-r--r--ot/bregman.py20
1 files changed, 11 insertions, 9 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index a13345d..68be01c 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -416,6 +416,16 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war
err=1
while loop:
+
+
+ uprev = u
+ vprev = v
+
+ # sinkhorn update
+ v = b/(np.dot(K.T,u)+1e-16)
+ u = a/(np.dot(K,v)+1e-16)
+
+
# remove numerical problems and store them in K
if np.abs(u).max()>tau or np.abs(v).max()>tau:
if nbb:
@@ -428,12 +438,6 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war
u,v = np.ones(na)/na,np.ones(nb)/nb
K=get_K(alpha,beta)
- uprev = u
- vprev = v
-
- # sinkhorn update
- v = b/np.dot(K.T,u)
- u = a/np.dot(K,v)
if cpt%print_period==0:
# we can speed up the process by checking for the error only all the 10th iterations
@@ -458,9 +462,7 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war
loop=False
- if (np.any(np.dot(K.T,u)==0) or
- np.any(np.isnan(u)) or np.any(np.isnan(v)) or
- np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ if np.any(np.isnan(u)) or np.any(np.isnan(v)):
# we have reached the machine precision
# come back to previous solution and quit loop
print('Warning: numerical errors at iteration', cpt)