From 161d68a79bc528a0d87e421f67a419cd757c7fba Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Sun, 20 Oct 2019 22:55:21 +0200 Subject: fix loop counter in barycenter + precision of dual variables --- ot/unbalanced.py | 102 +++++++++++++++++++++++++++---------------------------- 1 file changed, 50 insertions(+), 52 deletions(-) diff --git a/ot/unbalanced.py b/ot/unbalanced.py index 978df08..23f6607 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -384,10 +384,9 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, fi = reg_m / (reg_m + reg) - cpt = 0 err = 1. - while (err > stopThr and cpt < numItermax): + for i in range(numItermax): uprev = u vprev = v @@ -401,28 +400,27 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, or np.any(np.isinf(u)) or np.any(np.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop - warnings.warn('Numerical errors at iteration %s' % cpt) + warnings.warn('Numerical errors at iteration %s' % i) u = uprev v = vprev break - if cpt % 10 == 0: - # we can speed up the process by checking for the error only all - # the 10th iterations - err_u = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), 1.) - err_v = abs(v - vprev).max() / max(abs(v).max(), abs(vprev).max(), 1.) - err = 0.5 * (err_u + err_v) - if log: - log['err'].append(err) + + err_u = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), 1.) + err_v = abs(v - vprev).max() / max(abs(v).max(), abs(vprev).max(), 1.) + err = 0.5 * (err_u + err_v) + if log: + log['err'].append(err) if verbose: - if cpt % 200 == 0: + if i % 50 == 0: print( '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) - cpt += 1 + print('{:5d}|{:8e}|'.format(i, err)) + if err < stopThr: + break if log: - log['logu'] = np.log(u + 1e-16) - log['logv'] = np.log(v + 1e-16) + log['logu'] = np.log(u + 1e-300) + log['logv'] = np.log(v + 1e-300) if n_hists: # return only loss res = np.einsum('ik,ij,jk,ij->k', u, K, v, M) @@ -747,8 +745,8 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, alpha = np.zeros(dim) beta = np.zeros(dim) q = np.ones(dim) / dim - while (err > stopThr and cpt < numItermax): - qprev = q + for i in range(numItermax): + qprev = q.copy() Kv = K.dot(v) f_alpha = np.exp(- alpha / (reg + reg_m)) f_beta = np.exp(- beta / (reg + reg_m)) @@ -777,7 +775,7 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, warnings.warn('Numerical errors at iteration %s' % cpt) q = qprev break - if (cpt % 10 == 0 and not absorbing) or cpt == 0: + if (i % 10 == 0 and not absorbing) or i == 0: # we can speed up the process by checking for the error only all # the 10th iterations err = abs(q - qprev).max() / max(abs(q).max(), @@ -785,20 +783,21 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, if log: log['err'].append(err) if verbose: - if cpt % 50 == 0: + if i % 50 == 0: print( '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) + print('{:5d}|{:8e}|'.format(i, err)) + if err < stopThr: + break - cpt += 1 if err > stopThr: warnings.warn("Stabilized Unbalanced Sinkhorn did not converge." + "Try a larger entropy `reg` or a lower mass `reg_m`." + "Or a larger absorption threshold `tau`.") if log: - log['niter'] = cpt - log['logu'] = np.log(u + 1e-16) - log['logv'] = np.log(v + 1e-16) + log['niter'] = i + log['logu'] = np.log(u + 1e-300) + log['logv'] = np.log(v + 1e-300) return q, log else: return q @@ -882,15 +881,15 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, fi = reg_m / (reg_m + reg) - v = np.ones((dim, n_hists)) / dim - u = np.ones((dim, 1)) / dim - - cpt = 0 + v = np.ones((dim, n_hists)) + u = np.ones((dim, 1)) + q = np.ones(dim) err = 1. - while (err > stopThr and cpt < numItermax): - uprev = u - vprev = v + for i in range(numItermax): + uprev = u.copy() + vprev = v.copy() + qprev = q.copy() Kv = K.dot(v) u = (A / Kv) ** fi @@ -905,31 +904,30 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, or np.any(np.isinf(u)) or np.any(np.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop - warnings.warn('Numerical errors at iteration %s' % cpt) + warnings.warn('Numerical errors at iteration %s' % i) u = uprev v = vprev + q = qprev break - if cpt % 10 == 0: - # we can speed up the process by checking for the error only all - # the 10th iterations - err_u = abs(u - uprev).max() - err_u /= max(abs(u).max(), abs(uprev).max(), 1.) - err_v = abs(v - vprev).max() - err_v /= max(abs(v).max(), abs(vprev).max(), 1.) - err = 0.5 * (err_u + err_v) - if log: - log['err'].append(err) - if verbose: - if cpt % 50 == 0: - print( - '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) + # compute change in barycenter + err = abs(q - qprev).max() + err /= max(abs(q).max(), abs(qprev).max(), 1.) + if log: + log['err'].append(err) + # if barycenter did not change + at least 10 iterations - stop + if err < stopThr and i > 10: + break + + if verbose: + if i % 10 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(i, err)) - cpt += 1 if log: - log['niter'] = cpt - log['logu'] = np.log(u + 1e-16) - log['logv'] = np.log(v + 1e-16) + log['niter'] = i + log['logu'] = np.log(u + 1e-300) + log['logv'] = np.log(v + 1e-300) return q, log else: return q -- cgit v1.2.3