diff options
Diffstat (limited to 'ot/unbalanced.py')
-rw-r--r-- | ot/unbalanced.py | 10 |
1 files changed, 6 insertions, 4 deletions
diff --git a/ot/unbalanced.py b/ot/unbalanced.py index 50ec03c..f6c2d5f 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -371,8 +371,9 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, if cpt % 10 == 0: # we can speed up the process by checking for the error only all # the 10th iterations - err = np.sum((u - uprev)**2) / np.sum((u)**2) + \ - np.sum((v - vprev)**2) / np.sum((v)**2) + err_u = abs(u - uprev).max() / max(abs(u), abs(uprev), 1.) + err_v = abs(v - vprev).max() / max(abs(v), abs(vprev), 1.) + err = 0.5 * (err_u + err_v) if log: log['err'].append(err) if verbose: @@ -498,8 +499,9 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000, if cpt % 10 == 0: # we can speed up the process by checking for the error only all # the 10th iterations - err = np.sum((u - uprev) ** 2) / np.sum((u) ** 2) + \ - np.sum((v - vprev) ** 2) / np.sum((v) ** 2) + err_u = abs(u - uprev).max() / max(abs(u), abs(uprev), 1.) + err_v = abs(v - vprev).max() / max(abs(v), abs(vprev), 1.) + err = 0.5 * (err_u + err_v) if log: log['err'].append(err) if verbose: |