summaryrefslogtreecommitdiff
path: root/ot/gromov.py
diff options
context:
space:
mode:
authortvayer <titouan.vayer@gmail.com>2019-05-29 17:05:38 +0200
committertvayer <titouan.vayer@gmail.com>2019-05-29 17:05:48 +0200
commite1bd94bb7e85a0d2fd0fcd7642b06da12c1db6db (patch)
tree1e85920b878ab715d211db56f99e25bfa2482fd3 /ot/gromov.py
parentd4320382fa8873d15dcaec7adca3a4723c142515 (diff)
code review1
Diffstat (limited to 'ot/gromov.py')
-rw-r--r--ot/gromov.py108
1 files changed, 96 insertions, 12 deletions
diff --git a/ot/gromov.py b/ot/gromov.py
index 5a57dc8..53349b7 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -10,6 +10,7 @@ Gromov-Wasserstein transport method
# Nicolas Courty <ncourty@irisa.fr>
# RĂ©mi Flamary <remi.flamary@unice.fr>
# Titouan Vayer <titouan.vayer@irisa.fr>
+#
# License: MIT License
import numpy as np
@@ -351,9 +352,9 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs
return cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
-def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, **kwargs):
+def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
"""
- Computes the FGW distance between two graphs see [3]
+ Computes the FGW transport between two graphs see [24]
.. math::
\gamma = arg\min_\gamma (1-\alpha)*<\gamma,M>_F + alpha* \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
s.t. \gamma 1 = p
@@ -377,7 +378,7 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
distribution in the source space
q : ndarray, shape (nt,)
distribution in the target space
- loss_fun : string,optionnal
+ loss_fun : string,optional
loss function used for the solver
max_iter : int, optional
Max number of iterations
@@ -416,7 +417,86 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
def df(G):
return gwggrad(constC, hC1, hC2, G)
- return cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+ if log:
+ res, log = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
+ log['fgw_dist'] = log['loss'][::-1][0]
+ return res, log
+ else:
+ return cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+
+
+def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
+ """
+ Computes the FGW distance between two graphs see [24]
+ .. math::
+ \gamma = arg\min_\gamma (1-\alpha)*<\gamma,M>_F + alpha* \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
+ s.t. \gamma 1 = p
+ \gamma^T 1= q
+ \gamma\geq 0
+ where :
+ - M is the (ns,nt) metric cost matrix
+ - :math:`f` is the regularization term ( and df is its gradient)
+ - a and b are source and target weights (sum to 1)
+ - L is a loss function to account for the misfit between the similarity matrices
+ The algorithm used for solving the problem is conditional gradient as discussed in [1]_
+ Parameters
+ ----------
+ M : ndarray, shape (ns, nt)
+ Metric cost matrix between features across domains
+ C1 : ndarray, shape (ns, ns)
+ Metric cost matrix respresentative of the structure in the source space
+ C2 : ndarray, shape (nt, nt)
+ Metric cost matrix espresentative of the structure in the target space
+ p : ndarray, shape (ns,)
+ distribution in the source space
+ q : ndarray, shape (nt,)
+ distribution in the target space
+ loss_fun : string,optional
+ loss function used for the solver
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ armijo : bool, optional
+ If True the steps of the line-search is found via an armijo research. Else closed form is used.
+ If there is convergence issues use False.
+ **kwargs : dict
+ parameters can be directly pased to the ot.optim.cg solver
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary return only if log==True in parameters
+ References
+ ----------
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
+ and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+ """
+
+ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
+
+ G0 = p[:, None] * q[None, :]
+
+ def f(G):
+ return gwloss(constC, hC1, hC2, G)
+
+ def df(G):
+ return gwggrad(constC, hC1, hC2, G)
+
+ res, log = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
+ if log:
+ log['fgw_dist'] = log['loss'][::-1][0]
+ log['T'] = res
+ return log['fgw_dist'], log
+ else:
+ return log['fgw_dist']
def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs):
@@ -889,7 +969,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False,
p=None, loss_fun='square_loss', max_iter=100, tol=1e-9,
- verbose=False, log=True, init_C=None, init_X=None):
+ verbose=False, log=False, init_C=None, init_X=None):
"""
Compute the fgw barycenter as presented eq (5) in [24].
----------
@@ -919,7 +999,8 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
Barycenters' features
C : ndarray, shape (N,N)
Barycenters' structure matrix
- log_:
+ log_: dictionary
+ Only returned when log=True
T : list of (N,ns) transport matrices
Ms : all distance matrices between the feature of the barycenter and the other features dist(X,Ys) shape (N,ns)
References
@@ -1015,14 +1096,13 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
T = [fused_gromov_wasserstein((1 - alpha) * Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha, numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)]
# T is N,ns
-
- log_['Ts_iter'].append(T)
err_feature = np.linalg.norm(X - Xprev.reshape(N, d))
err_structure = np.linalg.norm(C - Cprev)
if log:
log_['err_feature'].append(err_feature)
log_['err_structure'].append(err_structure)
+ log_['Ts_iter'].append(T)
if verbose:
if cpt % 200 == 0:
@@ -1032,11 +1112,15 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
print('{:5d}|{:8e}|'.format(cpt, err_feature))
cpt += 1
- log_['T'] = T # from target to Ys
- log_['p'] = p
- log_['Ms'] = Ms # Ms are N,ns
+ if log:
+ log_['T'] = T # from target to Ys
+ log_['p'] = p
+ log_['Ms'] = Ms # Ms are N,ns
- return X, C, log_
+ if log:
+ return X, C, log_
+ else:
+ return X, C
def update_sructure_matrix(p, lambdas, T, Cs):