From f70aabfcc11f92181e0dc987b341bad8ec030d75 Mon Sep 17 00:00:00 2001 From: tvayer Date: Wed, 29 May 2019 14:16:23 +0200 Subject: pep8 --- ot/gromov.py | 124 +++++++++++++++++++++++++++++------------------------------ ot/optim.py | 59 ++++++++++++++-------------- 2 files changed, 91 insertions(+), 92 deletions(-) diff --git a/ot/gromov.py b/ot/gromov.py index 297b194..fe4fc15 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -78,16 +78,16 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss'): if loss_fun == 'square_loss': def f1(a): - return (a**2) + return (a**2) def f2(b): - return (b**2) + return (b**2) def h1(a): return a def h2(b): - return 2*b + return 2 * b elif loss_fun == 'kl_loss': def f1(a): return a * np.log(a + 1e-15) - a @@ -269,7 +269,7 @@ def update_kl_loss(p, lambdas, T, Cs): return np.exp(np.divide(tmpsum, ppt)) -def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False,amijo=False, **kwargs): +def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, amijo=False, **kwargs): """ Returns the gromov-wasserstein transport between (C1,p) and (C2,q) @@ -344,13 +344,14 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False,amijo=False, **kwargs): return gwggrad(constC, hC1, hC2, G) if log: - res, log = cg(p, q, 0, 1, f, df, G0,log=True,amijo=amijo,C1=C1,C2=C2,constC=constC, **kwargs) + res, log = cg(p, q, 0, 1, f, df, G0, log=True, amijo=amijo, C1=C1, C2=C2, constC=constC, **kwargs) log['gw_dist'] = gwloss(constC, hC1, hC2, res) return res, log else: - return cg(p, q, 0, 1, f, df, G0,amijo=amijo, **kwargs) + return cg(p, q, 0, 1, f, df, G0, amijo=amijo, **kwargs) -def fused_gromov_wasserstein(M,C1,C2,p,q,loss_fun='square_loss',alpha=0.5,amijo=False,**kwargs): + +def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, amijo=False, **kwargs): """ Computes the FGW distance between two graphs see [3] .. math:: @@ -376,7 +377,7 @@ def fused_gromov_wasserstein(M,C1,C2,p,q,loss_fun='square_loss',alpha=0.5,amijo= q : ndarray, shape (nt,) distribution in the target space loss_fun : string,optionnal - loss function used for the solver + loss function used for the solver max_iter : int, optional Max number of iterations tol : float, optional @@ -404,19 +405,20 @@ def fused_gromov_wasserstein(M,C1,C2,p,q,loss_fun='square_loss',alpha=0.5,amijo= International Conference on Machine Learning (ICML). 2019. """ - constC,hC1,hC2=init_matrix(C1,C2,p,q,loss_fun) - - G0=p[:,None]*q[None,:] - + constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) + + G0 = p[:, None] * q[None, :] + def f(G): - return gwloss(constC,hC1,hC2,G) + return gwloss(constC, hC1, hC2, G) + def df(G): - return gwggrad(constC,hC1,hC2,G) - - return cg(p,q,M,alpha,f,df,G0,amijo=amijo,C1=C1,C2=C2,constC=constC,**kwargs) + return gwggrad(constC, hC1, hC2, G) + + return cg(p, q, M, alpha, f, df, G0, amijo=amijo, C1=C1, C2=C2, constC=constC, **kwargs) -def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False,amijo=False, **kwargs): +def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, amijo=False, **kwargs): """ Returns the gromov-wasserstein discrepancy between (C1,p) and (C2,q) @@ -485,7 +487,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False,amijo=False, **kwargs) def df(G): return gwggrad(constC, hC1, hC2, G) - res, log = cg(p, q, 0, 1, f, df, G0, log=True,amijo=amijo,C1=C1,C2=C2,constC=constC, **kwargs) + res, log = cg(p, q, 0, 1, f, df, G0, log=True, amijo=amijo, C1=C1, C2=C2, constC=constC, **kwargs) log['gw_dist'] = gwloss(constC, hC1, hC2, res) log['T'] = res if log: @@ -883,14 +885,14 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, return C -def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_features=False, - p=None,loss_fun='square_loss',max_iter=100, tol=1e-9, - verbose=False,log=True,init_C=None,init_X=None): - + +def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False, + p=None, loss_fun='square_loss', max_iter=100, tol=1e-9, + verbose=False, log=True, init_C=None, init_X=None): """ Compute the fgw barycenter as presented eq (5) in [3]. ---------- - N : integer + N : integer Desired number of samples of the target barycenter Ys: list of ndarray, each element has shape (ns,d) Features of all samples @@ -906,9 +908,9 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature Wether to fix the structure of the barycenter during the updates fixed_features : bool Wether to fix the feature of the barycenter during the updates - init_C : ndarray, shape (N,N), optional + init_C : ndarray, shape (N,N), optional initialization for the barycenters' structure matrix. If not set random init - init_X : ndarray, shape (N,d), optional + init_X : ndarray, shape (N,d), optional initialization for the barycenters' features. If not set random init Returns ---------- @@ -926,14 +928,14 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature "Optimal Transport for structured data with application on graphs" International Conference on Machine Learning (ICML). 2019. """ - + class UndefinedParameter(Exception): pass - + S = len(Cs) - d = Ys[0].shape[1] #dimension on the node features + d = Ys[0].shape[1] # dimension on the node features if p is None: - p = np.ones(N)/N + p = np.ones(N) / N Cs = [np.asarray(Cs[s], dtype=np.float64) for s in range(S)] Ys = [np.asarray(Ys[s], dtype=np.float64) for s in range(S)] @@ -944,7 +946,7 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature if init_C is None: raise UndefinedParameter('If C is fixed it must be initialized') else: - C=init_C + C = init_C else: if init_C is None: xalea = np.random.randn(N, 2) @@ -954,20 +956,20 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature if fixed_features: if init_X is None: - raise UndefinedParameter('If X is fixed it must be initialized') - else : - X= init_X + raise UndefinedParameter('If X is fixed it must be initialized') + else: + X = init_X else: - if init_X is None: - X=np.zeros((N,d)) + if init_X is None: + X = np.zeros((N, d)) else: X = init_X - - T=[np.outer(p,q) for q in ps] + + T = [np.outer(p, q) for q in ps] # X is N,d # Ys is ns,d - Ms = [np.asarray(dist(X,Ys[s]), dtype=np.float64) for s in range(len(Ys))] + Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))] # Ms is N,ns cpt = 0 @@ -975,46 +977,46 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature err_structure = 1 if log: - log_={} - log_['err_feature']=[] - log_['err_structure']=[] - log_['Ts_iter']=[] + log_ = {} + log_['err_feature'] = [] + log_['err_structure'] = [] + log_['Ts_iter'] = [] while((err_feature > tol or err_structure > tol) and cpt < max_iter): Cprev = C Xprev = X if not fixed_features: - Ys_temp=[y.T for y in Ys] - X=update_feature_matrix(lambdas,Ys_temp,T,p).T + Ys_temp = [y.T for y in Ys] + X = update_feature_matrix(lambdas, Ys_temp, T, p).T # X must be N,d # Ys must be ns,d - Ms=[np.asarray(dist(X,Ys[s]), dtype=np.float64) for s in range(len(Ys))] + Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))] if not fixed_structure: if loss_fun == 'square_loss': # T must be ns,N # Cs must be ns,ns # p must be N,1 - T_temp=[t.T for t in T] + T_temp = [t.T for t in T] C = update_sructure_matrix(p, lambdas, T_temp, Cs) # Ys must be d,ns # Ts must be N,ns # p must be N,1 # Ms is N,ns - # C is N,N + # C is N,N # Cs is ns,ns # p is N,1 # ps is ns,1 - - T = [fused_gromov_wasserstein((1-alpha)*Ms[s],C,Cs[s],p,ps[s],loss_fun,alpha,numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)] - # T is N,ns + T = [fused_gromov_wasserstein((1 - alpha) * Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha, numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)] + + # T is N,ns log_['Ts_iter'].append(T) - err_feature = np.linalg.norm(X - Xprev.reshape(N,d)) + err_feature = np.linalg.norm(X - Xprev.reshape(N, d)) err_structure = np.linalg.norm(C - Cprev) if log: @@ -1029,11 +1031,11 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature print('{:5d}|{:8e}|'.format(cpt, err_feature)) cpt += 1 - log_['T']=T # from target to Ys - log_['p']=p - log_['Ms']=Ms #Ms are N,ns + log_['T'] = T # from target to Ys + log_['p'] = p + log_['Ms'] = Ms # Ms are N,ns - return X,C,log_ + return X, C, log_ def update_sructure_matrix(p, lambdas, T, Cs): @@ -1060,8 +1062,8 @@ def update_sructure_matrix(p, lambdas, T, Cs): return np.divide(tmpsum, ppt) -def update_feature_matrix(lambdas,Ys,Ts,p): - + +def update_feature_matrix(lambdas, Ys, Ts, p): """ Updates the feature with respect to the S Ts couplings. See "Solving the barycenter problem with Block Coordinate Descent (BCD)" in [3] calculated at each iteration @@ -1078,7 +1080,7 @@ def update_feature_matrix(lambdas,Ys,Ts,p): Returns ---------- X : ndarray, shape (d,N) - + References ---------- .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain @@ -1087,10 +1089,8 @@ def update_feature_matrix(lambdas,Ys,Ts,p): International Conference on Machine Learning (ICML). 2019. """ - p=np.diag(np.array(1/p).reshape(-1,)) + p = np.diag(np.array(1 / p).reshape(-1,)) - tmpsum = sum([lambdas[s] * np.dot(Ys[s],Ts[s].T).dot(p) for s in range(len(Ts))]) + tmpsum = sum([lambdas[s] * np.dot(Ys[s], Ts[s].T).dot(p) for s in range(len(Ts))]) return tmpsum - - diff --git a/ot/optim.py b/ot/optim.py index 9fce21e..cbfb187 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -71,8 +71,9 @@ def line_search_armijo(f, xk, pk, gfk, old_fval, return alpha, fc[0], phi1 -def do_linesearch(cost,G,deltaG,Mi,f_val, - amijo=False,C1=None,C2=None,reg=None,Gc=None,constC=None,M=None): + +def do_linesearch(cost, G, deltaG, Mi, f_val, + amijo=False, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None): """ Solve the linesearch in the FW iterations Parameters @@ -119,22 +120,22 @@ def do_linesearch(cost,G,deltaG,Mi,f_val, """ if amijo: alpha, fc, f_val = line_search_armijo(cost, G, deltaG, Mi, f_val) - else: # requires symetric matrices - dot1=np.dot(C1,deltaG) - dot12=dot1.dot(C2) - a=-2*reg*np.sum(dot12*deltaG) - b=np.sum((M+reg*constC)*deltaG)-2*reg*(np.sum(dot12*G)+np.sum(np.dot(C1,G).dot(C2)*deltaG)) - c=cost(G) + else: # requires symetric matrices + dot1 = np.dot(C1, deltaG) + dot12 = dot1.dot(C2) + a = -2 * reg * np.sum(dot12 * deltaG) + b = np.sum((M + reg * constC) * deltaG) - 2 * reg * (np.sum(dot12 * G) + np.sum(np.dot(C1, G).dot(C2) * deltaG)) + c = cost(G) + + alpha = solve_1d_linesearch_quad_funct(a, b, c) + fc = None + f_val = cost(G + alpha * deltaG) - alpha=solve_1d_linesearch_quad_funct(a,b,c) - fc=None - f_val=cost(G+alpha*deltaG) - - return alpha,fc,f_val + return alpha, fc, f_val def cg(a, b, M, reg, f, df, G0=None, numItermax=200, - stopThr=1e-9, verbose=False, log=False,**kwargs): + stopThr=1e-9, verbose=False, log=False, **kwargs): """ Solve the general regularized OT problem with conditional gradient @@ -240,7 +241,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, deltaG = Gc - G # line search - alpha, fc, f_val = do_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc,**kwargs) + alpha, fc, f_val = do_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, **kwargs) G = G + alpha * deltaG @@ -403,11 +404,12 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, else: return G -def solve_1d_linesearch_quad_funct(a,b,c): + +def solve_1d_linesearch_quad_funct(a, b, c): """ - Solve on 0,1 the following problem: + Solve on 0,1 the following problem: .. math:: - \min f(x)=a*x^{2}+b*x+c + \min f(x)=a*x^{2}+b*x+c Parameters ---------- @@ -416,22 +418,19 @@ def solve_1d_linesearch_quad_funct(a,b,c): Returns ------- - x : float + x : float The optimal value which leads to the minimal cost - + """ - f0=c - df0=b - f1=a+f0+df0 + f0 = c + df0 = b + f1 = a + f0 + df0 - if a>0: # convex - minimum=min(1,max(0,-b/(2*a))) - #print('entrelesdeux') + if a > 0: # convex + minimum = min(1, max(0, -b / (2 * a))) return minimum - else: # non convexe donc sur les coins - if f0>f1: - #print('sur1 f(1)={}'.format(f(1))) + else: # non convex + if f0 > f1: return 1 else: - #print('sur0 f(0)={}'.format(f(0))) return 0 -- cgit v1.2.3