summaryrefslogtreecommitdiff
path: root/ot/gromov/_gw.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/gromov/_gw.py')
-rw-r--r--ot/gromov/_gw.py978
1 files changed, 978 insertions, 0 deletions
diff --git a/ot/gromov/_gw.py b/ot/gromov/_gw.py
new file mode 100644
index 0000000..c6e4076
--- /dev/null
+++ b/ot/gromov/_gw.py
@@ -0,0 +1,978 @@
+# -*- coding: utf-8 -*-
+"""
+Gromov-Wasserstein and Fused-Gromov-Wasserstein conditional gradient solvers.
+"""
+
+# Author: Erwan Vautier <erwan.vautier@gmail.com>
+# Nicolas Courty <ncourty@irisa.fr>
+# Rémi Flamary <remi.flamary@unice.fr>
+# Titouan Vayer <titouan.vayer@irisa.fr>
+# Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
+#
+# License: MIT License
+
+import numpy as np
+
+
+from ..utils import dist, UndefinedParameter, list_to_array
+from ..optim import cg, line_search_armijo, solve_1d_linesearch_quad
+from ..utils import check_random_state
+from ..backend import get_backend, NumpyBackend
+
+from ._utils import init_matrix, gwloss, gwggrad
+from ._utils import update_square_loss, update_kl_loss
+
+
def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', symmetric=None, log=False, armijo=False, G0=None,
                       max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
    r"""
    Returns the Gromov-Wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`

    The function solves the following optimization problem:

    .. math::
        \mathbf{T}^* \in \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l}
        L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}

        s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}

             \mathbf{T}^T \mathbf{1} &= \mathbf{q}

             \mathbf{T} &\geq 0

    Where :

    - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
    - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
    - :math:`\mathbf{p}`: distribution in the source space
    - :math:`\mathbf{q}`: distribution in the target space
    - `L`: loss function to account for the misfit between the similarity matrices

    .. note:: This function is backend-compatible and will work on arrays
        from all compatible backends. But the algorithm uses the C++ CPU backend
        which can lead to copy overhead on GPU arrays.
    .. note:: All computations in the conditional gradient solver are done with
        numpy to limit memory overhead.

    Parameters
    ----------
    C1 : array-like, shape (ns, ns)
        Metric cost matrix in the source space
    C2 : array-like, shape (nt, nt)
        Metric cost matrix in the target space
    p : array-like, shape (ns,)
        Distribution in the source space
    q : array-like, shape (nt,)
        Distribution in the target space
    loss_fun : str
        Loss function used for the solver, either 'square_loss' or 'kl_loss'
    symmetric : bool, optional
        Whether C1 and C2 are to be assumed symmetric.
        If left to its default None value, a symmetry test is conducted
        (both matrices compared to their transpose with `atol=1e-10`).
        If set to True (resp. False), C1 and C2 are assumed symmetric (resp. asymmetric).
    log : bool, optional
        Record log if True
    armijo : bool, optional
        If True the step of the line-search is found via an armijo search,
        else the closed form is used. Armijo is always used with 'kl_loss'
        since no closed form exists for it.
        If there are convergence issues use False.
    G0 : array-like, shape (ns, nt), optional
        If None the initial transport plan of the solver is :math:`\mathbf{p}\mathbf{q}^T`.
        Otherwise G0 must satisfy the marginal constraints and is used as
        initial transport plan of the solver.
    max_iter : int, optional
        Max number of iterations
    tol_rel : float, optional
        Stop threshold on relative error (>0)
    tol_abs : float, optional
        Stop threshold on absolute error (>0)
    **kwargs : dict
        Parameters directly passed to the ot.optim.cg solver (this is how
        e.g. `verbose` can be provided, it is not a named parameter here).

    Returns
    -------
    T : array-like, shape (`ns`, `nt`)
        Coupling between the two spaces that minimizes:

        :math:`\sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}`
    log : dict
        Convergence information and loss, only returned if `log` is True.
        The solver log is extended with the key 'gw_dist' (last recorded
        loss) and its entries 'u' and 'v' are converted to the input backend.

    Raises
    ------
    AssertionError
        If a provided `G0` does not have marginals `p` and `q` (`atol=1e-08`).

    References
    ----------
    .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
        "Gromov-Wasserstein averaging of kernel and distance matrices."
        International Conference on Machine Learning (ICML). 2016.

    .. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the
        metric approach to object matching. Foundations of computational
        mathematics 11.4 (2011): 417-487.

    .. [47] Chowdhury, S., & Mémoli, F. (2019). The gromov–wasserstein
        distance between networks and stable network invariants.
        Information and Inference: A Journal of the IMA, 8(4), 757-787.
    """
    p, q = list_to_array(p, q)

    # The arrays given by the caller select the backend; C1 additionally
    # serves as type reference for everything converted back at the end.
    user_inputs = [p, q, C1, C2]
    if G0 is not None:
        user_inputs.append(G0)
    nx = get_backend(*user_inputs)
    C1_ref = C1

    # From here on the solver only manipulates numpy arrays.
    p, q = nx.to_numpy(p), nx.to_numpy(q)
    C1, C2 = nx.to_numpy(C1), nx.to_numpy(C2)

    if symmetric is None:
        symmetric = np.allclose(C1, C1.T, atol=1e-10) and np.allclose(C2, C2.T, atol=1e-10)

    if G0 is not None:
        G0 = nx.to_numpy(G0)
        # A user-provided plan has to be a coupling of p and q already
        np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08)
        np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08)
    else:
        G0 = p[:, None] * q[None, :]

    # cg for GW is implemented using numpy on CPU
    np_ = NumpyBackend()

    constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun, np_)
    if not symmetric:
        # Second set of matrices, built from the transposed structures
        constCt, hC1t, hC2t = init_matrix(C1.T, C2.T, p, q, loss_fun, np_)

    def f(G):
        return gwloss(constC, hC1, hC2, G, np_)

    def df(G):
        grad = gwggrad(constC, hC1, hC2, G, np_)
        if symmetric:
            return grad
        # Asymmetric case: average with the gradient on transposed structures
        return 0.5 * (grad + gwggrad(constCt, hC1t, hC2t, G, np_))

    # there is no closed form line-search with KL
    if armijo or loss_fun == 'kl_loss':
        def line_search(cost, G, deltaG, Mi, cost_G, **ls_kwargs):
            return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=np_, **ls_kwargs)
    else:
        def line_search(cost, G, deltaG, Mi, cost_G, **ls_kwargs):
            return solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M=0., reg=1., nx=np_, **ls_kwargs)

    if not log:
        plan = cg(p, q, 0., 1., f, df, G0, line_search, log=False, numItermax=max_iter,
                  stopThr=tol_rel, stopThr2=tol_abs, **kwargs)
        return nx.from_numpy(plan, type_as=C1_ref)

    plan, cg_log = cg(p, q, 0., 1., f, df, G0, line_search, log=True, numItermax=max_iter,
                      stopThr=tol_rel, stopThr2=tol_abs, **kwargs)
    cg_log['gw_dist'] = nx.from_numpy(cg_log['loss'][-1], type_as=C1_ref)
    for key in ('u', 'v'):
        cg_log[key] = nx.from_numpy(cg_log[key], type_as=C1_ref)
    return nx.from_numpy(plan, type_as=C1_ref), cg_log
