diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2016-10-28 11:22:23 +0200 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2016-10-28 11:22:23 +0200 |
commit | 3067c8873bf325808985453f0ac968d13435e032 (patch) | |
tree | d64a5820ec8b48706b8928bdd337ee42cf7ee0a0 /ot/bregman.py | |
parent | 8cd50c55f398cc371db2ef334c803dec99cc209a (diff) |
update bregman with doc
Diffstat (limited to 'ot/bregman.py')
-rw-r--r-- | ot/bregman.py | 91 |
1 files changed, 76 insertions, 15 deletions
diff --git a/ot/bregman.py b/ot/bregman.py index 08f965b..b6cdf80 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -43,9 +43,9 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9,verbose=False,log=False) Max number of iterations stopThr: float, optional Stop threshol on error (>0) - verbose : int, optional + verbose : bool, optional Print information along iterations - log : int, optional + log : bool, optional record log if True @@ -96,7 +96,7 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9,verbose=False,log=False) cpt = 0 if log: - log={'loss':[]} + log={'err':[]} # we assume that no distances are null except those of the diagonal of distances u = np.ones(Nini)/Nini @@ -131,7 +131,7 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9,verbose=False,log=False) transp = np.dot(np.diag(u),np.dot(K,np.diag(v))) err = np.linalg.norm((np.sum(transp,axis=0)-b))**2 if log: - log['loss'].append(err) + log['err'].append(err) if verbose: if cpt%200 ==0: @@ -146,10 +146,12 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9,verbose=False,log=False) def geometricBar(weights,alldistribT): + """return the weighted geometric mean of distributions""" assert(len(weights)==alldistribT.shape[1]) return np.exp(np.dot(np.log(alldistribT),weights.T)) def geometricMean(alldistribT): + """return the geometric mean of distributions""" return np.exp(np.mean(np.log(alldistribT),axis=1)) def projR(gamma,p): @@ -161,16 +163,66 @@ def projC(gamma,q): return np.multiply(gamma,q/np.maximum(np.sum(gamma,axis=0),1e-10)) -def barycenter(A,M,reg, weights=None, numItermax = 1000, tol_error=1e-4,log=dict()): - """Compute the Regularizzed wassersteien barycenter of distributions A""" +def barycenter(A,M,reg, weights=None, numItermax = 1000, stopThr=1e-4,verbose=False,log=False): + """Compute the entropic regularized wasserstein barycenter of distributions A + + The function solves the following optimization problem: + + .. math:: + \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) + + where : + + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn) + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` + - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT + + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [3]_ + + Parameters + ---------- + A : np.ndarray (d,n) + n training distributions of size d + M : np.ndarray (ns,nt) + loss matrix for OT + reg: float + Regularization term >0 + numItermax: int, optional + Max number of iterations + stopThr: float, optional + Stop threshol on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + a: (d,) ndarray + Wasserstein barycenter + log: dict + log dictionary return only if log==True in parameters + + + References + ---------- + + .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138. + + + + """ if weights is None: weights=np.ones(A.shape[1])/A.shape[1] else: assert(len(weights)==A.shape[1]) + + if log: + log={'err':[]} - #compute Mmax once for all #M = M/np.median(M) # suggested by G. Peyre K = np.exp(-M/reg) @@ -180,19 +232,28 @@ def barycenter(A,M,reg, weights=None, numItermax = 1000, tol_error=1e-4,log=dict UKv=np.dot(K,np.divide(A.T,np.sum(K,axis=0)).T) u = (geometricMean(UKv)/UKv.T).T - log['niter']=0 - log['all_err']=[] - - while (err>tol_error and cpt<numItermax): + while (err>stopThr and cpt<numItermax): cpt = cpt +1 UKv=u*np.dot(K,np.divide(A,np.dot(K,u))) u = (u.T*geometricBar(weights,UKv)).T/UKv + if cpt%10==1: err=np.sum(np.std(UKv,axis=1)) - log['all_err'].append(err) - - log['niter']=cpt - return geometricBar(weights,UKv),log + + # log and verbose print + if log: + log['err'].append(err) + + if verbose: + if cpt%200 ==0: + print('{:5s}|{:12s}'.format('It.','Err')+'\n'+'-'*19) + print('{:5d}|{:8e}|'.format(cpt,err)) + + if log: + log['niter']=cpt + return geometricBar(weights,UKv),log + else: + return geometricBar(weights,UKv) def unmix(distrib,D,M,M0,h0,reg,reg0,alpha,numItermax = 1000, tol_error=1e-3,log=dict()): |