path: root/ot/gromov.py
Diffstat (limited to 'ot/gromov.py')
-rw-r--r--  ot/gromov.py  206
1 file changed, 102 insertions, 104 deletions
diff --git a/ot/gromov.py b/ot/gromov.py
index 699ae4c..4427a96 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
"""
-Gromov-Wasserstein transport method
+Gromov-Wasserstein and Fused-Gromov-Wasserstein solvers
"""
# Author: Erwan Vautier <erwan.vautier@gmail.com>
@@ -276,7 +276,6 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs
- p : distribution in the source space
- q : distribution in the target space
- L : loss function to account for the misfit between the similarity matrices
- - H : entropy
Parameters
----------
@@ -343,6 +342,83 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs
return cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs):
+ """
+ Returns the Gromov-Wasserstein discrepancy between (C1, p) and (C2, q)
+
+ The function solves the following optimization problem:
+
+ .. math::
+ GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
+
+ Where :
+ - C1 : Metric cost matrix in the source space
+ - C2 : Metric cost matrix in the target space
+ - p : distribution in the source space
+ - q : distribution in the target space
+ - L : loss function to account for the misfit between the similarity matrices
+
+ Parameters
+ ----------
+ C1 : ndarray, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : ndarray, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : ndarray, shape (ns,)
+ Distribution in the source space.
+ q : ndarray, shape (nt,)
+ Distribution in the target space.
+ loss_fun : str
+ Loss function used for the solver, either 'square_loss' or 'kl_loss'
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ armijo : bool, optional
+ If True, the step of the line-search is found via an Armijo line search. Otherwise the closed form is used.
+ If there are convergence issues, use False.
+
+ Returns
+ -------
+ gw_dist : float
+ Gromov-Wasserstein distance
+ log : dict
+ Convergence information and coupling matrix
+
+ References
+ ----------
+ .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
+ .. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the
+ metric approach to object matching. Foundations of computational
+ mathematics 11.4 (2011): 417-487.
+
+ """
+
+ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
+
+ G0 = p[:, None] * q[None, :]
+
+ def f(G):
+ return gwloss(constC, hC1, hC2, G)
+
+ def df(G):
+ return gwggrad(constC, hC1, hC2, G)
+ res, log_gw = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+ log_gw['gw_dist'] = gwloss(constC, hC1, hC2, res)
+ log_gw['T'] = res
+ if log:
+ return log_gw['gw_dist'], log_gw
+ else:
+ return log_gw['gw_dist']
+
+
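The block above relocates gromov_wasserstein2 ahead of the FGW solvers and fixes its log handling (the result of cg is now kept in log_gw instead of shadowing the log flag). A minimal usage sketch, assuming the package is imported as ot and that ot.dist / ot.unif are available as in the rest of the library:

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    xs = rng.randn(20, 2)                       # source samples
    xt = rng.randn(30, 3)                       # target samples, different ambient dimension

    C1 = ot.dist(xs, xs)                        # intra-domain cost matrices
    C2 = ot.dist(xt, xt)
    C1 /= C1.max()
    C2 /= C2.max()

    p, q = ot.unif(20), ot.unif(30)             # uniform weights on both sides

    # scalar GW discrepancy; the log dict also carries the optimal coupling under 'T'
    gw_dist, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'square_loss', log=True)
    T = log['T']                                # (20, 30) coupling matrix

The coupling alone can still be obtained from gromov_wasserstein, which returns the transport plan rather than the scalar value.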
def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
"""
Computes the FGW transport between two graphs, see [24]
@@ -357,8 +433,7 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
where :
- M is the (ns,nt) metric cost matrix
- - :math:`f` is the regularization term ( and df is its gradient)
- - a and b are source and target weights (sum to 1)
+ - p and q are source and target weights (sum to 1)
- L is a loss function to account for the misfit between the similarity matrices
The algorithm used for solving the problem is conditional gradient as discussed in [24]_
@@ -377,17 +452,13 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
Distribution in the target space
loss_fun : str, optional
Loss function used for the solver
- max_iter : int, optional
- Max number of iterations
- tol : float, optional
- Stop threshold on error (>0)
- verbose : bool, optional
- Print information along iterations
- log : bool, optional
- record log if True
+ alpha : float, optional
+ Trade-off parameter (0 < alpha < 1)
armijo : bool, optional
If True, the step of the line-search is found via an Armijo line search. Otherwise the closed form is used.
If there are convergence issues, use False.
+ log : bool, optional
+ record log if True
**kwargs : dict
parameters can be directly passed to the ot.optim.cg solver
@@ -417,11 +488,11 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
return gwggrad(constC, hC1, hC2, G)
if log:
- res, log = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
+ res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
log['fgw_dist'] = log['loss'][::-1][0]
return res, log
else:
- return cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+ return cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
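With the change above, the (1 - alpha) weighting of the linear term is applied inside the solver, so callers pass the raw feature cost M and the logged 'fgw_dist' corresponds to (1 - alpha) <M, T> plus alpha times the Gromov term. A minimal sketch under that reading, again assuming ot.dist / ot.unif:

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    xs, xt = rng.randn(20, 2), rng.randn(30, 2)        # structures of the two graphs
    ys, yt = rng.randn(20, 5), rng.randn(30, 5)        # node features, shared feature dimension

    C1, C2 = ot.dist(xs, xs), ot.dist(xt, xt)          # intra-graph structure matrices
    M = ot.dist(ys, yt)                                # (ns, nt) feature cost, passed unweighted
    p, q = ot.unif(20), ot.unif(30)
    alpha = 0.5

    T, log = ot.gromov.fused_gromov_wasserstein(
        M, C1, C2, p, q, loss_fun='square_loss', alpha=alpha, log=True)

    fgw_dist = log['fgw_dist']                         # final value of the cg objective
    lin_term = (1 - alpha) * np.sum(M * T)             # linear part, now weighted inside the solver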
def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
@@ -439,8 +510,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
where :
- M is the (ns,nt) metric cost matrix
- - :math:`f` is the regularization term ( and df is its gradient)
- - a and b are source and target weights (sum to 1)
+ - p and q are source and target weights (sum to 1)
- L is a loss function to account for the misfit between the similarity matrices
The algorithm used for solving the problem is conditional gradient as discussed in [1]_
@@ -458,17 +528,13 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
Distribution in the target space.
loss_fun : str, optional
Loss function used for the solver.
- max_iter : int, optional
- Max number of iterations
- tol : float, optional
- Stop threshold on error (>0)
- verbose : bool, optional
- Print information along iterations
- log : bool, optional
- Record log if True.
+ alpha : float, optional
+ Trade-off parameter (0 < alpha < 1)
armijo : bool, optional
If True, the step of the line-search is found via an Armijo line search.
Otherwise the closed form is used. If there are convergence issues, use False.
+ log : bool, optional
+ Record log if True.
**kwargs : dict
Parameters can be directly passed to the ot.optim.cg solver.
@@ -497,7 +563,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
def df(G):
return gwggrad(constC, hC1, hC2, G)
- res, log = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
+ res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
if log:
log['fgw_dist'] = log['loss'][::-1][0]
log['T'] = res
@@ -506,84 +572,6 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
return log['fgw_dist']
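fused_gromov_wasserstein2 now builds the same (1 - alpha)-weighted cg problem, so the scalar it returns should agree with the 'fgw_dist' logged by fused_gromov_wasserstein on identical inputs; a quick consistency check, reusing the variables from the sketch above:

    # same M, C1, C2, p, q, alpha and log as in the previous sketch
    fgw_val = ot.gromov.fused_gromov_wasserstein2(
        M, C1, C2, p, q, loss_fun='square_loss', alpha=alpha)
    assert np.isclose(fgw_val, log['fgw_dist'])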
-def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs):
- """
- Returns the gromov-wasserstein discrepancy between (C1,p) and (C2,q)
-
- The function solves the following optimization problem:
-
- .. math::
- GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
-
- Where :
- - C1 : Metric cost matrix in the source space
- - C2 : Metric cost matrix in the target space
- - p : distribution in the source space
- - q : distribution in the target space
- - L : loss function to account for the misfit between the similarity matrices
- - H : entropy
-
- Parameters
- ----------
- C1 : ndarray, shape (ns, ns)
- Metric cost matrix in the source space
- C2 : ndarray, shape (nt, nt)
- Metric cost matrix in the target space
- p : ndarray, shape (ns,)
- Distribution in the source space.
- q : ndarray, shape (nt,)
- Distribution in the target space.
- loss_fun : str
- loss function used for the solver either 'square_loss' or 'kl_loss'
- max_iter : int, optional
- Max number of iterations
- tol : float, optional
- Stop threshold on error (>0)
- verbose : bool, optional
- Print information along iterations
- log : bool, optional
- record log if True
- armijo : bool, optional
- If True the steps of the line-search is found via an armijo research. Else closed form is used.
- If there is convergence issues use False.
-
- Returns
- -------
- gw_dist : float
- Gromov-Wasserstein distance
- log : dict
- convergence information and Coupling marix
-
- References
- ----------
- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
- "Gromov-Wasserstein averaging of kernel and distance matrices."
- International Conference on Machine Learning (ICML). 2016.
-
- .. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the
- metric approach to object matching. Foundations of computational
- mathematics 11.4 (2011): 417-487.
-
- """
-
- constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
-
- G0 = p[:, None] * q[None, :]
-
- def f(G):
- return gwloss(constC, hC1, hC2, G)
-
- def df(G):
- return gwggrad(constC, hC1, hC2, G)
- res, log = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
- log['gw_dist'] = gwloss(constC, hC1, hC2, res)
- log['T'] = res
- if log:
- return log['gw_dist'], log
- else:
- return log['gw_dist']
-
-
def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon,
max_iter=1000, tol=1e-9, verbose=False, log=False):
"""
@@ -996,6 +984,16 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
Whether to fix the structure of the barycenter during the updates
fixed_features : bool
Whether to fix the feature of the barycenter during the updates
+ loss_fun : str
+ Loss function used for the solver, either 'square_loss' or 'kl_loss'
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error (>0).
+ verbose : bool, optional
+ Print information along iterations.
+ log : bool, optional
+ Record log if True.
init_C : ndarray, shape (N,N), optional
Initialization for the barycenters' structure matrix. If not set
a random init is used.
@@ -1084,7 +1082,7 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
T_temp = [t.T for t in T]
C = update_sructure_matrix(p, lambdas, T_temp, Cs)
- T = [fused_gromov_wasserstein((1 - alpha) * Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha,
+ T = [fused_gromov_wasserstein(Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha,
numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)]
# T is N,ns
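Because fused_gromov_wasserstein now applies the (1 - alpha) factor itself, the barycenter loop above passes the raw Ms[s] instead of pre-scaling it. A small call sketch, assuming fgw_barycenters returns the barycenter features X and structure matrix C when log is False (only part of its signature is visible in this diff):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    sizes = [10, 15, 12]
    Xs = [rng.randn(n, 2) for n in sizes]              # node coordinates of the input graphs
    Ys = [rng.randn(n, 3) for n in sizes]              # node features, shared feature dimension
    Cs = [ot.dist(x, x) for x in Xs]                   # structure matrices
    ps = [ot.unif(n) for n in sizes]                   # node weights
    lambdas = [1.0 / 3] * 3                            # barycenter weights

    # FGW barycenter with N = 8 nodes
    X, C = ot.gromov.fgw_barycenters(8, Ys, Cs, ps, lambdas, alpha=0.5,
                                     max_iter=100, tol=1e-5)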