diff options
author | RĂ©mi Flamary <remi.flamary@gmail.com> | 2019-10-24 14:39:12 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-10-24 14:39:12 +0200 |
commit | 65ca6bfde77dd11d84cbd151fe9ff98454f8e206 (patch) | |
tree | d2c22473d9dbc48ad16ce2b95863eaa2ae6242b3 /ot | |
parent | 5e70a77fbb2feec513f21c9ef65dcc535329ace6 (diff) | |
parent | 161d68a79bc528a0d87e421f67a419cd757c7fba (diff) |
Merge pull request #106 from hichamjanati/fix-weighted-bar
MRG: Forgotten weights arg in barycenter funcs
Diffstat (limited to 'ot')
-rw-r--r-- | ot/bregman.py | 6 | ||||
-rw-r--r-- | ot/unbalanced.py | 105 |
2 files changed, 57 insertions, 54 deletions
diff --git a/ot/bregman.py b/ot/bregman.py index 2cd832b..ba5c7ba 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -1037,11 +1037,13 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, """ if method.lower() == 'sinkhorn': - return barycenter_sinkhorn(A, M, reg, numItermax=numItermax, + return barycenter_sinkhorn(A, M, reg, weights=weights, + numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) elif method.lower() == 'sinkhorn_stabilized': - return barycenter_stabilized(A, M, reg, numItermax=numItermax, + return barycenter_stabilized(A, M, reg, weights=weights, + numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) else: diff --git a/ot/unbalanced.py b/ot/unbalanced.py index d516dfc..23f6607 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -384,10 +384,9 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, fi = reg_m / (reg_m + reg) - cpt = 0 err = 1. - while (err > stopThr and cpt < numItermax): + for i in range(numItermax): uprev = u vprev = v @@ -401,28 +400,27 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, or np.any(np.isinf(u)) or np.any(np.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop - warnings.warn('Numerical errors at iteration %s' % cpt) + warnings.warn('Numerical errors at iteration %s' % i) u = uprev v = vprev break - if cpt % 10 == 0: - # we can speed up the process by checking for the error only all - # the 10th iterations - err_u = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), 1.) - err_v = abs(v - vprev).max() / max(abs(v).max(), abs(vprev).max(), 1.) - err = 0.5 * (err_u + err_v) - if log: - log['err'].append(err) + + err_u = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), 1.) + err_v = abs(v - vprev).max() / max(abs(v).max(), abs(vprev).max(), 1.) + err = 0.5 * (err_u + err_v) + if log: + log['err'].append(err) if verbose: - if cpt % 200 == 0: + if i % 50 == 0: print( '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) - cpt += 1 + print('{:5d}|{:8e}|'.format(i, err)) + if err < stopThr: + break if log: - log['logu'] = np.log(u + 1e-16) - log['logv'] = np.log(v + 1e-16) + log['logu'] = np.log(u + 1e-300) + log['logv'] = np.log(v + 1e-300) if n_hists: # return only loss res = np.einsum('ik,ij,jk,ij->k', u, K, v, M) @@ -747,8 +745,8 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, alpha = np.zeros(dim) beta = np.zeros(dim) q = np.ones(dim) / dim - while (err > stopThr and cpt < numItermax): - qprev = q + for i in range(numItermax): + qprev = q.copy() Kv = K.dot(v) f_alpha = np.exp(- alpha / (reg + reg_m)) f_beta = np.exp(- beta / (reg + reg_m)) @@ -777,7 +775,7 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, warnings.warn('Numerical errors at iteration %s' % cpt) q = qprev break - if (cpt % 10 == 0 and not absorbing) or cpt == 0: + if (i % 10 == 0 and not absorbing) or i == 0: # we can speed up the process by checking for the error only all # the 10th iterations err = abs(q - qprev).max() / max(abs(q).max(), @@ -785,20 +783,21 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, if log: log['err'].append(err) if verbose: - if cpt % 50 == 0: + if i % 50 == 0: print( '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) + print('{:5d}|{:8e}|'.format(i, err)) + if err < stopThr: + break - cpt += 1 if err > stopThr: warnings.warn("Stabilized Unbalanced Sinkhorn did not converge." + "Try a larger entropy `reg` or a lower mass `reg_m`." + "Or a larger absorption threshold `tau`.") if log: - log['niter'] = cpt - log['logu'] = np.log(u + 1e-16) - log['logv'] = np.log(v + 1e-16) + log['niter'] = i + log['logu'] = np.log(u + 1e-300) + log['logv'] = np.log(v + 1e-300) return q, log else: return q @@ -882,15 +881,15 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, fi = reg_m / (reg_m + reg) - v = np.ones((dim, n_hists)) / dim - u = np.ones((dim, 1)) / dim - - cpt = 0 + v = np.ones((dim, n_hists)) + u = np.ones((dim, 1)) + q = np.ones(dim) err = 1. - while (err > stopThr and cpt < numItermax): - uprev = u - vprev = v + for i in range(numItermax): + uprev = u.copy() + vprev = v.copy() + qprev = q.copy() Kv = K.dot(v) u = (A / Kv) ** fi @@ -905,31 +904,30 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, or np.any(np.isinf(u)) or np.any(np.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop - warnings.warn('Numerical errors at iteration %s' % cpt) + warnings.warn('Numerical errors at iteration %s' % i) u = uprev v = vprev + q = qprev break - if cpt % 10 == 0: - # we can speed up the process by checking for the error only all - # the 10th iterations - err_u = abs(u - uprev).max() - err_u /= max(abs(u).max(), abs(uprev).max(), 1.) - err_v = abs(v - vprev).max() - err_v /= max(abs(v).max(), abs(vprev).max(), 1.) - err = 0.5 * (err_u + err_v) - if log: - log['err'].append(err) - if verbose: - if cpt % 50 == 0: - print( - '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) + # compute change in barycenter + err = abs(q - qprev).max() + err /= max(abs(q).max(), abs(qprev).max(), 1.) + if log: + log['err'].append(err) + # if barycenter did not change + at least 10 iterations - stop + if err < stopThr and i > 10: + break + + if verbose: + if i % 10 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(i, err)) - cpt += 1 if log: - log['niter'] = cpt - log['logu'] = np.log(u + 1e-16) - log['logv'] = np.log(v + 1e-16) + log['niter'] = i + log['logu'] = np.log(u + 1e-300) + log['logv'] = np.log(v + 1e-300) return q, log else: return q @@ -1002,12 +1000,14 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None, if method.lower() == 'sinkhorn': return barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, + weights=weights, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) elif method.lower() == 'sinkhorn_stabilized': return barycenter_unbalanced_stabilized(A, M, reg, reg_m, + weights=weights, numItermax=numItermax, stopThr=stopThr, verbose=verbose, @@ -1015,6 +1015,7 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None, elif method.lower() in ['sinkhorn_reg_scaling']: warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp') return barycenter_unbalanced(A, M, reg, reg_m, + weights=weights, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) |