From f70aabfcc11f92181e0dc987b341bad8ec030d75 Mon Sep 17 00:00:00 2001 From: tvayer Date: Wed, 29 May 2019 14:16:23 +0200 Subject: pep8 --- ot/optim.py | 59 +++++++++++++++++++++++++++++------------------------------ 1 file changed, 29 insertions(+), 30 deletions(-) (limited to 'ot/optim.py') diff --git a/ot/optim.py b/ot/optim.py index 9fce21e..cbfb187 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -71,8 +71,9 @@ def line_search_armijo(f, xk, pk, gfk, old_fval, return alpha, fc[0], phi1 -def do_linesearch(cost,G,deltaG,Mi,f_val, - amijo=False,C1=None,C2=None,reg=None,Gc=None,constC=None,M=None): + +def do_linesearch(cost, G, deltaG, Mi, f_val, + amijo=False, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None): """ Solve the linesearch in the FW iterations Parameters @@ -119,22 +120,22 @@ def do_linesearch(cost,G,deltaG,Mi,f_val, """ if amijo: alpha, fc, f_val = line_search_armijo(cost, G, deltaG, Mi, f_val) - else: # requires symetric matrices - dot1=np.dot(C1,deltaG) - dot12=dot1.dot(C2) - a=-2*reg*np.sum(dot12*deltaG) - b=np.sum((M+reg*constC)*deltaG)-2*reg*(np.sum(dot12*G)+np.sum(np.dot(C1,G).dot(C2)*deltaG)) - c=cost(G) + else: # requires symetric matrices + dot1 = np.dot(C1, deltaG) + dot12 = dot1.dot(C2) + a = -2 * reg * np.sum(dot12 * deltaG) + b = np.sum((M + reg * constC) * deltaG) - 2 * reg * (np.sum(dot12 * G) + np.sum(np.dot(C1, G).dot(C2) * deltaG)) + c = cost(G) + + alpha = solve_1d_linesearch_quad_funct(a, b, c) + fc = None + f_val = cost(G + alpha * deltaG) - alpha=solve_1d_linesearch_quad_funct(a,b,c) - fc=None - f_val=cost(G+alpha*deltaG) - - return alpha,fc,f_val + return alpha, fc, f_val def cg(a, b, M, reg, f, df, G0=None, numItermax=200, - stopThr=1e-9, verbose=False, log=False,**kwargs): + stopThr=1e-9, verbose=False, log=False, **kwargs): """ Solve the general regularized OT problem with conditional gradient @@ -240,7 +241,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, deltaG = Gc - G # line search - alpha, fc, f_val = do_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc,**kwargs) + alpha, fc, f_val = do_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, **kwargs) G = G + alpha * deltaG @@ -403,11 +404,12 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, else: return G -def solve_1d_linesearch_quad_funct(a,b,c): + +def solve_1d_linesearch_quad_funct(a, b, c): """ - Solve on 0,1 the following problem: + Solve on 0,1 the following problem: .. math:: - \min f(x)=a*x^{2}+b*x+c + \min f(x)=a*x^{2}+b*x+c Parameters ---------- @@ -416,22 +418,19 @@ def solve_1d_linesearch_quad_funct(a,b,c): Returns ------- - x : float + x : float The optimal value which leads to the minimal cost - + """ - f0=c - df0=b - f1=a+f0+df0 + f0 = c + df0 = b + f1 = a + f0 + df0 - if a>0: # convex - minimum=min(1,max(0,-b/(2*a))) - #print('entrelesdeux') + if a > 0: # convex + minimum = min(1, max(0, -b / (2 * a))) return minimum - else: # non convexe donc sur les coins - if f0>f1: - #print('sur1 f(1)={}'.format(f(1))) + else: # non convex + if f0 > f1: return 1 else: - #print('sur0 f(0)={}'.format(f(0))) return 0 -- cgit v1.2.3