diff options
author | Gard Spreemann <gspr@nonempty.org> | 2020-07-09 08:50:11 +0200 |
---|---|---|
committer | Gard Spreemann <gspr@nonempty.org> | 2020-07-09 08:50:11 +0200 |
commit | 616c6039614e96b6cfd88eb7db25ce11c7302c30 (patch) | |
tree | a745ec24604f660b256dbaee214affa497c9c21e /ot/unbalanced.py | |
parent | d62f05daf348b4e554056f298c66cbd64f5e3c6e (diff) | |
parent | a16b9471d7114ec08977479b7249efe747702b97 (diff) |
Merge branch 'dfsg/latest' into debian/sid
Diffstat (limited to 'ot/unbalanced.py')
-rw-r--r-- | ot/unbalanced.py | 107 |
1 files changed, 54 insertions, 53 deletions
diff --git a/ot/unbalanced.py b/ot/unbalanced.py index d516dfc..e37f10c 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- """ -Regularized Unbalanced OT +Regularized Unbalanced OT solvers """ # Author: Hicham Janati <hicham.janati@inria.fr> @@ -384,10 +384,9 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, fi = reg_m / (reg_m + reg) - cpt = 0 err = 1. - while (err > stopThr and cpt < numItermax): + for i in range(numItermax): uprev = u vprev = v @@ -401,28 +400,27 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, or np.any(np.isinf(u)) or np.any(np.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop - warnings.warn('Numerical errors at iteration %s' % cpt) + warnings.warn('Numerical errors at iteration %s' % i) u = uprev v = vprev break - if cpt % 10 == 0: - # we can speed up the process by checking for the error only all - # the 10th iterations - err_u = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), 1.) - err_v = abs(v - vprev).max() / max(abs(v).max(), abs(vprev).max(), 1.) - err = 0.5 * (err_u + err_v) - if log: - log['err'].append(err) + + err_u = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), 1.) + err_v = abs(v - vprev).max() / max(abs(v).max(), abs(vprev).max(), 1.) + err = 0.5 * (err_u + err_v) + if log: + log['err'].append(err) if verbose: - if cpt % 200 == 0: + if i % 50 == 0: print( '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) - cpt += 1 + print('{:5d}|{:8e}|'.format(i, err)) + if err < stopThr: + break if log: - log['logu'] = np.log(u + 1e-16) - log['logv'] = np.log(v + 1e-16) + log['logu'] = np.log(u + 1e-300) + log['logv'] = np.log(v + 1e-300) if n_hists: # return only loss res = np.einsum('ik,ij,jk,ij->k', u, K, v, M) @@ -747,8 +745,8 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, alpha = np.zeros(dim) beta = np.zeros(dim) q = np.ones(dim) / dim - while (err > stopThr and cpt < numItermax): - qprev = q + for i in range(numItermax): + qprev = q.copy() Kv = K.dot(v) f_alpha = np.exp(- alpha / (reg + reg_m)) f_beta = np.exp(- beta / (reg + reg_m)) @@ -777,7 +775,7 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, warnings.warn('Numerical errors at iteration %s' % cpt) q = qprev break - if (cpt % 10 == 0 and not absorbing) or cpt == 0: + if (i % 10 == 0 and not absorbing) or i == 0: # we can speed up the process by checking for the error only all # the 10th iterations err = abs(q - qprev).max() / max(abs(q).max(), @@ -785,20 +783,21 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, if log: log['err'].append(err) if verbose: - if cpt % 50 == 0: + if i % 50 == 0: print( '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) + print('{:5d}|{:8e}|'.format(i, err)) + if err < stopThr: + break - cpt += 1 if err > stopThr: warnings.warn("Stabilized Unbalanced Sinkhorn did not converge." + "Try a larger entropy `reg` or a lower mass `reg_m`." + "Or a larger absorption threshold `tau`.") if log: - log['niter'] = cpt - log['logu'] = np.log(u + 1e-16) - log['logv'] = np.log(v + 1e-16) + log['niter'] = i + log['logu'] = np.log(u + 1e-300) + log['logv'] = np.log(v + 1e-300) return q, log else: return q @@ -882,15 +881,15 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, fi = reg_m / (reg_m + reg) - v = np.ones((dim, n_hists)) / dim - u = np.ones((dim, 1)) / dim - - cpt = 0 + v = np.ones((dim, n_hists)) + u = np.ones((dim, 1)) + q = np.ones(dim) err = 1. - while (err > stopThr and cpt < numItermax): - uprev = u - vprev = v + for i in range(numItermax): + uprev = u.copy() + vprev = v.copy() + qprev = q.copy() Kv = K.dot(v) u = (A / Kv) ** fi @@ -905,31 +904,30 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, or np.any(np.isinf(u)) or np.any(np.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop - warnings.warn('Numerical errors at iteration %s' % cpt) + warnings.warn('Numerical errors at iteration %s' % i) u = uprev v = vprev + q = qprev break - if cpt % 10 == 0: - # we can speed up the process by checking for the error only all - # the 10th iterations - err_u = abs(u - uprev).max() - err_u /= max(abs(u).max(), abs(uprev).max(), 1.) - err_v = abs(v - vprev).max() - err_v /= max(abs(v).max(), abs(vprev).max(), 1.) - err = 0.5 * (err_u + err_v) - if log: - log['err'].append(err) - if verbose: - if cpt % 50 == 0: - print( - '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) + # compute change in barycenter + err = abs(q - qprev).max() + err /= max(abs(q).max(), abs(qprev).max(), 1.) + if log: + log['err'].append(err) + # if barycenter did not change + at least 10 iterations - stop + if err < stopThr and i > 10: + break + + if verbose: + if i % 10 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(i, err)) - cpt += 1 if log: - log['niter'] = cpt - log['logu'] = np.log(u + 1e-16) - log['logv'] = np.log(v + 1e-16) + log['niter'] = i + log['logu'] = np.log(u + 1e-300) + log['logv'] = np.log(v + 1e-300) return q, log else: return q @@ -1002,12 +1000,14 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None, if method.lower() == 'sinkhorn': return barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, + weights=weights, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) elif method.lower() == 'sinkhorn_stabilized': return barycenter_unbalanced_stabilized(A, M, reg, reg_m, + weights=weights, numItermax=numItermax, stopThr=stopThr, verbose=verbose, @@ -1015,6 +1015,7 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None, elif method.lower() in ['sinkhorn_reg_scaling']: warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp') return barycenter_unbalanced(A, M, reg, reg_m, + weights=weights, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) |