+
+
def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', symmetric=None, log=False, armijo=False, G0=None,
                        max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
    r"""
    Returns the Gromov-Wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`

    The function solves the following optimization problem:

    .. math::
        GW = \min_\mathbf{T} \quad \sum_{i,j,k,l}
        L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}

        s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}

             \mathbf{T}^T \mathbf{1} &= \mathbf{q}

             \mathbf{T} &\geq 0

    Where :

    - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
    - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
    - :math:`\mathbf{p}`: distribution in the source space
    - :math:`\mathbf{q}`: distribution in the target space
    - `L`: loss function to account for the misfit between the similarity
      matrices

    Note that when using backends, this loss function is differentiable wrt the
    matrices (C1, C2) and weights (p, q) for quadratic loss using the gradients from [38]_.

    .. note:: This function is backend-compatible and will work on arrays
        from all compatible backends. But the algorithm uses the C++ CPU backend
        which can lead to copy overhead on GPU arrays.
    .. note:: All computations in the conditional gradient solver are done with
        numpy to limit memory overhead.

    Parameters
    ----------
    C1 : array-like, shape (ns, ns)
        Metric cost matrix in the source space
    C2 : array-like, shape (nt, nt)
        Metric cost matrix in the target space
    p : array-like, shape (ns,)
        Distribution in the source space.
    q : array-like, shape (nt,)
        Distribution in the target space.
    loss_fun : str
        Loss function used for the solver, either 'square_loss' or 'kl_loss'
    symmetric : bool, optional
        Whether C1 and C2 are to be assumed symmetric.
        If left to its default None value, a symmetry test is conducted.
        If set to True (resp. False), C1 and C2 are assumed symmetric (resp. asymmetric).
    log : bool, optional
        Record log if True
    armijo : bool, optional
        If True the step of the line-search is found via an armijo search,
        else the closed form is used.
        If there are convergence issues use False.
    G0 : array-like, shape (ns, nt), optional
        If None the initial transport plan of the solver is :math:`\mathbf{p}\mathbf{q}^T`.
        Otherwise G0 must satisfy the marginal constraints and is used as
        initial transport plan of the solver.
    max_iter : int, optional
        Max number of iterations
    tol_rel : float, optional
        Stop threshold on relative error (>0)
    tol_abs : float, optional
        Stop threshold on absolute error (>0)
    **kwargs : dict
        Parameters forwarded to :func:`gromov_wasserstein`, which passes
        them to the ot.optim.cg solver (e.g. `verbose`).

    Returns
    -------
    gw_dist : float
        Gromov-Wasserstein distance
    log : dict
        Convergence information and coupling matrix (key 'T'), only
        returned if `log` is True.

    References
    ----------
    .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
        "Gromov-Wasserstein averaging of kernel and distance matrices."
        International Conference on Machine Learning (ICML). 2016.

    .. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the
        metric approach to object matching. Foundations of computational
        mathematics 11.4 (2011): 417-487.

    .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online
        Graph Dictionary Learning, International Conference on Machine Learning
        (ICML), 2021.

    .. [47] Chowdhury, S., & Mémoli, F. (2019). The gromov–wasserstein
        distance between networks and stable network invariants.
        Information and Inference: A Journal of the IMA, 8(4), 757-787.
    """
    # simple get_backend as the full one will be handled in gromov_wasserstein
    nx = get_backend(C1, C2)

    # The log is always requested: it carries the loss value and the potentials
    T, solver_log = gromov_wasserstein(
        C1, C2, p, q, loss_fun=loss_fun, symmetric=symmetric, log=True,
        armijo=armijo, G0=G0, max_iter=max_iter, tol_rel=tol_rel,
        tol_abs=tol_abs, **kwargs)

    solver_log['T'] = T
    gw = solver_log['gw_dist']

    if loss_fun == 'square_loss':
        # Attach gradients wrt (p, q, C1, C2): centered potentials for the
        # weights, closed-form expressions for the structure matrices.
        u, v = solver_log['u'], solver_log['v']
        grad_p = u - nx.mean(u)
        grad_q = v - nx.mean(v)
        grad_C1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T))
        grad_C2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T))
        gw = nx.set_gradients(gw, (p, q, C1, C2),
                              (grad_p, grad_q, grad_C1, grad_C2))

    if log:
        return gw, solver_log
    return gw
