From 8b9e641645f7825255afacfd141bfbf52ba2857e Mon Sep 17 00:00:00 2001 From: Nathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com> Date: Wed, 17 Nov 2021 18:59:59 +0100 Subject: [MRG] SinkhornL1L2 bug solve (#313) * Now limiting alpha to a minimum value as well as a max value * Docs * typo --- ot/optim.py | 39 ++++++++++++++++++++++++++++++--------- 1 file changed, 30 insertions(+), 9 deletions(-) diff --git a/ot/optim.py b/ot/optim.py index cacec53..9b8a8f7 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -18,8 +18,10 @@ from .backend import get_backend # The corresponding scipy function does not work for matrices -def line_search_armijo(f, xk, pk, gfk, old_fval, - args=(), c1=1e-4, alpha0=0.99): +def line_search_armijo( + f, xk, pk, gfk, old_fval, args=(), c1=1e-4, + alpha0=0.99, alpha_min=None, alpha_max=None +): r""" Armijo linesearch function that works with matrices @@ -44,6 +46,10 @@ def line_search_armijo(f, xk, pk, gfk, old_fval, :math:`c_1` const in armijo rule (>0) alpha0 : float, optional initial step (>0) + alpha_min : float, optional + minimum value for alpha + alpha_max : float, optional + maximum value for alpha Returns ------- @@ -80,13 +86,15 @@ def line_search_armijo(f, xk, pk, gfk, old_fval, if alpha is None: return 0., fc[0], phi0 else: - # scalar_search_armijo can return alpha > 1 - alpha = min(1, alpha) + if alpha_min is not None or alpha_max is not None: + alpha = np.clip(alpha, alpha_min, alpha_max) return alpha, fc[0], phi1 -def solve_linesearch(cost, G, deltaG, Mi, f_val, - armijo=True, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None): +def solve_linesearch( + cost, G, deltaG, Mi, f_val, armijo=True, C1=None, C2=None, + reg=None, Gc=None, constC=None, M=None, alpha_min=None, alpha_max=None +): """ Solve the linesearch in the FW iterations @@ -117,6 +125,10 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val, Constant for the gromov cost. See :ref:`[24] <references-solve-linesearch>`. 
Only used and necessary when armijo=False M : array-like (ns,nt), optional Cost matrix between the features. Only used and necessary when armijo=False + alpha_min : float, optional + Minimum value for alpha + alpha_max : float, optional + Maximum value for alpha Returns ------- @@ -136,7 +148,9 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val, International Conference on Machine Learning (ICML). 2019. """ if armijo: - alpha, fc, f_val = line_search_armijo(cost, G, deltaG, Mi, f_val) + alpha, fc, f_val = line_search_armijo( + cost, G, deltaG, Mi, f_val, alpha_min=alpha_min, alpha_max=alpha_max + ) else: # requires symetric matrices G, deltaG, C1, C2, constC, M = list_to_array(G, deltaG, C1, C2, constC, M) if isinstance(M, int) or isinstance(M, float): @@ -150,6 +164,8 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val, c = cost(G) alpha = solve_1d_linesearch_quad(a, b, c) + if alpha_min is not None or alpha_max is not None: + alpha = np.clip(alpha, alpha_min, alpha_max) fc = None f_val = cost(G + alpha * deltaG) @@ -274,7 +290,10 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, deltaG = Gc - G # line search - alpha, fc, f_val = solve_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, **kwargs) + alpha, fc, f_val = solve_linesearch( + cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, + alpha_min=0., alpha_max=1., **kwargs + ) G = G + alpha * deltaG @@ -420,7 +439,9 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, # line search dcost = Mi + reg1 * (1 + nx.log(G)) # ?? - alpha, fc, f_val = line_search_armijo(cost, G, deltaG, dcost, f_val) + alpha, fc, f_val = line_search_armijo( + cost, G, deltaG, dcost, f_val, alpha_min=0., alpha_max=1. + ) G = G + alpha * deltaG -- cgit v1.2.3