From 05905acdbf2e22bc4b3dc7556d0c6faba0786d23 Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Mon, 24 Oct 2016 16:08:55 +0200 Subject: add barycenter and unmixing --- ot/bregman.py | 89 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 89 insertions(+) (limited to 'ot') diff --git a/ot/bregman.py b/ot/bregman.py index 8b97e1e..1761029 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -89,3 +89,92 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9): return np.dot(np.diag(u),np.dot(K,np.diag(v))) +def geometricBar(weights,alldistribT): + assert(len(weights)==alldistribT.shape[1]) + return np.exp(np.dot(np.log(alldistribT),weights.T)) + +def geometricMean(alldistribT): + return np.exp(np.mean(np.log(alldistribT),axis=1)) + +def projR(gamma,p): + #return np.dot(np.diag(p/np.maximum(np.sum(gamma,axis=1),1e-10)),gamma) + return np.multiply(gamma.T,p/np.maximum(np.sum(gamma,axis=1),1e-10)).T + +def projC(gamma,q): + #return (np.dot(np.diag(q/np.maximum(np.sum(gamma,axis=0),1e-10)),gamma.T)).T + return np.multiply(gamma,q/np.maximum(np.sum(gamma,axis=0),1e-10)) + + +def barycenter(A,M,reg, weights=None, numItermax = 1000, tol_error=1e-4,log=dict()): + """Compute the Regularizzed wassersteien barycenter of distributions A""" + + + if weights is None: + weights=np.ones(A.shape[1])/A.shape[1] + else: + assert(len(weights)==A.shape[1]) + + #compute Mmax once for all + #M = M/np.median(M) # suggested by G. Peyre + K = np.exp(-reg*M) + + cpt = 0 + err=1 + + UKv=np.dot(K,np.divide(A.T,np.sum(K,axis=0)).T) + u = (geometricMean(UKv)/UKv.T).T + + log = {'niter':0, 'all_err':[]} + + while (err>tol_error and cpttol_error and cpt