+
+
def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', symmetric=None, alpha=0.5,
                             armijo=False, G0=None, log=False, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
    r"""
    Computes the FGW transport between two graphs (see :ref:`[24] <references-fused-gromov-wasserstein>`)

    .. math::
        \mathbf{T}^* \in \mathop{\arg \min}_\mathbf{T} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F +
        \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}

        s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}

             \mathbf{T}^T \mathbf{1} &= \mathbf{q}

             \mathbf{T} &\geq 0

    where :

    - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
    - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1)
    - `L` is a loss function to account for the misfit between the similarity matrices

    The algorithm used for solving the problem is conditional gradient as
    discussed in :ref:`[24] <references-fused-gromov-wasserstein>`

    .. note:: This function is backend-compatible and will work on arrays
        from all compatible backends. But the algorithm uses the C++ CPU backend
        which can lead to copy overhead on GPU arrays.
    .. note:: All computations in the conditional gradient solver are done with
        numpy to limit memory overhead.

    Parameters
    ----------
    M : array-like, shape (ns, nt)
        Metric cost matrix between features across domains
    C1 : array-like, shape (ns, ns)
        Metric cost matrix representative of the structure in the source space
    C2 : array-like, shape (nt, nt)
        Metric cost matrix representative of the structure in the target space
    p : array-like, shape (ns,)
        Distribution in the source space
    q : array-like, shape (nt,)
        Distribution in the target space
    loss_fun : str, optional
        Loss function used for the solver
    symmetric : bool, optional
        Whether C1 and C2 are to be assumed symmetric.
        If left to its default None value, a symmetry test is conducted
        (both matrices compared to their transpose with `atol=1e-10`).
        If set to True (resp. False), C1 and C2 are assumed symmetric (resp. asymmetric).
    alpha : float, optional
        Trade-off parameter (0 < alpha < 1)
    armijo : bool, optional
        If True the step of the line-search is found via an armijo search,
        else the closed form is used. Armijo is always used with 'kl_loss'
        since no closed form exists for it.
        If there are convergence issues use False.
    G0 : array-like, shape (ns, nt), optional
        If None the initial transport plan of the solver is :math:`\mathbf{p}\mathbf{q}^T`.
        Otherwise G0 must satisfy the marginal constraints and is used as
        initial transport plan of the solver.
    log : bool, optional
        Record log if True
    max_iter : int, optional
        Max number of iterations
    tol_rel : float, optional
        Stop threshold on relative error (>0)
    tol_abs : float, optional
        Stop threshold on absolute error (>0)
    **kwargs : dict
        Parameters directly passed to the ot.optim.cg solver

    Returns
    -------
    T : array-like, shape (`ns`, `nt`)
        Optimal transportation matrix for the given parameters.
    log : dict
        Log dictionary, returned only if `log` is True. The solver log is
        extended with the key 'fgw_dist' (last recorded loss) and its
        entries 'u' and 'v' are converted to the input backend.

    Raises
    ------
    AssertionError
        If a provided `G0` does not have marginals `p` and `q` (`atol=1e-08`).


    .. _references-fused-gromov-wasserstein:
    References
    ----------
    .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
        and Courty Nicolas "Optimal Transport for structured data with
        application on graphs", International Conference on Machine Learning
        (ICML). 2019.

    .. [47] Chowdhury, S., & Mémoli, F. (2019). The gromov–wasserstein
        distance between networks and stable network invariants.
        Information and Inference: A Journal of the IMA, 8(4), 757-787.
    """
    p, q = list_to_array(p, q)

    # The arrays given by the caller select the backend; C1 additionally
    # serves as type reference for everything converted back at the end.
    user_inputs = [p, q, C1, C2, M]
    if G0 is not None:
        user_inputs.append(G0)
    nx = get_backend(*user_inputs)
    C1_ref = C1

    # From here on the solver only manipulates numpy arrays.
    p, q = nx.to_numpy(p), nx.to_numpy(q)
    C1, C2 = nx.to_numpy(C1), nx.to_numpy(C2)
    M = nx.to_numpy(M)

    if symmetric is None:
        symmetric = np.allclose(C1, C1.T, atol=1e-10) and np.allclose(C2, C2.T, atol=1e-10)

    if G0 is not None:
        G0 = nx.to_numpy(G0)
        # A user-provided plan has to be a coupling of p and q already
        np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08)
        np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08)
    else:
        G0 = p[:, None] * q[None, :]

    # cg for GW is implemented using numpy on CPU
    np_ = NumpyBackend()

    constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun, np_)
    if not symmetric:
        # Second set of matrices, built from the transposed structures
        constCt, hC1t, hC2t = init_matrix(C1.T, C2.T, p, q, loss_fun, np_)

    def f(G):
        return gwloss(constC, hC1, hC2, G, np_)

    def df(G):
        grad = gwggrad(constC, hC1, hC2, G, np_)
        if symmetric:
            return grad
        # Asymmetric case: average with the gradient on transposed structures
        return 0.5 * (grad + gwggrad(constCt, hC1t, hC2t, G, np_))

    # there is no closed form line-search with KL
    if armijo or loss_fun == 'kl_loss':
        def line_search(cost, G, deltaG, Mi, cost_G, **ls_kwargs):
            return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=np_, **ls_kwargs)
    else:
        def line_search(cost, G, deltaG, Mi, cost_G, **ls_kwargs):
            return solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M=(1 - alpha) * M, reg=alpha, nx=np_, **ls_kwargs)

    if not log:
        plan = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=False, numItermax=max_iter,
                  stopThr=tol_rel, stopThr2=tol_abs, **kwargs)
        return nx.from_numpy(plan, type_as=C1_ref)

    plan, cg_log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=True, numItermax=max_iter,
                      stopThr=tol_rel, stopThr2=tol_abs, **kwargs)
    cg_log['fgw_dist'] = nx.from_numpy(cg_log['loss'][-1], type_as=C1_ref)
    for key in ('u', 'v'):
        cg_log[key] = nx.from_numpy(cg_log[key], type_as=C1_ref)
    return nx.from_numpy(plan, type_as=C1_ref), cg_log
