summaryrefslogtreecommitdiff
path: root/ot/optim.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2016-10-28 10:58:04 +0200
committerRémi Flamary <remi.flamary@gmail.com>2016-10-28 10:58:04 +0200
commit8cd50c55f398cc371db2ef334c803dec99cc209a (patch)
treed6084ced937c38603dab8a72d0cc5e64aaf83480 /ot/optim.py
parenta0d8139af3407e567e1dc9a5e8c10d9218ddd185 (diff)
update doc optim+bregman; add log to sinkhorn
Diffstat (limited to 'ot/optim.py')
-rw-r--r--ot/optim.py58
1 files changed, 52 insertions, 6 deletions
diff --git a/ot/optim.py b/ot/optim.py
index e6373ce..d1bf672 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -1,8 +1,6 @@
# -*- coding: utf-8 -*-
"""
-Created on Wed Oct 26 15:08:19 2016
-
-@author: rflamary
+Optimization algorithms for OT
"""
import numpy as np
@@ -12,6 +10,42 @@ from lp import emd
# The corresponding scipy function does not work for matrices
def line_search_armijo(f,xk,pk,gfk,old_fval,args=(),c1=1e-4,alpha0=0.99):
+ """
+ Armijo linesearch function that works with matrices
+
+ find an approximate minimum of f(xk+alpha*pk) that satifies the
+ armijo conditions.
+
+ Parameters
+ ----------
+
+ f : function
+ loss function
+ xk : np.ndarray
+ initial position
+ pk : np.ndarray
+ descent direction
+ gfk : np.ndarray
+ gradient of f at xk
+ old_fval: float
+ loss value at xk
+ args : tuple, optional
+ arguments given to f
+ c1 : float, optional
+ c1 const in armijo rule (>0)
+ alpha0 : float, optional
+ initial step (>0)
+
+ Returns
+ -------
+ alpha : float
+ step that satisfy armijo conditions
+ fc : int
+ nb of function call
+ fa : float
+ loss value at step alpha
+
+ """
xk = np.atleast_1d(xk)
fc = [0]
@@ -61,14 +95,26 @@ def cg(a,b,M,reg,f,df,G0=None,numItermax = 200,stopThr=1e-9,verbose=False,log=Fa
samples in the target domain
M : np.ndarray (ns,nt)
loss matrix
- reg: float()
+ reg : float
Regularization term >0
-
+ G0 : np.ndarray (ns,nt), optional
+ initial guess (default is indep joint density)
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshol on error (>0)
+ verbose : int, optional
+ Print information along iterations
+ log : int, optional
+ record log if True
Returns
-------
gamma: (ns x nt) ndarray
Optimal transportation matrix for the given parameters
+ log: dict
+ log dictionary return only if log==True in parameters
+
References
----------
@@ -77,7 +123,7 @@ def cg(a,b,M,reg,f,df,G0=None,numItermax = 200,stopThr=1e-9,verbose=False,log=Fa
See Also
--------
- ot.emd.emd : Unregularized optimal ransport
+ ot.lp.emd : Unregularized optimal ransport
ot.bregman.sinkhorn : Entropic regularized optimal transport
"""