From a5930d3b3a446bf860d6dfacc1e17151fae1dd1d Mon Sep 17 00:00:00 2001
From: Cédric Vincent-Cuaz
Date: Thu, 9 Mar 2023 14:21:33 +0100
Subject: [MRG] Semi-relaxed (fused) Gromov-Wasserstein divergence and improvements of Gromov-Wasserstein solvers (#431)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* update gw/ srgw/ generic cg solver

* correct pep8 on current state

* fix bug in previous tests

* fix pep8

* fix bug srGW constC in loss and gradient

* fix doc html

* fix doc html

* start updating test_optim.py

* update tests gromov and optim - plus fix gromov dependencies

* add symmetry feature to entropic gw

* add symmetry feature to entropic gw

* add example for sr(F)GW matchings

* small stuff

* remove (reg, M) from line-search / complete srgw tests with backend

* remove backend repetitions / rename fG to costG / fix innerlog to True

* fix pep8

* take comments into account / new nx parameters still to test

* factor (f)gw2 + test new backend parameters in ot.gromov + harmonize stopping criteria

* split gromov.py into ot/gromov/ + update test_gromov with helper_backend functions

* manual documentation gromov

* remove circular autosummary

* trying stuff

* debug documentation

* alphabetical ordering of module

* merge into branch

* add note in entropic gw solvers

---------

Co-authored-by: Rémi Flamary
---
 ot/gromov/_utils.py | 413 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 413 insertions(+)
 create mode 100644 ot/gromov/_utils.py

diff --git a/ot/gromov/_utils.py b/ot/gromov/_utils.py
new file mode 100644
index 0000000..e842250
--- /dev/null
+++ b/ot/gromov/_utils.py
@@ -0,0 +1,413 @@
+# -*- coding: utf-8 -*-
+"""
+Gromov-Wasserstein and Fused Gromov-Wasserstein utils.
+"""
+
+# Author: Erwan Vautier
+#         Nicolas Courty
+#         Rémi Flamary
+#         Titouan Vayer
+#         Cédric Vincent-Cuaz
+#
+# License: MIT License
+
+
+from ..utils import list_to_array
+from ..backend import get_backend
+
+
+def init_matrix(C1, C2, p, q, loss_fun='square_loss', nx=None):
+    r"""Return loss matrices and tensors for Gromov-Wasserstein fast computation
+
+    Returns the value of :math:`\mathcal{L}(\mathbf{C_1}, \mathbf{C_2}) \otimes \mathbf{T}` with the
+    selected loss function as the loss function of Gromov-Wasserstein discrepancy.
+
+    The matrices are computed as described in Proposition 1 in :ref:`[12] <references-init-matrix>`
+
+    Where:
+
+    - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+    - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+    - :math:`\mathbf{T}`: A coupling between those two spaces
+
+    The square-loss function :math:`L(a, b) = |a - b|^2` decomposes as:
+
+    .. math::
+
+        L(a, b) = f_1(a) + f_2(b) - h_1(a) h_2(b)
+
+        \mathrm{with} \ f_1(a) &= a^2
+
+        f_2(b) &= b^2
+
+        h_1(a) &= a
+
+        h_2(b) &= 2b
+
+    The kl-loss function :math:`L(a, b) = a \log\left(\frac{a}{b}\right) - a + b` decomposes as:
+
+    .. math::
+
+        L(a, b) = f_1(a) + f_2(b) - h_1(a) h_2(b)
+
+        \mathrm{with} \ f_1(a) &= a \log(a) - a
+
+        f_2(b) &= b
+
+        h_1(a) &= a
+
+        h_2(b) &= \log(b)
+
+    Parameters
+    ----------
+    C1 : array-like, shape (ns, ns)
+        Metric cost matrix in the source space
+    C2 : array-like, shape (nt, nt)
+        Metric cost matrix in the target space
+    p : array-like, shape (ns,)
+        Probability distribution in the source space
+    q : array-like, shape (nt,)
+        Probability distribution in the target space
+    loss_fun : str, optional
+        Name of loss function to use: either 'square_loss' or 'kl_loss' (default='square_loss')
+    nx : backend, optional
+        If left to its default value None, the backend is inferred from the input arrays.
+
+    Returns
+    -------
+    constC : array-like, shape (ns, nt)
+        Constant :math:`\mathbf{C}` matrix in Eq. (6)
+    hC1 : array-like, shape (ns, ns)
+        :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6)
+    hC2 : array-like, shape (nt, nt)
+        :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6)
+
+
+    .. _references-init-matrix:
+    References
+    ----------
+    .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
+        "Gromov-Wasserstein averaging of kernel and distance matrices."
+        International Conference on Machine Learning (ICML). 2016.
+
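+    Examples
+    --------
+    A minimal sketch with the NumPy backend, on illustrative toy matrices
+    (two points per space, uniform weights):
+
+    >>> import numpy as np
+    >>> C1 = np.array([[0., 1.], [1., 0.]])
+    >>> C2 = np.array([[0., 2.], [2., 0.]])
+    >>> p = q = np.ones(2) / 2
+    >>> constC, hC1, hC2 = init_matrix(C1, C2, p, q)
+    >>> constC
+    array([[2.5, 2.5],
+           [2.5, 2.5]])
+    >>> hC2  # h2 doubles C2 for the square loss
+    array([[0., 4.],
+           [4., 0.]])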
+    """
+    if nx is None:
+        C1, C2, p, q = list_to_array(C1, C2, p, q)
+        nx = get_backend(C1, C2, p, q)
+
+    if loss_fun == 'square_loss':
+        def f1(a):
+            return (a**2)
+
+        def f2(b):
+            return (b**2)
+
+        def h1(a):
+            return a
+
+        def h2(b):
+            return 2 * b
+    elif loss_fun == 'kl_loss':
+        def f1(a):
+            return a * nx.log(a + 1e-15) - a
+
+        def f2(b):
+            return b
+
+        def h1(a):
+            return a
+
+        def h2(b):
+            return nx.log(b + 1e-15)
+    else:
+        raise ValueError(f"Unknown loss_fun='{loss_fun}'. Use 'square_loss' or 'kl_loss'.")
+
+    constC1 = nx.dot(
+        nx.dot(f1(C1), nx.reshape(p, (-1, 1))),
+        nx.ones((1, len(q)), type_as=q)
+    )
+    constC2 = nx.dot(
+        nx.ones((len(p), 1), type_as=p),
+        nx.dot(nx.reshape(q, (1, -1)), f2(C2).T)
+    )
+    constC = constC1 + constC2
+    hC1 = h1(C1)
+    hC2 = h2(C2)
+
+    return constC, hC1, hC2
+
+
+def tensor_product(constC, hC1, hC2, T, nx=None):
+    r"""Return the tensor for Gromov-Wasserstein fast computation
+
+    The tensor is computed as described in Proposition 1 Eq. (6) in :ref:`[12] <references-tensor-product>`
+
+    Parameters
+    ----------
+    constC : array-like, shape (ns, nt)
+        Constant :math:`\mathbf{C}` matrix in Eq. (6)
+    hC1 : array-like, shape (ns, ns)
+        :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6)
+    hC2 : array-like, shape (nt, nt)
+        :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6)
+    T : array-like, shape (ns, nt)
+        Current value of transport matrix :math:`\mathbf{T}`
+    nx : backend, optional
+        If left to its default value None, the backend is inferred from the input arrays.
+
+    Returns
+    -------
+    tens : array-like, shape (`ns`, `nt`)
+        :math:`\mathcal{L}(\mathbf{C_1}, \mathbf{C_2}) \otimes \mathbf{T}` tensor-matrix multiplication result
+
+
+    .. _references-tensor-product:
+    References
+    ----------
+    .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
+        "Gromov-Wasserstein averaging of kernel and distance matrices."
+        International Conference on Machine Learning (ICML). 2016.
+
+    """
+    if nx is None:
+        constC, hC1, hC2, T = list_to_array(constC, hC1, hC2, T)
+        nx = get_backend(constC, hC1, hC2, T)
+
+    A = - nx.dot(
+        nx.dot(hC1, T), hC2.T
+    )
+    tens = constC + A
+    # tens -= tens.min()
+    return tens
+
+
+def gwloss(constC, hC1, hC2, T, nx=None):
+    r"""Return the Loss for Gromov-Wasserstein
+
+    The loss is computed as described in Proposition 1 Eq. (6) in :ref:`[12] <references-gwloss>`
+
+    Parameters
+    ----------
+    constC : array-like, shape (ns, nt)
+        Constant :math:`\mathbf{C}` matrix in Eq. (6)
+    hC1 : array-like, shape (ns, ns)
+        :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6)
+    hC2 : array-like, shape (nt, nt)
+        :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6)
+    T : array-like, shape (ns, nt)
+        Current value of transport matrix :math:`\mathbf{T}`
+    nx : backend, optional
+        If left to its default value None, the backend is inferred from the input arrays.
+
+    Returns
+    -------
+    loss : float
+        Gromov-Wasserstein loss
+
+
+    .. _references-gwloss:
+    References
+    ----------
+    .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
+        "Gromov-Wasserstein averaging of kernel and distance matrices."
+        International Conference on Machine Learning (ICML). 2016.
+
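+    Examples
+    --------
+    A minimal sketch with the NumPy backend, evaluating the loss at the
+    independent coupling of the toy matrices from the `init_matrix` example:
+
+    >>> import numpy as np
+    >>> C1 = np.array([[0., 1.], [1., 0.]])
+    >>> C2 = np.array([[0., 2.], [2., 0.]])
+    >>> p = q = np.ones(2) / 2
+    >>> constC, hC1, hC2 = init_matrix(C1, C2, p, q)
+    >>> float(gwloss(constC, hC1, hC2, np.outer(p, q)))
+    1.5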
+    """
+
+    tens = tensor_product(constC, hC1, hC2, T, nx)
+    if nx is None:
+        tens, T = list_to_array(tens, T)
+        nx = get_backend(tens, T)
+
+    return nx.sum(tens * T)
+
+
+def gwggrad(constC, hC1, hC2, T, nx=None):
+    r"""Return the gradient for Gromov-Wasserstein
+
+    The gradient is computed as described in Proposition 2 in :ref:`[12] <references-gwggrad>`
+
+    Parameters
+    ----------
+    constC : array-like, shape (ns, nt)
+        Constant :math:`\mathbf{C}` matrix in Eq. (6)
+    hC1 : array-like, shape (ns, ns)
+        :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6)
+    hC2 : array-like, shape (nt, nt)
+        :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6)
+    T : array-like, shape (ns, nt)
+        Current value of transport matrix :math:`\mathbf{T}`
+    nx : backend, optional
+        If left to its default value None, the backend is inferred from the input arrays.
+
+    Returns
+    -------
+    grad : array-like, shape (`ns`, `nt`)
+        Gromov-Wasserstein gradient
+
+
+    .. _references-gwggrad:
+    References
+    ----------
+    .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
+        "Gromov-Wasserstein averaging of kernel and distance matrices."
+        International Conference on Machine Learning (ICML). 2016.
+
+    """
+    return 2 * tensor_product(constC, hC1, hC2,
+                              T, nx)  # [12] Prop. 2 misses a 2 factor
+
+
+def update_square_loss(p, lambdas, T, Cs):
+    r"""
+    Updates :math:`\mathbf{C}` according to the L2 loss kernel with the `S` :math:`\mathbf{T}_s`
+    couplings calculated at each iteration
+
+    Parameters
+    ----------
+    p : array-like, shape (N,)
+        Masses in the targeted barycenter.
+    lambdas : list of float
+        List of the `S` spaces' weights.
+    T : list of S array-like of shape (ns, N)
+        The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration.
+    Cs : list of S array-like, shape (ns, ns)
+        Metric cost matrices.
+
+    Returns
+    -------
+    C : array-like, shape (N, N)
+        Updated :math:`\mathbf{C}` matrix.
+    """
+    T = list_to_array(*T)
+    Cs = list_to_array(*Cs)
+    p = list_to_array(p)
+    nx = get_backend(p, *T, *Cs)
+
+    tmpsum = sum([
+        lambdas[s] * nx.dot(
+            nx.dot(T[s].T, Cs[s]),
+            T[s]
+        ) for s in range(len(T))
+    ])
+    ppt = nx.outer(p, p)
+
+    return tmpsum / ppt
+
+
+def update_kl_loss(p, lambdas, T, Cs):
+    r"""
+    Updates :math:`\mathbf{C}` according to the KL loss kernel with the `S` :math:`\mathbf{T}_s`
+    couplings calculated at each iteration
+
+    Parameters
+    ----------
+    p : array-like, shape (N,)
+        Weights in the targeted barycenter.
+    lambdas : list of float
+        List of the `S` spaces' weights.
+    T : list of S array-like of shape (ns, N)
+        The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration.
+    Cs : list of S array-like, shape (ns, ns)
+        Metric cost matrices.
+
+    Returns
+    -------
+    C : array-like, shape (N, N)
+        Updated :math:`\mathbf{C}` matrix.
+    """
+    Cs = list_to_array(*Cs)
+    T = list_to_array(*T)
+    p = list_to_array(p)
+    nx = get_backend(p, *T, *Cs)
+
+    tmpsum = sum([
+        lambdas[s] * nx.dot(
+            nx.dot(T[s].T, Cs[s]),
+            T[s]
+        ) for s in range(len(T))
+    ])
+    ppt = nx.outer(p, p)
+
+    return nx.exp(tmpsum / ppt)
+
+
+def init_matrix_semirelaxed(C1, C2, p, loss_fun='square_loss', nx=None):
+    r"""Return loss matrices and tensors for semi-relaxed Gromov-Wasserstein fast computation
+
+    Returns the value of :math:`\mathcal{L}(\mathbf{C_1}, \mathbf{C_2}) \otimes \mathbf{T}` with the
+    selected loss function as the loss function of semi-relaxed Gromov-Wasserstein discrepancy.
+
+    The matrices are computed as described in Proposition 1 in :ref:`[12] <references-init-matrix-semirelaxed>`
+    and adapted to the semi-relaxed problem where the second marginal is not a constant anymore.
+
+    Where:
+
+    - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+    - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+    - :math:`\mathbf{T}`: A coupling between those two spaces
+
+    The square-loss function :math:`L(a, b) = |a - b|^2` decomposes as:
+
+    .. math::
+
+        L(a, b) = f_1(a) + f_2(b) - h_1(a) h_2(b)
+
+        \mathrm{with} \ f_1(a) &= a^2
+
+        f_2(b) &= b^2
+
+        h_1(a) &= a
+
+        h_2(b) &= 2b
+
+    Parameters
+    ----------
+    C1 : array-like, shape (ns, ns)
+        Metric cost matrix in the source space
+    C2 : array-like, shape (nt, nt)
+        Metric cost matrix in the target space
+    p : array-like, shape (ns,)
+        Probability distribution in the source space
+    loss_fun : str, optional
+        Name of loss function to use: currently only 'square_loss' is supported (default='square_loss')
+    nx : backend, optional
+        If left to its default value None, the backend is inferred from the input arrays.
+
+    Returns
+    -------
+    constC : array-like, shape (ns, nt)
+        Constant :math:`\mathbf{C}` matrix in Eq. (6) adapted to srGW
+    hC1 : array-like, shape (ns, ns)
+        :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6)
+    hC2 : array-like, shape (nt, nt)
+        :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6)
+    fC2t : array-like, shape (nt, nt)
+        :math:`\mathbf{f2}(\mathbf{C2})^\top` matrix in Eq. (6)
+
+
+    .. _references-init-matrix-semirelaxed:
+    References
+    ----------
+    .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
+        "Gromov-Wasserstein averaging of kernel and distance matrices."
+        International Conference on Machine Learning (ICML). 2016.
+
+    .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty.
+        "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
+        International Conference on Learning Representations (ICLR), 2022.
+
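+    Examples
+    --------
+    A minimal sketch with the NumPy backend, on the same illustrative toy
+    matrices as the `init_matrix` example (no target marginal is needed here):
+
+    >>> import numpy as np
+    >>> C1 = np.array([[0., 1.], [1., 0.]])
+    >>> C2 = np.array([[0., 2.], [2., 0.]])
+    >>> p = np.ones(2) / 2
+    >>> constC, hC1, hC2, fC2t = init_matrix_semirelaxed(C1, C2, p)
+    >>> constC  # only the source marginal enters the constant term
+    array([[0.5, 0.5],
+           [0.5, 0.5]])
+    >>> fC2t
+    array([[0., 4.],
+           [4., 0.]])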
+    """
+    if nx is None:
+        C1, C2, p = list_to_array(C1, C2, p)
+        nx = get_backend(C1, C2, p)
+
+    if loss_fun == 'square_loss':
+        def f1(a):
+            return (a**2)
+
+        def f2(b):
+            return (b**2)
+
+        def h1(a):
+            return a
+
+        def h2(b):
+            return 2 * b
+    else:
+        raise ValueError(f"Unknown loss_fun='{loss_fun}'. Only 'square_loss' is supported.")
+
+    constC = nx.dot(nx.dot(f1(C1), nx.reshape(p, (-1, 1))),
+                    nx.ones((1, C2.shape[0]), type_as=p))
+
+    hC1 = h1(C1)
+    hC2 = h2(C2)
+    fC2t = f2(C2).T
+    return constC, hC1, hC2, fC2t
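+
+
+if __name__ == '__main__':
+    # Minimal usage sketch (illustrative only; run with `python -m ot.gromov._utils`):
+    # evaluate the GW objective and its gradient at the independent coupling,
+    # assuming the NumPy backend. These are the two quantities the conditional
+    # gradient solvers consume at each iteration.
+    import numpy as np
+
+    C1 = np.array([[0., 1.], [1., 0.]])  # toy source cost matrix
+    C2 = np.array([[0., 2.], [2., 0.]])  # toy target cost matrix
+    p = q = np.ones(2) / 2
+    T = np.outer(p, q)  # independent coupling used as a starting point
+
+    constC, hC1, hC2 = init_matrix(C1, C2, p, q)
+    print(gwloss(constC, hC1, hC2, T))   # 1.5 on this toy example
+    print(gwggrad(constC, hC1, hC2, T))  # 2 * (constC - hC1 T hC2^T), constant here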