+
+
def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', symmetric=None, alpha=0.5,
                              armijo=False, G0=None, log=False, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
    r"""
    Computes the FGW distance between two graphs (see :ref:`[24] <references-fused-gromov-wasserstein2>`)

    .. math::
        \min_\mathbf{T} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + \alpha \sum_{i,j,k,l}
        L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}

        s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}

             \mathbf{T}^T \mathbf{1} &= \mathbf{q}

             \mathbf{T} &\geq 0

    where :

    - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
    - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1)
    - `L` is a loss function to account for the misfit between the similarity matrices

    The algorithm used for solving the problem is conditional gradient as
    discussed in :ref:`[24] <references-fused-gromov-wasserstein2>`

    .. note:: This function is backend-compatible and will work on arrays
        from all compatible backends. But the algorithm uses the C++ CPU backend
        which can lead to copy overhead on GPU arrays.
    .. note:: All computations in the conditional gradient solver are done with
        numpy to limit memory overhead.

    Note that when using backends, this loss function is differentiable wrt the
    matrices (C1, C2, M) and weights (p, q) for quadratic loss using the gradients from [38]_.

    Parameters
    ----------
    M : array-like, shape (ns, nt)
        Metric cost matrix between features across domains
    C1 : array-like, shape (ns, ns)
        Metric cost matrix representative of the structure in the source space.
    C2 : array-like, shape (nt, nt)
        Metric cost matrix representative of the structure in the target space.
    p : array-like, shape (ns,)
        Distribution in the source space.
    q : array-like, shape (nt,)
        Distribution in the target space.
    loss_fun : str, optional
        Loss function used for the solver.
    symmetric : bool, optional
        Whether C1 and C2 are to be assumed symmetric.
        If left to its default None value, a symmetry test is conducted.
        If set to True (resp. False), C1 and C2 are assumed symmetric (resp. asymmetric).
    alpha : float, optional
        Trade-off parameter (0 < alpha < 1)
    armijo : bool, optional
        If True the step of the line-search is found via an armijo search,
        else the closed form is used. If there are convergence issues use False.
    G0 : array-like, shape (ns, nt), optional
        If None the initial transport plan of the solver is :math:`\mathbf{p}\mathbf{q}^T`.
        Otherwise G0 must satisfy the marginal constraints and is used as
        initial transport plan of the solver.
    log : bool, optional
        Record log if True.
    max_iter : int, optional
        Max number of iterations
    tol_rel : float, optional
        Stop threshold on relative error (>0)
    tol_abs : float, optional
        Stop threshold on absolute error (>0)
    **kwargs : dict
        Parameters forwarded to :func:`fused_gromov_wasserstein`, which
        passes them to the ot.optim.cg solver.

    Returns
    -------
    fgw-distance : float
        Fused Gromov-Wasserstein distance for the given parameters.
    log : dict
        Log dictionary, returned only if `log` is True. It also holds the
        coupling matrix under the key 'T'.


    .. _references-fused-gromov-wasserstein2:
    References
    ----------
    .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
        and Courty Nicolas
        "Optimal Transport for structured data with application on graphs"
        International Conference on Machine Learning (ICML). 2019.

    .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online
        Graph Dictionary Learning, International Conference on Machine Learning
        (ICML), 2021.

    .. [47] Chowdhury, S., & Mémoli, F. (2019). The gromov–wasserstein
        distance between networks and stable network invariants.
        Information and Inference: A Journal of the IMA, 8(4), 757-787.
    """
    nx = get_backend(C1, C2, M)

    # The log is always requested: it carries the loss value and the potentials
    T, solver_log = fused_gromov_wasserstein(
        M, C1, C2, p, q, loss_fun=loss_fun, symmetric=symmetric, alpha=alpha,
        armijo=armijo, G0=G0, log=True, max_iter=max_iter, tol_rel=tol_rel,
        tol_abs=tol_abs, **kwargs)

    fgw_dist = solver_log['fgw_dist']
    solver_log['T'] = T

    if loss_fun == 'square_loss':
        # Attach gradients wrt (p, q, C1, C2, M): centered potentials for the
        # weights, alpha-weighted closed forms for the structures and the
        # (1 - alpha)-weighted plan for the feature cost.
        u, v = solver_log['u'], solver_log['v']
        grad_p = u - nx.mean(u)
        grad_q = v - nx.mean(v)
        gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T))
        gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T))
        fgw_dist = nx.set_gradients(
            fgw_dist, (p, q, C1, C2, M),
            (grad_p, grad_q, alpha * gC1, alpha * gC2, (1 - alpha) * T))

    if log:
        return fgw_dist, solver_log
    return fgw_dist
