summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
authorHicham Janati <hicham.janati@inria.fr>2019-07-19 17:04:14 +0200
committerHicham Janati <hicham.janati@inria.fr>2019-07-19 17:04:14 +0200
commit0d23718409b1f0ac41b9302d98ca3d1ab9577855 (patch)
tree3d24ffe71208d2cf6c91ae94c4c99542a0bb560a /ot
parent952503e02b1fc9bdf0811b937baacca57e4a98f1 (diff)
remove square in convergence check
Diffstat (limited to 'ot')
-rw-r--r--ot/unbalanced.py10
1 files changed, 6 insertions, 4 deletions
diff --git a/ot/unbalanced.py b/ot/unbalanced.py
index 50ec03c..f6c2d5f 100644
--- a/ot/unbalanced.py
+++ b/ot/unbalanced.py
@@ -371,8 +371,9 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
if cpt % 10 == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
- err = np.sum((u - uprev)**2) / np.sum((u)**2) + \
- np.sum((v - vprev)**2) / np.sum((v)**2)
+ err_u = abs(u - uprev).max() / max(abs(u), abs(uprev), 1.)
+ err_v = abs(v - vprev).max() / max(abs(v), abs(vprev), 1.)
+ err = 0.5 * (err_u + err_v)
if log:
log['err'].append(err)
if verbose:
@@ -498,8 +499,9 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
if cpt % 10 == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
- err = np.sum((u - uprev) ** 2) / np.sum((u) ** 2) + \
- np.sum((v - vprev) ** 2) / np.sum((v) ** 2)
+ err_u = abs(u - uprev).max() / max(abs(u), abs(uprev), 1.)
+ err_v = abs(v - vprev).max() / max(abs(v), abs(vprev), 1.)
+ err = 0.5 * (err_u + err_v)
if log:
log['err'].append(err)
if verbose: