summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
authortvayer <titouan.vayer@gmail.com>2019-05-29 14:16:23 +0200
committertvayer <titouan.vayer@gmail.com>2019-05-29 14:16:23 +0200
commitf70aabfcc11f92181e0dc987b341bad8ec030d75 (patch)
tree3f7209a6f8421294fe030cab3fbad49904413e4e /ot
parent6484c9ea301fc15ae53b4afe134941909f581ffe (diff)
pep8
Diffstat (limited to 'ot')
-rw-r--r--ot/gromov.py124
-rw-r--r--ot/optim.py59
2 files changed, 91 insertions, 92 deletions
diff --git a/ot/gromov.py b/ot/gromov.py
index 297b194..fe4fc15 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -78,16 +78,16 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss'):
if loss_fun == 'square_loss':
def f1(a):
- return (a**2)
+ return (a**2)
def f2(b):
- return (b**2)
+ return (b**2)
def h1(a):
return a
def h2(b):
- return 2*b
+ return 2 * b
elif loss_fun == 'kl_loss':
def f1(a):
return a * np.log(a + 1e-15) - a
@@ -269,7 +269,7 @@ def update_kl_loss(p, lambdas, T, Cs):
return np.exp(np.divide(tmpsum, ppt))
-def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False,amijo=False, **kwargs):
+def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, amijo=False, **kwargs):
"""
Returns the gromov-wasserstein transport between (C1,p) and (C2,q)
@@ -344,13 +344,14 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False,amijo=False, **kwargs):
return gwggrad(constC, hC1, hC2, G)
if log:
- res, log = cg(p, q, 0, 1, f, df, G0,log=True,amijo=amijo,C1=C1,C2=C2,constC=constC, **kwargs)
+ res, log = cg(p, q, 0, 1, f, df, G0, log=True, amijo=amijo, C1=C1, C2=C2, constC=constC, **kwargs)
log['gw_dist'] = gwloss(constC, hC1, hC2, res)
return res, log
else:
- return cg(p, q, 0, 1, f, df, G0,amijo=amijo, **kwargs)
+ return cg(p, q, 0, 1, f, df, G0, amijo=amijo, **kwargs)
-def fused_gromov_wasserstein(M,C1,C2,p,q,loss_fun='square_loss',alpha=0.5,amijo=False,**kwargs):
+
+def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, amijo=False, **kwargs):
"""
Computes the FGW distance between two graphs, see [3]
.. math::
@@ -376,7 +377,7 @@ def fused_gromov_wasserstein(M,C1,C2,p,q,loss_fun='square_loss',alpha=0.5,amijo=
q : ndarray, shape (nt,)
distribution in the target space
loss_fun : string, optional
- loss function used for the solver
+ loss function used for the solver
max_iter : int, optional
Max number of iterations
tol : float, optional
@@ -404,19 +405,20 @@ def fused_gromov_wasserstein(M,C1,C2,p,q,loss_fun='square_loss',alpha=0.5,amijo=
International Conference on Machine Learning (ICML). 2019.
"""
- constC,hC1,hC2=init_matrix(C1,C2,p,q,loss_fun)
-
- G0=p[:,None]*q[None,:]
-
+ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
+
+ G0 = p[:, None] * q[None, :]
+
def f(G):
- return gwloss(constC,hC1,hC2,G)
+ return gwloss(constC, hC1, hC2, G)
+
def df(G):
- return gwggrad(constC,hC1,hC2,G)
-
- return cg(p,q,M,alpha,f,df,G0,amijo=amijo,C1=C1,C2=C2,constC=constC,**kwargs)
+ return gwggrad(constC, hC1, hC2, G)
+
+ return cg(p, q, M, alpha, f, df, G0, amijo=amijo, C1=C1, C2=C2, constC=constC, **kwargs)
-def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False,amijo=False, **kwargs):
+def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, amijo=False, **kwargs):
"""
Returns the gromov-wasserstein discrepancy between (C1,p) and (C2,q)
@@ -485,7 +487,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False,amijo=False, **kwargs)
def df(G):
return gwggrad(constC, hC1, hC2, G)
- res, log = cg(p, q, 0, 1, f, df, G0, log=True,amijo=amijo,C1=C1,C2=C2,constC=constC, **kwargs)
+ res, log = cg(p, q, 0, 1, f, df, G0, log=True, amijo=amijo, C1=C1, C2=C2, constC=constC, **kwargs)
log['gw_dist'] = gwloss(constC, hC1, hC2, res)
log['T'] = res
if log:
@@ -883,14 +885,14 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
return C
-def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_features=False,
- p=None,loss_fun='square_loss',max_iter=100, tol=1e-9,
- verbose=False,log=True,init_C=None,init_X=None):
-
+
+def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False,
+ p=None, loss_fun='square_loss', max_iter=100, tol=1e-9,
+ verbose=False, log=True, init_C=None, init_X=None):
"""
Compute the FGW barycenter as presented in eq. (5) of [3].
----------
- N : integer
+ N : integer
Desired number of samples of the target barycenter
Ys: list of ndarray, each element has shape (ns,d)
Features of all samples
@@ -906,9 +908,9 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature
Whether to fix the structure of the barycenter during the updates
fixed_features : bool
Whether to fix the feature of the barycenter during the updates
- init_C : ndarray, shape (N,N), optional
+ init_C : ndarray, shape (N,N), optional
initialization for the barycenters' structure matrix. If not set random init
- init_X : ndarray, shape (N,d), optional
+ init_X : ndarray, shape (N,d), optional
initialization for the barycenters' features. If not set random init
Returns
----------
@@ -926,14 +928,14 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature
"Optimal Transport for structured data with application on graphs"
International Conference on Machine Learning (ICML). 2019.
"""
-
+
class UndefinedParameter(Exception):
pass
-
+
S = len(Cs)
- d = Ys[0].shape[1] #dimension on the node features
+ d = Ys[0].shape[1] # dimension on the node features
if p is None:
- p = np.ones(N)/N
+ p = np.ones(N) / N
Cs = [np.asarray(Cs[s], dtype=np.float64) for s in range(S)]
Ys = [np.asarray(Ys[s], dtype=np.float64) for s in range(S)]
@@ -944,7 +946,7 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature
if init_C is None:
raise UndefinedParameter('If C is fixed it must be initialized')
else:
- C=init_C
+ C = init_C
else:
if init_C is None:
xalea = np.random.randn(N, 2)
@@ -954,20 +956,20 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature
if fixed_features:
if init_X is None:
- raise UndefinedParameter('If X is fixed it must be initialized')
- else :
- X= init_X
+ raise UndefinedParameter('If X is fixed it must be initialized')
+ else:
+ X = init_X
else:
- if init_X is None:
- X=np.zeros((N,d))
+ if init_X is None:
+ X = np.zeros((N, d))
else:
X = init_X
-
- T=[np.outer(p,q) for q in ps]
+
+ T = [np.outer(p, q) for q in ps]
# X is N,d
# Ys is ns,d
- Ms = [np.asarray(dist(X,Ys[s]), dtype=np.float64) for s in range(len(Ys))]
+ Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))]
# Ms is N,ns
cpt = 0
@@ -975,46 +977,46 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature
err_structure = 1
if log:
- log_={}
- log_['err_feature']=[]
- log_['err_structure']=[]
- log_['Ts_iter']=[]
+ log_ = {}
+ log_['err_feature'] = []
+ log_['err_structure'] = []
+ log_['Ts_iter'] = []
while((err_feature > tol or err_structure > tol) and cpt < max_iter):
Cprev = C
Xprev = X
if not fixed_features:
- Ys_temp=[y.T for y in Ys]
- X=update_feature_matrix(lambdas,Ys_temp,T,p).T
+ Ys_temp = [y.T for y in Ys]
+ X = update_feature_matrix(lambdas, Ys_temp, T, p).T
# X must be N,d
# Ys must be ns,d
- Ms=[np.asarray(dist(X,Ys[s]), dtype=np.float64) for s in range(len(Ys))]
+ Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))]
if not fixed_structure:
if loss_fun == 'square_loss':
# T must be ns,N
# Cs must be ns,ns
# p must be N,1
- T_temp=[t.T for t in T]
+ T_temp = [t.T for t in T]
C = update_sructure_matrix(p, lambdas, T_temp, Cs)
# Ys must be d,ns
# Ts must be N,ns
# p must be N,1
# Ms is N,ns
- # C is N,N
+ # C is N,N
# Cs is ns,ns
# p is N,1
# ps is ns,1
-
- T = [fused_gromov_wasserstein((1-alpha)*Ms[s],C,Cs[s],p,ps[s],loss_fun,alpha,numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)]
- # T is N,ns
+ T = [fused_gromov_wasserstein((1 - alpha) * Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha, numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)]
+
+ # T is N,ns
log_['Ts_iter'].append(T)
- err_feature = np.linalg.norm(X - Xprev.reshape(N,d))
+ err_feature = np.linalg.norm(X - Xprev.reshape(N, d))
err_structure = np.linalg.norm(C - Cprev)
if log:
@@ -1029,11 +1031,11 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature
print('{:5d}|{:8e}|'.format(cpt, err_feature))
cpt += 1
- log_['T']=T # from target to Ys
- log_['p']=p
- log_['Ms']=Ms #Ms are N,ns
+ log_['T'] = T # from target to Ys
+ log_['p'] = p
+ log_['Ms'] = Ms # Ms are N,ns
- return X,C,log_
+ return X, C, log_
def update_sructure_matrix(p, lambdas, T, Cs):
@@ -1060,8 +1062,8 @@ def update_sructure_matrix(p, lambdas, T, Cs):
return np.divide(tmpsum, ppt)
-def update_feature_matrix(lambdas,Ys,Ts,p):
-
+
+def update_feature_matrix(lambdas, Ys, Ts, p):
"""
Updates the feature with respect to the S Ts couplings calculated at each iteration.
See "Solving the barycenter problem with Block Coordinate Descent (BCD)" in [3]
@@ -1078,7 +1080,7 @@ def update_feature_matrix(lambdas,Ys,Ts,p):
Returns
----------
X : ndarray, shape (d,N)
-
+
References
----------
.. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
@@ -1087,10 +1089,8 @@ def update_feature_matrix(lambdas,Ys,Ts,p):
International Conference on Machine Learning (ICML). 2019.
"""
- p=np.diag(np.array(1/p).reshape(-1,))
+ p = np.diag(np.array(1 / p).reshape(-1,))
- tmpsum = sum([lambdas[s] * np.dot(Ys[s],Ts[s].T).dot(p) for s in range(len(Ts))])
+ tmpsum = sum([lambdas[s] * np.dot(Ys[s], Ts[s].T).dot(p) for s in range(len(Ts))])
return tmpsum
-
-
diff --git a/ot/optim.py b/ot/optim.py
index 9fce21e..cbfb187 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -71,8 +71,9 @@ def line_search_armijo(f, xk, pk, gfk, old_fval,
return alpha, fc[0], phi1
-def do_linesearch(cost,G,deltaG,Mi,f_val,
- amijo=False,C1=None,C2=None,reg=None,Gc=None,constC=None,M=None):
+
+def do_linesearch(cost, G, deltaG, Mi, f_val,
+ amijo=False, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None):
"""
Solve the linesearch in the FW iterations
Parameters
@@ -119,22 +120,22 @@ def do_linesearch(cost,G,deltaG,Mi,f_val,
"""
if amijo:
alpha, fc, f_val = line_search_armijo(cost, G, deltaG, Mi, f_val)
- else: # requires symetric matrices
- dot1=np.dot(C1,deltaG)
- dot12=dot1.dot(C2)
- a=-2*reg*np.sum(dot12*deltaG)
- b=np.sum((M+reg*constC)*deltaG)-2*reg*(np.sum(dot12*G)+np.sum(np.dot(C1,G).dot(C2)*deltaG))
- c=cost(G)
+ else: # requires symetric matrices
+ dot1 = np.dot(C1, deltaG)
+ dot12 = dot1.dot(C2)
+ a = -2 * reg * np.sum(dot12 * deltaG)
+ b = np.sum((M + reg * constC) * deltaG) - 2 * reg * (np.sum(dot12 * G) + np.sum(np.dot(C1, G).dot(C2) * deltaG))
+ c = cost(G)
+
+ alpha = solve_1d_linesearch_quad_funct(a, b, c)
+ fc = None
+ f_val = cost(G + alpha * deltaG)
- alpha=solve_1d_linesearch_quad_funct(a,b,c)
- fc=None
- f_val=cost(G+alpha*deltaG)
-
- return alpha,fc,f_val
+ return alpha, fc, f_val
def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
- stopThr=1e-9, verbose=False, log=False,**kwargs):
+ stopThr=1e-9, verbose=False, log=False, **kwargs):
"""
Solve the general regularized OT problem with conditional gradient
@@ -240,7 +241,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
deltaG = Gc - G
# line search
- alpha, fc, f_val = do_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc,**kwargs)
+ alpha, fc, f_val = do_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, **kwargs)
G = G + alpha * deltaG
@@ -403,11 +404,12 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
else:
return G
-def solve_1d_linesearch_quad_funct(a,b,c):
+
+def solve_1d_linesearch_quad_funct(a, b, c):
"""
- Solve on 0,1 the following problem:
+ Solve on 0,1 the following problem:
.. math::
- \min f(x)=a*x^{2}+b*x+c
+ \min f(x)=a*x^{2}+b*x+c
Parameters
----------
@@ -416,22 +418,19 @@ def solve_1d_linesearch_quad_funct(a,b,c):
Returns
-------
- x : float
+ x : float
The optimal value which leads to the minimal cost
-
+
"""
- f0=c
- df0=b
- f1=a+f0+df0
+ f0 = c
+ df0 = b
+ f1 = a + f0 + df0
- if a>0: # convex
- minimum=min(1,max(0,-b/(2*a)))
- #print('entrelesdeux')
+ if a > 0: # convex
+ minimum = min(1, max(0, -b / (2 * a)))
return minimum
- else: # non convexe donc sur les coins
- if f0>f1:
- #print('sur1 f(1)={}'.format(f(1)))
+ else: # non convex
+ if f0 > f1:
return 1
else:
- #print('sur0 f(0)={}'.format(f(0)))
return 0