+
+
def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg,
                            alpha_min=None, alpha_max=None, nx=None, **kwargs):
    """
    Solve the linesearch in the FW iterations

    Along the direction `deltaG` the cost is modelled as the quadratic
    `cost_G + a * alpha ** 2 + b * alpha`; its coefficients are computed
    from the inputs and the step is obtained from `solve_1d_linesearch_quad`.

    Parameters
    ----------

    G : array-like, shape(ns,nt)
        The transport map at a given iteration of the FW
    deltaG : array-like (ns,nt)
        Difference between the optimal map found by linearization in the FW algorithm and the value at a given iteration
    cost_G : float
        Value of the cost at `G`
    C1 : array-like (ns,ns)
        Structure matrix in the source domain.
    C2 : array-like (nt,nt)
        Structure matrix in the target domain.
    M : array-like (ns,nt) or scalar
        Cost matrix between the features. A plain int or float is accepted
        too and is then left out of the backend detection.
    reg : float
        Regularization parameter.
    alpha_min : float, optional
        Minimum value for alpha
    alpha_max : float, optional
        Maximum value for alpha
    nx : backend, optional
        If left to its default value None, a backend test will be conducted.
    **kwargs : dict
        Ignored; accepted so that solver options can be passed through.

    Returns
    -------
    alpha : float
        The optimal step size of the FW
    fc : int
        nb of function call. Useless here (always 1)
    cost_G : float
        The value of the cost for the next iteration


    .. _references-solve-linesearch:
    References
    ----------
    .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas
        "Optimal Transport for structured data with application on graphs"
        International Conference on Machine Learning (ICML). 2019.
    """
    if nx is None:
        G, deltaG, C1, C2, M = list_to_array(G, deltaG, C1, C2, M)

        # A scalar M carries no backend information
        arrays = [G, deltaG, C1, C2]
        if not isinstance(M, (int, float)):
            arrays.append(M)
        nx = get_backend(*arrays)

    # C1 . deltaG . C2^T and C1 . G . C2^T
    c1_dg_c2t = nx.dot(nx.dot(C1, deltaG), C2.T)
    c1_g_c2t = nx.dot(nx.dot(C1, G), C2.T)

    # Quadratic and linear coefficients of the cost along deltaG
    a = -2 * reg * nx.sum(c1_dg_c2t * deltaG)
    b = nx.sum(M * deltaG) - 2 * reg * (nx.sum(c1_dg_c2t * G) + nx.sum(c1_g_c2t * deltaG))

    alpha = solve_1d_linesearch_quad(a, b)
    if not (alpha_min is None and alpha_max is None):
        alpha = np.clip(alpha, alpha_min, alpha_max)

    # the new cost is deduced from the line search quadratic function
    cost_G = cost_G + a * (alpha ** 2) + b * alpha

    return alpha, 1, cost_G
