Diffstat (limited to 'ot/bregman.py')
-rw-r--r--  ot/bregman.py  232
1 file changed, 116 insertions(+), 116 deletions(-)
diff --git a/ot/bregman.py b/ot/bregman.py
index 2d82ae4..a770c5a 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -14,21 +14,21 @@ def sinkhorn(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=Fa
.. math::
\gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
-
+
s.t. \gamma 1 = a
-
- \gamma^T 1= b
-
+
+ \gamma^T 1= b
+
\gamma\geq 0
where :
-
+
- M is the (ns,nt) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- a and b are source and target weights (sum to 1)
-
+
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
-
-
+
+
Parameters
----------
a : np.ndarray (ns,)
@@ -36,79 +36,79 @@ def sinkhorn(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=Fa
b : np.ndarray (nt,)
samples in the target domain
M : np.ndarray (ns,nt)
- loss matrix
- reg: float
+ loss matrix
+ reg : float
Regularization term >0
- numItermax: int, optional
+ numItermax : int, optional
Max number of iterations
- stopThr: float, optional
+ stopThr : float, optional
Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
- record log if True
-
-
+ record log if True
+
+
Returns
-------
- gamma: (ns x nt) ndarray
+ gamma : (ns x nt) ndarray
Optimal transportation matrix for the given parameters
- log: dict
- log dictionary return only if log==True in parameters
+ log : dict
+ log dictionary returned only if log==True in parameters
Examples
--------
-
+
>>> a=[.5,.5]
>>> b=[.5,.5]
>>> M=[[0.,1.],[1.,0.]]
>>> ot.sinkhorn(a,b,M,1)
array([[ 0.36552929, 0.13447071],
[ 0.13447071, 0.36552929]])
-
-
+
+
References
----------
-
+
.. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
-
-
+
+
See Also
--------
ot.lp.emd : Unregularized OT
ot.optim.cg : General regularized OT
-
- """
-
+
+ """
+
a=np.asarray(a,dtype=np.float64)
b=np.asarray(b,dtype=np.float64)
M=np.asarray(M,dtype=np.float64)
-
+
if len(a)==0:
a=np.ones((M.shape[0],),dtype=np.float64)/M.shape[0]
if len(b)==0:
- b=np.ones((M.shape[1],),dtype=np.float64)/M.shape[1]
-
+ b=np.ones((M.shape[1],),dtype=np.float64)/M.shape[1]
+
# init data
Nini = len(a)
Nfin = len(b)
-
-
+
+
cpt = 0
if log:
log={'err':[]}
-
+
# we assume that no distances are null except those of the diagonal of distances
u = np.ones(Nini)/Nini
- v = np.ones(Nfin)/Nfin
+ v = np.ones(Nfin)/Nfin
uprev=np.zeros(Nini)
vprev=np.zeros(Nfin)
#print reg
-
+
K = np.exp(-M/reg)
#print np.min(K)
-
+
Kp = np.dot(np.diag(1/a),K)
transp = K
cpt = 0
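An aside on the initialization above (illustrative, not part of the patch): precomputing Kp = diag(1/a) K lets the loop compute u = 1/(Kp v) instead of u = a/(K v), hoisting one diagonal product out of the iteration. A quick check of the equivalence:

import numpy as np

a = np.array([.5, .5])
K = np.exp(-np.array([[0., 1.], [1., 0.]]))   # kernel for reg = 1
v = np.ones(2) / 2
Kp = np.dot(np.diag(1 / a), K)
# diag(1/a) K v equals (K v)/a, so 1/(Kp v) equals a/(K v)
assert np.allclose(1. / np.dot(Kp, v), a / np.dot(K, v))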
@@ -120,10 +120,10 @@ def sinkhorn(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=Fa
print('Warning: numerical errors')
if cpt!=0:
u = uprev
- v = vprev
+ v = vprev
break
uprev = u
- vprev = v
+ vprev = v
v = np.divide(b,np.dot(K.T,u))
u = 1./np.dot(Kp,v)
if cpt%10==0:
@@ -131,14 +131,14 @@ def sinkhorn(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=Fa
transp = np.dot(np.diag(u),np.dot(K,np.diag(v)))
err = np.linalg.norm((np.sum(transp,axis=0)-b))**2
if log:
- log['err'].append(err)
-
+ log['err'].append(err)
+
if verbose:
if cpt%200 ==0:
print('{:5s}|{:12s}'.format('It.','Err')+'\n'+'-'*19)
print('{:5d}|{:8e}|'.format(cpt,err))
cpt = cpt +1
- #print 'err=',err,' cpt=',cpt
+ #print 'err=',err,' cpt=',cpt
if log:
return np.dot(np.diag(u),np.dot(K,np.diag(v))),log
else:
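For reference while reading the whitespace-only hunks above, a minimal standalone sketch of the Sinkhorn-Knopp iteration this function implements (toy data; the real code adds the nan/underflow guard and checks the marginal error only every 10 iterations):

import numpy as np

a = np.array([.5, .5])                 # source weights
b = np.array([.5, .5])                 # target weights
M = np.array([[0., 1.], [1., 0.]])     # cost matrix
reg = 1.0

K = np.exp(-M / reg)                   # Gibbs kernel
u = np.ones(len(a)) / len(a)
for _ in range(1000):
    v = b / np.dot(K.T, u)             # enforce the target marginal
    u = a / np.dot(K, v)               # enforce the source marginal
gamma = u[:, None] * K * v[None, :]    # = diag(u) K diag(v)
# gamma reproduces the docstring example: [[0.3655, 0.1344], [0.1344, 0.3655]]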
@@ -147,12 +147,12 @@ def sinkhorn(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=Fa
def geometricBar(weights,alldistribT):
"""return the weighted geometric mean of distributions"""
- assert(len(weights)==alldistribT.shape[1])
- return np.exp(np.dot(np.log(alldistribT),weights.T))
+ assert(len(weights)==alldistribT.shape[1])
+ return np.exp(np.dot(np.log(alldistribT),weights.T))
-def geometricMean(alldistribT):
+def geometricMean(alldistribT):
"""return the geometric mean of distributions"""
- return np.exp(np.mean(np.log(alldistribT),axis=1))
+ return np.exp(np.mean(np.log(alldistribT),axis=1))
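An illustrative check of the helper above (not part of the patch): the exp/mean/log composition is the element-wise geometric mean across the distributions stored in the columns, computed in log-space for numerical safety; geometricBar is the same quantity with non-uniform weights:

import numpy as np

A = np.array([[.2, .8],
              [.8, .2]])               # two distributions in the columns
gm = np.exp(np.mean(np.log(A), axis=1))
assert np.allclose(gm, np.prod(A, axis=1) ** (1. / A.shape[1]))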
def projR(gamma,p):
#return np.dot(np.diag(p/np.maximum(np.sum(gamma,axis=1),1e-10)),gamma)
@@ -161,65 +161,65 @@ def projR(gamma,p):
def projC(gamma,q):
#return (np.dot(np.diag(q/np.maximum(np.sum(gamma,axis=0),1e-10)),gamma.T)).T
return np.multiply(gamma,q/np.maximum(np.sum(gamma,axis=0),1e-10))
-
+
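projR and projC are the two KL/Bregman projections used by the iterations below: they rescale the rows (resp. columns) of gamma so the corresponding marginal matches p (resp. q), with a 1e-10 floor to avoid division by zero. A minimal check for projC:

import numpy as np

gamma = np.ones((2, 2))
q = np.array([.3, .7])
g = np.multiply(gamma, q / np.maximum(np.sum(gamma, axis=0), 1e-10))
assert np.allclose(g.sum(axis=0), q)   # column sums now equal q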
def barycenter(A,M,reg, weights=None, numItermax = 1000, stopThr=1e-4,verbose=False,log=False):
"""Compute the entropic regularized wasserstein barycenter of distributions A
The function solves the following optimization problem:
-
+
.. math::
\mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
-
+
where :
-
+
- :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
- :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
- reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
-
+
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [3]_
-
+
Parameters
----------
A : np.ndarray (d,n)
- n training distributions of size d
+ n training distributions of size d
M : np.ndarray (d,d)
- loss matrix for OT
- reg: float
+ loss matrix for OT
+ reg : float
Regularization term >0
- numItermax: int, optional
+ numItermax : int, optional
Max number of iterations
- stopThr: float, optional
+ stopThr : float, optional
Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
- record log if True
-
-
+ record log if True
+
+
Returns
-------
- a: (d,) ndarray
+ a : (d,) ndarray
Wasserstein barycenter
- log: dict
- log dictionary return only if log==True in parameters
+ log : dict
+ log dictionary returned only if log==True in parameters
+
-
References
----------
-
+
.. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
-
-
-
+
+
+
"""
-
-
+
+
if weights is None:
weights=np.ones(A.shape[1])/A.shape[1]
else:
assert(len(weights)==A.shape[1])
-
+
if log:
log={'err':[]}
@@ -231,130 +231,130 @@ def barycenter(A,M,reg, weights=None, numItermax = 1000, stopThr=1e-4,verbose=Fa
UKv=np.dot(K,np.divide(A.T,np.sum(K,axis=0)).T)
u = (geometricMean(UKv)/UKv.T).T
-
+
while (err>stopThr and cpt<numItermax):
cpt = cpt +1
UKv=u*np.dot(K,np.divide(A,np.dot(K,u)))
u = (u.T*geometricBar(weights,UKv)).T/UKv
-
+
if cpt%10==1:
err=np.sum(np.std(UKv,axis=1))
-
+
# log and verbose print
if log:
log['err'].append(err)
-
+
if verbose:
if cpt%200 ==0:
print('{:5s}|{:12s}'.format('It.','Err')+'\n'+'-'*19)
- print('{:5d}|{:8e}|'.format(cpt,err))
+ print('{:5d}|{:8e}|'.format(cpt,err))
if log:
log['niter']=cpt
return geometricBar(weights,UKv),log
else:
return geometricBar(weights,UKv)
-
+
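A hedged usage sketch for barycenter before moving on (toy data, not from this file; the grid and reg value are arbitrary choices):

import numpy as np
import ot                                # assumes the POT package is installed

x = np.arange(4, dtype=np.float64)
M = (x[:, None] - x[None, :]) ** 2       # (d,d) squared-distance cost
A = np.array([[.6, .1],
              [.2, .2],
              [.1, .3],
              [.1, .4]])                 # (d,n): two distributions in columns
bary = ot.bregman.barycenter(A, M, reg=1e-1)
print(bary.sum())                        # ~1: the barycenter is a distribution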
def unmix(a,D,M,M0,h0,reg,reg0,alpha,numItermax = 1000, stopThr=1e-3,verbose=False,log=False):
"""
Compute the unmixing of an observation with a given dictionary using Wasserstein distance
-
+
The function solves the following optimization problem:
-
+
.. math::
\mathbf{h} = arg\min_\mathbf{h} (1- \\alpha) W_{M,reg}(\mathbf{a},\mathbf{Dh})+\\alpha W_{M0,reg0}(\mathbf{h}_0,\mathbf{h})
where :
-
+
- :math:`W_{M,reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance with M loss matrix (see ot.bregman.sinkhorn)
- - :math:`\mathbf{a}` is an observed distribution, :math:`\mathbf{h}_0` is aprior on unmixing
+ - :math:`\mathbf{a}` is an observed distribution, :math:`\mathbf{h}_0` is a prior on unmixing
- reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT data fitting
- reg0 and :math:`\mathbf{M0}` are respectively the regularization term and the cost matrix for regularization
- - :math:`\\alpha`weight data fitting and regularization
-
+ - :math:`\\alpha` weights data fitting and regularization
+
The optimization problem is solved using the algorithm described in [4]_
-
-
+
+
Parameters
----------
a : np.ndarray (d,)
observed distribution
D : np.ndarray (d,n)
- dictionary matrix
+ dictionary matrix
M : np.ndarray (d,d)
- loss matrix
+ loss matrix
M0 : np.ndarray (n,n)
- loss matrix
+ loss matrix
h0 : np.ndarray (n,)
- prior on h
- reg: float
+ prior on h
+ reg : float
Regularization term >0 (Wasserstein data fitting)
- reg0: float
- Regularization term >0 (Wasserstein reg with h0)
- alpha: float
+ reg0 : float
+ Regularization term >0 (Wasserstein reg with h0)
+ alpha : float
How much should we trust the prior ([0,1])
- numItermax: int, optional
+ numItermax : int, optional
Max number of iterations
- stopThr: float, optional
+ stopThr : float, optional
Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
- record log if True
-
-
+ record log if True
+
+
Returns
-------
- a: (d,) ndarray
+ h : (n,) ndarray
estimated unmixing weights of the observation a
- log: dict
- log dictionary return only if log==True in parameters
-
+ log : dict
+ log dictionary return only if log==True in parameters
+
References
----------
-
+
.. [4] S. Nakhostin, N. Courty, R. Flamary, D. Tuia, T. Corpetti, Supervised planetary unmixing with optimal transport, Workshop on Hyperspectral Image and Signal Processing : Evolution in Remote Sensing (WHISPERS), 2016.
"""
-
- #M = M/np.median(M)
+
+ #M = M/np.median(M)
K = np.exp(-M/reg)
-
- #M0 = M0/np.median(M0)
+
+ #M0 = M0/np.median(M0)
K0 = np.exp(-M0/reg0)
old = h0
err=1
- cpt=0
+ cpt=0
#log = {'niter':0, 'all_err':[]}
if log:
log={'err':[]}
-
-
+
+
while (err>stopThr and cpt<numItermax):
- K = projC(K,a)
- K0 = projC(K0,h0)
+ K = projC(K,a)
+ K0 = projC(K0,h0)
new = np.sum(K0,axis=1)
inv_new = np.dot(D,new) # we recombine the current selection from the dictionary
other = np.sum(K,axis=1)
delta = np.exp(alpha*np.log(other)+(1-alpha)*np.log(inv_new)) # geometric interpolation
K = projR(K,delta)
K0 = np.dot(np.diag(np.dot(D.T,delta/inv_new)),K0)
-
+
err=np.linalg.norm(np.sum(K0,axis=1)-old)
old = new
if log:
log['err'].append(err)
-
+
if verbose:
if cpt%200 ==0:
print('{:5s}|{:12s}'.format('It.','Err')+'\n'+'-'*19)
- print('{:5d}|{:8e}|'.format(cpt,err))
-
+ print('{:5d}|{:8e}|'.format(cpt,err))
+
cpt = cpt+1
-
+
if log:
log['niter']=cpt
return np.sum(K0,axis=1),log
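And a similar hedged sketch for unmix (toy values chosen for illustration; argument names follow the docstring above):

import numpy as np
import ot                                # assumes the POT package is installed

a = np.array([.4, .6])                   # observed distribution
D = np.eye(2)                            # (d,n) dictionary matrix
M = np.array([[0., 1.], [1., 0.]])       # data-fitting cost
M0 = M.copy()                            # regularization cost
h0 = np.array([.5, .5])                  # prior on the unmixing
h = ot.bregman.unmix(a, D, M, M0, h0, reg=1e-1, reg0=1e-1, alpha=.5)
print(h)                                 # unmixing weights, between a and h0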