summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2016-10-28 11:47:22 +0200
committerRémi Flamary <remi.flamary@gmail.com>2016-10-28 11:47:22 +0200
commit062d6fd23baa27e0334023af320230120b50c828 (patch)
tree5eb36c71552f0f814ea7491b931636e3cc2b46fb /ot
parent3067c8873bf325808985453f0ac968d13435e032 (diff)
bregman doc finished
Diffstat (limited to 'ot')
-rw-r--r--ot/bregman.py89
1 files changed, 78 insertions, 11 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index b6cdf80..9183bea 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -183,7 +183,7 @@ def barycenter(A,M,reg, weights=None, numItermax = 1000, stopThr=1e-4,verbose=Fa
----------
A : np.ndarray (d,n)
n training distributions of size d
- M : np.ndarray (ns,nt)
+ M : np.ndarray (d,d)
loss matrix for OT
reg: float
Regularization term >0
@@ -256,15 +256,73 @@ def barycenter(A,M,reg, weights=None, numItermax = 1000, stopThr=1e-4,verbose=Fa
return geometricBar(weights,UKv)
-def unmix(distrib,D,M,M0,h0,reg,reg0,alpha,numItermax = 1000, tol_error=1e-3,log=dict()):
+def unmix(a,D,M,M0,h0,reg,reg0,alpha,numItermax = 1000, stopThr=1e-3,verbose=False,log=False):
"""
- distrib : distribution to unmix
+ Compute the unmixing of an observation with a given dictionary using Wasserstein distance
+
+ The function solve the following optimization problem:
+
+ .. math::
+ \mathbf{h} = arg\min_\mathbf{h} (1- \\alpha) W_{M,reg}(\mathbf{a},\mathbf{Dh})+\\alpha W_{M0,reg0}(\mathbf{h}_0,\mathbf{h})
+
+
+ where :
+
+ - :math:`W_{M,reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance with M loss matrix (see ot.bregman.sinkhorn)
+ - :math:`\mathbf{a}` is an observed distribution, :math:`\mathbf{h}_0` is aprior on unmixing
+ - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT data fitting
+ - reg0 and :math:`\mathbf{M0}` are respectively the regularization term and the cost matrix for regularization
+ - :math:`\\alpha`weight data fitting and regularization
+
+ The optimization problem is solved suing the algorithm described in [4]
+
+
+ distrib : distribution to unmix
D : Dictionnary
M : Metric matrix in the space of the distributions to unmix
M0 : Metric matrix in the space of the 'abundance values' to solve for
h0 : prior on solution (generally uniform distribution)
reg,reg0 : transport regularizations
alpha : how much should we trust the prior ? ([0,1])
+
+ Parameters
+ ----------
+ a : np.ndarray (d)
+ observed distribution
+ D : np.ndarray (d,n)
+ dictionary matrix
+ M : np.ndarray (d,d)
+ loss matrix
+ M0 : np.ndarray (n,n)
+ loss matrix
+ h0 : np.ndarray (n,)
+ prior on h
+ reg: float
+ Regularization term >0 (Wasserstein data fitting)
+ reg0: float
+ Regularization term >0 (Wasserstein reg with h0)
+ numItermax: int, optional
+ Max number of iterations
+ stopThr: float, optional
+ Stop threshol on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ a: (d,) ndarray
+ Wasserstein barycenter
+ log: dict
+ log dictionary return only if log==True in parameters
+
+ References
+ ----------
+
+ .. [4] S. Nakhostin, N. Courty, R. Flamary, D. Tuia, T. Corpetti, Supervised planetary unmixing with optimal transport, Whorkshop on Hyperspectral Image and Signal Processing : Evolution in Remote Sensing (WHISPERS), 2016.
+
"""
#M = M/np.median(M)
@@ -277,12 +335,12 @@ def unmix(distrib,D,M,M0,h0,reg,reg0,alpha,numItermax = 1000, tol_error=1e-3,log
err=1
cpt=0
#log = {'niter':0, 'all_err':[]}
- log['niter']=0
- log['all_err']=[]
+ if log:
+ log={'err':[]}
- while (err>tol_error and cpt<numItermax):
- K = projC(K,distrib)
+ while (err>stopThr and cpt<numItermax):
+ K = projC(K,a)
K0 = projC(K0,h0)
new = np.sum(K0,axis=1)
inv_new = np.dot(D,new) # we recombine the current selection from dictionnary
@@ -293,9 +351,18 @@ def unmix(distrib,D,M,M0,h0,reg,reg0,alpha,numItermax = 1000, tol_error=1e-3,log
err=np.linalg.norm(np.sum(K0,axis=1)-old)
old = new
- log['all_err'].append(err)
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if cpt%200 ==0:
+ print('{:5s}|{:12s}'.format('It.','Err')+'\n'+'-'*19)
+ print('{:5d}|{:8e}|'.format(cpt,err))
+
cpt = cpt+1
-
- log['niter']=cpt
- return np.sum(K0,axis=1),log
+ if log:
+ log['niter']=cpt
+ return np.sum(K0,axis=1),log
+ else:
+ return np.sum(K0,axis=1)