+
+
def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, symmetric=True, armijo=False,
                       max_iter=1000, tol=1e-9, verbose=False, log=False,
                       init_C=None, random_state=None, **kwargs):
    r"""
    Returns the Gromov-Wasserstein barycenters of `S` measured similarity matrices :math:`(\mathbf{C}_s)_{1 \leq s \leq S}`

    The function solves the following optimization problem with block coordinate descent:

    .. math::

        \mathbf{C} = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \quad \sum_s \lambda_s \mathrm{GW}(\mathbf{C}, \mathbf{C}_s, \mathbf{p}, \mathbf{p}_s)

    Where :

    - :math:`\mathbf{C}_s`: metric cost matrix
    - :math:`\mathbf{p}_s`: distribution

    Each iteration solves one Gromov-Wasserstein problem per input space
    (with `tol_rel=1e-5`, `tol_abs=0.` and the same `max_iter`), then
    updates :math:`\mathbf{C}` from the resulting couplings.

    Parameters
    ----------
    N : int
        Size of the targeted barycenter
    Cs : list of S array-like of shape (ns, ns)
        Metric cost matrices
    ps : list of S array-like of shape (ns,)
        Sample weights in the `S` spaces
    p : array-like, shape (N,)
        Weights in the targeted barycenter
    lambdas : list of float
        List of the `S` spaces' weights
    loss_fun : str
        Loss function used for the solver, either 'square_loss' or
        'kl_loss'. It selects the update rule of :math:`\mathbf{C}`; any
        other value leaves :math:`\mathbf{C}` untouched by the update step.
    symmetric : bool, optional.
        Whether structures are to be assumed symmetric. Default value is True.
        If set to True (resp. False), structures are assumed symmetric (resp. asymmetric).
    armijo : bool, optional
        If True the step of the line-search is found via an armijo search,
        else the closed form is used. Forced to True with 'kl_loss'.
    max_iter : int, optional
        Max number of iterations
    tol : float, optional
        Stop threshold on the error (>0), i.e. the norm of the difference
        between two successive :math:`\mathbf{C}`, evaluated every 10
        iterations only.
    verbose : bool, optional
        Print information along iterations.
    log : bool, optional
        Record log if True.
    init_C : array-like, shape(N,N), optional
        Initial value for the :math:`\mathbf{C}` matrix provided by user.
        If None, a random distance matrix normalized by its maximum is used.
    random_state : int or RandomState instance, optional
        Fix the seed for reproducibility
    **kwargs : dict
        Parameters forwarded to :func:`gromov_wasserstein`

    Returns
    -------
    C : array-like, shape (`N`, `N`)
        Similarity matrix in the barycenter space (permutated arbitrarily)
    log : dict
        Log dictionary of error during iterations (key "err"). Return only
        if `log=True` in parameters.

    References
    ----------
    .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
        "Gromov-Wasserstein averaging of kernel and distance matrices."
        International Conference on Machine Learning (ICML). 2016.

    """
    Cs = list_to_array(*Cs)
    ps = list_to_array(*ps)
    p = list_to_array(p)
    nx = get_backend(*Cs, *ps, p)

    S = len(Cs)

    if init_C is not None:
        C = init_C
    else:
        # Random initialization: pairwise distances of N random 2D points,
        # rescaled so that the largest entry is 1
        generator = check_random_state(random_state)
        xalea = generator.randn(N, 2)
        C = dist(xalea, xalea)
        C /= C.max()
        C = nx.from_numpy(C, type_as=p)

    if loss_fun == 'kl_loss':
        armijo = True

    n_iter = 0
    err = 1
    err_history = []

    while err > tol and n_iter < max_iter:
        C_old = C

        # One coupling per input space, towards the current barycenter
        T = [
            gromov_wasserstein(
                Cs[s], C, ps[s], p, loss_fun, symmetric=symmetric,
                armijo=armijo, max_iter=max_iter, tol_rel=1e-5, tol_abs=0.,
                verbose=verbose, log=False, **kwargs)
            for s in range(S)
        ]

        if loss_fun == 'square_loss':
            C = update_square_loss(p, lambdas, T, Cs)
        elif loss_fun == 'kl_loss':
            C = update_kl_loss(p, lambdas, T, Cs)

        # we can speed up the process by checking for the error only all
        # the 10th iterations
        if n_iter % 10 == 0:
            err = nx.norm(C - C_old)
            err_history.append(err)

            if verbose:
                if n_iter % 200 == 0:
                    print('{:5s}|{:12s}'.format(
                        'It.', 'Err') + '\n' + '-' * 19)
                print('{:5d}|{:8e}|'.format(n_iter, err))

        n_iter += 1

    if log:
        return C, {"err": err_history}
    return C
