diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2016-10-28 10:58:04 +0200 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2016-10-28 10:58:04 +0200 |
commit | 8cd50c55f398cc371db2ef334c803dec99cc209a (patch) | |
tree | d6084ced937c38603dab8a72d0cc5e64aaf83480 /ot/optim.py | |
parent | a0d8139af3407e567e1dc9a5e8c10d9218ddd185 (diff) |
update doc optim+bregman; add log to sinkhorn
Diffstat (limited to 'ot/optim.py')
-rw-r--r-- | ot/optim.py | 58 |
1 files changed, 52 insertions, 6 deletions
diff --git a/ot/optim.py b/ot/optim.py index e6373ce..d1bf672 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -1,8 +1,6 @@ # -*- coding: utf-8 -*- """ -Created on Wed Oct 26 15:08:19 2016 - -@author: rflamary +Optimization algorithms for OT """ import numpy as np @@ -12,6 +10,42 @@ from lp import emd # The corresponding scipy function does not work for matrices def line_search_armijo(f,xk,pk,gfk,old_fval,args=(),c1=1e-4,alpha0=0.99): + """ + Armijo linesearch function that works with matrices + + find an approximate minimum of f(xk+alpha*pk) that satifies the + armijo conditions. + + Parameters + ---------- + + f : function + loss function + xk : np.ndarray + initial position + pk : np.ndarray + descent direction + gfk : np.ndarray + gradient of f at xk + old_fval: float + loss value at xk + args : tuple, optional + arguments given to f + c1 : float, optional + c1 const in armijo rule (>0) + alpha0 : float, optional + initial step (>0) + + Returns + ------- + alpha : float + step that satisfy armijo conditions + fc : int + nb of function call + fa : float + loss value at step alpha + + """ xk = np.atleast_1d(xk) fc = [0] @@ -61,14 +95,26 @@ def cg(a,b,M,reg,f,df,G0=None,numItermax = 200,stopThr=1e-9,verbose=False,log=Fa samples in the target domain M : np.ndarray (ns,nt) loss matrix - reg: float() + reg : float Regularization term >0 - + G0 : np.ndarray (ns,nt), optional + initial guess (default is indep joint density) + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (>0) + verbose : int, optional + Print information along iterations + log : int, optional + record log if True Returns ------- gamma: (ns x nt) ndarray Optimal transportation matrix for the given parameters + log: dict + log dictionary return only if log==True in parameters + References ---------- @@ -77,7 +123,7 @@ def cg(a,b,M,reg,f,df,G0=None,numItermax = 200,stopThr=1e-9,verbose=False,log=Fa See Also -------- - ot.emd.emd : Unregularized optimal ransport + ot.lp.emd : Unregularized optimal ransport ot.bregman.sinkhorn : Entropic regularized optimal transport """ |