diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2016-10-28 11:47:22 +0200 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2016-10-28 11:47:22 +0200 |
commit | 062d6fd23baa27e0334023af320230120b50c828 (patch) | |
tree | 5eb36c71552f0f814ea7491b931636e3cc2b46fb /ot/bregman.py | |
parent | 3067c8873bf325808985453f0ac968d13435e032 (diff) |
bregman doc finished
Diffstat (limited to 'ot/bregman.py')
-rw-r--r-- | ot/bregman.py | 89 |
1 files changed, 78 insertions, 11 deletions
diff --git a/ot/bregman.py b/ot/bregman.py index b6cdf80..9183bea 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -183,7 +183,7 @@ def barycenter(A,M,reg, weights=None, numItermax = 1000, stopThr=1e-4,verbose=Fa ---------- A : np.ndarray (d,n) n training distributions of size d - M : np.ndarray (ns,nt) + M : np.ndarray (d,d) loss matrix for OT reg: float Regularization term >0 @@ -256,15 +256,73 @@ def barycenter(A,M,reg, weights=None, numItermax = 1000, stopThr=1e-4,verbose=Fa return geometricBar(weights,UKv) -def unmix(distrib,D,M,M0,h0,reg,reg0,alpha,numItermax = 1000, tol_error=1e-3,log=dict()): +def unmix(a,D,M,M0,h0,reg,reg0,alpha,numItermax = 1000, stopThr=1e-3,verbose=False,log=False): """ - distrib : distribution to unmix + Compute the unmixing of an observation with a given dictionary using Wasserstein distance + + The function solve the following optimization problem: + + .. math:: + \mathbf{h} = arg\min_\mathbf{h} (1- \\alpha) W_{M,reg}(\mathbf{a},\mathbf{Dh})+\\alpha W_{M0,reg0}(\mathbf{h}_0,\mathbf{h}) + + + where : + + - :math:`W_{M,reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance with M loss matrix (see ot.bregman.sinkhorn) + - :math:`\mathbf{a}` is an observed distribution, :math:`\mathbf{h}_0` is aprior on unmixing + - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT data fitting + - reg0 and :math:`\mathbf{M0}` are respectively the regularization term and the cost matrix for regularization + - :math:`\\alpha`weight data fitting and regularization + + The optimization problem is solved suing the algorithm described in [4] + + + distrib : distribution to unmix D : Dictionnary M : Metric matrix in the space of the distributions to unmix M0 : Metric matrix in the space of the 'abundance values' to solve for h0 : prior on solution (generally uniform distribution) reg,reg0 : transport regularizations alpha : how much should we trust the prior ? ([0,1]) + + Parameters + ---------- + a : np.ndarray (d) + observed distribution + D : np.ndarray (d,n) + dictionary matrix + M : np.ndarray (d,d) + loss matrix + M0 : np.ndarray (n,n) + loss matrix + h0 : np.ndarray (n,) + prior on h + reg: float + Regularization term >0 (Wasserstein data fitting) + reg0: float + Regularization term >0 (Wasserstein reg with h0) + numItermax: int, optional + Max number of iterations + stopThr: float, optional + Stop threshol on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + a: (d,) ndarray + Wasserstein barycenter + log: dict + log dictionary return only if log==True in parameters + + References + ---------- + + .. [4] S. Nakhostin, N. Courty, R. Flamary, D. Tuia, T. Corpetti, Supervised planetary unmixing with optimal transport, Whorkshop on Hyperspectral Image and Signal Processing : Evolution in Remote Sensing (WHISPERS), 2016. + """ #M = M/np.median(M) @@ -277,12 +335,12 @@ def unmix(distrib,D,M,M0,h0,reg,reg0,alpha,numItermax = 1000, tol_error=1e-3,log err=1 cpt=0 #log = {'niter':0, 'all_err':[]} - log['niter']=0 - log['all_err']=[] + if log: + log={'err':[]} - while (err>tol_error and cpt<numItermax): - K = projC(K,distrib) + while (err>stopThr and cpt<numItermax): + K = projC(K,a) K0 = projC(K0,h0) new = np.sum(K0,axis=1) inv_new = np.dot(D,new) # we recombine the current selection from dictionnary @@ -293,9 +351,18 @@ def unmix(distrib,D,M,M0,h0,reg,reg0,alpha,numItermax = 1000, tol_error=1e-3,log err=np.linalg.norm(np.sum(K0,axis=1)-old) old = new - log['all_err'].append(err) + if log: + log['err'].append(err) + + if verbose: + if cpt%200 ==0: + print('{:5s}|{:12s}'.format('It.','Err')+'\n'+'-'*19) + print('{:5d}|{:8e}|'.format(cpt,err)) + cpt = cpt+1 - - log['niter']=cpt - return np.sum(K0,axis=1),log + if log: + log['niter']=cpt + return np.sum(K0,axis=1),log + else: + return np.sum(K0,axis=1) |