+
+
def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False,
                    p=None, loss_fun='square_loss', armijo=False, symmetric=True, max_iter=100, tol=1e-9,
                    verbose=False, log=False, init_C=None, init_X=None, random_state=None, **kwargs):
    r"""Compute the fgw barycenter as presented eq (5) in :ref:`[24] <references-fgw-barycenters>`

    Each iteration updates the features (unless fixed), recomputes the
    feature cost matrices, updates the structure (unless fixed, and only
    for `loss_fun='square_loss'`), then solves one FGW problem per sample
    (with `tol_rel=1e-5`, `tol_abs=0.` and the same `max_iter`).

    Parameters
    ----------
    N : int
        Desired number of samples of the target barycenter
    Ys: list of array-like, each element has shape (ns,d)
        Features of all samples
    Cs : list of array-like, each element has shape (ns,ns)
        Structure matrices of all samples
    ps : list of array-like, each element has shape (ns,)
        Masses of all samples.
    lambdas : list of float
        List of the `S` spaces' weights
    alpha : float
        Alpha parameter for the fgw distance
    fixed_structure : bool
        Whether to fix the structure of the barycenter during the updates
    fixed_features : bool
        Whether to fix the feature of the barycenter during the updates
    p : array-like, shape (N,), optional
        Weights in the targeted barycenter. If None, uniform weights are used.
    loss_fun : str
        Loss function used for the solver either 'square_loss' or 'kl_loss'.
        The structure matrix is only updated for 'square_loss'.
    armijo : bool, optional
        If True the step of the line-search is found via an armijo search,
        else the closed form is used. Forced to True with 'kl_loss'.
    symmetric : bool, optional
        Whether structures are to be assumed symmetric. Default value is True.
        If set to True (resp. False), structures are assumed symmetric (resp. asymmetric).
    max_iter : int, optional
        Max number of iterations
    tol : float, optional
        Stop threshold on the errors (>0): iterations go on while the
        feature error or the structure error exceeds it.
    verbose : bool, optional
        Print information along iterations.
    log : bool, optional
        Record log if True.
    init_C : array-like, shape (N,N), optional
        Initialization for the barycenters' structure matrix. If not set
        a random init is used.
    init_X : array-like, shape (N,d), optional
        Initialization for the barycenters' features. If not set, features
        are initialized with zeros.
    random_state : int or RandomState instance, optional
        Fix the seed for reproducibility
    **kwargs : dict
        Parameters forwarded to :func:`fused_gromov_wasserstein`

    Returns
    -------
    X : array-like, shape (`N`, `d`)
        Barycenters' features
    C : array-like, shape (`N`, `N`)
        Barycenters' structure matrix
    log : dict
        Only returned when log=True. It contains the keys:

        - 'T': list of (`N`, `ns`) transport matrices
        - 'Ms': all distance matrices between the feature of the barycenter and the other features :math:`(dist(\mathbf{X}, \mathbf{Y}_s))_s` shape (`N`, `ns`)
        - 'p': weights of the barycenter
        - 'err_feature', 'err_structure': errors recorded at each iteration
        - 'Ts_iter': transport matrices recorded at each iteration

    Raises
    ------
    UndefinedParameter
        If the structure (resp. the features) is fixed while `init_C`
        (resp. `init_X`) is not provided.


    .. _references-fgw-barycenters:
    References
    ----------
    .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
        and Courty Nicolas
        "Optimal Transport for structured data with application on graphs"
        International Conference on Machine Learning (ICML). 2019.
    """
    Cs = list_to_array(*Cs)
    ps = list_to_array(*ps)
    Ys = list_to_array(*Ys)
    p = list_to_array(p)
    nx = get_backend(*Cs, *Ys, *ps)

    S = len(Cs)
    d = Ys[0].shape[1]  # dimension on the node features
    if p is None:
        p = nx.ones(N, type_as=Cs[0]) / N

    # Structure initialization
    if init_C is not None:
        C = init_C
    elif fixed_structure:
        raise UndefinedParameter('If C is fixed it must be initialized')
    else:
        generator = check_random_state(random_state)
        xalea = generator.randn(N, 2)
        C = nx.from_numpy(dist(xalea, xalea), type_as=ps[0])

    # Features initialization
    if init_X is not None:
        X = init_X
    elif fixed_features:
        raise UndefinedParameter('If X is fixed it must be initialized')
    else:
        X = nx.zeros((N, d), type_as=ps[0])

    # Couplings start from the independent plans between p and each ps[s]
    T = [nx.outer(p, q) for q in ps]

    Ms = [dist(X, Y) for Y in Ys]

    if loss_fun == 'kl_loss':
        armijo = True

    n_iter = 0
    err_feature = 1
    err_structure = 1

    if log:
        log_ = {'err_feature': [], 'err_structure': [], 'Ts_iter': []}

    while (err_feature > tol or err_structure > tol) and n_iter < max_iter:
        C_old = C
        X_old = X

        if not fixed_features:
            # update_feature_matrix works with (d, ns) features and
            # returns a (d, N) matrix
            Ys_transposed = [y.T for y in Ys]
            X = update_feature_matrix(lambdas, Ys_transposed, T, p).T

        Ms = [dist(X, Y) for Y in Ys]

        if not fixed_structure and loss_fun == 'square_loss':
            # update_structure_matrix expects (ns, N) couplings
            T_transposed = [t.T for t in T]
            C = update_structure_matrix(p, lambdas, T_transposed, Cs)

        # T is N,ns
        T = [
            fused_gromov_wasserstein(
                Ms[s], C, Cs[s], p, ps[s], loss_fun=loss_fun, alpha=alpha,
                armijo=armijo, symmetric=symmetric, max_iter=max_iter,
                tol_rel=1e-5, tol_abs=0., verbose=verbose, **kwargs)
            for s in range(S)
        ]

        err_feature = nx.norm(X - nx.reshape(X_old, (N, d)))
        err_structure = nx.norm(C - C_old)
        if log:
            log_['err_feature'].append(err_feature)
            log_['err_structure'].append(err_structure)
            log_['Ts_iter'].append(T)

        if verbose:
            if n_iter % 200 == 0:
                print('{:5s}|{:12s}'.format(
                    'It.', 'Err') + '\n' + '-' * 19)
            print('{:5d}|{:8e}|'.format(n_iter, err_structure))
            print('{:5d}|{:8e}|'.format(n_iter, err_feature))

        n_iter += 1

    if not log:
        return X, C

    log_['T'] = T  # from target to Ys
    log_['p'] = p
    log_['Ms'] = Ms
    return X, C, log_
