diff options
author | tvayer <titouan.vayer@gmail.com> | 2019-05-29 15:51:57 +0200 |
---|---|---|
committer | tvayer <titouan.vayer@gmail.com> | 2019-05-29 15:51:57 +0200 |
commit | 9421dddd8890d4c575b593d678eb7bdf5f933f83 (patch) | |
tree | ea589599791cf38f7f6c2420d919bc3a627f5ae0 /ot/optim.py | |
parent | 94d2fe5fd0b07060426e9449de0331b88ab53df4 (diff) |
Doc+armijo
Diffstat (limited to 'ot/optim.py')
-rw-r--r-- | ot/optim.py | 22 |
1 files changed, 11 insertions, 11 deletions
diff --git a/ot/optim.py b/ot/optim.py index b96d920..82a91bf 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -73,13 +73,13 @@ def line_search_armijo(f, xk, pk, gfk, old_fval, def do_linesearch(cost, G, deltaG, Mi, f_val, - amijo=True, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None): + armijo=True, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None): """ Solve the linesearch in the FW iterations Parameters ---------- cost : method - The FGW cost + Cost in the FW for the linesearch G : ndarray, shape(ns,nt) The transport map at a given iteration of the FW deltaG : ndarray (ns,nt) @@ -88,21 +88,21 @@ def do_linesearch(cost, G, deltaG, Mi, f_val, Cost matrix of the linearized transport problem. Corresponds to the gradient of the cost f_val : float Value of the cost at G - amijo : bool, optionnal - If True the steps of the line-search is found via an amijo research. Else closed form is used. + armijo : bool, optionnal + If True the steps of the line-search is found via an armijo research. Else closed form is used. If there is convergence issues use False. C1 : ndarray (ns,ns), optionnal - Structure matrix in the source domain. Only used when amijo=False + Structure matrix in the source domain. Only used when armijo=False C2 : ndarray (nt,nt), optionnal - Structure matrix in the target domain. Only used when amijo=False + Structure matrix in the target domain. Only used when armijo=False reg : float, optionnal - Regularization parameter. Corresponds to the alpha parameter of FGW. Only used when amijo=False + Regularization parameter. Only used when armijo=False Gc : ndarray (ns,nt) - Optimal map found by linearization in the FW algorithm. Only used when amijo=False + Optimal map found by linearization in the FW algorithm. Only used when armijo=False constC : ndarray (ns,nt) - Constant for the gromov cost. See [3]. Only used when amijo=False + Constant for the gromov cost. See [24]. Only used when armijo=False M : ndarray (ns,nt), optionnal - Cost matrix between the features. Only used when amijo=False + Cost matrix between the features. Only used when armijo=False Returns ------- alpha : float @@ -118,7 +118,7 @@ def do_linesearch(cost, G, deltaG, Mi, f_val, "Optimal Transport for structured data with application on graphs" International Conference on Machine Learning (ICML). 2019. """ - if amijo: + if armijo: alpha, fc, f_val = line_search_armijo(cost, G, deltaG, Mi, f_val) else: # requires symetric matrices dot1 = np.dot(C1, deltaG) |