diff options
Diffstat (limited to 'ot/gromov/_semirelaxed.py')
-rw-r--r-- | ot/gromov/_semirelaxed.py | 591 |
1 files changed, 557 insertions, 34 deletions
diff --git a/ot/gromov/_semirelaxed.py b/ot/gromov/_semirelaxed.py index 94dc975..206329d 100644 --- a/ot/gromov/_semirelaxed.py +++ b/ot/gromov/_semirelaxed.py @@ -18,7 +18,7 @@ from ..backend import get_backend from ._utils import init_matrix_semirelaxed, gwloss, gwggrad -def semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', symmetric=None, log=False, G0=None, +def semirelaxed_gromov_wasserstein(C1, C2, p=None, loss_fun='square_loss', symmetric=None, log=False, G0=None, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): r""" Returns the semi-relaxed Gromov-Wasserstein divergence transport from :math:`(\mathbf{C_1}, \mathbf{p})` to :math:`\mathbf{C_2}` @@ -26,12 +26,12 @@ def semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', symmetric= The function solves the following optimization problem: .. math:: - \mathbf{srGW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l} + \mathbf{T}^^* \in \mathop{\arg \min}_{\mathbf{T}} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p} + s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} - \mathbf{\gamma} &\geq 0 + \mathbf{T} &\geq 0 Where : @@ -51,8 +51,9 @@ def semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', symmetric= Metric cost matrix in the source space C2 : array-like, shape (nt, nt) Metric cost matrix in the target space - p : array-like, shape (ns,) - Distribution in the source space + p : array-like, shape (ns,), optional + Distribution in the source space. + If let to its default value None, uniform distribution is taken. loss_fun : str loss function used for the solver either 'square_loss' or 'kl_loss'. 'kl_loss' is not implemented yet and will raise an error. @@ -93,11 +94,16 @@ def semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', symmetric= """ if loss_fun == 'kl_loss': raise NotImplementedError() - p = list_to_array(p) - if G0 is None: - nx = get_backend(p, C1, C2) + arr = [C1, C2] + if p is not None: + arr.append(list_to_array(p)) else: - nx = get_backend(p, C1, C2, G0) + p = unif(C1.shape[0], type_as=C1) + + if G0 is not None: + arr.append(G0) + + nx = get_backend(*arr) if symmetric is None: symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10) @@ -143,7 +149,7 @@ def semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', symmetric= return semirelaxed_cg(p, q, 0., 1., f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs) -def semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric=None, log=False, G0=None, +def semirelaxed_gromov_wasserstein2(C1, C2, p=None, loss_fun='square_loss', symmetric=None, log=False, G0=None, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): r""" Returns the semi-relaxed gromov-wasserstein divergence from :math:`(\mathbf{C_1}, \mathbf{p})` to :math:`\mathbf{C_2}` @@ -151,12 +157,12 @@ def semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric The function solves the following optimization problem: .. math:: - srGW = \min_\mathbf{T} \quad \sum_{i,j,k,l} + \text{srGW} = \min_{\mathbf{T}} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p} + s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} - \mathbf{\gamma} &\geq 0 + \mathbf{T} &\geq 0 Where : @@ -179,8 +185,9 @@ def semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric Metric cost matrix in the source space C2 : array-like, shape (nt, nt) Metric cost matrix in the target space - p : array-like, shape (ns,) + p : array-like, shape (ns,), optional Distribution in the source space. + If let to its default value None, uniform distribution is taken. loss_fun : str loss function used for the solver either 'square_loss' or 'kl_loss'. 'kl_loss' is not implemented yet and will raise an error. @@ -218,7 +225,12 @@ def semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" International Conference on Learning Representations (ICLR), 2022. """ - nx = get_backend(p, C1, C2) + # partial get_backend as the full one will be handled in gromov_wasserstein + nx = get_backend(C1, C2) + + # init marginals if set as None + if p is None: + p = unif(C1.shape[0], type_as=C1) T, log_srgw = semirelaxed_gromov_wasserstein( C1, C2, p, loss_fun, symmetric, log=True, G0=G0, @@ -239,18 +251,19 @@ def semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric return srgw -def semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun='square_loss', symmetric=None, alpha=0.5, G0=None, log=False, - max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): +def semirelaxed_fused_gromov_wasserstein( + M, C1, C2, p=None, loss_fun='square_loss', symmetric=None, alpha=0.5, + G0=None, log=False, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): r""" Computes the semi-relaxed FGW transport between two graphs (see :ref:`[48] <references-semirelaxed-fused-gromov-wasserstein>`) .. math:: - \gamma = \mathop{\arg \min}_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F + - \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + \mathbf{T}^* \in \mathop{\arg \min}_{\mathbf{T}} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) T_{i,j} T_{k,l} - s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p} + s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} - \mathbf{\gamma} &\geq 0 + \mathbf{T} &\geq 0 where : @@ -273,8 +286,9 @@ def semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun='square_loss', s Metric cost matrix representative of the structure in the source space C2 : array-like, shape (nt, nt) Metric cost matrix representative of the structure in the target space - p : array-like, shape (ns,) - Distribution in the source space + p : array-like, shape (ns,), optional + Distribution in the source space. + If let to its default value None, uniform distribution is taken. loss_fun : str loss function used for the solver either 'square_loss' or 'kl_loss'. 'kl_loss' is not implemented yet and will raise an error. @@ -321,11 +335,16 @@ def semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun='square_loss', s if loss_fun == 'kl_loss': raise NotImplementedError() - p = list_to_array(p) - if G0 is None: - nx = get_backend(p, C1, C2, M) + arr = [M, C1, C2] + if p is not None: + arr.append(list_to_array(p)) else: - nx = get_backend(p, C1, C2, M, G0) + p = unif(C1.shape[0], type_as=C1) + + if G0 is not None: + arr.append(G0) + + nx = get_backend(*arr) if symmetric is None: symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10) @@ -373,18 +392,18 @@ def semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun='square_loss', s return semirelaxed_cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs) -def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', symmetric=None, alpha=0.5, G0=None, log=False, +def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p=None, loss_fun='square_loss', symmetric=None, alpha=0.5, G0=None, log=False, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): r""" Computes the semi-relaxed FGW divergence between two graphs (see :ref:`[48] <references-semirelaxed-fused-gromov-wasserstein2>`) .. math:: - \min_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F + \alpha \sum_{i,j,k,l} - L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + \mathbf{srFGW} = \min_{\mathbf{T}} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) T_{i,j} T_{k,l} - s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p} + s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} - \mathbf{\gamma} &\geq 0 + \mathbf{T} &\geq 0 where : @@ -412,6 +431,7 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', Metric cost matrix representative of the structure in the target space. p : array-like, shape (ns,) Distribution in the source space. + If let to its default value None, uniform distribution is taken. loss_fun : str, optional loss function used for the solver either 'square_loss' or 'kl_loss'. 'kl_loss' is not implemented yet and will raise an error. @@ -455,7 +475,12 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" International Conference on Learning Representations (ICLR), 2022. """ - nx = get_backend(p, C1, C2, M) + # partial get_backend as the full one will be handled in gromov_wasserstein + nx = get_backend(C1, C2) + + # init marginals if set as None + if p is None: + p = unif(C1.shape[0], type_as=C1) T, log_fgw = semirelaxed_fused_gromov_wasserstein( M, C1, C2, p, loss_fun, symmetric, alpha, G0, log=True, @@ -551,3 +576,501 @@ def solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p, cost_G = cost_G + a * (alpha ** 2) + b * alpha return alpha, 1, cost_G + + +def entropic_semirelaxed_gromov_wasserstein( + C1, C2, p=None, loss_fun='square_loss', epsilon=0.1, symmetric=None, + G0=None, max_iter=1e4, tol=1e-9, log=False, verbose=False, **kwargs): + r""" + Returns the entropic-regularized semi-relaxed gromov-wasserstein divergence + transport plan from :math:`(\mathbf{C_1}, \mathbf{p})` to :math:`\mathbf{C_2}` + estimated using a Mirror Descent algorithm following the KL geometry. + + The function solves the following optimization problem: + + .. math:: + \mathbf{T}^* \in \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l} + L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} + + \mathbf{T} &\geq 0 + Where : + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{p}`: distribution in the source space + + - `L`: loss function to account for the misfit between the similarity matrices + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. However all the steps in the conditional + gradient are not differentiable. + + Parameters + ---------- + C1 : array-like, shape (ns, ns) + Metric cost matrix in the source space + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,), optional + Distribution in the source space. + If let to its default value None, uniform distribution is taken. + loss_fun : str + loss function used for the solver either 'square_loss' or 'kl_loss'. + 'kl_loss' is not implemented yet and will raise an error. + epsilon : float + Regularization term >0 + symmetric : bool, optional + Either C1 and C2 are to be assumed symmetric or not. + If let to its default None value, a symmetry test will be conducted. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric). + verbose : bool, optional + Print information along iterations + G0: array-like, shape (ns,nt), optional + If None the initial transport plan of the solver is pq^T. + Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. + max_iter : int, optional + Max number of iterations + tol : float, optional + Stop threshold on error computed on transport plans + log : bool, optional + record log if True + verbose : bool, optional + Print information along iterations + Returns + ------- + G : array-like, shape (`ns`, `nt`) + Coupling between the two spaces that minimizes: + + :math:`\sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}` + log : dict + Convergence information and loss. + + References + ---------- + .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty. + "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" + International Conference on Learning Representations (ICLR), 2022. + """ + if loss_fun == 'kl_loss': + raise NotImplementedError() + arr = [C1, C2] + if p is not None: + arr.append(list_to_array(p)) + else: + p = unif(C1.shape[0], type_as=C1) + + if G0 is not None: + arr.append(G0) + + nx = get_backend(*arr) + + if symmetric is None: + symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10) + if G0 is None: + q = unif(C2.shape[0], type_as=p) + G0 = nx.outer(p, q) + else: + q = nx.sum(G0, 0) + # Check first marginal of G0 + np.testing.assert_allclose(nx.sum(G0, 1), p, atol=1e-08) + + constC, hC1, hC2, fC2t = init_matrix_semirelaxed(C1, C2, p, loss_fun, nx) + + ones_p = nx.ones(p.shape[0], type_as=p) + + if symmetric: + def df(G): + qG = nx.sum(G, 0) + marginal_product = nx.outer(ones_p, nx.dot(qG, fC2t)) + return gwggrad(constC + marginal_product, hC1, hC2, G, nx) + else: + constCt, hC1t, hC2t, fC2 = init_matrix_semirelaxed(C1.T, C2.T, p, loss_fun, nx) + + def df(G): + qG = nx.sum(G, 0) + marginal_product_1 = nx.outer(ones_p, nx.dot(qG, fC2t)) + marginal_product_2 = nx.outer(ones_p, nx.dot(qG, fC2)) + return 0.5 * (gwggrad(constC + marginal_product_1, hC1, hC2, G, nx) + gwggrad(constCt + marginal_product_2, hC1t, hC2t, G, nx)) + + cpt = 0 + err = 1e15 + G = G0 + + if log: + log = {'err': []} + + while (err > tol and cpt < max_iter): + + Gprev = G + # compute the kernel + K = G * nx.exp(- df(G) / epsilon) + scaling = p / nx.sum(K, 1) + G = nx.reshape(scaling, (-1, 1)) * K + if cpt % 10 == 0: + # we can speed up the process by checking for the error only all + # the 10th iterations + err = nx.norm(G - Gprev) + + if log: + log['err'].append(err) + + if verbose: + if cpt % 200 == 0: + print('{:5s}|{:12s}'.format( + 'It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) + + cpt += 1 + + if log: + qG = nx.sum(G, 0) + marginal_product = nx.outer(ones_p, nx.dot(qG, fC2t)) + log['srgw_dist'] = gwloss(constC + marginal_product, hC1, hC2, G, nx) + return G, log + else: + return G + + +def entropic_semirelaxed_gromov_wasserstein2( + C1, C2, p=None, loss_fun='square_loss', epsilon=0.1, symmetric=None, + G0=None, max_iter=1e4, tol=1e-9, log=False, verbose=False, **kwargs): + r""" + Returns the entropic-regularized semi-relaxed gromov-wasserstein divergence + from :math:`(\mathbf{C_1}, \mathbf{p})` to :math:`\mathbf{C_2}` + estimated using a Mirror Descent algorithm following the KL geometry. + + The function solves the following optimization problem: + + .. math:: + \mathbf{srGW} = \min_{\mathbf{T}} \quad \sum_{i,j,k,l} + L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} + + \mathbf{T} &\geq 0 + Where : + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{p}`: distribution in the source space + - `L`: loss function to account for the misfit between the similarity + matrices + + Note that when using backends, this loss function is differentiable wrt the + matrices (C1, C2) but not yet for the weights p. + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. However all the steps in the conditional + gradient are not differentiable. + + Parameters + ---------- + C1 : array-like, shape (ns, ns) + Metric cost matrix in the source space + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,), optional + Distribution in the source space. + If let to its default value None, uniform distribution is taken. + loss_fun : str + loss function used for the solver either 'square_loss' or 'kl_loss'. + 'kl_loss' is not implemented yet and will raise an error. + epsilon : float + Regularization term >0 + symmetric : bool, optional + Either C1 and C2 are to be assumed symmetric or not. + If let to its default None value, a symmetry test will be conducted. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric). + verbose : bool, optional + Print information along iterations + G0: array-like, shape (ns,nt), optional + If None the initial transport plan of the solver is pq^T. + Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. + max_iter : int, optional + Max number of iterations + tol : float, optional + Stop threshold on error computed on transport plans + log : bool, optional + record log if True + verbose : bool, optional + Print information along iterations + **kwargs : dict + parameters can be directly passed to the ot.optim.cg solver + + Returns + ------- + srgw : float + Semi-relaxed Gromov-Wasserstein divergence + log : dict + convergence information and Coupling matrix + + References + ---------- + + .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty. + "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" + International Conference on Learning Representations (ICLR), 2022. + """ + T, log_srgw = entropic_semirelaxed_gromov_wasserstein( + C1, C2, p, loss_fun, epsilon, symmetric, G0, + max_iter, tol, log=True, verbose=verbose, **kwargs) + + log_srgw['T'] = T + + if log: + return log_srgw['srgw_dist'], log_srgw + else: + return log_srgw['srgw_dist'] + + +def entropic_semirelaxed_fused_gromov_wasserstein( + M, C1, C2, p=None, loss_fun='square_loss', symmetric=None, epsilon=0.1, + alpha=0.5, G0=None, max_iter=1e4, tol=1e-9, log=False, verbose=False, **kwargs): + r""" + Computes the entropic-regularized semi-relaxed FGW transport between two graphs (see :ref:`[48] <references-semirelaxed-fused-gromov-wasserstein>`) + + .. math:: + \mathbf{T}^* \in \mathop{\arg \min}_{\mathbf{T}} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} + + \mathbf{T} &\geq 0 + + where : + + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix between features + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{p}` source weights (sum to 1) + - `L` is a loss function to account for the misfit between the similarity matrices + + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. However all the steps in the conditional + gradient are not differentiable. + + The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[48] <references-semirelaxed-fused-gromov-wasserstein>` + + Parameters + ---------- + M : array-like, shape (ns, nt) + Metric cost matrix between features across domains + C1 : array-like, shape (ns, ns) + Metric cost matrix representative of the structure in the source space + C2 : array-like, shape (nt, nt) + Metric cost matrix representative of the structure in the target space + p : array-like, shape (ns,), optional + Distribution in the source space. + If let to its default value None, uniform distribution is taken. + loss_fun : str + loss function used for the solver either 'square_loss' or 'kl_loss'. + 'kl_loss' is not implemented yet and will raise an error. + epsilon : float + Regularization term >0 + symmetric : bool, optional + Either C1 and C2 are to be assumed symmetric or not. + If let to its default None value, a symmetry test will be conducted. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric). + alpha : float, optional + Trade-off parameter (0 < alpha < 1) + G0: array-like, shape (ns,nt), optional + If None the initial transport plan of the solver is pq^T. + Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. + max_iter : int, optional + Max number of iterations + tol : float, optional + Stop threshold on error computed on transport plans + log : bool, optional + record log if True + verbose : bool, optional + Print information along iterations + **kwargs : dict + parameters can be directly passed to the ot.optim.cg solver + + Returns + ------- + G : array-like, shape (`ns`, `nt`) + Optimal transportation matrix for the given parameters. + log : dict + Log dictionary return only if log==True in parameters. + + + .. _references-semirelaxed-fused-gromov-wasserstein: + References + ---------- + .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty. + "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" + International Conference on Learning Representations (ICLR), 2022. + """ + if loss_fun == 'kl_loss': + raise NotImplementedError() + arr = [M, C1, C2] + if p is not None: + arr.append(list_to_array(p)) + else: + p = unif(C1.shape[0], type_as=C1) + + if G0 is not None: + arr.append(G0) + + nx = get_backend(*arr) + + if symmetric is None: + symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10) + if G0 is None: + q = unif(C2.shape[0], type_as=p) + G0 = nx.outer(p, q) + else: + q = nx.sum(G0, 0) + # Check first marginal of G0 + np.testing.assert_allclose(nx.sum(G0, 1), p, atol=1e-08) + + constC, hC1, hC2, fC2t = init_matrix_semirelaxed(C1, C2, p, loss_fun, nx) + + ones_p = nx.ones(p.shape[0], type_as=p) + dM = (1 - alpha) * M + if symmetric: + def df(G): + qG = nx.sum(G, 0) + marginal_product = nx.outer(ones_p, nx.dot(qG, fC2t)) + return alpha * gwggrad(constC + marginal_product, hC1, hC2, G, nx) + dM + else: + constCt, hC1t, hC2t, fC2 = init_matrix_semirelaxed(C1.T, C2.T, p, loss_fun, nx) + + def df(G): + qG = nx.sum(G, 0) + marginal_product_1 = nx.outer(ones_p, nx.dot(qG, fC2t)) + marginal_product_2 = nx.outer(ones_p, nx.dot(qG, fC2)) + return 0.5 * alpha * (gwggrad(constC + marginal_product_1, hC1, hC2, G, nx) + gwggrad(constCt + marginal_product_2, hC1t, hC2t, G, nx)) + dM + + cpt = 0 + err = 1e15 + G = G0 + + if log: + log = {'err': []} + + while (err > tol and cpt < max_iter): + + Gprev = G + # compute the kernel + K = G * nx.exp(- df(G) / epsilon) + scaling = p / nx.sum(K, 1) + G = nx.reshape(scaling, (-1, 1)) * K + if cpt % 10 == 0: + # we can speed up the process by checking for the error only all + # the 10th iterations + err = nx.norm(G - Gprev) + + if log: + log['err'].append(err) + + if verbose: + if cpt % 200 == 0: + print('{:5s}|{:12s}'.format( + 'It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) + + cpt += 1 + + if log: + qG = nx.sum(G, 0) + marginal_product = nx.outer(ones_p, nx.dot(qG, fC2t)) + log['srfgw_dist'] = alpha * gwloss(constC + marginal_product, hC1, hC2, G, nx) + (1 - alpha) * nx.sum(M * G) + return G, log + else: + return G + + +def entropic_semirelaxed_fused_gromov_wasserstein2( + M, C1, C2, p=None, loss_fun='square_loss', symmetric=None, epsilon=0.1, + alpha=0.5, G0=None, max_iter=1e4, tol=1e-9, log=False, verbose=False, **kwargs): + r""" + Computes the entropic-regularized semi-relaxed FGW transport between two graphs (see :ref:`[48] <references-semirelaxed-fused-gromov-wasserstein>`) + + .. math:: + \mathbf{srFGW} = \min_{\mathbf{T}} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} + + \mathbf{T} &\geq 0 + + where : + + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix between features + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{p}` source weights (sum to 1) + - `L` is a loss function to account for the misfit between the similarity matrices + + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. However all the steps in the conditional + gradient are not differentiable. + + The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[48] <references-semirelaxed-fused-gromov-wasserstein>` + + Parameters + ---------- + M : array-like, shape (ns, nt) + Metric cost matrix between features across domains + C1 : array-like, shape (ns, ns) + Metric cost matrix representative of the structure in the source space. + C2 : array-like, shape (nt, nt) + Metric cost matrix representative of the structure in the target space. + p : array-like, shape (ns,), optional + Distribution in the source space. + If let to its default value None, uniform distribution is taken. + loss_fun : str, optional + loss function used for the solver either 'square_loss' or 'kl_loss'. + 'kl_loss' is not implemented yet and will raise an error. + epsilon : float + Regularization term >0 + symmetric : bool, optional + Either C1 and C2 are to be assumed symmetric or not. + If let to its default None value, a symmetry test will be conducted. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric). + alpha : float, optional + Trade-off parameter (0 < alpha < 1) + G0: array-like, shape (ns,nt), optional + If None the initial transport plan of the solver is pq^T. + Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. + max_iter : int, optional + Max number of iterations + tol : float, optional + Stop threshold on error computed on transport plans + log : bool, optional + record log if True + verbose : bool, optional + Print information along iterations + **kwargs : dict + Parameters can be directly passed to the ot.optim.cg solver. + + Returns + ------- + srfgw-divergence : float + Semi-relaxed Fused gromov wasserstein divergence for the given parameters. + log : dict + Log dictionary return only if log==True in parameters. + + + .. _references-semirelaxed-fused-gromov-wasserstein2: + References + ---------- + .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty. + "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" + International Conference on Learning Representations (ICLR), 2022. + """ + T, log_srfgw = entropic_semirelaxed_fused_gromov_wasserstein( + M, C1, C2, p, loss_fun, symmetric, epsilon, alpha, G0, + max_iter, tol, log=True, verbose=verbose, **kwargs) + + log_srfgw['T'] = T + + if log: + return log_srfgw['srfgw_dist'], log_srfgw + else: + return log_srfgw['srfgw_dist'] |