summaryrefslogtreecommitdiff
path: root/ot/optim.py
diff options
context:
space:
mode:
authortvayer <titouan.vayer@gmail.com>2019-06-04 10:32:30 +0200
committertvayer <titouan.vayer@gmail.com>2019-06-04 10:32:35 +0200
commitad450b0a5bb63ee9731e88d4a8e7423b16f1abd8 (patch)
treecab0421292074e59cb4eeb2846e8cca5aa159d3a /ot/optim.py
parent89a2e0aee4353a051d924de0457f8976c26fa5d7 (diff)
changes forgotten coments
Diffstat (limited to 'ot/optim.py')
-rw-r--r--ot/optim.py32
1 files changed, 16 insertions, 16 deletions
diff --git a/ot/optim.py b/ot/optim.py
index 4d428d9..f94aceb 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -73,8 +73,8 @@ def line_search_armijo(f, xk, pk, gfk, old_fval,
return alpha, fc[0], phi1
-def do_linesearch(cost, G, deltaG, Mi, f_val,
- armijo=True, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None):
+def solve_linesearch(cost, G, deltaG, Mi, f_val,
+ armijo=True, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None):
"""
Solve the linesearch in the FW iterations
Parameters
@@ -93,17 +93,17 @@ def do_linesearch(cost, G, deltaG, Mi, f_val,
If True the steps of the line-search is found via an armijo research. Else closed form is used.
If there is convergence issues use False.
C1 : ndarray (ns,ns), optional
- Structure matrix in the source domain. Only used when armijo=False
+ Structure matrix in the source domain. Only used and necessary when armijo=False
C2 : ndarray (nt,nt), optional
- Structure matrix in the target domain. Only used when armijo=False
+ Structure matrix in the target domain. Only used and necessary when armijo=False
reg : float, optional
- Regularization parameter. Only used when armijo=False
+ Regularization parameter. Only used and necessary when armijo=False
Gc : ndarray (ns,nt)
- Optimal map found by linearization in the FW algorithm. Only used when armijo=False
+ Optimal map found by linearization in the FW algorithm. Only used and necessary when armijo=False
constC : ndarray (ns,nt)
- Constant for the gromov cost. See [24]. Only used when armijo=False
+ Constant for the gromov cost. See [24]. Only used and necessary when armijo=False
M : ndarray (ns,nt), optional
- Cost matrix between the features. Only used when armijo=False
+ Cost matrix between the features. Only used and necessary when armijo=False
Returns
-------
alpha : float
@@ -128,7 +128,7 @@ def do_linesearch(cost, G, deltaG, Mi, f_val,
b = np.sum((M + reg * constC) * deltaG) - 2 * reg * (np.sum(dot12 * G) + np.sum(np.dot(C1, G).dot(C2) * deltaG))
c = cost(G)
- alpha = solve_1d_linesearch_quad_funct(a, b, c)
+ alpha = solve_1d_linesearch_quad(a, b, c)
fc = None
f_val = cost(G + alpha * deltaG)
@@ -181,7 +181,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
Print information along iterations
log : bool, optional
record log if True
- kwargs : dict
+ **kwargs : dict
Parameters for linesearch
Returns
@@ -244,7 +244,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
deltaG = Gc - G
# line search
- alpha, fc, f_val = do_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, **kwargs)
+ alpha, fc, f_val = solve_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, **kwargs)
G = G + alpha * deltaG
@@ -254,7 +254,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
abs_delta_fval = abs(f_val - old_fval)
relative_delta_fval = abs_delta_fval / abs(f_val)
- if relative_delta_fval < stopThr and abs_delta_fval < stopThr2:
+ if relative_delta_fval < stopThr or abs_delta_fval < stopThr2:
loop = 0
if log:
@@ -395,7 +395,7 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
abs_delta_fval = abs(f_val - old_fval)
relative_delta_fval = abs_delta_fval / abs(f_val)
- if relative_delta_fval < stopThr and abs_delta_fval < stopThr2:
+ if relative_delta_fval < stopThr or abs_delta_fval < stopThr2:
loop = 0
if log:
@@ -413,11 +413,11 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
return G
-def solve_1d_linesearch_quad_funct(a, b, c):
+def solve_1d_linesearch_quad(a, b, c):
"""
- Solve on 0,1 the following problem:
+ For any convex or non-convex 1d quadratic function f, solve on [0,1] the following problem:
.. math::
- \min f(x)=a*x^{2}+b*x+c
+ \argmin f(x)=a*x^{2}+b*x+c
Parameters
----------