From e235b08c2ac86f75b6c1b8e96e503305aa0449e1 Mon Sep 17 00:00:00 2001 From: Nathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com> Date: Wed, 17 Nov 2021 11:16:24 +0100 Subject: [MRG] SinkhornL1L2Transport bug (#312) * solve bug * Linesearch no longer return None as alpha, only 0 --- ot/optim.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'ot') diff --git a/ot/optim.py b/ot/optim.py index bd8ca26..cacec53 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -77,10 +77,12 @@ def line_search_armijo(f, xk, pk, gfk, old_fval, alpha, phi1 = scalar_search_armijo( phi, phi0, derphi0, c1=c1, alpha0=alpha0) - # scalar_search_armijo can return alpha > 1 - if alpha is not None: + if alpha is None: + return 0., fc[0], phi0 + else: + # scalar_search_armijo can return alpha > 1 alpha = min(1, alpha) - return alpha, fc[0], phi1 + return alpha, fc[0], phi1 def solve_linesearch(cost, G, deltaG, Mi, f_val, @@ -273,8 +275,6 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, # line search alpha, fc, f_val = solve_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, **kwargs) - if alpha is None: - alpha = 0.0 G = G + alpha * deltaG -- cgit v1.2.3