summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2016-10-24 16:08:55 +0200
committerRémi Flamary <remi.flamary@gmail.com>2016-10-24 16:08:55 +0200
commit05905acdbf2e22bc4b3dc7556d0c6faba0786d23 (patch)
tree46496d8cec2a4241fa59a4fd91d5786c33265c41 /ot
parent70595f400acfa7b9b32eebc8885a4d4d27d7243d (diff)
add barycenter and unmixing
Diffstat (limited to 'ot')
-rw-r--r--ot/bregman.py89
1 files changed, 89 insertions, 0 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index 8b97e1e..1761029 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -89,3 +89,92 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9):
return np.dot(np.diag(u),np.dot(K,np.diag(v)))
+def geometricBar(weights,alldistribT):
+ assert(len(weights)==alldistribT.shape[1])
+ return np.exp(np.dot(np.log(alldistribT),weights.T))
+
+def geometricMean(alldistribT):
+ return np.exp(np.mean(np.log(alldistribT),axis=1))
+
+def projR(gamma,p):
+ #return np.dot(np.diag(p/np.maximum(np.sum(gamma,axis=1),1e-10)),gamma)
+ return np.multiply(gamma.T,p/np.maximum(np.sum(gamma,axis=1),1e-10)).T
+
+def projC(gamma,q):
+ #return (np.dot(np.diag(q/np.maximum(np.sum(gamma,axis=0),1e-10)),gamma.T)).T
+ return np.multiply(gamma,q/np.maximum(np.sum(gamma,axis=0),1e-10))
+
+
+def barycenter(A,M,reg, weights=None, numItermax = 1000, tol_error=1e-4,log=dict()):
+ """Compute the Regularizzed wassersteien barycenter of distributions A"""
+
+
+ if weights is None:
+ weights=np.ones(A.shape[1])/A.shape[1]
+ else:
+ assert(len(weights)==A.shape[1])
+
+ #compute Mmax once for all
+ #M = M/np.median(M) # suggested by G. Peyre
+ K = np.exp(-reg*M)
+
+ cpt = 0
+ err=1
+
+ UKv=np.dot(K,np.divide(A.T,np.sum(K,axis=0)).T)
+ u = (geometricMean(UKv)/UKv.T).T
+
+ log = {'niter':0, 'all_err':[]}
+
+ while (err>tol_error and cpt<numItermax):
+ cpt = cpt +1
+ UKv=u*np.dot(K,np.divide(A,np.dot(K,u)))
+ u = (u.T*geometricBar(weights,UKv)).T/UKv
+ if cpt%10==1:
+ err=np.sum(np.std(UKv,axis=1))
+ log['all_err'].append(err)
+
+ log['niter']=cpt
+ return geometricBar(weights,UKv)
+
+
+def unmixBregman(distrib,D,M,M0,h0,reg,reg0,alpha,numItermax = 1000, tol_error=1e-3,log=dict()):
+ """
+ distrib : distribution to unmix
+ D : Dictionnary
+ M : Metric matrix in the space of the distributions to unmix
+ M0 : Metric matrix in the space of the 'abundance values' to solve for
+ h0 : prior on solution (generally uniform distribution)
+ reg,reg0 : transport regularizations
+ alpha : how much should we trust the prior ? ([0,1])
+ """
+
+ M = M/np.median(M)
+ K = np.exp(-reg*M)
+
+ M0 = M0/np.median(M0)
+ K0 = np.exp(-reg0*M0)
+ old = h0
+
+ err=1
+ cpt=0
+ log = {'niter':0, 'all_err':[]}
+
+ while (err>tol_error and cpt<numItermax):
+ K = projC(K,distrib)
+ K0 = projC(K0,h0)
+ new = np.sum(K0,axis=1)
+ inv_new = np.dot(D,new) # we recombine the current selection from dictionnary
+ other = np.sum(K,axis=1)
+ delta = np.exp(alpha*np.log(other)+(1-alpha)*np.log(inv_new)) # geometric interpolation
+ K = projR(K,delta)
+ K0 = np.dot(np.diag(np.dot(D.T,delta/inv_new)),K0)
+
+ err=np.linalg.norm(np.sum(K0,axis=1)-old)
+ old = new
+ log['all_err'].append(err)
+ cpt = cpt+1
+
+
+ log['niter']=cpt
+ return np.sum(K0,axis=1)