+
+
def update_structure_matrix(p, lambdas, T, Cs):
    r"""Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s` couplings.

    It is calculated at each iteration as

    .. math::
        \mathbf{C} = \frac{\sum_s \lambda_s \mathbf{T}_s^T \mathbf{C}_s \mathbf{T}_s}{\mathbf{p} \mathbf{p}^T}

    where the division is element-wise.

    Parameters
    ----------
    p : array-like, shape (N,)
        Masses in the targeted barycenter.
    lambdas : list of float
        List of the `S` spaces' weights.
    T : list of S array-like of shape (ns, N)
        The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration.
    Cs : list of S array-like, shape (ns, ns)
        Metric cost matrices.

    Returns
    -------
    C : array-like, shape (`N`, `N`)
        Updated :math:`\mathbf{C}` matrix.
    """
    p = list_to_array(p)
    T = list_to_array(*T)
    Cs = list_to_array(*Cs)
    nx = get_backend(*Cs, *T, p)

    def weighted_projection(s):
        # lambda_s * T_s^T . C_s . T_s
        projected = nx.dot(nx.dot(T[s].T, Cs[s]), T[s])
        return lambdas[s] * projected

    numerator = sum(weighted_projection(s) for s in range(len(T)))
    return numerator / nx.outer(p, p)
+
+
def update_feature_matrix(lambdas, Ys, Ts, p):
    r"""Updates the feature with respect to the `S` :math:`\mathbf{T}_s` couplings.


    See "Solving the barycenter problem with Block Coordinate Descent (BCD)"
    in :ref:`[24] <references-update-feature-matrix>` calculated at each iteration

    The returned matrix is the sum over `s` of
    `lambdas[s] * dot(Ys[s], Ts[s].T)` with each column divided by the
    corresponding entry of `p`.

    Parameters
    ----------
    lambdas : list of float
        List of the `S` spaces' weights
    Ys : list of S array-like, shape (d,ns)
        The features.
    Ts : list of S array-like, shape (N,ns)
        The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration
    p : array-like, shape (N,)
        masses in the targeted barycenter

    Returns
    -------
    X : array-like, shape (`d`, `N`)


    .. _references-update-feature-matrix:
    References
    ----------
    .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas
        "Optimal Transport for structured data with application on graphs"
        International Conference on Machine Learning (ICML). 2019.
    """
    p = list_to_array(p)
    Ts = list_to_array(*Ts)
    Ys = list_to_array(*Ys)
    nx = get_backend(*Ys, *Ts, p)

    # Element-wise inverse of the barycenter masses
    inv_p = 1. / p

    def weighted_features(s):
        # lambda_s * Y_s . T_s^T, columns rescaled by 1 / p
        return lambdas[s] * nx.dot(Ys[s], Ts[s].T) * inv_p[None, :]

    return sum(weighted_features(s) for s in range(len(Ts)))