path: root/ot/gromov.py
author    tvayer <titouan.vayer@gmail.com>  2019-05-28 16:08:41 +0200
committer tvayer <titouan.vayer@gmail.com>  2019-05-28 16:08:41 +0200
commit    549b95b5736b42f3fe74daf9805303a08b1ae01d (patch)
tree      d4d8ac5252bff2fef688e2fc81087293364b3ac7 /ot/gromov.py
parent    327b0c6e0ccb0c9453179eb316021c34bcdffec4 (diff)
FGW+gromov changes
Diffstat (limited to 'ot/gromov.py')
-rw-r--r--  ot/gromov.py  310
1 file changed, 292 insertions(+), 18 deletions(-)
diff --git a/ot/gromov.py b/ot/gromov.py
index 0278e99..7491664 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -9,17 +9,18 @@ Gromov-Wasserstein transport method
# Author: Erwan Vautier <erwan.vautier@gmail.com>
# Nicolas Courty <ncourty@irisa.fr>
# RĂ©mi Flamary <remi.flamary@unice.fr>
-#
+# Titouan Vayer <titouan.vayer@irisa.fr>
# License: MIT License
import numpy as np
+
from .bregman import sinkhorn
from .utils import dist
from .optim import cg
-def init_matrix(C1, C2, T, p, q, loss_fun='square_loss'):
+def init_matrix(C1, C2, p, q, loss_fun='square_loss'):
""" Return loss matrices and tensors for Gromov-Wasserstein fast computation
Returns the value of \mathcal{L}(C1,C2) \otimes T with the selected loss
@@ -77,16 +78,16 @@ def init_matrix(C1, C2, T, p, q, loss_fun='square_loss'):
if loss_fun == 'square_loss':
def f1(a):
- return (a**2) / 2
+ return (a**2)
def f2(b):
- return (b**2) / 2
+ return (b**2)
def h1(a):
return a
def h2(b):
- return b
+ return 2*b
elif loss_fun == 'kl_loss':
def f1(a):
return a * np.log(a + 1e-15) - a
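Note on the hunk above: with the new definitions, the decomposition L(a,b) = f1(a) + f2(b) - h1(a)*h2(b) used by init_matrix recovers the plain squared difference for the square loss, since a^2 + b^2 - 2ab = (a-b)^2 (the former factor 1/2 is gone). A quick sanity check, assuming only NumPy:

import numpy as np

# square-loss decomposition as defined after this change:
# L(a, b) = f1(a) + f2(b) - h1(a) * h2(b)
def f1(a):
    return a**2

def f2(b):
    return b**2

def h1(a):
    return a

def h2(b):
    return 2 * b

a, b = np.random.rand(5), np.random.rand(5)
# the square loss is now the plain squared difference, no longer halved
assert np.allclose(f1(a) + f2(b) - h1(a) * h2(b), (a - b)**2)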
@@ -268,7 +269,7 @@ def update_kl_loss(p, lambdas, T, Cs):
return np.exp(np.divide(tmpsum, ppt))
-def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, **kwargs):
+def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, amijo=False, **kwargs):
"""
Returns the gromov-wasserstein transport between (C1,p) and (C2,q)
@@ -306,6 +307,9 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, **kwargs):
Print information along iterations
log : bool, optional
record log if True
+    amijo : bool, optional
+        If True, the step of the line search is found via an Armijo search. Else the closed form is used.
+        If there are convergence issues, use False.
**kwargs : dict
        parameters can be directly passed to the ot.optim.cg solver
@@ -329,9 +333,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, **kwargs):
"""
- T = np.eye(len(p), len(q))
-
- constC, hC1, hC2 = init_matrix(C1, C2, T, p, q, loss_fun)
+ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
G0 = p[:, None] * q[None, :]
@@ -342,14 +344,79 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, **kwargs):
return gwggrad(constC, hC1, hC2, G)
if log:
- res, log = cg(p, q, 0, 1, f, df, G0, log=True, **kwargs)
+        res, log = cg(p, q, 0, 1, f, df, G0, log=True, amijo=amijo, C1=C1, C2=C2, constC=constC, **kwargs)
log['gw_dist'] = gwloss(constC, hC1, hC2, res)
return res, log
else:
- return cg(p, q, 0, 1, f, df, G0, **kwargs)
+        return cg(p, q, 0, 1, f, df, G0, amijo=amijo, **kwargs)
+
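For reference, a minimal usage sketch of the updated gromov_wasserstein signature; this assumes this version of POT is importable as ot, and the sizes and point clouds are arbitrary:

import numpy as np
import ot

ns, nt = 30, 20
xs = np.random.randn(ns, 2)
xt = np.random.randn(nt, 3)

C1 = ot.dist(xs, xs)  # (ns, ns) intra-source structure matrix
C2 = ot.dist(xt, xt)  # (nt, nt) intra-target structure matrix
C1 /= C1.max()
C2 /= C2.max()

p, q = ot.unif(ns), ot.unif(nt)  # uniform weights

# amijo=False (the default) uses the closed-form line search
T, log = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', log=True)
print(log['gw_dist'])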
+def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, amijo=False, **kwargs):
+ """
+    Computes the FGW transport between two graphs, see [18]
+    .. math::
+        \gamma = arg\min_\gamma (1-\alpha)*<\gamma,M>_F + \alpha* \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
+        s.t. \gamma 1 = p
+             \gamma^T 1 = q
+             \gamma \geq 0
+    where :
+    - M is the (ns,nt) metric cost matrix between the features
+    - p and q are the source and target weights (sum to 1)
+    - L is a loss function taking into account the misfit between the structure matrices
+    The algorithm used for solving the problem is conditional gradient as discussed in [1]_
+ Parameters
+ ----------
+ M : ndarray, shape (ns, nt)
+ Metric cost matrix between features across domains
+    C1 : ndarray, shape (ns, ns)
+        Metric cost matrix representative of the structure in the source space
+    C2 : ndarray, shape (nt, nt)
+        Metric cost matrix representative of the structure in the target space
+ p : ndarray, shape (ns,)
+ distribution in the source space
+ q : ndarray, shape (nt,)
+ distribution in the target space
+    loss_fun : string, optional
+ loss function used for the solver
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+    amijo : bool, optional
+        If True, the step of the line search is found via an Armijo search. Else the closed form is used.
+        If there are convergence issues, use False.
+ **kwargs : dict
+        parameters can be directly passed to the ot.optim.cg solver
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary return only if log==True in parameters
+ References
+ ----------
+    .. [18] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
+        and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+ """
+
+    constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
+
+    G0 = p[:, None] * q[None, :]
+
+    def f(G):
+        return gwloss(constC, hC1, hC2, G)
+
+    def df(G):
+        return gwggrad(constC, hC1, hC2, G)
+
+    return cg(p, q, M, alpha, f, df, G0, amijo=amijo, C1=C1, C2=C2, constC=constC, **kwargs)
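A minimal usage sketch of the new function, assuming this version of POT; note that, mirroring the fgw_barycenters call below, the (1-alpha) weight on the feature term is applied by scaling M at the call site, since cg optimizes <G,M> + alpha*f(G):

import numpy as np
import ot
from ot.gromov import fused_gromov_wasserstein

ns, nt, d = 30, 20, 4
Fs, Ft = np.random.randn(ns, d), np.random.randn(nt, d)  # node features
xs, xt = np.random.randn(ns, 2), np.random.randn(nt, 2)

M = ot.dist(Fs, Ft)    # (ns, nt) feature cost matrix
C1 = ot.dist(xs, xs)   # (ns, ns) source structure
C2 = ot.dist(xt, xt)   # (nt, nt) target structure

p, q = ot.unif(ns), ot.unif(nt)
alpha = 0.5  # trade-off: alpha -> 1 is pure GW, alpha -> 0 is pure Wasserstein

T = fused_gromov_wasserstein((1 - alpha) * M, C1, C2, p, q, 'square_loss', alpha=alpha)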
-def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, **kwargs):
+def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, amijo=False, **kwargs):
"""
Returns the gromov-wasserstein discrepancy between (C1,p) and (C2,q)
@@ -387,7 +454,9 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, **kwargs):
Print information along iterations
log : bool, optional
record log if True
-
+    amijo : bool, optional
+        If True, the step of the line search is found via an Armijo search. Else the closed form is used.
+        If there are convergence issues, use False.
Returns
-------
gw_dist : float
@@ -407,9 +476,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, **kwargs):
"""
- T = np.eye(len(p), len(q))
-
- constC, hC1, hC2 = init_matrix(C1, C2, T, p, q, loss_fun)
+ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
G0 = p[:, None] * q[None, :]
@@ -418,7 +485,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, **kwargs):
def df(G):
return gwggrad(constC, hC1, hC2, G)
- res, log = cg(p, q, 0, 1, f, df, G0, log=True, **kwargs)
+    res, log = cg(p, q, 0, 1, f, df, G0, log=True, amijo=amijo, C1=C1, C2=C2, constC=constC, **kwargs)
log['gw_dist'] = gwloss(constC, hC1, hC2, res)
log['T'] = res
if log:
@@ -495,7 +562,7 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon,
T = np.outer(p, q) # Initialization
- constC, hC1, hC2 = init_matrix(C1, C2, T, p, q, loss_fun)
+ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
cpt = 0
err = 1
@@ -815,3 +882,210 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
cpt += 1
return C
+
+def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False, p=None, loss_fun='square_loss',
+                    max_iter=100, tol=1e-9, verbose=False, log=True, init_C=None, init_X=None):
+
+ """
+    Compute the FGW barycenter as presented in eq. (5) of [18].
+    Parameters
+    ----------
+ N : integer
+ Desired number of samples of the target barycenter
+ Ys: list of ndarray, each element has shape (ns,d)
+ Features of all samples
+ Cs : list of ndarray, each element has shape (ns,ns)
+ Structure matrices of all samples
+ ps : list of ndarray, each element has shape (ns,)
+ masses of all samples
+ lambdas : list of float
+ list of the S spaces' weights
+ alpha : float
+ Alpha parameter for the fgw distance
+    fixed_structure : bool
+        Whether to fix the structure of the barycenter during the updates
+    fixed_features : bool
+        Whether to fix the features of the barycenter during the updates
+    init_C : ndarray, shape (N,N), optional
+        initialization for the barycenter's structure matrix. If not set, a random init is used
+    init_X : ndarray, shape (N,d), optional
+        initialization for the barycenter's features. If not set, a random init is used
+ Returns
+ ----------
+ X : ndarray, shape (N,d)
+ Barycenters' features
+ C : ndarray, shape (N,N)
+ Barycenters' structure matrix
+    log_ : dict
+        T : list of (N,ns) transport matrices
+        Ms : all distance matrices between the features of the barycenter and the other features, dist(X,Ys), each of shape (N,ns)
+ References
+ ----------
+    .. [18] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
+        and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+ """
+ S = len(Cs)
+    d = Ys[0].shape[1]  # dimension of the node features
+    if p is None:
+        p = np.ones(N) / N
+
+ Cs = [np.asarray(Cs[s], dtype=np.float64) for s in range(S)]
+ Ys = [np.asarray(Ys[s], dtype=np.float64) for s in range(S)]
+
+ lambdas = np.asarray(lambdas, dtype=np.float64)
+
+    if fixed_structure:
+        if init_C is None:
+            C = Cs[0]
+        else:
+            C = init_C
+    else:
+        if init_C is None:
+            xalea = np.random.randn(N, 2)
+            C = dist(xalea, xalea)
+        else:
+            C = init_C
+
+    if fixed_features:
+        if init_X is None:
+            X = Ys[0]
+        else:
+            X = init_X
+    else:
+        if init_X is None:
+            X = np.zeros((N, d))
+        else:
+            X = init_X
+
+    T = [np.outer(p, q) for q in ps]
+
+ # X is N,d
+ # Ys is ns,d
+    Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))]
+ # Ms is N,ns
+
+ cpt = 0
+ err_feature = 1
+ err_structure = 1
+
+    if log:
+        log_ = {}
+        log_['err_feature'] = []
+        log_['err_structure'] = []
+        log_['Ts_iter'] = []
+
+    while (err_feature > tol or err_structure > tol) and cpt < max_iter:
+ Cprev = C
+ Xprev = X
+
+        if not fixed_features:
+            Ys_temp = [y.T for y in Ys]
+            # update_feature_matrix returns a (d,N) array; transpose so X stays (N,d)
+            X = update_feature_matrix(lambdas, Ys_temp, T, p).T
+
+        # X must be N,d
+        # Ys must be ns,d
+        Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))]
+
+        if not fixed_structure:
+            if loss_fun == 'square_loss':
+                # T must be ns,N
+                # Cs must be ns,ns
+                # p must be N,1
+                T_temp = [t.T for t in T]
+                C = update_sructure_matrix(p, lambdas, T_temp, Cs)
+
+ # Ys must be d,ns
+ # Ts must be N,ns
+ # p must be N,1
+ # Ms is N,ns
+ # C is N,N
+ # Cs is ns,ns
+ # p is N,1
+ # ps is ns,1
+
+        T = [fused_gromov_wasserstein((1 - alpha) * Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha, numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)]
+
+ # T is N,ns
+
+        err_feature = np.linalg.norm(X - Xprev)
+        err_structure = np.linalg.norm(C - Cprev)
+
+        if log:
+            log_['err_feature'].append(err_feature)
+            log_['err_structure'].append(err_structure)
+            log_['Ts_iter'].append(T)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err_structure))
+ print('{:5d}|{:8e}|'.format(cpt, err_feature))
+
+ cpt += 1
+    log_['T'] = T  # transport matrices from the barycenter towards the Ys
+    log_['p'] = p
+    log_['Ms'] = Ms  # Ms are of size (N,ns)
+
+    return X, C, log_
+
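A minimal usage sketch of fgw_barycenters, assuming this version of POT; the number of spaces, sizes and weights are arbitrary:

import numpy as np
import ot
from ot.gromov import fgw_barycenters

S, d, N = 3, 2, 10
sizes = [12, 15, 8]
Ys = [np.random.randn(n, d) for n in sizes]    # node features, (ns, d)
pos = [np.random.randn(n, 2) for n in sizes]
Cs = [ot.dist(x, x) for x in pos]              # structure matrices, (ns, ns)
ps = [ot.unif(n) for n in sizes]               # sample weights
lambdas = [1.0 / S] * S                        # space weights

X, C, log_ = fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha=0.5,
                             max_iter=100, tol=1e-3)
# X: (N, d) barycenter features, C: (N, N) barycenter structure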
+
+def update_sructure_matrix(p, lambdas, T, Cs):
+ """
+ Updates C according to the L2 Loss kernel with the S Ts couplings
+ calculated at each iteration
+ Parameters
+ ----------
+ p : ndarray, shape (N,)
+ masses in the targeted barycenter
+ lambdas : list of float
+ list of the S spaces' weights
+ T : list of S np.ndarray(ns,N)
+ the S Ts couplings calculated at each iteration
+ Cs : list of S ndarray, shape(ns,ns)
+ Metric cost matrices
+ Returns
+ ----------
+    C : ndarray, shape (N,N)
+ updated C matrix
+ """
+ tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s]) for s in range(len(T))])
+ ppt = np.outer(p, p)
+
+ return np.divide(tmpsum, ppt)
+
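The function implements the closed-form L2 update C = (sum_s lambda_s * T_s^T C_s T_s) / (p p^T); a shape-level sketch with toy sizes, assuming only NumPy:

import numpy as np

N, sizes, lambdas = 3, [4, 5], [0.5, 0.5]
p = np.ones(N) / N
Cs = [np.random.rand(n, n) for n in sizes]
T = [np.outer(np.ones(n) / n, p) for n in sizes]  # feasible (ns, N) couplings

tmpsum = sum(lambdas[s] * T[s].T.dot(Cs[s]).dot(T[s]) for s in range(len(T)))
C = tmpsum / np.outer(p, p)
assert C.shape == (N, N)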
+def update_feature_matrix(lambdas, Ys, Ts, p):
+
+ """
+    Updates the features with respect to the S Ts couplings. See "Solving the barycenter problem with Block Coordinate Descent (BCD)" in [18]
+ calculated at each iteration
+ Parameters
+ ----------
+ p : ndarray, shape (N,)
+ masses in the targeted barycenter
+ lambdas : list of float
+ list of the S spaces' weights
+    Ts : list of S np.ndarray(N,ns)
+        the S Ts couplings calculated at each iteration
+ Ys : list of S ndarray, shape(d,ns)
+ The features
+ Returns
+ ----------
+    X : ndarray, shape (d,N)
+        updated features of the barycenter
+ References
+ ----------
+    .. [18] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
+        and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+ """
+
+    p = np.diag(np.array(1 / p).reshape(-1,))
+
+    tmpsum = sum([lambdas[s] * np.dot(Ys[s], Ts[s].T).dot(p) for s in range(len(Ts))])
+
+ return tmpsum
+
+
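Likewise, the feature update computes X = sum_s lambda_s * Y_s T_s^T diag(1/p), with Y_s of shape (d,ns) and T_s of shape (N,ns); a shape-level sketch with toy sizes, assuming only NumPy:

import numpy as np

N, d, sizes, lambdas = 3, 2, [4, 5], [0.5, 0.5]
p = np.ones(N) / N
Ys = [np.random.randn(d, n) for n in sizes]        # features, (d, ns)
Ts = [np.outer(p, np.ones(n) / n) for n in sizes]  # couplings, (N, ns)

D = np.diag(1.0 / p)
X = sum(lambdas[s] * Ys[s].dot(Ts[s].T).dot(D) for s in range(len(Ts)))
assert X.shape == (d, N)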