summaryrefslogtreecommitdiff
path: root/ot/gromov.py
diff options
context:
space:
mode:
authortvayer <titouan.vayer@gmail.com>2019-05-29 14:16:23 +0200
committertvayer <titouan.vayer@gmail.com>2019-05-29 14:16:23 +0200
commitf70aabfcc11f92181e0dc987b341bad8ec030d75 (patch)
tree3f7209a6f8421294fe030cab3fbad49904413e4e /ot/gromov.py
parent6484c9ea301fc15ae53b4afe134941909f581ffe (diff)
pep8
Diffstat (limited to 'ot/gromov.py')
-rw-r--r--ot/gromov.py124
1 files changed, 62 insertions, 62 deletions
diff --git a/ot/gromov.py b/ot/gromov.py
index 297b194..fe4fc15 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -78,16 +78,16 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss'):
if loss_fun == 'square_loss':
def f1(a):
- return (a**2)
+ return (a**2)
def f2(b):
- return (b**2)
+ return (b**2)
def h1(a):
return a
def h2(b):
- return 2*b
+ return 2 * b
elif loss_fun == 'kl_loss':
def f1(a):
return a * np.log(a + 1e-15) - a
@@ -269,7 +269,7 @@ def update_kl_loss(p, lambdas, T, Cs):
return np.exp(np.divide(tmpsum, ppt))
-def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False,amijo=False, **kwargs):
+def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, amijo=False, **kwargs):
"""
Returns the gromov-wasserstein transport between (C1,p) and (C2,q)
@@ -344,13 +344,14 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False,amijo=False, **kwargs):
return gwggrad(constC, hC1, hC2, G)
if log:
- res, log = cg(p, q, 0, 1, f, df, G0,log=True,amijo=amijo,C1=C1,C2=C2,constC=constC, **kwargs)
+ res, log = cg(p, q, 0, 1, f, df, G0, log=True, amijo=amijo, C1=C1, C2=C2, constC=constC, **kwargs)
log['gw_dist'] = gwloss(constC, hC1, hC2, res)
return res, log
else:
- return cg(p, q, 0, 1, f, df, G0,amijo=amijo, **kwargs)
+ return cg(p, q, 0, 1, f, df, G0, amijo=amijo, **kwargs)
-def fused_gromov_wasserstein(M,C1,C2,p,q,loss_fun='square_loss',alpha=0.5,amijo=False,**kwargs):
+
+def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, amijo=False, **kwargs):
"""
Computes the FGW distance between two graphs see [3]
.. math::
@@ -376,7 +377,7 @@ def fused_gromov_wasserstein(M,C1,C2,p,q,loss_fun='square_loss',alpha=0.5,amijo=
q : ndarray, shape (nt,)
distribution in the target space
loss_fun : string,optionnal
- loss function used for the solver
+ loss function used for the solver
max_iter : int, optional
Max number of iterations
tol : float, optional
@@ -404,19 +405,20 @@ def fused_gromov_wasserstein(M,C1,C2,p,q,loss_fun='square_loss',alpha=0.5,amijo=
International Conference on Machine Learning (ICML). 2019.
"""
- constC,hC1,hC2=init_matrix(C1,C2,p,q,loss_fun)
-
- G0=p[:,None]*q[None,:]
-
+ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
+
+ G0 = p[:, None] * q[None, :]
+
def f(G):
- return gwloss(constC,hC1,hC2,G)
+ return gwloss(constC, hC1, hC2, G)
+
def df(G):
- return gwggrad(constC,hC1,hC2,G)
-
- return cg(p,q,M,alpha,f,df,G0,amijo=amijo,C1=C1,C2=C2,constC=constC,**kwargs)
+ return gwggrad(constC, hC1, hC2, G)
+
+ return cg(p, q, M, alpha, f, df, G0, amijo=amijo, C1=C1, C2=C2, constC=constC, **kwargs)
-def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False,amijo=False, **kwargs):
+def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, amijo=False, **kwargs):
"""
Returns the gromov-wasserstein discrepancy between (C1,p) and (C2,q)
@@ -485,7 +487,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False,amijo=False, **kwargs)
def df(G):
return gwggrad(constC, hC1, hC2, G)
- res, log = cg(p, q, 0, 1, f, df, G0, log=True,amijo=amijo,C1=C1,C2=C2,constC=constC, **kwargs)
+ res, log = cg(p, q, 0, 1, f, df, G0, log=True, amijo=amijo, C1=C1, C2=C2, constC=constC, **kwargs)
log['gw_dist'] = gwloss(constC, hC1, hC2, res)
log['T'] = res
if log:
@@ -883,14 +885,14 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
return C
-def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_features=False,
- p=None,loss_fun='square_loss',max_iter=100, tol=1e-9,
- verbose=False,log=True,init_C=None,init_X=None):
-
+
+def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False,
+ p=None, loss_fun='square_loss', max_iter=100, tol=1e-9,
+ verbose=False, log=True, init_C=None, init_X=None):
"""
Compute the fgw barycenter as presented eq (5) in [3].
----------
- N : integer
+ N : integer
Desired number of samples of the target barycenter
Ys: list of ndarray, each element has shape (ns,d)
Features of all samples
@@ -906,9 +908,9 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature
Wether to fix the structure of the barycenter during the updates
fixed_features : bool
Wether to fix the feature of the barycenter during the updates
- init_C : ndarray, shape (N,N), optional
+ init_C : ndarray, shape (N,N), optional
initialization for the barycenters' structure matrix. If not set random init
- init_X : ndarray, shape (N,d), optional
+ init_X : ndarray, shape (N,d), optional
initialization for the barycenters' features. If not set random init
Returns
----------
@@ -926,14 +928,14 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature
"Optimal Transport for structured data with application on graphs"
International Conference on Machine Learning (ICML). 2019.
"""
-
+
class UndefinedParameter(Exception):
pass
-
+
S = len(Cs)
- d = Ys[0].shape[1] #dimension on the node features
+ d = Ys[0].shape[1] # dimension on the node features
if p is None:
- p = np.ones(N)/N
+ p = np.ones(N) / N
Cs = [np.asarray(Cs[s], dtype=np.float64) for s in range(S)]
Ys = [np.asarray(Ys[s], dtype=np.float64) for s in range(S)]
@@ -944,7 +946,7 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature
if init_C is None:
raise UndefinedParameter('If C is fixed it must be initialized')
else:
- C=init_C
+ C = init_C
else:
if init_C is None:
xalea = np.random.randn(N, 2)
@@ -954,20 +956,20 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature
if fixed_features:
if init_X is None:
- raise UndefinedParameter('If X is fixed it must be initialized')
- else :
- X= init_X
+ raise UndefinedParameter('If X is fixed it must be initialized')
+ else:
+ X = init_X
else:
- if init_X is None:
- X=np.zeros((N,d))
+ if init_X is None:
+ X = np.zeros((N, d))
else:
X = init_X
-
- T=[np.outer(p,q) for q in ps]
+
+ T = [np.outer(p, q) for q in ps]
# X is N,d
# Ys is ns,d
- Ms = [np.asarray(dist(X,Ys[s]), dtype=np.float64) for s in range(len(Ys))]
+ Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))]
# Ms is N,ns
cpt = 0
@@ -975,46 +977,46 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature
err_structure = 1
if log:
- log_={}
- log_['err_feature']=[]
- log_['err_structure']=[]
- log_['Ts_iter']=[]
+ log_ = {}
+ log_['err_feature'] = []
+ log_['err_structure'] = []
+ log_['Ts_iter'] = []
while((err_feature > tol or err_structure > tol) and cpt < max_iter):
Cprev = C
Xprev = X
if not fixed_features:
- Ys_temp=[y.T for y in Ys]
- X=update_feature_matrix(lambdas,Ys_temp,T,p).T
+ Ys_temp = [y.T for y in Ys]
+ X = update_feature_matrix(lambdas, Ys_temp, T, p).T
# X must be N,d
# Ys must be ns,d
- Ms=[np.asarray(dist(X,Ys[s]), dtype=np.float64) for s in range(len(Ys))]
+ Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))]
if not fixed_structure:
if loss_fun == 'square_loss':
# T must be ns,N
# Cs must be ns,ns
# p must be N,1
- T_temp=[t.T for t in T]
+ T_temp = [t.T for t in T]
C = update_sructure_matrix(p, lambdas, T_temp, Cs)
# Ys must be d,ns
# Ts must be N,ns
# p must be N,1
# Ms is N,ns
- # C is N,N
+ # C is N,N
# Cs is ns,ns
# p is N,1
# ps is ns,1
-
- T = [fused_gromov_wasserstein((1-alpha)*Ms[s],C,Cs[s],p,ps[s],loss_fun,alpha,numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)]
- # T is N,ns
+ T = [fused_gromov_wasserstein((1 - alpha) * Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha, numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)]
+
+ # T is N,ns
log_['Ts_iter'].append(T)
- err_feature = np.linalg.norm(X - Xprev.reshape(N,d))
+ err_feature = np.linalg.norm(X - Xprev.reshape(N, d))
err_structure = np.linalg.norm(C - Cprev)
if log:
@@ -1029,11 +1031,11 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature
print('{:5d}|{:8e}|'.format(cpt, err_feature))
cpt += 1
- log_['T']=T # from target to Ys
- log_['p']=p
- log_['Ms']=Ms #Ms are N,ns
+ log_['T'] = T # from target to Ys
+ log_['p'] = p
+ log_['Ms'] = Ms # Ms are N,ns
- return X,C,log_
+ return X, C, log_
def update_sructure_matrix(p, lambdas, T, Cs):
@@ -1060,8 +1062,8 @@ def update_sructure_matrix(p, lambdas, T, Cs):
return np.divide(tmpsum, ppt)
-def update_feature_matrix(lambdas,Ys,Ts,p):
-
+
+def update_feature_matrix(lambdas, Ys, Ts, p):
"""
Updates the feature with respect to the S Ts couplings. See "Solving the barycenter problem with Block Coordinate Descent (BCD)" in [3]
calculated at each iteration
@@ -1078,7 +1080,7 @@ def update_feature_matrix(lambdas,Ys,Ts,p):
Returns
----------
X : ndarray, shape (d,N)
-
+
References
----------
.. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
@@ -1087,10 +1089,8 @@ def update_feature_matrix(lambdas,Ys,Ts,p):
International Conference on Machine Learning (ICML). 2019.
"""
- p=np.diag(np.array(1/p).reshape(-1,))
+ p = np.diag(np.array(1 / p).reshape(-1,))
- tmpsum = sum([lambdas[s] * np.dot(Ys[s],Ts[s].T).dot(p) for s in range(len(Ts))])
+ tmpsum = sum([lambdas[s] * np.dot(Ys[s], Ts[s].T).dot(p) for s in range(len(Ts))])
return tmpsum
-
-