summaryrefslogtreecommitdiff
path: root/ot/gromov/_semirelaxed.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/gromov/_semirelaxed.py')
-rw-r--r--ot/gromov/_semirelaxed.py591
1 files changed, 557 insertions, 34 deletions
diff --git a/ot/gromov/_semirelaxed.py b/ot/gromov/_semirelaxed.py
index 94dc975..206329d 100644
--- a/ot/gromov/_semirelaxed.py
+++ b/ot/gromov/_semirelaxed.py
@@ -18,7 +18,7 @@ from ..backend import get_backend
from ._utils import init_matrix_semirelaxed, gwloss, gwggrad
-def semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', symmetric=None, log=False, G0=None,
+def semirelaxed_gromov_wasserstein(C1, C2, p=None, loss_fun='square_loss', symmetric=None, log=False, G0=None,
max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
r"""
Returns the semi-relaxed Gromov-Wasserstein divergence transport from :math:`(\mathbf{C_1}, \mathbf{p})` to :math:`\mathbf{C_2}`
@@ -26,12 +26,12 @@ def semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', symmetric=
The function solves the following optimization problem:
.. math::
- \mathbf{srGW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l}
+ \mathbf{T}^^* \in \mathop{\arg \min}_{\mathbf{T}} \quad \sum_{i,j,k,l}
L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
- s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p}
+ s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
- \mathbf{\gamma} &\geq 0
+ \mathbf{T} &\geq 0
Where :
@@ -51,8 +51,9 @@ def semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', symmetric=
Metric cost matrix in the source space
C2 : array-like, shape (nt, nt)
Metric cost matrix in the target space
- p : array-like, shape (ns,)
- Distribution in the source space
+ p : array-like, shape (ns,), optional
+ Distribution in the source space.
+ If let to its default value None, uniform distribution is taken.
loss_fun : str
loss function used for the solver either 'square_loss' or 'kl_loss'.
'kl_loss' is not implemented yet and will raise an error.
@@ -93,11 +94,16 @@ def semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', symmetric=
"""
if loss_fun == 'kl_loss':
raise NotImplementedError()
- p = list_to_array(p)
- if G0 is None:
- nx = get_backend(p, C1, C2)
+ arr = [C1, C2]
+ if p is not None:
+ arr.append(list_to_array(p))
else:
- nx = get_backend(p, C1, C2, G0)
+ p = unif(C1.shape[0], type_as=C1)
+
+ if G0 is not None:
+ arr.append(G0)
+
+ nx = get_backend(*arr)
if symmetric is None:
symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10)
@@ -143,7 +149,7 @@ def semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', symmetric=
return semirelaxed_cg(p, q, 0., 1., f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs)
-def semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric=None, log=False, G0=None,
+def semirelaxed_gromov_wasserstein2(C1, C2, p=None, loss_fun='square_loss', symmetric=None, log=False, G0=None,
max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
r"""
Returns the semi-relaxed gromov-wasserstein divergence from :math:`(\mathbf{C_1}, \mathbf{p})` to :math:`\mathbf{C_2}`
@@ -151,12 +157,12 @@ def semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric
The function solves the following optimization problem:
.. math::
- srGW = \min_\mathbf{T} \quad \sum_{i,j,k,l}
+ \text{srGW} = \min_{\mathbf{T}} \quad \sum_{i,j,k,l}
L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
- s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p}
+ s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
- \mathbf{\gamma} &\geq 0
+ \mathbf{T} &\geq 0
Where :
@@ -179,8 +185,9 @@ def semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric
Metric cost matrix in the source space
C2 : array-like, shape (nt, nt)
Metric cost matrix in the target space
- p : array-like, shape (ns,)
+ p : array-like, shape (ns,), optional
Distribution in the source space.
+ If let to its default value None, uniform distribution is taken.
loss_fun : str
loss function used for the solver either 'square_loss' or 'kl_loss'.
'kl_loss' is not implemented yet and will raise an error.
@@ -218,7 +225,12 @@ def semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric
"Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
International Conference on Learning Representations (ICLR), 2022.
"""
- nx = get_backend(p, C1, C2)
+ # partial get_backend as the full one will be handled in gromov_wasserstein
+ nx = get_backend(C1, C2)
+
+ # init marginals if set as None
+ if p is None:
+ p = unif(C1.shape[0], type_as=C1)
T, log_srgw = semirelaxed_gromov_wasserstein(
C1, C2, p, loss_fun, symmetric, log=True, G0=G0,
@@ -239,18 +251,19 @@ def semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric
return srgw
-def semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun='square_loss', symmetric=None, alpha=0.5, G0=None, log=False,
- max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
+def semirelaxed_fused_gromov_wasserstein(
+ M, C1, C2, p=None, loss_fun='square_loss', symmetric=None, alpha=0.5,
+ G0=None, log=False, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
r"""
Computes the semi-relaxed FGW transport between two graphs (see :ref:`[48] <references-semirelaxed-fused-gromov-wasserstein>`)
.. math::
- \gamma = \mathop{\arg \min}_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F +
- \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+ \mathbf{T}^* \in \mathop{\arg \min}_{\mathbf{T}} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F +
+ \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) T_{i,j} T_{k,l}
- s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p}
+ s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
- \mathbf{\gamma} &\geq 0
+ \mathbf{T} &\geq 0
where :
@@ -273,8 +286,9 @@ def semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun='square_loss', s
Metric cost matrix representative of the structure in the source space
C2 : array-like, shape (nt, nt)
Metric cost matrix representative of the structure in the target space
- p : array-like, shape (ns,)
- Distribution in the source space
+ p : array-like, shape (ns,), optional
+ Distribution in the source space.
+ If let to its default value None, uniform distribution is taken.
loss_fun : str
loss function used for the solver either 'square_loss' or 'kl_loss'.
'kl_loss' is not implemented yet and will raise an error.
@@ -321,11 +335,16 @@ def semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun='square_loss', s
if loss_fun == 'kl_loss':
raise NotImplementedError()
- p = list_to_array(p)
- if G0 is None:
- nx = get_backend(p, C1, C2, M)
+ arr = [M, C1, C2]
+ if p is not None:
+ arr.append(list_to_array(p))
else:
- nx = get_backend(p, C1, C2, M, G0)
+ p = unif(C1.shape[0], type_as=C1)
+
+ if G0 is not None:
+ arr.append(G0)
+
+ nx = get_backend(*arr)
if symmetric is None:
symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10)
@@ -373,18 +392,18 @@ def semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun='square_loss', s
return semirelaxed_cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs)
-def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', symmetric=None, alpha=0.5, G0=None, log=False,
+def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p=None, loss_fun='square_loss', symmetric=None, alpha=0.5, G0=None, log=False,
max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
r"""
Computes the semi-relaxed FGW divergence between two graphs (see :ref:`[48] <references-semirelaxed-fused-gromov-wasserstein2>`)
.. math::
- \min_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F + \alpha \sum_{i,j,k,l}
- L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+ \mathbf{srFGW} = \min_{\mathbf{T}} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F +
+ \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) T_{i,j} T_{k,l}
- s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p}
+ s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
- \mathbf{\gamma} &\geq 0
+ \mathbf{T} &\geq 0
where :
@@ -412,6 +431,7 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss',
Metric cost matrix representative of the structure in the target space.
p : array-like, shape (ns,)
Distribution in the source space.
+ If let to its default value None, uniform distribution is taken.
loss_fun : str, optional
loss function used for the solver either 'square_loss' or 'kl_loss'.
'kl_loss' is not implemented yet and will raise an error.
@@ -455,7 +475,12 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss',
"Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
International Conference on Learning Representations (ICLR), 2022.
"""
- nx = get_backend(p, C1, C2, M)
+ # partial get_backend as the full one will be handled in gromov_wasserstein
+ nx = get_backend(C1, C2)
+
+ # init marginals if set as None
+ if p is None:
+ p = unif(C1.shape[0], type_as=C1)
T, log_fgw = semirelaxed_fused_gromov_wasserstein(
M, C1, C2, p, loss_fun, symmetric, alpha, G0, log=True,
@@ -551,3 +576,501 @@ def solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p,
cost_G = cost_G + a * (alpha ** 2) + b * alpha
return alpha, 1, cost_G
+
+
+def entropic_semirelaxed_gromov_wasserstein(
+ C1, C2, p=None, loss_fun='square_loss', epsilon=0.1, symmetric=None,
+ G0=None, max_iter=1e4, tol=1e-9, log=False, verbose=False, **kwargs):
+ r"""
+ Returns the entropic-regularized semi-relaxed gromov-wasserstein divergence
+ transport plan from :math:`(\mathbf{C_1}, \mathbf{p})` to :math:`\mathbf{C_2}`
+ estimated using a Mirror Descent algorithm following the KL geometry.
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{T}^* \in \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l}
+ L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+
+ s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
+
+ \mathbf{T} &\geq 0
+ Where :
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{p}`: distribution in the source space
+
+ - `L`: loss function to account for the misfit between the similarity matrices
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. However all the steps in the conditional
+ gradient are not differentiable.
+
+ Parameters
+ ----------
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : array-like, shape (ns,), optional
+ Distribution in the source space.
+ If let to its default value None, uniform distribution is taken.
+ loss_fun : str
+ loss function used for the solver either 'square_loss' or 'kl_loss'.
+ 'kl_loss' is not implemented yet and will raise an error.
+ epsilon : float
+ Regularization term >0
+ symmetric : bool, optional
+ Either C1 and C2 are to be assumed symmetric or not.
+ If let to its default None value, a symmetry test will be conducted.
+ Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric).
+ verbose : bool, optional
+ Print information along iterations
+ G0: array-like, shape (ns,nt), optional
+ If None the initial transport plan of the solver is pq^T.
+ Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver.
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error computed on transport plans
+ log : bool, optional
+ record log if True
+ verbose : bool, optional
+ Print information along iterations
+ Returns
+ -------
+ G : array-like, shape (`ns`, `nt`)
+ Coupling between the two spaces that minimizes:
+
+ :math:`\sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}`
+ log : dict
+ Convergence information and loss.
+
+ References
+ ----------
+ .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty.
+ "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
+ International Conference on Learning Representations (ICLR), 2022.
+ """
+ if loss_fun == 'kl_loss':
+ raise NotImplementedError()
+ arr = [C1, C2]
+ if p is not None:
+ arr.append(list_to_array(p))
+ else:
+ p = unif(C1.shape[0], type_as=C1)
+
+ if G0 is not None:
+ arr.append(G0)
+
+ nx = get_backend(*arr)
+
+ if symmetric is None:
+ symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10)
+ if G0 is None:
+ q = unif(C2.shape[0], type_as=p)
+ G0 = nx.outer(p, q)
+ else:
+ q = nx.sum(G0, 0)
+ # Check first marginal of G0
+ np.testing.assert_allclose(nx.sum(G0, 1), p, atol=1e-08)
+
+ constC, hC1, hC2, fC2t = init_matrix_semirelaxed(C1, C2, p, loss_fun, nx)
+
+ ones_p = nx.ones(p.shape[0], type_as=p)
+
+ if symmetric:
+ def df(G):
+ qG = nx.sum(G, 0)
+ marginal_product = nx.outer(ones_p, nx.dot(qG, fC2t))
+ return gwggrad(constC + marginal_product, hC1, hC2, G, nx)
+ else:
+ constCt, hC1t, hC2t, fC2 = init_matrix_semirelaxed(C1.T, C2.T, p, loss_fun, nx)
+
+ def df(G):
+ qG = nx.sum(G, 0)
+ marginal_product_1 = nx.outer(ones_p, nx.dot(qG, fC2t))
+ marginal_product_2 = nx.outer(ones_p, nx.dot(qG, fC2))
+ return 0.5 * (gwggrad(constC + marginal_product_1, hC1, hC2, G, nx) + gwggrad(constCt + marginal_product_2, hC1t, hC2t, G, nx))
+
+ cpt = 0
+ err = 1e15
+ G = G0
+
+ if log:
+ log = {'err': []}
+
+ while (err > tol and cpt < max_iter):
+
+ Gprev = G
+ # compute the kernel
+ K = G * nx.exp(- df(G) / epsilon)
+ scaling = p / nx.sum(K, 1)
+ G = nx.reshape(scaling, (-1, 1)) * K
+ if cpt % 10 == 0:
+ # we can speed up the process by checking for the error only all
+ # the 10th iterations
+ err = nx.norm(G - Gprev)
+
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ cpt += 1
+
+ if log:
+ qG = nx.sum(G, 0)
+ marginal_product = nx.outer(ones_p, nx.dot(qG, fC2t))
+ log['srgw_dist'] = gwloss(constC + marginal_product, hC1, hC2, G, nx)
+ return G, log
+ else:
+ return G
+
+
+def entropic_semirelaxed_gromov_wasserstein2(
+ C1, C2, p=None, loss_fun='square_loss', epsilon=0.1, symmetric=None,
+ G0=None, max_iter=1e4, tol=1e-9, log=False, verbose=False, **kwargs):
+ r"""
+ Returns the entropic-regularized semi-relaxed gromov-wasserstein divergence
+ from :math:`(\mathbf{C_1}, \mathbf{p})` to :math:`\mathbf{C_2}`
+ estimated using a Mirror Descent algorithm following the KL geometry.
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{srGW} = \min_{\mathbf{T}} \quad \sum_{i,j,k,l}
+ L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+
+ s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
+
+ \mathbf{T} &\geq 0
+ Where :
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{p}`: distribution in the source space
+ - `L`: loss function to account for the misfit between the similarity
+ matrices
+
+ Note that when using backends, this loss function is differentiable wrt the
+ matrices (C1, C2) but not yet for the weights p.
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. However all the steps in the conditional
+ gradient are not differentiable.
+
+ Parameters
+ ----------
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : array-like, shape (ns,), optional
+ Distribution in the source space.
+ If let to its default value None, uniform distribution is taken.
+ loss_fun : str
+ loss function used for the solver either 'square_loss' or 'kl_loss'.
+ 'kl_loss' is not implemented yet and will raise an error.
+ epsilon : float
+ Regularization term >0
+ symmetric : bool, optional
+ Either C1 and C2 are to be assumed symmetric or not.
+ If let to its default None value, a symmetry test will be conducted.
+ Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric).
+ verbose : bool, optional
+ Print information along iterations
+ G0: array-like, shape (ns,nt), optional
+ If None the initial transport plan of the solver is pq^T.
+ Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver.
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error computed on transport plans
+ log : bool, optional
+ record log if True
+ verbose : bool, optional
+ Print information along iterations
+ **kwargs : dict
+ parameters can be directly passed to the ot.optim.cg solver
+
+ Returns
+ -------
+ srgw : float
+ Semi-relaxed Gromov-Wasserstein divergence
+ log : dict
+ convergence information and Coupling matrix
+
+ References
+ ----------
+
+ .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty.
+ "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
+ International Conference on Learning Representations (ICLR), 2022.
+ """
+ T, log_srgw = entropic_semirelaxed_gromov_wasserstein(
+ C1, C2, p, loss_fun, epsilon, symmetric, G0,
+ max_iter, tol, log=True, verbose=verbose, **kwargs)
+
+ log_srgw['T'] = T
+
+ if log:
+ return log_srgw['srgw_dist'], log_srgw
+ else:
+ return log_srgw['srgw_dist']
+
+
+def entropic_semirelaxed_fused_gromov_wasserstein(
+ M, C1, C2, p=None, loss_fun='square_loss', symmetric=None, epsilon=0.1,
+ alpha=0.5, G0=None, max_iter=1e4, tol=1e-9, log=False, verbose=False, **kwargs):
+ r"""
+ Computes the entropic-regularized semi-relaxed FGW transport between two graphs (see :ref:`[48] <references-semirelaxed-fused-gromov-wasserstein>`)
+
+ .. math::
+ \mathbf{T}^* \in \mathop{\arg \min}_{\mathbf{T}} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F +
+ \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+
+ s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
+
+ \mathbf{T} &\geq 0
+
+ where :
+
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix between features
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{p}` source weights (sum to 1)
+ - `L` is a loss function to account for the misfit between the similarity matrices
+
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. However all the steps in the conditional
+ gradient are not differentiable.
+
+ The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[48] <references-semirelaxed-fused-gromov-wasserstein>`
+
+ Parameters
+ ----------
+ M : array-like, shape (ns, nt)
+ Metric cost matrix between features across domains
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix representative of the structure in the source space
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix representative of the structure in the target space
+ p : array-like, shape (ns,), optional
+ Distribution in the source space.
+ If let to its default value None, uniform distribution is taken.
+ loss_fun : str
+ loss function used for the solver either 'square_loss' or 'kl_loss'.
+ 'kl_loss' is not implemented yet and will raise an error.
+ epsilon : float
+ Regularization term >0
+ symmetric : bool, optional
+ Either C1 and C2 are to be assumed symmetric or not.
+ If let to its default None value, a symmetry test will be conducted.
+ Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric).
+ alpha : float, optional
+ Trade-off parameter (0 < alpha < 1)
+ G0: array-like, shape (ns,nt), optional
+ If None the initial transport plan of the solver is pq^T.
+ Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver.
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error computed on transport plans
+ log : bool, optional
+ record log if True
+ verbose : bool, optional
+ Print information along iterations
+ **kwargs : dict
+ parameters can be directly passed to the ot.optim.cg solver
+
+ Returns
+ -------
+ G : array-like, shape (`ns`, `nt`)
+ Optimal transportation matrix for the given parameters.
+ log : dict
+ Log dictionary return only if log==True in parameters.
+
+
+ .. _references-semirelaxed-fused-gromov-wasserstein:
+ References
+ ----------
+ .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty.
+ "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
+ International Conference on Learning Representations (ICLR), 2022.
+ """
+ if loss_fun == 'kl_loss':
+ raise NotImplementedError()
+ arr = [M, C1, C2]
+ if p is not None:
+ arr.append(list_to_array(p))
+ else:
+ p = unif(C1.shape[0], type_as=C1)
+
+ if G0 is not None:
+ arr.append(G0)
+
+ nx = get_backend(*arr)
+
+ if symmetric is None:
+ symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10)
+ if G0 is None:
+ q = unif(C2.shape[0], type_as=p)
+ G0 = nx.outer(p, q)
+ else:
+ q = nx.sum(G0, 0)
+ # Check first marginal of G0
+ np.testing.assert_allclose(nx.sum(G0, 1), p, atol=1e-08)
+
+ constC, hC1, hC2, fC2t = init_matrix_semirelaxed(C1, C2, p, loss_fun, nx)
+
+ ones_p = nx.ones(p.shape[0], type_as=p)
+ dM = (1 - alpha) * M
+ if symmetric:
+ def df(G):
+ qG = nx.sum(G, 0)
+ marginal_product = nx.outer(ones_p, nx.dot(qG, fC2t))
+ return alpha * gwggrad(constC + marginal_product, hC1, hC2, G, nx) + dM
+ else:
+ constCt, hC1t, hC2t, fC2 = init_matrix_semirelaxed(C1.T, C2.T, p, loss_fun, nx)
+
+ def df(G):
+ qG = nx.sum(G, 0)
+ marginal_product_1 = nx.outer(ones_p, nx.dot(qG, fC2t))
+ marginal_product_2 = nx.outer(ones_p, nx.dot(qG, fC2))
+ return 0.5 * alpha * (gwggrad(constC + marginal_product_1, hC1, hC2, G, nx) + gwggrad(constCt + marginal_product_2, hC1t, hC2t, G, nx)) + dM
+
+ cpt = 0
+ err = 1e15
+ G = G0
+
+ if log:
+ log = {'err': []}
+
+ while (err > tol and cpt < max_iter):
+
+ Gprev = G
+ # compute the kernel
+ K = G * nx.exp(- df(G) / epsilon)
+ scaling = p / nx.sum(K, 1)
+ G = nx.reshape(scaling, (-1, 1)) * K
+ if cpt % 10 == 0:
+ # we can speed up the process by checking for the error only all
+ # the 10th iterations
+ err = nx.norm(G - Gprev)
+
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ cpt += 1
+
+ if log:
+ qG = nx.sum(G, 0)
+ marginal_product = nx.outer(ones_p, nx.dot(qG, fC2t))
+ log['srfgw_dist'] = alpha * gwloss(constC + marginal_product, hC1, hC2, G, nx) + (1 - alpha) * nx.sum(M * G)
+ return G, log
+ else:
+ return G
+
+
+def entropic_semirelaxed_fused_gromov_wasserstein2(
+ M, C1, C2, p=None, loss_fun='square_loss', symmetric=None, epsilon=0.1,
+ alpha=0.5, G0=None, max_iter=1e4, tol=1e-9, log=False, verbose=False, **kwargs):
+ r"""
+ Computes the entropic-regularized semi-relaxed FGW transport between two graphs (see :ref:`[48] <references-semirelaxed-fused-gromov-wasserstein>`)
+
+ .. math::
+ \mathbf{srFGW} = \min_{\mathbf{T}} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F +
+ \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+
+ s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
+
+ \mathbf{T} &\geq 0
+
+ where :
+
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix between features
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{p}` source weights (sum to 1)
+ - `L` is a loss function to account for the misfit between the similarity matrices
+
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. However all the steps in the conditional
+ gradient are not differentiable.
+
+ The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[48] <references-semirelaxed-fused-gromov-wasserstein>`
+
+ Parameters
+ ----------
+ M : array-like, shape (ns, nt)
+ Metric cost matrix between features across domains
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix representative of the structure in the source space.
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix representative of the structure in the target space.
+ p : array-like, shape (ns,), optional
+ Distribution in the source space.
+ If let to its default value None, uniform distribution is taken.
+ loss_fun : str, optional
+ loss function used for the solver either 'square_loss' or 'kl_loss'.
+ 'kl_loss' is not implemented yet and will raise an error.
+ epsilon : float
+ Regularization term >0
+ symmetric : bool, optional
+ Either C1 and C2 are to be assumed symmetric or not.
+ If let to its default None value, a symmetry test will be conducted.
+ Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric).
+ alpha : float, optional
+ Trade-off parameter (0 < alpha < 1)
+ G0: array-like, shape (ns,nt), optional
+ If None the initial transport plan of the solver is pq^T.
+ Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver.
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error computed on transport plans
+ log : bool, optional
+ record log if True
+ verbose : bool, optional
+ Print information along iterations
+ **kwargs : dict
+ Parameters can be directly passed to the ot.optim.cg solver.
+
+ Returns
+ -------
+ srfgw-divergence : float
+ Semi-relaxed Fused gromov wasserstein divergence for the given parameters.
+ log : dict
+ Log dictionary return only if log==True in parameters.
+
+
+ .. _references-semirelaxed-fused-gromov-wasserstein2:
+ References
+ ----------
+ .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty.
+ "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
+ International Conference on Learning Representations (ICLR), 2022.
+ """
+ T, log_srfgw = entropic_semirelaxed_fused_gromov_wasserstein(
+ M, C1, C2, p, loss_fun, symmetric, epsilon, alpha, G0,
+ max_iter, tol, log=True, verbose=verbose, **kwargs)
+
+ log_srfgw['T'] = T
+
+ if log:
+ return log_srfgw['srfgw_dist'], log_srfgw
+ else:
+ return log_srfgw['srfgw_dist']