summaryrefslogtreecommitdiff
path: root/ot/optim.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/optim.py')
-rw-r--r--ot/optim.py47
1 files changed, 34 insertions, 13 deletions
diff --git a/ot/optim.py b/ot/optim.py
index bd8ca26..f25e2c9 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -18,8 +18,10 @@ from .backend import get_backend
# The corresponding scipy function does not work for matrices
-def line_search_armijo(f, xk, pk, gfk, old_fval,
- args=(), c1=1e-4, alpha0=0.99):
+def line_search_armijo(
+ f, xk, pk, gfk, old_fval, args=(), c1=1e-4,
+ alpha0=0.99, alpha_min=None, alpha_max=None
+):
r"""
Armijo linesearch function that works with matrices
@@ -44,6 +46,10 @@ def line_search_armijo(f, xk, pk, gfk, old_fval,
:math:`c_1` const in armijo rule (>0)
alpha0 : float, optional
initial step (>0)
+ alpha_min : float, optional
+ minimum value for alpha
+ alpha_max : float, optional
+ maximum value for alpha
Returns
-------
@@ -77,14 +83,18 @@ def line_search_armijo(f, xk, pk, gfk, old_fval,
alpha, phi1 = scalar_search_armijo(
phi, phi0, derphi0, c1=c1, alpha0=alpha0)
- # scalar_search_armijo can return alpha > 1
- if alpha is not None:
- alpha = min(1, alpha)
- return alpha, fc[0], phi1
+ if alpha is None:
+ return 0., fc[0], phi0
+ else:
+ if alpha_min is not None or alpha_max is not None:
+ alpha = np.clip(alpha, alpha_min, alpha_max)
+ return float(alpha), fc[0], phi1
-def solve_linesearch(cost, G, deltaG, Mi, f_val,
- armijo=True, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None):
+def solve_linesearch(
+ cost, G, deltaG, Mi, f_val, armijo=True, C1=None, C2=None,
+ reg=None, Gc=None, constC=None, M=None, alpha_min=None, alpha_max=None
+):
"""
Solve the linesearch in the FW iterations
@@ -115,6 +125,10 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val,
Constant for the gromov cost. See :ref:`[24] <references-solve-linesearch>`. Only used and necessary when armijo=False
M : array-like (ns,nt), optional
Cost matrix between the features. Only used and necessary when armijo=False
+ alpha_min : float, optional
+ Minimum value for alpha
+ alpha_max : float, optional
+ Maximum value for alpha
Returns
-------
@@ -134,7 +148,9 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val,
International Conference on Machine Learning (ICML). 2019.
"""
if armijo:
- alpha, fc, f_val = line_search_armijo(cost, G, deltaG, Mi, f_val)
+ alpha, fc, f_val = line_search_armijo(
+ cost, G, deltaG, Mi, f_val, alpha_min=alpha_min, alpha_max=alpha_max
+ )
else: # requires symetric matrices
G, deltaG, C1, C2, constC, M = list_to_array(G, deltaG, C1, C2, constC, M)
if isinstance(M, int) or isinstance(M, float):
@@ -148,6 +164,8 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val,
c = cost(G)
alpha = solve_1d_linesearch_quad(a, b, c)
+ if alpha_min is not None or alpha_max is not None:
+ alpha = np.clip(alpha, alpha_min, alpha_max)
fc = None
f_val = cost(G + alpha * deltaG)
@@ -272,9 +290,10 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
deltaG = Gc - G
# line search
- alpha, fc, f_val = solve_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, **kwargs)
- if alpha is None:
- alpha = 0.0
+ alpha, fc, f_val = solve_linesearch(
+ cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc,
+ alpha_min=0., alpha_max=1., **kwargs
+ )
G = G + alpha * deltaG
@@ -420,7 +439,9 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
# line search
dcost = Mi + reg1 * (1 + nx.log(G)) # ??
- alpha, fc, f_val = line_search_armijo(cost, G, deltaG, dcost, f_val)
+ alpha, fc, f_val = line_search_armijo(
+ cost, G, deltaG, dcost, f_val, alpha_min=0., alpha_max=1.
+ )
G = G + alpha * deltaG