summaryrefslogtreecommitdiff
path: root/ot/optim.py
diff options
context:
space:
mode:
authortvayer <titouan.vayer@gmail.com>2019-05-29 15:51:57 +0200
committertvayer <titouan.vayer@gmail.com>2019-05-29 15:51:57 +0200
commit9421dddd8890d4c575b593d678eb7bdf5f933f83 (patch)
treeea589599791cf38f7f6c2420d919bc3a627f5ae0 /ot/optim.py
parent94d2fe5fd0b07060426e9449de0331b88ab53df4 (diff)
Doc+armijo
Diffstat (limited to 'ot/optim.py')
-rw-r--r--ot/optim.py22
1 files changed, 11 insertions, 11 deletions
diff --git a/ot/optim.py b/ot/optim.py
index b96d920..82a91bf 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -73,13 +73,13 @@ def line_search_armijo(f, xk, pk, gfk, old_fval,
def do_linesearch(cost, G, deltaG, Mi, f_val,
- amijo=True, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None):
+ armijo=True, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None):
"""
Solve the linesearch in the FW iterations
Parameters
----------
cost : method
- The FGW cost
+ Cost in the FW for the linesearch
G : ndarray, shape(ns,nt)
The transport map at a given iteration of the FW
deltaG : ndarray (ns,nt)
@@ -88,21 +88,21 @@ def do_linesearch(cost, G, deltaG, Mi, f_val,
Cost matrix of the linearized transport problem. Corresponds to the gradient of the cost
f_val : float
Value of the cost at G
- amijo : bool, optionnal
- If True the steps of the line-search is found via an amijo research. Else closed form is used.
+ armijo : bool, optionnal
+ If True the steps of the line-search is found via an armijo research. Else closed form is used.
If there is convergence issues use False.
C1 : ndarray (ns,ns), optionnal
- Structure matrix in the source domain. Only used when amijo=False
+ Structure matrix in the source domain. Only used when armijo=False
C2 : ndarray (nt,nt), optionnal
- Structure matrix in the target domain. Only used when amijo=False
+ Structure matrix in the target domain. Only used when armijo=False
reg : float, optionnal
- Regularization parameter. Corresponds to the alpha parameter of FGW. Only used when amijo=False
+ Regularization parameter. Only used when armijo=False
Gc : ndarray (ns,nt)
- Optimal map found by linearization in the FW algorithm. Only used when amijo=False
+ Optimal map found by linearization in the FW algorithm. Only used when armijo=False
constC : ndarray (ns,nt)
- Constant for the gromov cost. See [3]. Only used when amijo=False
+ Constant for the gromov cost. See [24]. Only used when armijo=False
M : ndarray (ns,nt), optionnal
- Cost matrix between the features. Only used when amijo=False
+ Cost matrix between the features. Only used when armijo=False
Returns
-------
alpha : float
@@ -118,7 +118,7 @@ def do_linesearch(cost, G, deltaG, Mi, f_val,
"Optimal Transport for structured data with application on graphs"
International Conference on Machine Learning (ICML). 2019.
"""
- if amijo:
+ if armijo:
alpha, fc, f_val = line_search_armijo(cost, G, deltaG, Mi, f_val)
else: # requires symetric matrices
dot1 = np.dot(C1, deltaG)