summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
Diffstat (limited to 'ot')
-rw-r--r--ot/unbalanced.py10
1 files changed, 6 insertions, 4 deletions
diff --git a/ot/unbalanced.py b/ot/unbalanced.py
index 50ec03c..f6c2d5f 100644
--- a/ot/unbalanced.py
+++ b/ot/unbalanced.py
@@ -371,8 +371,9 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
if cpt % 10 == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
- err = np.sum((u - uprev)**2) / np.sum((u)**2) + \
- np.sum((v - vprev)**2) / np.sum((v)**2)
+ err_u = abs(u - uprev).max() / max(abs(u), abs(uprev), 1.)
+ err_v = abs(v - vprev).max() / max(abs(v), abs(vprev), 1.)
+ err = 0.5 * (err_u + err_v)
if log:
log['err'].append(err)
if verbose:
@@ -498,8 +499,9 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
if cpt % 10 == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
- err = np.sum((u - uprev) ** 2) / np.sum((u) ** 2) + \
- np.sum((v - vprev) ** 2) / np.sum((v) ** 2)
+ err_u = abs(u - uprev).max() / max(abs(u), abs(uprev), 1.)
+ err_v = abs(v - vprev).max() / max(abs(v), abs(vprev), 1.)
+ err = 0.5 * (err_u + err_v)
if log:
log['err'].append(err)
if verbose: