diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2016-10-24 16:08:55 +0200 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2016-10-24 16:08:55 +0200 |
commit | 05905acdbf2e22bc4b3dc7556d0c6faba0786d23 (patch) | |
tree | 46496d8cec2a4241fa59a4fd91d5786c33265c41 /ot | |
parent | 70595f400acfa7b9b32eebc8885a4d4d27d7243d (diff) |
add barycenter and unmixing
Diffstat (limited to 'ot')
-rw-r--r-- | ot/bregman.py | 89 |
1 files changed, 89 insertions, 0 deletions
diff --git a/ot/bregman.py b/ot/bregman.py index 8b97e1e..1761029 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -89,3 +89,92 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9): return np.dot(np.diag(u),np.dot(K,np.diag(v))) +def geometricBar(weights,alldistribT): + assert(len(weights)==alldistribT.shape[1]) + return np.exp(np.dot(np.log(alldistribT),weights.T)) + +def geometricMean(alldistribT): + return np.exp(np.mean(np.log(alldistribT),axis=1)) + +def projR(gamma,p): + #return np.dot(np.diag(p/np.maximum(np.sum(gamma,axis=1),1e-10)),gamma) + return np.multiply(gamma.T,p/np.maximum(np.sum(gamma,axis=1),1e-10)).T + +def projC(gamma,q): + #return (np.dot(np.diag(q/np.maximum(np.sum(gamma,axis=0),1e-10)),gamma.T)).T + return np.multiply(gamma,q/np.maximum(np.sum(gamma,axis=0),1e-10)) + + +def barycenter(A,M,reg, weights=None, numItermax = 1000, tol_error=1e-4,log=dict()): + """Compute the Regularizzed wassersteien barycenter of distributions A""" + + + if weights is None: + weights=np.ones(A.shape[1])/A.shape[1] + else: + assert(len(weights)==A.shape[1]) + + #compute Mmax once for all + #M = M/np.median(M) # suggested by G. Peyre + K = np.exp(-reg*M) + + cpt = 0 + err=1 + + UKv=np.dot(K,np.divide(A.T,np.sum(K,axis=0)).T) + u = (geometricMean(UKv)/UKv.T).T + + log = {'niter':0, 'all_err':[]} + + while (err>tol_error and cpt<numItermax): + cpt = cpt +1 + UKv=u*np.dot(K,np.divide(A,np.dot(K,u))) + u = (u.T*geometricBar(weights,UKv)).T/UKv + if cpt%10==1: + err=np.sum(np.std(UKv,axis=1)) + log['all_err'].append(err) + + log['niter']=cpt + return geometricBar(weights,UKv) + + +def unmixBregman(distrib,D,M,M0,h0,reg,reg0,alpha,numItermax = 1000, tol_error=1e-3,log=dict()): + """ + distrib : distribution to unmix + D : Dictionnary + M : Metric matrix in the space of the distributions to unmix + M0 : Metric matrix in the space of the 'abundance values' to solve for + h0 : prior on solution (generally uniform distribution) + reg,reg0 : transport regularizations + alpha : how much should we trust the prior ? ([0,1]) + """ + + M = M/np.median(M) + K = np.exp(-reg*M) + + M0 = M0/np.median(M0) + K0 = np.exp(-reg0*M0) + old = h0 + + err=1 + cpt=0 + log = {'niter':0, 'all_err':[]} + + while (err>tol_error and cpt<numItermax): + K = projC(K,distrib) + K0 = projC(K0,h0) + new = np.sum(K0,axis=1) + inv_new = np.dot(D,new) # we recombine the current selection from dictionnary + other = np.sum(K,axis=1) + delta = np.exp(alpha*np.log(other)+(1-alpha)*np.log(inv_new)) # geometric interpolation + K = projR(K,delta) + K0 = np.dot(np.diag(np.dot(D.T,delta/inv_new)),K0) + + err=np.linalg.norm(np.sum(K0,axis=1)-old) + old = new + log['all_err'].append(err) + cpt = cpt+1 + + + log['niter']=cpt + return np.sum(K0,axis=1) |