From ccc076e0fc535b2c734214c0ac1936e9e2cbeb62 Mon Sep 17 00:00:00 2001 From: eloitanguy <69361683+eloitanguy@users.noreply.github.com> Date: Fri, 6 May 2022 08:43:21 +0200 Subject: [WIP] Generalized Wasserstein Barycenters (#372) * GWB first solver version * tests + example for gwb (untested) + free_bar doc fix * improved doc, fixed minor bugs, better example visu * minor doc + visu fixes * plot GWB pep8 fix * fixed partial gromov test reproductibility * added an animation for the GWB visu * added PR num * minor doc fixes + better gwb logo --- ot/lp/__init__.py | 145 ++++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 125 insertions(+), 20 deletions(-) (limited to 'ot/lp/__init__.py') diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 390c32d..572781d 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -26,10 +26,8 @@ from ..utils import dist, list_to_array from ..utils import parmap from ..backend import get_backend - - __all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', ' emd_1d_sorted', - 'emd_1d', 'emd2_1d', 'wasserstein_1d'] + 'emd_1d', 'emd2_1d', 'wasserstein_1d', 'generalized_free_support_barycenter'] def check_number_threads(numThreads): @@ -517,8 +515,8 @@ def emd2(a, b, M, processes=1, log['warning'] = result_code_string log['result_code'] = result_code cost = nx.set_gradients(nx.from_numpy(cost, type_as=type_as), - (a0, b0, M0), (log['u'] - nx.mean(log['u']), - log['v'] - nx.mean(log['v']), G)) + (a0, b0, M0), (log['u'] - nx.mean(log['u']), + log['v'] - nx.mean(log['v']), G)) return [cost, log] else: def f(b): @@ -572,18 +570,18 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None where : - :math:`w \in \mathbb{(0, 1)}^{N}`'s are the barycenter weights and sum to one - - the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}` are the empirical measures weights and sum to one for each :math:`i` - - the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d}` are the empirical measures atoms locations + - `measure_weights` denotes the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}`: empirical measures weights (on simplex) + - `measures_locations` denotes the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d}`: empirical measures atoms locations - :math:`\mathbf{b} \in \mathbb{R}^{k}` is the desired weights vector of the barycenter - This problem is considered in :ref:`[1] ` (Algorithm 2). + This problem is considered in :ref:`[20] ` (Algorithm 2). There are two differences with the following codes: - we do not optimize over the weights - we do not do line search for the locations updates, we use i.e. :math:`\theta = 1` in - :ref:`[1] ` (Algorithm 2). This can be seen as a discrete + :ref:`[20] ` (Algorithm 2). This can be seen as a discrete implementation of the fixed-point algorithm of - :ref:`[2] ` proposed in the continuous setting. + :ref:`[43] ` proposed in the continuous setting. Parameters ---------- @@ -623,13 +621,13 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None .. _references-free-support-barycenter: References ---------- - .. [1] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. + .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. - .. [2] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. + .. 
[43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. """ - nx = get_backend(*measures_locations,*measures_weights,X_init) + nx = get_backend(*measures_locations, *measures_weights, X_init) iter_count = 0 @@ -637,9 +635,9 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None k = X_init.shape[0] d = X_init.shape[1] if b is None: - b = nx.ones((k,),type_as=X_init) / k + b = nx.ones((k,), type_as=X_init) / k if weights is None: - weights = nx.ones((N,),type_as=X_init) / N + weights = nx.ones((N,), type_as=X_init) / N X = X_init @@ -650,15 +648,14 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None while (displacement_square_norm > stopThr and iter_count < numItermax): - T_sum = nx.zeros((k, d),type_as=X_init) - + T_sum = nx.zeros((k, d), type_as=X_init) - for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights): + for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights): M_i = dist(X, measure_locations_i) T_i = emd(b, measure_weights_i, M_i, numThreads=numThreads) - T_sum = T_sum + weight_i * 1. / b[:,None] * nx.dot(T_i, measure_locations_i) + T_sum = T_sum + weight_i * 1. / b[:, None] * nx.dot(T_i, measure_locations_i) - displacement_square_norm = nx.sum((T_sum - X)**2) + displacement_square_norm = nx.sum((T_sum - X) ** 2) if log: displacement_square_norms.append(displacement_square_norm) @@ -675,3 +672,111 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None else: return X + +def generalized_free_support_barycenter(X_list, a_list, P_list, n_samples_bary, Y_init=None, b=None, weights=None, + numItermax=100, stopThr=1e-7, verbose=False, log=None, numThreads=1, eps=0): + r""" + Solves the free support generalised Wasserstein barycenter problem: finding a barycenter (a discrete measure with + a fixed amount of points of uniform weights) whose respective projections fit the input measures. + More formally: + + .. math:: + \min_\gamma \quad \sum_{i=1}^p w_i W_2^2(\nu_i, \mathbf{P}_i\#\gamma) + + where : + + - :math:`\gamma = \sum_{l=1}^n b_l\delta_{y_l}` is the desired barycenter with each :math:`y_l \in \mathbb{R}^d` + - :math:`\mathbf{b} \in \mathbb{R}^{n}` is the desired weights vector of the barycenter + - The input measures are :math:`\nu_i = \sum_{j=1}^{k_i}a_{i,j}\delta_{x_{i,j}}` + - The :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}` are the respective empirical measures weights (on the simplex) + - The :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d_i}` are the respective empirical measures atoms locations + - :math:`w = (w_1, \cdots w_p)` are the barycenter coefficients (on the simplex) + - Each :math:`\mathbf{P}_i \in \mathbb{R}^{d, d_i}`, and :math:`P_i\#\nu_i = \sum_{j=1}^{k_i}a_{i,j}\delta_{P_ix_{i,j}}` + + As show by :ref:`[42] `, + this problem can be re-written as a Wasserstein Barycenter problem, + which we solve using the free support method :ref:`[20] ` + (Algorithm 2). 
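A minimal usage sketch may help fix ideas (the sizes, data and projection matrices
below are purely illustrative and not taken from this PR); note that, matching the
implementation, each :math:`\mathbf{P}_i` is passed with shape `(d_i, d)`::

    import numpy as np
    import ot

    # two 1-D measures viewed as axis projections of an unknown 2-D barycenter
    X_list = [np.random.rand(30, 1), np.random.rand(20, 1)]  # supports, shapes (k_i, d_i)
    a_list = [ot.unif(30), ot.unif(20)]                       # weights on the simplex
    P_list = [np.array([[1., 0.]]), np.array([[0., 1.]])]     # each P_i of shape (d_i, d), d=2
    Y = ot.lp.generalized_free_support_barycenter(X_list, a_list, P_list, n_samples_bary=10)
    # Y has shape (10, 2): the free support of the generalised barycenter in R^d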
+ + Parameters + ---------- + X_list : list of p (k_i,d_i) array-like + Discrete supports of the input measures: each consists of :math:`k_i` locations of a `d_i`-dimensional space + (:math:`k_i` can be different for each element of the list) + a_list : list of p (k_i,) array-like + Measure weights: each element is a vector (k_i) on the simplex + P_list : list of p (d_i,d) array-like + Each :math:`P_i` is a linear map :math:`\mathbb{R}^{d} \rightarrow \mathbb{R}^{d_i}` + n_samples_bary : int + Number of barycenter points + Y_init : (n_samples_bary,d) array-like + Initialization of the support locations (on `k` atoms) of the barycenter + b : (n_samples_bary,) array-like + Initialization of the weights of the barycenter measure (on the simplex) + weights : (p,) array-like + Initialization of the coefficients of the barycenter (on the simplex) + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + numThreads: int or "max", optional (default=1, i.e. OpenMP is not used) + If compiled with OpenMP, chooses the number of threads to parallelize. + "max" selects the highest number possible. + eps: Stability coefficient for the change of variable matrix inversion + If the :math:`\mathbf{P}_i^T` matrices don't span :math:`\mathbb{R}^d`, the problem is ill-defined and a matrix + inversion will fail. In this case one may set eps=1e-8 and get a solution anyway (which may make little sense) + + + Returns + ------- + Y : (n_samples_bary,d) array-like + Support locations (on n_samples_bary atoms) of the barycenter + + + .. _references-generalized-free-support-barycenter: + References + ---------- + .. [20] Cuturi, M. and Doucet, A.. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. + + .. [42] Delon, J., Gozlan, N., and Saint-Dizier, A.. Generalized Wasserstein barycenters between probability measures living on different subspaces. arXiv preprint arXiv:2105.09755, 2021. 
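To make the role of ``eps`` concrete, a small illustrative check (hypothetical
values, not part of this PR): when the :math:`\mathbf{P}_i` do not jointly span
:math:`\mathbb{R}^d`, the matrix :math:`\sum_i w_i \mathbf{P}_i^T \mathbf{P}_i`
inverted by the solver is singular, and a small ``eps`` restores invertibility::

    import numpy as np

    # two projections that both observe only the first coordinate of R^2
    P_list = [np.array([[1., 0.]]), np.array([[1., 0.]])]
    weights = np.array([0.5, 0.5])

    A = sum(w * P.T @ P for w, P in zip(weights, P_list))
    print(np.linalg.matrix_rank(A))      # 1 < d = 2: the change of variables is ill-defined
    A_reg = A + 1e-8 * np.eye(2)         # what eps=1e-8 adds inside the solver
    print(np.linalg.matrix_rank(A_reg))  # 2: invertible, though the solution may make little sense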
+ + """ + nx = get_backend(*X_list, *a_list, *P_list) + d = P_list[0].shape[1] + p = len(X_list) + + if weights is None: + weights = nx.ones(p, type_as=X_list[0]) / p + + # variable change matrix to reduce the problem to a Wasserstein Barycenter (WB) + A = eps * nx.eye(d, type_as=X_list[0]) # if eps nonzero: will force the invertibility of A + for (P_i, lambda_i) in zip(P_list, weights): + A = A + lambda_i * P_i.T @ P_i + B = nx.inv(nx.sqrtm(A)) + + Z_list = [x @ Pi @ B.T for (x, Pi) in zip(X_list, P_list)] # change of variables -> (WB) problem on Z + + if Y_init is None: + Y_init = nx.randn(n_samples_bary, d, type_as=X_list[0]) + + if b is None: + b = nx.ones(n_samples_bary, type_as=X_list[0]) / n_samples_bary # not optimised + + out = free_support_barycenter(Z_list, a_list, Y_init, b, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, numThreads=numThreads) + + if log: # unpack + Y, log_dict = out + else: + Y = out + log_dict = None + Y = Y @ B.T # return to the Generalised WB formulation + + if log: + return Y, log_dict + else: + return Y -- cgit v1.2.3 From 7c2a9523747c90aebfef711fdf34b5bbdb6f2f4d Mon Sep 17 00:00:00 2001 From: clecoz Date: Tue, 21 Jun 2022 17:36:22 +0200 Subject: [MRG] raise error if mass mismatch in emd2 (#386) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Two lines added in the function emd2 to ensure that the distributions have the same mass (same as it already was in the function emd). * The same mass test has been moved inside the function f(b) to be compatible with emd2 with multiple b. * Test added. The function test_emd_dimension_and_mass_mismatch (in test/test_ot.py) has been modified to check for mass mismatch with emd2. * Add PR in releases.md * Merge and add PR in releases.md * Add name in contributors.md * Correction contribution in contributors.md * Move test on mass outside of functions f(b) * Update doc of emd and emd2 Co-authored-by: Camille Le Coz Co-authored-by: Rémi Flamary --- CONTRIBUTORS.md | 1 + RELEASES.md | 1 + ot/lp/__init__.py | 9 +++++++++ test/test_ot.py | 3 +++ 4 files changed, 14 insertions(+) (limited to 'ot/lp/__init__.py') diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 0909b14..c535c09 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -38,6 +38,7 @@ The contributors to this library are: * [Nathan Cassereau](https://github.com/ncassereau-idris) (Backends) * [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning) * [Eloi Tanguy](https://github.com/eloitanguy) (Generalized Wasserstein Barycenters) +* [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug) ## Acknowledgments diff --git a/RELEASES.md b/RELEASES.md index b384617..78a7d9e 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -17,6 +17,7 @@ - Fixed an issue where pointers would overflow in the EMD solver, returning an incomplete transport plan above a certain size (slightly above 46k, its square being roughly 2^31) (PR #381) +- Error raised when mass mismatch in emd2 (PR #386) ## 0.8.2 diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 572781d..17411d0 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -230,6 +230,8 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1): If this behaviour is unwanted, please make sure to provide a floating point input. + .. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value. 
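A minimal illustration of this behaviour, mirroring the mass-mismatch test added in
this PR (``test_emd_dimension_and_mass_mismatch``); the data are illustrative::

    import numpy as np
    import ot

    n = 5
    M = ot.dist(np.random.rand(n, 2), np.random.rand(n, 2))
    a = ot.utils.unif(n)
    b = a.copy()
    a[0] = 100  # masses now differ: a.sum() != b.sum()
    np.testing.assert_raises(AssertionError, ot.emd, a, b, M)
    np.testing.assert_raises(AssertionError, ot.emd2, a, b, M)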
+ Uses the algorithm proposed in :ref:`[1] `. Parameters @@ -389,6 +391,8 @@ def emd2(a, b, M, processes=1, If this behaviour is unwanted, please make sure to provide a floating point input. + .. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value. + Uses the algorithm proposed in :ref:`[1] `. Parameters @@ -481,6 +485,11 @@ def emd2(a, b, M, processes=1, assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \ "Dimension mismatch, check dimensions of M with a and b" + # ensure that same mass + np.testing.assert_almost_equal(a.sum(0), + b.sum(0,keepdims=True), err_msg='a and b vector must have the same sum') + b = b * a.sum(0) / b.sum(0,keepdims=True) + asel = a != 0 numThreads = check_number_threads(numThreads) diff --git a/test/test_ot.py b/test/test_ot.py index ba3ef6a..9a4e175 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -29,9 +29,12 @@ def test_emd_dimension_and_mass_mismatch(): np.testing.assert_raises(AssertionError, ot.emd2, a, a, M) + # test emd and emd2 for mass mismatch + a = ot.utils.unif(n_samples) b = a.copy() a[0] = 100 np.testing.assert_raises(AssertionError, ot.emd, a, b, M) + np.testing.assert_raises(AssertionError, ot.emd2, a, b, M) def test_emd_backends(nx): -- cgit v1.2.3 From 80e3c23bc968f866fd20344ddc443a3c7fcb3b0d Mon Sep 17 00:00:00 2001 From: Clément Bonet <32179275+clbonet@users.noreply.github.com> Date: Thu, 23 Feb 2023 08:31:01 +0100 Subject: [WIP] Wasserstein distance on the circle and Spherical Sliced-Wasserstein (#434) * W circle + SSW * Tests + Example SSW_1 * Example Wasserstein Circle + Tests * Wasserstein on the circle wrt Unif * Example SSW unif * pep8 * np.linalg.qr for numpy < 1.22 by batch + add python3.11 to tests * np qr * rm test python 3.11 * update names, tests, backend transpose * Comment error batchs * semidiscrete_wasserstein2_unif_circle example * torch permute method instead of torch.permute for previous versions * update comments and doc * doc wasserstein circle model as [0,1[ * Added ot.utils.get_coordinate_circle to get coordinates on the circle in turn --- CONTRIBUTORS.md | 1 + README.md | 10 +- RELEASES.md | 4 + examples/backends/plot_ssw_unif_torch.py | 153 ++++++ examples/plot_compute_wasserstein_circle.py | 161 ++++++ examples/sliced-wasserstein/plot_variance_ssw.py | 111 ++++ ot/__init__.py | 13 +- ot/backend.py | 204 +++++++- ot/lp/__init__.py | 7 +- ot/lp/solver_1d.py | 627 ++++++++++++++++++++++- ot/sliced.py | 185 ++++++- ot/utils.py | 30 ++ test/test_1d_solver.py | 127 +++++ test/test_backend.py | 46 ++ test/test_sliced.py | 186 +++++++ test/test_utils.py | 10 + 16 files changed, 1852 insertions(+), 23 deletions(-) create mode 100644 examples/backends/plot_ssw_unif_torch.py create mode 100644 examples/plot_compute_wasserstein_circle.py create mode 100644 examples/sliced-wasserstein/plot_variance_ssw.py (limited to 'ot/lp/__init__.py') diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 67d8337..1437821 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -41,6 +41,7 @@ The contributors to this library are: * [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug) * [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter) * [Theo Gnassounou](https://github.com/tgnassou) (OT between Gaussian distributions) +* [Clément Bonet](https://clbonet.github.io) (Wassertstein on circle, Spherical Sliced-Wasserstein) ## Acknowledgments diff --git a/README.md b/README.md index 
7c9475b..d5e6854 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,8 @@ POT provides the following generic OT solvers (links to examples): * [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3] formulations). * [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32] and Max-sliced Wasserstein [35] that can be used for gradient flows [36]. +* [Wasserstein distance on the circle](https://pythonot.github.io/auto_examples/plot_compute_wasserstein_circle.html) [44, 45] +* [Spherical Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance_ssw.html) [46] * [Graph Dictionary Learning solvers](https://pythonot.github.io/auto_examples/gromov/plot_gromov_wasserstein_dictionary_learning.html) [38]. * [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays. @@ -292,4 +294,10 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer [42] Delon, J., Gozlan, N., and Saint-Dizier, A. [Generalized Wasserstein barycenters between probability measures living on different subspaces](https://arxiv.org/pdf/2105.09755). arXiv preprint arXiv:2105.09755, 2021. -[43] Álvarez-Esteban, Pedro C., et al. [A fixed-point approach to barycenters in Wasserstein space.](https://arxiv.org/pdf/1511.05355.pdf) Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. \ No newline at end of file +[43] Álvarez-Esteban, Pedro C., et al. [A fixed-point approach to barycenters in Wasserstein space.](https://arxiv.org/pdf/1511.05355.pdf) Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. + +[44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. [Fast transport optimization for Monge costs on the circle.](https://arxiv.org/abs/0902.3527) SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258. + +[45] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. [The statistics of circular optimal transport.](https://arxiv.org/abs/2103.15426) Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82. + +[46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). [Spherical Sliced-Wasserstein](https://openreview.net/forum?id=jXQ0ipgMdU). International Conference on Learning Representations. 
\ No newline at end of file diff --git a/RELEASES.md b/RELEASES.md index 4ed3625..f8ef653 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -4,6 +4,10 @@ #### New features +- Added the spherical sliced-Wasserstein discrepancy in `ot.sliced.sliced_wasserstein_sphere` and `ot.sliced.sliced_wasserstein_sphere_unif` + examples (PR #434) +- Added the Wasserstein distance on the circle in ``ot.lp.solver_1d.wasserstein_circle`` (PR #434) +- Added the Wasserstein distance on the circle (for p>=1) in `ot.lp.solver_1d.binary_search_circle` + examples (PR #434) +- Added the 2-Wasserstein distance on the circle w.r.t a uniform distribution in `ot.lp.solver_1d.semidiscrete_wasserstein2_unif_circle` (PR #434) - Added Bures Wasserstein distance in `ot.gaussian` (PR ##428) - Added Generalized Wasserstein Barycenter solver + example (PR #372), fixed graphical details on the example (PR #376) - Added Free Support Sinkhorn Barycenter + example (PR #387) diff --git a/examples/backends/plot_ssw_unif_torch.py b/examples/backends/plot_ssw_unif_torch.py new file mode 100644 index 0000000..d1de5a9 --- /dev/null +++ b/examples/backends/plot_ssw_unif_torch.py @@ -0,0 +1,153 @@ +# -*- coding: utf-8 -*- +r""" +================================================ +Spherical Sliced-Wasserstein Embedding on Sphere +================================================ + +Here, we aim at transforming samples into a uniform +distribution on the sphere by minimizing SSW: + +.. math:: + \min_{x} SSW_2(\nu, \frac{1}{n}\sum_{i=1}^n \delta_{x_i}) + +where :math:`\nu=\mathrm{Unif}(S^1)`. + +""" + +# Author: Clément Bonet +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 3 + +import numpy as np +import matplotlib.pyplot as pl +import matplotlib.animation as animation +import torch +import torch.nn.functional as F + +import ot + + +# %% +# Data generation +# --------------- + +torch.manual_seed(1) + +N = 1000 +x0 = torch.rand(N, 3) +x0 = F.normalize(x0, dim=-1) + + +# %% +# Plot data +# --------- + +def plot_sphere(ax): + xlist = np.linspace(-1.0, 1.0, 50) + ylist = np.linspace(-1.0, 1.0, 50) + r = np.linspace(1.0, 1.0, 50) + X, Y = np.meshgrid(xlist, ylist) + + Z = np.sqrt(r**2 - X**2 - Y**2) + + ax.plot_wireframe(X, Y, Z, color="gray", alpha=.3) + ax.plot_wireframe(X, Y, -Z, color="gray", alpha=.3) # Now plot the bottom half + + +# plot the distributions +pl.figure(1) +ax = pl.axes(projection='3d') +plot_sphere(ax) +ax.scatter(x0[:, 0], x0[:, 1], x0[:, 2], label='Data samples', alpha=0.5) +ax.set_title('Data distribution') +ax.legend() + + +# %% +# Gradient descent +# ---------------- + +x = x0.clone() +x.requires_grad_(True) + +n_iter = 500 +lr = 100 + +losses = [] +xvisu = torch.zeros(n_iter, N, 3) + +for i in range(n_iter): + sw = ot.sliced_wasserstein_sphere_unif(x, n_projections=500) + grad_x = torch.autograd.grad(sw, x)[0] + + x = x - lr * grad_x + x = F.normalize(x, p=2, dim=1) + + losses.append(sw.item()) + xvisu[i, :, :] = x.detach().clone() + + if i % 100 == 0: + print("Iter: {:3d}, loss={}".format(i, losses[-1])) + +pl.figure(1) +pl.semilogy(losses) +pl.grid() +pl.title('SSW') +pl.xlabel("Iterations") + + +# %% +# Plot trajectories of generated samples along iterations +# ------------------------------------------------------- + +ivisu = [0, 25, 50, 75, 100, 150, 200, 350, 499] + +fig = pl.figure(3, (10, 10)) +for i in range(9): + # pl.subplot(3, 3, i + 1) + # ax = pl.axes(projection='3d') + ax = fig.add_subplot(3, 3, i + 1, projection='3d') + plot_sphere(ax) + ax.scatter(xvisu[ivisu[i], :, 0], xvisu[ivisu[i], :, 
1], xvisu[ivisu[i], :, 2], label='Data samples', alpha=0.5) + ax.set_title('Iter. {}'.format(ivisu[i])) + #ax.axis("off") + if i == 0: + ax.legend() + + +# %% +# Animate trajectories of generated samples along iteration +# ------------------------------------------------------- + +pl.figure(4, (8, 8)) + + +def _update_plot(i): + i = 3 * i + pl.clf() + ax = pl.axes(projection='3d') + plot_sphere(ax) + ax.scatter(xvisu[i, :, 0], xvisu[i, :, 1], xvisu[i, :, 2], label='Data samples$', alpha=0.5) + ax.axis("off") + ax.set_xlim((-1.5, 1.5)) + ax.set_ylim((-1.5, 1.5)) + ax.set_title('Iter. {}'.format(i)) + return 1 + + +print(xvisu.shape) + +i = 0 +ax = pl.axes(projection='3d') +plot_sphere(ax) +ax.scatter(xvisu[i, :, 0], xvisu[i, :, 1], xvisu[i, :, 2], label='Data samples from $G\#\mu_n$', alpha=0.5) +ax.axis("off") +ax.set_xlim((-1.5, 1.5)) +ax.set_ylim((-1.5, 1.5)) +ax.set_title('Iter. {}'.format(ivisu[i])) + + +ani = animation.FuncAnimation(pl.gcf(), _update_plot, n_iter // 5, interval=100, repeat_delay=2000) +# %% diff --git a/examples/plot_compute_wasserstein_circle.py b/examples/plot_compute_wasserstein_circle.py new file mode 100644 index 0000000..3ede96f --- /dev/null +++ b/examples/plot_compute_wasserstein_circle.py @@ -0,0 +1,161 @@ +# -*- coding: utf-8 -*- +""" +========================= +OT distance on the Circle +========================= + +Shows how to compute the Wasserstein distance on the circle + + +""" + +# Author: Clément Bonet +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 2 + +import numpy as np +import matplotlib.pylab as pl +import ot + +from scipy.special import iv + +############################################################################## +# Plot data +# --------- + +#%% plot the distributions + + +def pdf_von_Mises(theta, mu, kappa): + pdf = np.exp(kappa * np.cos(theta - mu)) / (2.0 * np.pi * iv(0, kappa)) + return pdf + + +t = np.linspace(0, 2 * np.pi, 1000, endpoint=False) + +mu1 = 1 +kappa1 = 20 + +mu_targets = np.linspace(mu1, mu1 + 2 * np.pi, 10) + + +pdf1 = pdf_von_Mises(t, mu1, kappa1) + + +pl.figure(1) +for k, mu in enumerate(mu_targets): + pdf_t = pdf_von_Mises(t, mu, kappa1) + if k == 0: + label = "Source distributions" + else: + label = None + pl.plot(t / (2 * np.pi), pdf_t, c='b', label=label) + +pl.plot(t / (2 * np.pi), pdf1, c="r", label="Target distribution") +pl.legend() + +mu2 = 0 +kappa2 = kappa1 + +x1 = np.random.vonmises(mu1, kappa1, size=(10,)) + np.pi +x2 = np.random.vonmises(mu2, kappa2, size=(10,)) + np.pi + +angles = np.linspace(0, 2 * np.pi, 150) + +pl.figure(2) +pl.plot(np.cos(angles), np.sin(angles), c="k") +pl.xlim(-1.25, 1.25) +pl.ylim(-1.25, 1.25) +pl.scatter(np.cos(x1), np.sin(x1), c="b") +pl.scatter(np.cos(x2), np.sin(x2), c="r") + +######################################################################################### +# Compare the Euclidean Wasserstein distance with the Wasserstein distance on the circle +# --------------------------------------------------------------------------------------- +# This examples illustrates the periodicity of the Wasserstein distance on the circle. +# We choose as target distribution a von Mises distribution with mean :math:`\mu_{\mathrm{target}}` +# and :math:`\kappa=20`. Then, we compare the distances with samples obtained from a von Mises distribution +# with parameters :math:`\mu_{\mathrm{source}}` and :math:`\kappa=20`. 
+# The Wasserstein distance on the circle takes into account the periodicity +# and attains its maximum in :math:`\mu_{\mathrm{target}}+1` (the antipodal point) contrary to the +# Euclidean version. + +#%% Compute and plot distributions + +mu_targets = np.linspace(0, 2 * np.pi, 200) +xs = np.random.vonmises(mu1 - np.pi, kappa1, size=(500,)) + np.pi + +n_try = 5 + +xts = np.zeros((n_try, 200, 500)) +for i in range(n_try): + for k, mu in enumerate(mu_targets): + # np.random.vonmises deals with data on [-pi, pi[ + xt = np.random.vonmises(mu - np.pi, kappa2, size=(500,)) + np.pi + xts[i, k] = xt + +# Put data on S^1=[0,1[ +xts2 = xts / (2 * np.pi) +xs2 = np.concatenate([xs[None] for k in range(200)], axis=0) / (2 * np.pi) + +L_w2_circle = np.zeros((n_try, 200)) +L_w2 = np.zeros((n_try, 200)) + +for i in range(n_try): + w2_circle = ot.wasserstein_circle(xs2.T, xts2[i].T, p=2) + w2 = ot.wasserstein_1d(xs2.T, xts2[i].T, p=2) + + L_w2_circle[i] = w2_circle + L_w2[i] = w2 + +m_w2_circle = np.mean(L_w2_circle, axis=0) +std_w2_circle = np.std(L_w2_circle, axis=0) + +m_w2 = np.mean(L_w2, axis=0) +std_w2 = np.std(L_w2, axis=0) + +pl.figure(1) +pl.plot(mu_targets / (2 * np.pi), m_w2_circle, label="Wasserstein circle") +pl.fill_between(mu_targets / (2 * np.pi), m_w2_circle - 2 * std_w2_circle, m_w2_circle + 2 * std_w2_circle, alpha=0.5) +pl.plot(mu_targets / (2 * np.pi), m_w2, label="Euclidean Wasserstein") +pl.fill_between(mu_targets / (2 * np.pi), m_w2 - 2 * std_w2, m_w2 + 2 * std_w2, alpha=0.5) +pl.vlines(x=[mu1 / (2 * np.pi)], ymin=0, ymax=np.max(w2), linestyle="--", color="k", label=r"$\mu_{\mathrm{target}}$") +pl.legend() +pl.xlabel(r"$\mu_{\mathrm{source}}$") +pl.show() + + +######################################################################## +# Wasserstein distance between von Mises and uniform for different kappa +# ---------------------------------------------------------------------- +# When :math:`\kappa=0`, the von Mises distribution is the uniform distribution on :math:`S^1`. + +#%% Compute Wasserstein between Von Mises and uniform + +kappas = np.logspace(-5, 2, 100) +n_try = 20 + +xts = np.zeros((n_try, 100, 500)) +for i in range(n_try): + for k, kappa in enumerate(kappas): + # np.random.vonmises deals with data on [-pi, pi[ + xt = np.random.vonmises(0, kappa, size=(500,)) + np.pi + xts[i, k] = xt / (2 * np.pi) + +L_w2 = np.zeros((n_try, 100)) +for i in range(n_try): + L_w2[i] = ot.semidiscrete_wasserstein2_unif_circle(xts[i].T) + +m_w2 = np.mean(L_w2, axis=0) +std_w2 = np.std(L_w2, axis=0) + +pl.figure(1) +pl.plot(kappas, m_w2) +pl.fill_between(kappas, m_w2 - std_w2, m_w2 + std_w2, alpha=0.5) +pl.title(r"Evolution of $W_2^2(vM(0,\kappa), Unif(S^1))$") +pl.xlabel(r"$\kappa$") +pl.show() + +# %% diff --git a/examples/sliced-wasserstein/plot_variance_ssw.py b/examples/sliced-wasserstein/plot_variance_ssw.py new file mode 100644 index 0000000..83d458f --- /dev/null +++ b/examples/sliced-wasserstein/plot_variance_ssw.py @@ -0,0 +1,111 @@ +# -*- coding: utf-8 -*- +""" +==================================================== +Spherical Sliced Wasserstein on distributions in S^2 +==================================================== + +This example illustrates the computation of the spherical sliced Wasserstein discrepancy as +proposed in [46]. + +[46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). 'Spherical Sliced-Wasserstein". International Conference on Learning Representations. 
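As a complement to the script below (which compares two point clouds on the sphere),
a minimal sketch of the companion estimator against the uniform measure, assuming the
NumPy backend is accepted by ``sliced_wasserstein_sphere_unif`` as it is by
``sliced_wasserstein_sphere``::

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    x = rng.randn(200, 3)
    x /= np.sqrt(np.sum(x ** 2, -1, keepdims=True))  # samples on S^2
    ssw_unif = ot.sliced_wasserstein_sphere_unif(x, n_projections=200)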
+ +""" + +# Author: Clément Bonet +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 2 + +import matplotlib.pylab as pl +import numpy as np + +import ot + +############################################################################## +# Generate data +# ------------- + +# %% parameters and data generation + +n = 500 # nb samples + +xs = np.random.randn(n, 3) +xt = np.random.randn(n, 3) + +xs = xs / np.sqrt(np.sum(xs**2, -1, keepdims=True)) +xt = xt / np.sqrt(np.sum(xt**2, -1, keepdims=True)) + +a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples + +############################################################################## +# Plot data +# --------- + +# %% plot samples + +fig = pl.figure(figsize=(10, 10)) +ax = pl.axes(projection='3d') +ax.grid(False) + +u, v = np.mgrid[0:2 * np.pi:30j, 0:np.pi:30j] +x = np.cos(u) * np.sin(v) +y = np.sin(u) * np.sin(v) +z = np.cos(v) +ax.plot_surface(x, y, z, color="gray", alpha=0.03) +ax.plot_wireframe(x, y, z, linewidth=1, alpha=0.25, color="gray") + +ax.scatter(xs[:, 0], xs[:, 1], xs[:, 2], label="Source") +ax.scatter(xt[:, 0], xt[:, 1], xt[:, 2], label="Target") + +fs = 10 +# Labels +ax.set_xlabel('x', fontsize=fs) +ax.set_ylabel('y', fontsize=fs) +ax.set_zlabel('z', fontsize=fs) + +ax.view_init(20, 120) +ax.set_xlim(-1.5, 1.5) +ax.set_ylim(-1.5, 1.5) +ax.set_zlim(-1.5, 1.5) + +# Ticks +ax.set_xticks([-1, 0, 1]) +ax.set_yticks([-1, 0, 1]) +ax.set_zticks([-1, 0, 1]) + +pl.legend(loc=0) +pl.title("Source and Target distribution") + +############################################################################### +# Spherical Sliced Wasserstein for different seeds and number of projections +# -------------------------------------------------------------------------- + +n_seed = 50 +n_projections_arr = np.logspace(0, 3, 25, dtype=int) +res = np.empty((n_seed, 25)) + +# %% Compute statistics +for seed in range(n_seed): + for i, n_projections in enumerate(n_projections_arr): + res[seed, i] = ot.sliced_wasserstein_sphere(xs, xt, a, b, n_projections, seed=seed, p=1) + +res_mean = np.mean(res, axis=0) +res_std = np.std(res, axis=0) + +############################################################################### +# Plot Spherical Sliced Wasserstein +# --------------------------------- + +pl.figure(2) +pl.plot(n_projections_arr, res_mean, label=r"$SSW_1$") +pl.fill_between(n_projections_arr, res_mean - 2 * res_std, res_mean + 2 * res_std, alpha=0.5) + +pl.legend() +pl.xscale('log') + +pl.xlabel("Number of projections") +pl.ylabel("Distance") +pl.title('Spherical Sliced Wasserstein Distance with 95% confidence inverval') + +pl.show() diff --git a/ot/__init__.py b/ot/__init__.py index 0b55e0c..45d5cfa 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -38,12 +38,15 @@ from . import solvers from . 
import gaussian # OT functions -from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d +from .lp import (emd, emd2, emd_1d, emd2_1d, wasserstein_1d, + binary_search_circle, wasserstein_circle, + semidiscrete_wasserstein2_unif_circle) from .bregman import sinkhorn, sinkhorn2, barycenter from .unbalanced import (sinkhorn_unbalanced, barycenter_unbalanced, sinkhorn_unbalanced2) from .da import sinkhorn_lpl1_mm -from .sliced import sliced_wasserstein_distance, max_sliced_wasserstein_distance +from .sliced import (sliced_wasserstein_distance, max_sliced_wasserstein_distance, + sliced_wasserstein_sphere, sliced_wasserstein_sphere_unif) from .gromov import (gromov_wasserstein, gromov_wasserstein2, gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2) from .weak import weak_optimal_transport @@ -60,8 +63,10 @@ __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'emd2_1d', 'wasserstein_1d', 'backend', 'gaussian', 'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim', 'sinkhorn_unbalanced', 'barycenter_unbalanced', - 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', + 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', 'sliced_wasserstein_sphere', 'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2', 'max_sliced_wasserstein_distance', 'weak_optimal_transport', 'factored_optimal_transport', 'solve', - 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers'] + 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers', + 'binary_search_circle', 'wasserstein_circle', + 'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif'] diff --git a/ot/backend.py b/ot/backend.py index 337e040..0779243 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -534,9 +534,9 @@ class Backend(): """ raise NotImplementedError() - def zero_pad(self, a, pad_width): + def zero_pad(self, a, pad_width, value=0): r""" - Pads a tensor. + Pads a tensor with a given value (0 by default). This function follows the api from :any:`numpy.pad` @@ -895,6 +895,62 @@ class Backend(): """ raise NotImplementedError() + def tile(self, a, reps): + r""" + Construct an array by repeating a the number of times given by reps + + See: https://numpy.org/doc/stable/reference/generated/numpy.tile.html + """ + raise NotImplementedError() + + def floor(self, a): + r""" + Return the floor of the input element-wise + + See: https://numpy.org/doc/stable/reference/generated/numpy.floor.html + """ + raise NotImplementedError() + + def prod(self, a, axis=None): + r""" + Return the product of all elements. + + See: https://numpy.org/doc/stable/reference/generated/numpy.prod.html + """ + raise NotImplementedError() + + def sort2(self, a, axis=None): + r""" + Return the sorted array and the indices to sort the array + + See: https://pytorch.org/docs/stable/generated/torch.sort.html + """ + raise NotImplementedError() + + def qr(self, a): + r""" + Return the QR factorization + + See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.qr.html + """ + raise NotImplementedError() + + def atan2(self, a, b): + r""" + Element wise arctangent + + See: https://numpy.org/doc/stable/reference/generated/numpy.arctan2.html + """ + raise NotImplementedError() + + def transpose(self, a, axes=None): + r""" + Returns a tensor that is a transposed version of a. The given dimensions dim0 and dim1 are swapped. 
+ + See: https://numpy.org/doc/stable/reference/generated/numpy.transpose.html + """ + raise NotImplementedError() + class NumpyBackend(Backend): """ @@ -1039,8 +1095,8 @@ class NumpyBackend(Backend): def concatenate(self, arrays, axis=0): return np.concatenate(arrays, axis) - def zero_pad(self, a, pad_width): - return np.pad(a, pad_width) + def zero_pad(self, a, pad_width, value=0): + return np.pad(a, pad_width, constant_values=value) def argmax(self, a, axis=None): return np.argmax(a, axis=axis) @@ -1185,6 +1241,44 @@ class NumpyBackend(Backend): def is_floating_point(self, a): return a.dtype.kind == "f" + def tile(self, a, reps): + return np.tile(a, reps) + + def floor(self, a): + return np.floor(a) + + def prod(self, a, axis=0): + return np.prod(a, axis=axis) + + def sort2(self, a, axis=-1): + return self.sort(a, axis), self.argsort(a, axis) + + def qr(self, a): + np_version = tuple([int(k) for k in np.__version__.split(".")]) + if np_version < (1, 22, 0): + M, N = a.shape[-2], a.shape[-1] + K = min(M, N) + + if len(a.shape) >= 3: + n = a.shape[0] + + qs, rs = np.zeros((n, M, K)), np.zeros((n, K, N)) + + for i in range(a.shape[0]): + qs[i], rs[i] = np.linalg.qr(a[i]) + + else: + return np.linalg.qr(a) + + return qs, rs + return np.linalg.qr(a) + + def atan2(self, a, b): + return np.arctan2(a, b) + + def transpose(self, a, axes=None): + return np.transpose(a, axes) + class JaxBackend(Backend): """ @@ -1351,8 +1445,8 @@ class JaxBackend(Backend): def concatenate(self, arrays, axis=0): return jnp.concatenate(arrays, axis) - def zero_pad(self, a, pad_width): - return jnp.pad(a, pad_width) + def zero_pad(self, a, pad_width, value=0): + return jnp.pad(a, pad_width, constant_values=value) def argmax(self, a, axis=None): return jnp.argmax(a, axis=axis) @@ -1511,6 +1605,27 @@ class JaxBackend(Backend): def is_floating_point(self, a): return a.dtype.kind == "f" + def tile(self, a, reps): + return jnp.tile(a, reps) + + def floor(self, a): + return jnp.floor(a) + + def prod(self, a, axis=0): + return jnp.prod(a, axis=axis) + + def sort2(self, a, axis=-1): + return self.sort(a, axis), self.argsort(a, axis) + + def qr(self, a): + return jnp.linalg.qr(a) + + def atan2(self, a, b): + return jnp.arctan2(a, b) + + def transpose(self, a, axes=None): + return jnp.transpose(a, axes) + class TorchBackend(Backend): """ @@ -1729,13 +1844,13 @@ class TorchBackend(Backend): def concatenate(self, arrays, axis=0): return torch.cat(arrays, dim=axis) - def zero_pad(self, a, pad_width): + def zero_pad(self, a, pad_width, value=0): from torch.nn.functional import pad # pad_width is an array of ndim tuples indicating how many 0 before and after # we need to add. We first need to make it compliant with torch syntax, that # starts with the last dim, then second last, etc. 
how_pad = tuple(element for tupl in pad_width[::-1] for element in tupl) - return pad(a, how_pad) + return pad(a, how_pad, value=value) def argmax(self, a, axis=None): return torch.argmax(a, dim=axis) @@ -1934,6 +2049,29 @@ class TorchBackend(Backend): def is_floating_point(self, a): return a.dtype.is_floating_point + def tile(self, a, reps): + return a.repeat(reps) + + def floor(self, a): + return torch.floor(a) + + def prod(self, a, axis=0): + return torch.prod(a, dim=axis) + + def sort2(self, a, axis=-1): + return torch.sort(a, axis) + + def qr(self, a): + return torch.linalg.qr(a) + + def atan2(self, a, b): + return torch.atan2(a, b) + + def transpose(self, a, axes=None): + if axes is None: + axes = tuple(range(a.ndim)[::-1]) + return a.permute(axes) + class CupyBackend(Backend): # pragma: no cover """ @@ -2096,8 +2234,8 @@ class CupyBackend(Backend): # pragma: no cover def concatenate(self, arrays, axis=0): return cp.concatenate(arrays, axis) - def zero_pad(self, a, pad_width): - return cp.pad(a, pad_width) + def zero_pad(self, a, pad_width, value=0): + return cp.pad(a, pad_width, constant_values=value) def argmax(self, a, axis=None): return cp.argmax(a, axis=axis) @@ -2284,6 +2422,27 @@ class CupyBackend(Backend): # pragma: no cover def is_floating_point(self, a): return a.dtype.kind == "f" + def tile(self, a, reps): + return cp.tile(a, reps) + + def floor(self, a): + return cp.floor(a) + + def prod(self, a, axis=0): + return cp.prod(a, axis=axis) + + def sort2(self, a, axis=-1): + return self.sort(a, axis), self.argsort(a, axis) + + def qr(self, a): + return cp.linalg.qr(a) + + def atan2(self, a, b): + return cp.arctan2(a, b) + + def transpose(self, a, axes=None): + return cp.transpose(a, axes) + class TensorflowBackend(Backend): @@ -2454,8 +2613,8 @@ class TensorflowBackend(Backend): def concatenate(self, arrays, axis=0): return tnp.concatenate(arrays, axis) - def zero_pad(self, a, pad_width): - return tnp.pad(a, pad_width, mode="constant") + def zero_pad(self, a, pad_width, value=0): + return tnp.pad(a, pad_width, mode="constant", constant_values=value) def argmax(self, a, axis=None): return tnp.argmax(a, axis=axis) @@ -2646,3 +2805,24 @@ class TensorflowBackend(Backend): def is_floating_point(self, a): return a.dtype.is_floating + + def tile(self, a, reps): + return tnp.tile(a, reps) + + def floor(self, a): + return tf.floor(a) + + def prod(self, a, axis=0): + return tnp.prod(a, axis=axis) + + def sort2(self, a, axis=-1): + return self.sort(a, axis), self.argsort(a, axis) + + def qr(self, a): + return tf.linalg.qr(a) + + def atan2(self, a, b): + return tf.math.atan2(a, b) + + def transpose(self, a, axes=None): + return tf.transpose(a, perm=axes) diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 17411d0..7d0640f 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -20,14 +20,17 @@ from .cvx import barycenter # import compiled emd from .emd_wrap import emd_c, check_result, emd_1d_sorted -from .solver_1d import emd_1d, emd2_1d, wasserstein_1d +from .solver_1d import (emd_1d, emd2_1d, wasserstein_1d, + binary_search_circle, wasserstein_circle, + semidiscrete_wasserstein2_unif_circle) from ..utils import dist, list_to_array from ..utils import parmap from ..backend import get_backend __all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', ' emd_1d_sorted', - 'emd_1d', 'emd2_1d', 'wasserstein_1d', 'generalized_free_support_barycenter'] + 'emd_1d', 'emd2_1d', 'wasserstein_1d', 'generalized_free_support_barycenter', + 'binary_search_circle', 
'wasserstein_circle', 'semidiscrete_wasserstein2_unif_circle'] def check_number_threads(numThreads): diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py index 43763a9..e7add89 100644 --- a/ot/lp/solver_1d.py +++ b/ot/lp/solver_1d.py @@ -53,7 +53,7 @@ def wasserstein_1d(u_values, v_values, u_weights=None, v_weights=None, p=1, requ distributions .. math: - OT_{loss} = \int_0^1 |cdf_u^{-1}(q) cdf_v^{-1}(q)|^p dq + OT_{loss} = \int_0^1 |cdf_u^{-1}(q) - cdf_v^{-1}(q)|^p dq It is formally the p-Wasserstein distance raised to the power p. We do so in a vectorized way by first building the individual quantile functions then integrating them. @@ -365,3 +365,628 @@ def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, log_emd = {'G': G} return cost, log_emd return cost + + +def roll_cols(M, shifts): + r""" + Utils functions which allow to shift the order of each row of a 2d matrix + + Parameters + ---------- + M : (nr, nc) ndarray + Matrix to shift + shifts: int or (nr,) ndarray + + Returns + ------- + Shifted array + + Examples + -------- + >>> M = np.array([[1,2,3],[4,5,6],[7,8,9]]) + >>> roll_cols(M, 2) + array([[2, 3, 1], + [5, 6, 4], + [8, 9, 7]]) + >>> roll_cols(M, np.array([[1],[2],[1]])) + array([[3, 1, 2], + [5, 6, 4], + [9, 7, 8]]) + + References + ---------- + https://stackoverflow.com/questions/66596699/how-to-shift-columns-or-rows-in-a-tensor-with-different-offsets-in-pytorch + """ + nx = get_backend(M) + + n_rows, n_cols = M.shape + + arange1 = nx.tile(nx.reshape(nx.arange(n_cols), (1, n_cols)), (n_rows, 1)) + arange2 = (arange1 - shifts) % n_cols + + return nx.take_along_axis(M, arange2, 1) + + +def derivative_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p=2): + r""" Computes the left and right derivative of the cost (Equation (6.3) and (6.4) of [1]) + + Parameters + ---------- + theta: array-like, shape (n_batch, n) + Cuts on the circle + u_values: array-like, shape (n_batch, n) + locations of the first empirical distribution + v_values: array-like, shape (n_batch, n) + locations of the second empirical distribution + u_cdf: array-like, shape (n_batch, n) + cdf of the first empirical distribution + v_cdf: array-like, shape (n_batch, n) + cdf of the second empirical distribution + p: float, optional = 2 + Power p used for computing the Wasserstein distance + + Returns + ------- + dCp: array-like, shape (n_batch, 1) + The batched right derivative + dCm: array-like, shape (n_batch, 1) + The batched left derivative + + References + --------- + .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258. 
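Since these are internal helpers, a rough sanity check one might run (shapes follow
the convention used inside ``binary_search_circle``, one distribution per row; the
values are illustrative and this is not part of the library tests)::

    import numpy as np
    from ot.lp.solver_1d import derivative_cost_on_circle, ot_cost_on_circle

    rng = np.random.RandomState(0)
    n = m = 6
    u = np.sort(rng.rand(1, n), axis=-1)  # sorted samples on [0, 1[
    v = np.sort(rng.rand(1, m), axis=-1)
    u_cdf = np.cumsum(np.full((1, n), 1. / n), axis=-1)
    v_cdf = np.cumsum(np.full((1, m), 1. / m), axis=-1)

    theta = np.full((1, m), 0.3)  # a candidate cut of the circle
    dCp, dCm = derivative_cost_on_circle(theta, u, v, u_cdf, v_cdf, p=2)

    # away from the finitely many kinks of theta -> C(theta), the sign of both
    # one-sided derivatives should agree with a central finite difference of the cost
    h = 1e-4
    fd = (ot_cost_on_circle(theta + h, u, v, u_cdf, v_cdf, 2)
          - ot_cost_on_circle(theta - h, u, v, u_cdf, v_cdf, 2)) / (2 * h)
    print(float(fd[0]), float(dCp[0, 0]), float(dCm[0, 0]))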
+ """ + nx = get_backend(theta, u_values, v_values, u_cdf, v_cdf) + + v_values = nx.copy(v_values) + + n = u_values.shape[-1] + m_batch, m = v_values.shape + + v_cdf_theta = v_cdf - (theta - nx.floor(theta)) + + mask_p = v_cdf_theta >= 0 + mask_n = v_cdf_theta < 0 + + v_values[mask_n] += nx.floor(theta)[mask_n] + 1 + v_values[mask_p] += nx.floor(theta)[mask_p] + + if nx.any(mask_n) and nx.any(mask_p): + v_cdf_theta[mask_n] += 1 + + v_cdf_theta2 = nx.copy(v_cdf_theta) + v_cdf_theta2[mask_n] = np.inf + shift = (-nx.argmin(v_cdf_theta2, axis=-1)) + + v_cdf_theta = roll_cols(v_cdf_theta, nx.reshape(shift, (-1, 1))) + v_values = roll_cols(v_values, nx.reshape(shift, (-1, 1))) + v_values = nx.concatenate([v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1) + + if nx.__name__ == 'torch': + # this is to ensure the best performance for torch searchsorted + # and avoid a warninng related to non-contiguous arrays + u_cdf = u_cdf.contiguous() + v_cdf_theta = v_cdf_theta.contiguous() + + # quantiles of F_u evaluated in F_v^\theta + u_index = nx.searchsorted(u_cdf, v_cdf_theta) + u_icdf_theta = nx.take_along_axis(u_values, nx.clip(u_index, 0, n - 1), -1) + + # Deal with 1 + u_cdfm = nx.concatenate([u_cdf, nx.reshape(u_cdf[:, 0], (-1, 1)) + 1], axis=1) + u_valuesm = nx.concatenate([u_values, nx.reshape(u_values[:, 0], (-1, 1)) + 1], axis=1) + + if nx.__name__ == 'torch': + # this is to ensure the best performance for torch searchsorted + # and avoid a warninng related to non-contiguous arrays + u_cdfm = u_cdfm.contiguous() + v_cdf_theta = v_cdf_theta.contiguous() + + u_indexm = nx.searchsorted(u_cdfm, v_cdf_theta, side="right") + u_icdfm_theta = nx.take_along_axis(u_valuesm, nx.clip(u_indexm, 0, n), -1) + + dCp = nx.sum(nx.power(nx.abs(u_icdf_theta - v_values[:, 1:]), p) + - nx.power(nx.abs(u_icdf_theta - v_values[:, :-1]), p), axis=-1) + + dCm = nx.sum(nx.power(nx.abs(u_icdfm_theta - v_values[:, 1:]), p) + - nx.power(nx.abs(u_icdfm_theta - v_values[:, :-1]), p), axis=-1) + + return dCp.reshape(-1, 1), dCm.reshape(-1, 1) + + +def ot_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p): + r""" Computes the the cost (Equation (6.2) of [1]) + + Parameters + ---------- + theta: array-like, shape (n_batch, n) + Cuts on the circle + u_values: array-like, shape (n_batch, n) + locations of the first empirical distribution + v_values: array-like, shape (n_batch, n) + locations of the second empirical distribution + u_cdf: array-like, shape (n_batch, n) + cdf of the first empirical distribution + v_cdf: array-like, shape (n_batch, n) + cdf of the second empirical distribution + p: float, optional = 2 + Power p used for computing the Wasserstein distance + + Returns + ------- + ot_cost: array-like, shape (n_batch,) + OT cost evaluated at theta + + References + --------- + .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258. 
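The public circle solvers below expect coordinates in :math:`[0,1[`; a small sketch of
the conversion step they mention (illustrative data; ``ot.utils.get_coordinate_circle``
is the helper added in this PR and is assumed to implement the formula stated in the
docstrings)::

    import numpy as np
    import ot

    angles = np.random.rand(50) * 2 * np.pi
    x = np.stack([np.cos(angles), np.sin(angles)], axis=-1)  # points on S^1 in R^2
    u = ot.utils.get_coordinate_circle(x)                    # coordinates in [0, 1[
    # equivalently: u = (np.pi + np.arctan2(-x[:, 1], -x[:, 0])) / (2 * np.pi)
    w2_unif = ot.semidiscrete_wasserstein2_unif_circle(u)    # 2-Wasserstein to Unif(S^1)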
+ """ + nx = get_backend(theta, u_values, v_values, u_cdf, v_cdf) + + v_values = nx.copy(v_values) + + m_batch, m = v_values.shape + n_batch, n = u_values.shape + + v_cdf_theta = v_cdf - (theta - nx.floor(theta)) + + mask_p = v_cdf_theta >= 0 + mask_n = v_cdf_theta < 0 + + v_values[mask_n] += nx.floor(theta)[mask_n] + 1 + v_values[mask_p] += nx.floor(theta)[mask_p] + + if nx.any(mask_n) and nx.any(mask_p): + v_cdf_theta[mask_n] += 1 + + # Put negative values at the end + v_cdf_theta2 = nx.copy(v_cdf_theta) + v_cdf_theta2[mask_n] = np.inf + shift = (-nx.argmin(v_cdf_theta2, axis=-1)) + + v_cdf_theta = roll_cols(v_cdf_theta, nx.reshape(shift, (-1, 1))) + v_values = roll_cols(v_values, nx.reshape(shift, (-1, 1))) + v_values = nx.concatenate([v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1) + + # Compute absciss + cdf_axis = nx.sort(nx.concatenate((u_cdf, v_cdf_theta), -1), -1) + cdf_axis_pad = nx.zero_pad(cdf_axis, pad_width=[(0, 0), (1, 0)]) + + delta = cdf_axis_pad[..., 1:] - cdf_axis_pad[..., :-1] + + if nx.__name__ == 'torch': + # this is to ensure the best performance for torch searchsorted + # and avoid a warninng related to non-contiguous arrays + u_cdf = u_cdf.contiguous() + v_cdf_theta = v_cdf_theta.contiguous() + cdf_axis = cdf_axis.contiguous() + + # Compute icdf + u_index = nx.searchsorted(u_cdf, cdf_axis) + u_icdf = nx.take_along_axis(u_values, u_index.clip(0, n - 1), -1) + + v_values = nx.concatenate([v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1) + v_index = nx.searchsorted(v_cdf_theta, cdf_axis) + v_icdf = nx.take_along_axis(v_values, v_index.clip(0, m), -1) + + if p == 1: + ot_cost = nx.sum(delta * nx.abs(u_icdf - v_icdf), axis=-1) + else: + ot_cost = nx.sum(delta * nx.power(nx.abs(u_icdf - v_icdf), p), axis=-1) + + return ot_cost + + +def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1, + Lm=10, Lp=10, tm=-1, tp=1, eps=1e-6, require_sort=True, + log=False): + r"""Computes the Wasserstein distance on the circle using the Binary search algorithm proposed in [44]. + Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`, + takes the value modulo 1. + If the values are on :math:`S^1\subset\mathbb{R}^2`, it is required to first find the coordinates + using e.g. the atan2 function. + + .. math:: + W_p^p(u,v) = \inf_{\theta\in\mathbb{R}}\int_0^1 |F_u^{-1}(q) - (F_v-\theta)^{-1}(q)|^p\ \mathrm{d}q + + where: + + - :math:`F_u` and :math:`F_v` are respectively the cdfs of :math:`u` and :math:`v` + + For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with + + .. math:: + u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi} + + using e.g. ot.utils.get_coordinate_circle(x) + + The function runs on backend but tensorflow is not supported. + + Parameters + ---------- + u_values : ndarray, shape (n, ...) + samples in the source domain (coordinates on [0,1[) + v_values : ndarray, shape (n, ...) + samples in the target domain (coordinates on [0,1[) + u_weights : ndarray, shape (n, ...), optional + samples weights in the source domain + v_weights : ndarray, shape (n, ...), optional + samples weights in the target domain + p : float, optional (default=1) + Power p used for computing the Wasserstein distance + Lm : int, optional + Lower bound dC + Lp : int, optional + Upper bound dC + tm: float, optional + Lower bound theta + tp: float, optional + Upper bound theta + eps: float, optional + Stopping condition + require_sort: bool, optional + If True, sort the values. 
+ log: bool, optional + If True, returns also the optimal theta + + Returns + ------- + loss: float + Cost associated to the optimal transportation + log: dict, optional + log dictionary returned only if log==True in parameters + + Examples + -------- + >>> u = np.array([[0.2,0.5,0.8]])%1 + >>> v = np.array([[0.4,0.5,0.7]])%1 + >>> binary_search_circle(u.T, v.T, p=1) + array([0.1]) + + References + ---------- + .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258. + .. Matlab Code: https://users.mccme.ru/ansobol/otarie/software.html + """ + assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p) + + if u_weights is not None and v_weights is not None: + nx = get_backend(u_values, v_values, u_weights, v_weights) + else: + nx = get_backend(u_values, v_values) + + n = u_values.shape[0] + m = v_values.shape[0] + + if len(u_values.shape) == 1: + u_values = nx.reshape(u_values, (n, 1)) + if len(v_values.shape) == 1: + v_values = nx.reshape(v_values, (m, 1)) + + if u_values.shape[1] != v_values.shape[1]: + raise ValueError( + "u and v must have the same number of batchs {} and {} respectively given".format(u_values.shape[1], + v_values.shape[1])) + + u_values = u_values % 1 + v_values = v_values % 1 + + if u_weights is None: + u_weights = nx.full(u_values.shape, 1. / n, type_as=u_values) + elif u_weights.ndim != u_values.ndim: + u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) + if v_weights is None: + v_weights = nx.full(v_values.shape, 1. / m, type_as=v_values) + elif v_weights.ndim != v_values.ndim: + v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1) + + if require_sort: + u_sorter = nx.argsort(u_values, 0) + u_values = nx.take_along_axis(u_values, u_sorter, 0) + + v_sorter = nx.argsort(v_values, 0) + v_values = nx.take_along_axis(v_values, v_sorter, 0) + + u_weights = nx.take_along_axis(u_weights, u_sorter, 0) + v_weights = nx.take_along_axis(v_weights, v_sorter, 0) + + u_cdf = nx.cumsum(u_weights, 0).T + v_cdf = nx.cumsum(v_weights, 0).T + + u_values = u_values.T + v_values = v_values.T + + L = max(Lm, Lp) + + tm = tm * nx.reshape(nx.ones((u_values.shape[0],), type_as=u_values), (-1, 1)) + tm = nx.tile(tm, (1, m)) + tp = tp * nx.reshape(nx.ones((u_values.shape[0],), type_as=u_values), (-1, 1)) + tp = nx.tile(tp, (1, m)) + tc = (tm + tp) / 2 + + done = nx.zeros((u_values.shape[0], m)) + + cpt = 0 + while nx.any(1 - done): + cpt += 1 + + dCp, dCm = derivative_cost_on_circle(tc, u_values, v_values, u_cdf, v_cdf, p) + done = ((dCp * dCm) <= 0) * 1 + + mask = ((tp - tm) < eps / L) * (1 - done) + + if nx.any(mask): + # can probably be improved by computing only relevant values + dCptp, dCmtp = derivative_cost_on_circle(tp, u_values, v_values, u_cdf, v_cdf, p) + dCptm, dCmtm = derivative_cost_on_circle(tm, u_values, v_values, u_cdf, v_cdf, p) + Ctm = ot_cost_on_circle(tm, u_values, v_values, u_cdf, v_cdf, p).reshape(-1, 1) + Ctp = ot_cost_on_circle(tp, u_values, v_values, u_cdf, v_cdf, p).reshape(-1, 1) + + mask_end = mask * (nx.abs(dCptm - dCmtp) > 0.001) + tc[mask_end > 0] = ((Ctp - Ctm + tm * dCptm - tp * dCmtp) / (dCptm - dCmtp))[mask_end > 0] + done[nx.prod(mask, axis=-1) > 0] = 1 + elif nx.any(1 - done): + tm[((1 - mask) * (dCp < 0)) > 0] = tc[((1 - mask) * (dCp < 0)) > 0] + tp[((1 - mask) * (dCp >= 0)) > 0] = tc[((1 - mask) * (dCp >= 0)) > 0] + tc[((1 - mask) * (1 - done)) > 0] = (tm[((1 - mask) * (1 - 
done)) > 0] + tp[((1 - mask) * (1 - done)) > 0]) / 2 + + w = ot_cost_on_circle(tc, u_values, v_values, u_cdf, v_cdf, p) + + if log: + return w, {"optimal_theta": tc[:, 0]} + return w + + +def wasserstein1_circle(u_values, v_values, u_weights=None, v_weights=None, require_sort=True): + r"""Computes the 1-Wasserstein distance on the circle using the level median [45]. + Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`, + takes the value modulo 1. + If the values are on :math:`S^1\subset\mathbb{R}^2`, first find the coordinates + using e.g. the atan2 function. + The function runs on backend but tensorflow is not supported. + + .. math:: + W_1(u,v) = \int_0^1 |F_u(t)-F_v(t)-LevMed(F_u-F_v)|\ \mathrm{d}t + + Parameters + ---------- + u_values : ndarray, shape (n, ...) + samples in the source domain (coordinates on [0,1[) + v_values : ndarray, shape (n, ...) + samples in the target domain (coordinates on [0,1[) + u_weights : ndarray, shape (n, ...), optional + samples weights in the source domain + v_weights : ndarray, shape (n, ...), optional + samples weights in the target domain + require_sort: bool, optional + If True, sort the values. + + Returns + ------- + loss: float + Cost associated to the optimal transportation + + Examples + -------- + >>> u = np.array([[0.2,0.5,0.8]])%1 + >>> v = np.array([[0.4,0.5,0.7]])%1 + >>> wasserstein1_circle(u.T, v.T) + array([0.1]) + + References + ---------- + .. [45] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. "The statistics of circular optimal transport." Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82. + .. Code R: https://gitlab.gwdg.de/shundri/circularOT/-/tree/master/ + """ + + if u_weights is not None and v_weights is not None: + nx = get_backend(u_values, v_values, u_weights, v_weights) + else: + nx = get_backend(u_values, v_values) + + n = u_values.shape[0] + m = v_values.shape[0] + + if len(u_values.shape) == 1: + u_values = nx.reshape(u_values, (n, 1)) + if len(v_values.shape) == 1: + v_values = nx.reshape(v_values, (m, 1)) + + if u_values.shape[1] != v_values.shape[1]: + raise ValueError( + "u and v must have the same number of batchs {} and {} respectively given".format(u_values.shape[1], + v_values.shape[1])) + + u_values = u_values % 1 + v_values = v_values % 1 + + if u_weights is None: + u_weights = nx.full(u_values.shape, 1. / n, type_as=u_values) + elif u_weights.ndim != u_values.ndim: + u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) + if v_weights is None: + v_weights = nx.full(v_values.shape, 1. 
/ m, type_as=v_values) + elif v_weights.ndim != v_values.ndim: + v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1) + + if require_sort: + u_sorter = nx.argsort(u_values, 0) + u_values = nx.take_along_axis(u_values, u_sorter, 0) + + v_sorter = nx.argsort(v_values, 0) + v_values = nx.take_along_axis(v_values, v_sorter, 0) + + u_weights = nx.take_along_axis(u_weights, u_sorter, 0) + v_weights = nx.take_along_axis(v_weights, v_sorter, 0) + + # Code inspired from https://gitlab.gwdg.de/shundri/circularOT/-/tree/master/ + values_sorted, values_sorter = nx.sort2(nx.concatenate((u_values, v_values), 0), 0) + + cdf_diff = nx.cumsum(nx.take_along_axis(nx.concatenate((u_weights, -v_weights), 0), values_sorter, 0), 0) + cdf_diff_sorted, cdf_diff_sorter = nx.sort2(cdf_diff, axis=0) + + values_sorted = nx.zero_pad(values_sorted, pad_width=[(0, 1), (0, 0)], value=1) + delta = values_sorted[1:, ...] - values_sorted[:-1, ...] + weight_sorted = nx.take_along_axis(delta, cdf_diff_sorter, 0) + + sum_weights = nx.cumsum(weight_sorted, axis=0) - 0.5 + sum_weights[sum_weights < 0] = np.inf + inds = nx.argmin(sum_weights, axis=0) + + levMed = nx.take_along_axis(cdf_diff_sorted, nx.reshape(inds, (1, -1)), 0) + + return nx.sum(delta * nx.abs(cdf_diff - levMed), axis=0) + + +def wasserstein_circle(u_values, v_values, u_weights=None, v_weights=None, p=1, + Lm=10, Lp=10, tm=-1, tp=1, eps=1e-6, require_sort=True): + r"""Computes the Wasserstein distance on the circle using either [45] for p=1 or + the binary search algorithm proposed in [44] otherwise. + Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`, + takes the value modulo 1. + If the values are on :math:`S^1\subset\mathbb{R}^2`, it requires to first find the coordinates + using e.g. the atan2 function. + + General loss returned: + + .. math:: + OT_{loss} = \inf_{\theta\in\mathbb{R}}\int_0^1 |cdf_u^{-1}(q) - (cdf_v-\theta)^{-1}(q)|^p\ \mathrm{d}q + + For p=1, [45] + + .. math:: + W_1(u,v) = \int_0^1 |F_u(t)-F_v(t)-LevMed(F_u-F_v)|\ \mathrm{d}t + + For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with + + .. math:: + u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi} + + using e.g. ot.utils.get_coordinate_circle(x) + + The function runs on backend but tensorflow is not supported. + + Parameters + ---------- + u_values : ndarray, shape (n, ...) + samples in the source domain (coordinates on [0,1[) + v_values : ndarray, shape (n, ...) + samples in the target domain (coordinates on [0,1[) + u_weights : ndarray, shape (n, ...), optional + samples weights in the source domain + v_weights : ndarray, shape (n, ...), optional + samples weights in the target domain + p : float, optional (default=1) + Power p used for computing the Wasserstein distance + Lm : int, optional + Lower bound dC. For p>1. + Lp : int, optional + Upper bound dC. For p>1. + tm: float, optional + Lower bound theta. For p>1. + tp: float, optional + Upper bound theta. For p>1. + eps: float, optional + Stopping condition. For p>1. + require_sort: bool, optional + If True, sort the values. + + Returns + ------- + loss: float + Cost associated to the optimal transportation + + Examples + -------- + >>> u = np.array([[0.2,0.5,0.8]])%1 + >>> v = np.array([[0.4,0.5,0.7]])%1 + >>> wasserstein_circle(u.T, v.T) + array([0.1]) + + References + ---------- + .. [44] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. "The statistics of circular optimal transport." 
Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82. + .. [45] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258. + """ + assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p) + + if p == 1: + return wasserstein1_circle(u_values, v_values, u_weights, v_weights, require_sort) + + return binary_search_circle(u_values, v_values, u_weights, v_weights, + p=p, Lm=Lm, Lp=Lp, tm=tm, tp=tp, eps=eps, + require_sort=require_sort) + + +def semidiscrete_wasserstein2_unif_circle(u_values, u_weights=None): + r"""Computes the closed-form for the 2-Wasserstein distance between samples and a uniform distribution on :math:`S^1` + Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`, + takes the value modulo 1. + If the values are on :math:`S^1\subset\mathbb{R}^2`, it is required to first find the coordinates + using e.g. the atan2 function. + + .. math:: + W_2^2(\mu_n, \nu) = \sum_{i=1}^n \alpha_i x_i^2 - \left(\sum_{i=1}^n \alpha_i x_i\right)^2 + \sum_{i=1}^n \alpha_i x_i \left(1-\alpha_i-2\sum_{k=1}^{i-1}\alpha_k\right) + \frac{1}{12} + + where: + + - :math:`\nu=\mathrm{Unif}(S^1)` and :math:`\mu_n = \sum_{i=1}^n \alpha_i \delta_{x_i}` + + For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with + + .. math:: + u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi}, + + using e.g. ot.utils.get_coordinate_circle(x) + + Parameters + ---------- + u_values: ndarray, shape (n, ...) + Samples + u_weights : ndarray, shape (n, ...), optional + samples weights in the source domain + + Returns + ------- + loss: float + Cost associated to the optimal transportation + + Examples + -------- + >>> x0 = np.array([[0], [0.2], [0.4]]) + >>> semidiscrete_wasserstein2_unif_circle(x0) + array([0.02111111]) + + References + ---------- + .. [46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). Spherical sliced-wasserstein. International Conference on Learning Representations. + """ + + if u_weights is not None: + nx = get_backend(u_values, u_weights) + else: + nx = get_backend(u_values) + + n = u_values.shape[0] + + u_values = u_values % 1 + + if len(u_values.shape) == 1: + u_values = nx.reshape(u_values, (n, 1)) + + if u_weights is None: + u_weights = nx.full(u_values.shape, 1. 
/ n, type_as=u_values) + elif u_weights.ndim != u_values.ndim: + u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) + + u_values = nx.sort(u_values, 0) + u_cdf = nx.cumsum(u_weights, 0) + u_cdf = nx.zero_pad(u_cdf, [(1, 0), (0, 0)]) + + cpt1 = nx.sum(u_weights * u_values**2, axis=0) + u_mean = nx.sum(u_weights * u_values, axis=0) + + ns = 1 - u_weights - 2 * u_cdf[:-1] + cpt2 = nx.sum(u_values * u_weights * ns, axis=0) + + return cpt1 - u_mean**2 + cpt2 + 1 / 12 diff --git a/ot/sliced.py b/ot/sliced.py index 20891a4..077ff0b 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -12,7 +12,8 @@ Sliced OT Distances import numpy as np from .backend import get_backend, NumpyBackend -from .utils import list_to_array +from .utils import list_to_array, get_coordinate_circle +from .lp import wasserstein_circle, semidiscrete_wasserstein2_unif_circle def get_random_projections(d, n_projections, seed=None, backend=None, type_as=None): @@ -107,7 +108,6 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2, -------- >>> n_samples_a = 20 - >>> reg = 0.1 >>> X = np.random.normal(0., 1., (n_samples_a, 5)) >>> sliced_wasserstein_distance(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE 0.0 @@ -208,7 +208,6 @@ def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, -------- >>> n_samples_a = 20 - >>> reg = 0.1 >>> X = np.random.normal(0., 1., (n_samples_a, 5)) >>> sliced_wasserstein_distance(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE 0.0 @@ -258,3 +257,183 @@ def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, if log: return res, {"projections": projections, "projected_emds": projected_emd} return res + + +def sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50, + p=2, seed=None, log=False): + r""" + Compute the spherical sliced-Wasserstein discrepancy. + + .. math:: + SSW_p(\mu,\nu) = \left(\int_{\mathbb{V}_{d,2}} W_p^p(P^U_\#\mu, P^U_\#\nu)\ \mathrm{d}\sigma(U)\right)^{\frac{1}{p}} + + where: + + - :math:`P^U_\# \mu` stands for the pushforwards of the projection :math:`\forall x\in S^{d-1},\ P^U(x) = \frac{U^Tx}{\|U^Tx\|_2}` + + The function runs on backend but tensorflow is not supported. + + Parameters + ---------- + X_s: ndarray, shape (n_samples_a, dim) + Samples in the source domain + X_t: ndarray, shape (n_samples_b, dim) + Samples in the target domain + a : ndarray, shape (n_samples_a,), optional + samples weights in the source domain + b : ndarray, shape (n_samples_b,), optional + samples weights in the target domain + n_projections : int, optional + Number of projections used for the Monte-Carlo approximation + p: float, optional (default=2) + Power p used for computing the spherical sliced Wasserstein + seed: int or RandomState or None, optional + Seed used for random number generator + log: bool, optional + if True, sliced_wasserstein_sphere returns the projections used and their associated EMD. + + Returns + ------- + cost: float + Spherical Sliced Wasserstein Cost + log: dict, optional + log dictionary return only if log==True in parameters + + Examples + -------- + >>> n_samples_a = 20 + >>> X = np.random.normal(0., 1., (n_samples_a, 5)) + >>> X = X / np.sqrt(np.sum(X**2, -1, keepdims=True)) + >>> sliced_wasserstein_sphere(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE + 0.0 + + References + ---------- + .. [46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). Spherical sliced-wasserstein. International Conference on Learning Representations. 
+ """ + if a is not None and b is not None: + nx = get_backend(X_s, X_t, a, b) + else: + nx = get_backend(X_s, X_t) + + n, d = X_s.shape + m, _ = X_t.shape + + if X_s.shape[1] != X_t.shape[1]: + raise ValueError( + "X_s and X_t must have the same number of dimensions {} and {} respectively given".format(X_s.shape[1], + X_t.shape[1])) + if nx.any(nx.abs(nx.sum(X_s**2, axis=-1) - 1) > 10**(-4)): + raise ValueError("X_s is not on the sphere.") + if nx.any(nx.abs(nx.sum(X_t**2, axis=-1) - 1) > 10**(-4)): + raise ValueError("Xt is not on the sphere.") + + # Uniforms and independent samples on the Stiefel manifold V_{d,2} + if isinstance(seed, np.random.RandomState) and str(nx) == 'numpy': + Z = seed.randn(n_projections, d, 2) + else: + if seed is not None: + nx.seed(seed) + Z = nx.randn(n_projections, d, 2, type_as=X_s) + + projections, _ = nx.qr(Z) + + # Projection on S^1 + # Projection on plane + Xps = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, (0, 2, 1))[:, None], X_s[:, :, None]), (n_projections, 2, n)), (0, 2, 1)) + Xpt = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, (0, 2, 1))[:, None], X_t[:, :, None]), (n_projections, 2, m)), (0, 2, 1)) + + # Projection on sphere + Xps = Xps / nx.sqrt(nx.sum(Xps**2, -1, keepdims=True)) + Xpt = Xpt / nx.sqrt(nx.sum(Xpt**2, -1, keepdims=True)) + + # Get coordinates on [0,1[ + Xps_coords = nx.reshape(get_coordinate_circle(nx.reshape(Xps, (-1, 2))), (n_projections, n)) + Xpt_coords = nx.reshape(get_coordinate_circle(nx.reshape(Xpt, (-1, 2))), (n_projections, m)) + + projected_emd = wasserstein_circle(Xps_coords.T, Xpt_coords.T, u_weights=a, v_weights=b, p=p) + res = nx.mean(projected_emd) ** (1 / p) + + if log: + return res, {"projections": projections, "projected_emds": projected_emd} + return res + + +def sliced_wasserstein_sphere_unif(X_s, a=None, n_projections=50, seed=None, log=False): + r"""Compute the 2-spherical sliced wasserstein w.r.t. a uniform distribution. + + .. math:: + SSW_2(\mu_n, \nu) + + where + + - :math:`\mu_n=\sum_{i=1}^n \alpha_i \delta_{x_i}` + - :math:`\nu=\mathrm{Unif}(S^1)` + + Parameters + ---------- + X_s: ndarray, shape (n_samples_a, dim) + Samples in the source domain + a : ndarray, shape (n_samples_a,), optional + samples weights in the source domain + n_projections : int, optional + Number of projections used for the Monte-Carlo approximation + seed: int or RandomState or None, optional + Seed used for random number generator + log: bool, optional + if True, sliced_wasserstein_distance returns the projections used and their associated EMD. + + Returns + ------- + cost: float + Spherical Sliced Wasserstein Cost + log: dict, optional + log dictionary return only if log==True in parameters + + Examples + --------- + >>> np.random.seed(42) + >>> x0 = np.random.randn(500,3) + >>> x0 = x0 / np.sqrt(np.sum(x0**2, -1, keepdims=True)) + >>> ssw = sliced_wasserstein_sphere_unif(x0, seed=42) + >>> np.allclose(sliced_wasserstein_sphere_unif(x0, seed=42), 0.01734, atol=1e-3) + True + + References: + ----------- + .. [46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). Spherical sliced-wasserstein. International Conference on Learning Representations. 
+ """ + if a is not None: + nx = get_backend(X_s, a) + else: + nx = get_backend(X_s) + + n, d = X_s.shape + + if nx.any(nx.abs(nx.sum(X_s**2, axis=-1) - 1) > 10**(-4)): + raise ValueError("X_s is not on the sphere.") + + # Uniforms and independent samples on the Stiefel manifold V_{d,2} + if isinstance(seed, np.random.RandomState) and str(nx) == 'numpy': + Z = seed.randn(n_projections, d, 2) + else: + if seed is not None: + nx.seed(seed) + Z = nx.randn(n_projections, d, 2, type_as=X_s) + + projections, _ = nx.qr(Z) + + # Projection on S^1 + # Projection on plane + Xps = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, (0, 2, 1))[:, None], X_s[:, :, None]), (n_projections, 2, n)), (0, 2, 1)) + # Projection on sphere + Xps = Xps / nx.sqrt(nx.sum(Xps**2, -1, keepdims=True)) + # Get coordinates on [0,1[ + Xps_coords = nx.reshape(get_coordinate_circle(nx.reshape(Xps, (-1, 2))), (n_projections, n)) + + projected_emd = semidiscrete_wasserstein2_unif_circle(Xps_coords.T, u_weights=a) + res = nx.mean(projected_emd) ** (1 / 2) + + if log: + return res, {"projections": projections, "projected_emds": projected_emd} + return res diff --git a/ot/utils.py b/ot/utils.py index 9093f09..3423a7e 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -375,6 +375,36 @@ def check_random_state(seed): ' instance'.format(seed)) +def get_coordinate_circle(x): + r"""For :math:`x\in S^1 \subset \mathbb{R}^2`, returns the coordinates in + turn (in [0,1[). + + .. math:: + u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi} + + Parameters + ---------- + x: ndarray, shape (n, 2) + Samples on the circle with ambient coordinates + + Returns + ------- + x_t: ndarray, shape (n,) + Coordinates on [0,1[ + + Examples + -------- + >>> u = np.array([[0.2,0.5,0.8]]) * (2 * np.pi) + >>> x1, y1 = np.cos(u), np.sin(u) + >>> x = np.concatenate([x1, y1]).T + >>> get_coordinate_circle(x) + array([0.2, 0.5, 0.8]) + """ + nx = get_backend(x) + x_t = (nx.atan2(-x[:, 1], -x[:, 0]) + np.pi) / (2 * np.pi) + return x_t + + class deprecated(object): r"""Decorator to mark a function or class as deprecated. 
diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py index 20f307a..21abd1d 100644 --- a/test/test_1d_solver.py +++ b/test/test_1d_solver.py @@ -218,3 +218,130 @@ def test_emd1d_device_tf(): nx.assert_same_dtype_device(xb, emd) nx.assert_same_dtype_device(xb, emd2) assert nx.dtype_device(emd)[1].startswith("GPU") + + +def test_wasserstein_1d_circle(): + # test binary_search_circle and wasserstein_circle give similar results as emd + n = 20 + m = 30 + rng = np.random.RandomState(0) + u = rng.rand(n,) + v = rng.rand(m,) + + w_u = rng.uniform(0., 1., n) + w_u = w_u / w_u.sum() + + w_v = rng.uniform(0., 1., m) + w_v = w_v / w_v.sum() + + M1 = np.minimum(np.abs(u[:, None] - v[None]), 1 - np.abs(u[:, None] - v[None])) + + wass1 = ot.emd2(w_u, w_v, M1) + + wass1_bsc = ot.binary_search_circle(u, v, w_u, w_v, p=1) + w1_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=1) + + M2 = M1**2 + wass2 = ot.emd2(w_u, w_v, M2) + wass2_bsc = ot.binary_search_circle(u, v, w_u, w_v, p=2) + w2_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=2) + + # check loss is similar + np.testing.assert_allclose(wass1, wass1_bsc) + np.testing.assert_allclose(wass1, w1_circle, rtol=1e-2) + np.testing.assert_allclose(wass2, wass2_bsc) + np.testing.assert_allclose(wass2, w2_circle) + + +@pytest.skip_backend("tf") +def test_wasserstein1d_circle_devices(nx): + rng = np.random.RandomState(0) + + n = 10 + x = np.linspace(0, 1, n) + rho_u = np.abs(rng.randn(n)) + rho_u /= rho_u.sum() + rho_v = np.abs(rng.randn(n)) + rho_v /= rho_v.sum() + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp) + + w1 = ot.wasserstein_circle(xb, xb, rho_ub, rho_vb, p=1) + w2_bsc = ot.wasserstein_circle(xb, xb, rho_ub, rho_vb, p=2) + + nx.assert_same_dtype_device(xb, w1) + nx.assert_same_dtype_device(xb, w2_bsc) + + +def test_wasserstein_1d_unif_circle(): + # test semidiscrete_wasserstein2_unif_circle versus wasserstein_circle + n = 20 + m = 50000 + + rng = np.random.RandomState(0) + u = rng.rand(n,) + v = rng.rand(m,) + + # w_u = rng.uniform(0., 1., n) + # w_u = w_u / w_u.sum() + + w_u = ot.utils.unif(n) + w_v = ot.utils.unif(m) + + M1 = np.minimum(np.abs(u[:, None] - v[None]), 1 - np.abs(u[:, None] - v[None])) + wass2 = ot.emd2(w_u, w_v, M1**2) + + wass2_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=2, eps=1e-15) + wass2_unif_circle = ot.semidiscrete_wasserstein2_unif_circle(u, w_u) + + # check loss is similar + np.testing.assert_allclose(wass2, wass2_unif_circle, atol=1e-3) + np.testing.assert_allclose(wass2_circle, wass2_unif_circle, atol=1e-3) + + +def test_wasserstein1d_unif_circle_devices(nx): + rng = np.random.RandomState(0) + + n = 10 + x = np.linspace(0, 1, n) + rho_u = np.abs(rng.randn(n)) + rho_u /= rho_u.sum() + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + xb, rho_ub = nx.from_numpy(x, rho_u, type_as=tp) + + w2 = ot.semidiscrete_wasserstein2_unif_circle(xb, rho_ub) + + nx.assert_same_dtype_device(xb, w2) + + +def test_binary_search_circle_log(): + n = 20 + m = 30 + rng = np.random.RandomState(0) + u = rng.rand(n,) + v = rng.rand(m,) + + wass2_bsc, log = ot.binary_search_circle(u, v, p=2, log=True) + optimal_thetas = log["optimal_theta"] + + assert optimal_thetas.shape[0] == 1 + + +def test_wasserstein_circle_bad_shape(): + n = 20 + m = 30 + rng = np.random.RandomState(0) + u = rng.rand(n, 2) + v = rng.rand(m, 1) + + with pytest.raises(ValueError): + _ = ot.wasserstein_circle(u, v, p=2) + + with pytest.raises(ValueError): + _ = 
ot.wasserstein_circle(u, v, p=1) diff --git a/test/test_backend.py b/test/test_backend.py index 3628f61..fd9a761 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -282,6 +282,20 @@ def test_empty_backend(): nx.array_equal(M, M) with pytest.raises(NotImplementedError): nx.is_floating_point(M) + with pytest.raises(NotImplementedError): + nx.tile(M, (10, 1)) + with pytest.raises(NotImplementedError): + nx.floor(M) + with pytest.raises(NotImplementedError): + nx.prod(M) + with pytest.raises(NotImplementedError): + nx.sort2(M) + with pytest.raises(NotImplementedError): + nx.qr(M) + with pytest.raises(NotImplementedError): + nx.atan2(v, v) + with pytest.raises(NotImplementedError): + nx.transpose(M) def test_func_backends(nx): @@ -603,6 +617,38 @@ def test_func_backends(nx): lst_b.append(nx.to_numpy(A)) lst_name.append("isfinite") + A = nx.tile(vb, (10, 1)) + lst_b.append(nx.to_numpy(A)) + lst_name.append("tile") + + A = nx.floor(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append("floor") + + A = nx.prod(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append("prod") + + A, B = nx.sort2(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append("sort2 sort") + lst_b.append(nx.to_numpy(B)) + lst_name.append("sort2 argsort") + + A, B = nx.qr(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append("QR Q") + lst_b.append(nx.to_numpy(B)) + lst_name.append("QR R") + + A = nx.atan2(vb, vb) + lst_b.append(nx.to_numpy(A)) + lst_name.append("atan2") + + A = nx.transpose(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append("transpose") + assert not nx.array_equal(Mb, vb), "array_equal (shape)" assert nx.array_equal(Mb, Mb), "array_equal (elements) - expected true" assert not nx.array_equal( diff --git a/test/test_sliced.py b/test/test_sliced.py index eb13469..f54c799 100644 --- a/test/test_sliced.py +++ b/test/test_sliced.py @@ -266,3 +266,189 @@ def test_max_sliced_backend_device_tf(): valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb) nx.assert_same_dtype_device(xb, valb) assert nx.dtype_device(valb)[1].startswith("GPU") + + +def test_projections_stiefel(): + rng = np.random.RandomState(0) + + n_projs = 500 + x = np.random.randn(100, 3) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + + ssw, log = ot.sliced_wasserstein_sphere(x, x, n_projections=n_projs, + seed=rng, log=True) + + P = log["projections"] + P_T = np.transpose(P, [0, 2, 1]) + np.testing.assert_almost_equal(np.matmul(P_T, P), np.array([np.eye(2) for k in range(n_projs)])) + + +def test_sliced_sphere_same_dist(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 3) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + u = ot.utils.unif(n) + + res = ot.sliced_wasserstein_sphere(x, x, u, u, 10, seed=rng) + np.testing.assert_almost_equal(res, 0.) 
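
For readers skimming these tests, a minimal sketch of the spherical sliced-Wasserstein calls they exercise (a usage illustration, not part of the patch; the normalisation step simply projects Gaussian draws onto the sphere):

import numpy as np
import ot

rng = np.random.RandomState(0)
x = rng.randn(100, 3)
x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))   # project samples onto S^2
y = rng.randn(120, 3)
y = y / np.sqrt(np.sum(y**2, -1, keepdims=True))

ssw = ot.sliced_wasserstein_sphere(x, y, n_projections=50, p=2, seed=0)
ssw_unif = ot.sliced_wasserstein_sphere_unif(x, n_projections=50, seed=0)  # SSW_2 w.r.t. the uniform distribution
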
+ + +def test_sliced_sphere_bad_shapes(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 3) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + + y = rng.randn(n, 4) + y = y / np.sqrt(np.sum(x**2, -1, keepdims=True)) + + u = ot.utils.unif(n) + + with pytest.raises(ValueError): + _ = ot.sliced_wasserstein_sphere(x, y, u, u, 10, seed=rng) + + +def test_sliced_sphere_values_on_the_sphere(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 3) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + + y = rng.randn(n, 4) + + u = ot.utils.unif(n) + + with pytest.raises(ValueError): + _ = ot.sliced_wasserstein_sphere(x, y, u, u, 10, seed=rng) + + +def test_sliced_sphere_log(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 4) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + y = rng.randn(n, 4) + y = y / np.sqrt(np.sum(y**2, -1, keepdims=True)) + u = ot.utils.unif(n) + + res, log = ot.sliced_wasserstein_sphere(x, y, u, u, 10, p=1, seed=rng, log=True) + assert len(log) == 2 + projections = log["projections"] + projected_emds = log["projected_emds"] + + assert projections.shape[0] == len(projected_emds) == 10 + for emd in projected_emds: + assert emd > 0 + + +def test_sliced_sphere_different_dists(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 3) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + + u = ot.utils.unif(n) + y = rng.randn(n, 3) + y = y / np.sqrt(np.sum(y**2, -1, keepdims=True)) + + res = ot.sliced_wasserstein_sphere(x, y, u, u, 10, seed=rng) + assert res > 0. + + +def test_1d_sliced_sphere_equals_emd(): + n = 100 + m = 120 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + x_coords = (np.arctan2(-x[:, 1], -x[:, 0]) + np.pi) / (2 * np.pi) + a = rng.uniform(0, 1, n) + a /= a.sum() + + y = rng.randn(m, 2) + y = y / np.sqrt(np.sum(y**2, -1, keepdims=True)) + y_coords = (np.arctan2(-y[:, 1], -y[:, 0]) + np.pi) / (2 * np.pi) + u = ot.utils.unif(m) + + res = ot.sliced_wasserstein_sphere(x, y, a, u, 10, seed=42, p=2) + expected = ot.binary_search_circle(x_coords.T, y_coords.T, a, u, p=2) + + res1 = ot.sliced_wasserstein_sphere(x, y, a, u, 10, seed=42, p=1) + expected1 = ot.binary_search_circle(x_coords.T, y_coords.T, a, u, p=1) + + np.testing.assert_almost_equal(res ** 2, expected) + np.testing.assert_almost_equal(res1, expected1, decimal=3) + + +@pytest.skip_backend("tf") +def test_sliced_sphere_backend_type_devices(nx): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 3) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + + y = rng.randn(2 * n, 3) + y = y / np.sqrt(np.sum(y**2, -1, keepdims=True)) + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + xb, yb = nx.from_numpy(x, y, type_as=tp) + + valb = ot.sliced_wasserstein_sphere(xb, yb) + + nx.assert_same_dtype_device(xb, valb) + + +def test_sliced_sphere_unif_values_on_the_sphere(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 3) + u = ot.utils.unif(n) + + with pytest.raises(ValueError): + _ = ot.sliced_wasserstein_sphere_unif(x, u, 10, seed=rng) + + +def test_sliced_sphere_unif_log(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 4) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + u = ot.utils.unif(n) + + res, log = ot.sliced_wasserstein_sphere_unif(x, u, 10, seed=rng, log=True) + assert len(log) == 2 + projections = log["projections"] + projected_emds = log["projected_emds"] + + assert projections.shape[0] == len(projected_emds) == 10 + for 
emd in projected_emds: + assert emd > 0 + + +def test_sliced_sphere_unif_backend_type_devices(nx): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 3) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + xb = nx.from_numpy(x, type_as=tp) + + valb = ot.sliced_wasserstein_sphere_unif(xb) + + nx.assert_same_dtype_device(xb, valb) diff --git a/test/test_utils.py b/test/test_utils.py index 666c157..31b12ef 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -330,3 +330,13 @@ def test_OTResult(): for at in lst_attributes: with pytest.raises(NotImplementedError): getattr(res, at) + + +def test_get_coordinate_circle(): + + u = np.random.rand(1, 100) + x1, y1 = np.cos(u * (2 * np.pi)), np.sin(u * (2 * np.pi)) + x = np.concatenate([x1, y1]).T + x_p = ot.utils.get_coordinate_circle(x) + + np.testing.assert_allclose(u[0], x_p) -- cgit v1.2.3 From a5930d3b3a446bf860d6dfacc1e17151fae1dd1d Mon Sep 17 00:00:00 2001 From: Cédric Vincent-Cuaz Date: Thu, 9 Mar 2023 14:21:33 +0100 Subject: [MRG] Semi-relaxed (fused) gromov-wasserstein divergence and improvements of gromov-wasserstein solvers (#431) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * maj gw/ srgw/ generic cg solver * correct pep8 on current state * fix bug previous tests * fix pep8 * fix bug srGW constC in loss and gradient * fix doc html * fix doc html * start updating test_optim.py * update tests gromov and optim - plus fix gromov dependencies * add symmetry feature to entropic gw * add symmetry feature to entropic gw * add exemple for sr(F)GW matchings * small stuff * remove (reg,M) from line-search/ complete srgw tests with backend * remove backend repetitions / rename fG to costG/ fix innerlog to True * fix pep8 * take comments into account / new nx parameters still to test * factor (f)gw2 + test new backend parameters in ot.gromov + harmonize stopping criterions * split gromov.py in ot/gromov/ + update test_gromov with helper_backend functions * manual documentaion gromov * remove circular autosummary * trying stuff * debug documentation * alphabetic ordering of module * merge into branch * add note in entropic gw solvers --------- Co-authored-by: Rémi Flamary --- CONTRIBUTORS.md | 2 +- README.md | 7 +- RELEASES.md | 4 +- docs/cache_nbrun | 1 - docs/source/_templates/module.rst | 2 + docs/source/all.rst | 32 +- docs/source/conf.py | 5 +- examples/gromov/plot_gromov.py | 1 - examples/gromov/plot_semirelaxed_fgw.py | 300 ++++ ot/__init__.py | 1 - ot/bregman.py | 2 +- ot/dr.py | 2 +- ot/gromov.py | 2838 ------------------------------- ot/gromov/__init__.py | 48 + ot/gromov/_bregman.py | 351 ++++ ot/gromov/_dictionary.py | 1008 +++++++++++ ot/gromov/_estimators.py | 425 +++++ ot/gromov/_gw.py | 978 +++++++++++ ot/gromov/_semirelaxed.py | 543 ++++++ ot/gromov/_utils.py | 413 +++++ ot/lp/__init__.py | 2 +- ot/optim.py | 460 ++--- test/test_gromov.py | 621 ++++++- test/test_optim.py | 8 +- 24 files changed, 4964 insertions(+), 3090 deletions(-) delete mode 100644 docs/cache_nbrun create mode 100644 examples/gromov/plot_semirelaxed_fgw.py delete mode 100644 ot/gromov.py create mode 100644 ot/gromov/__init__.py create mode 100644 ot/gromov/_bregman.py create mode 100644 ot/gromov/_dictionary.py create mode 100644 ot/gromov/_estimators.py create mode 100644 ot/gromov/_gw.py create mode 100644 ot/gromov/_semirelaxed.py create mode 100644 ot/gromov/_utils.py (limited to 'ot/lp/__init__.py') diff --git a/CONTRIBUTORS.md 
b/CONTRIBUTORS.md index 1437821..6b35653 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -36,7 +36,7 @@ The contributors to this library are: * [Tanguy Kerdoncuff](https://hv0nnus.github.io/) (Sampled Gromov Wasserstein) * [Minhui Huang](https://mhhuang95.github.io) (Projection Robust Wasserstein Distance) * [Nathan Cassereau](https://github.com/ncassereau-idris) (Backends) -* [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning) +* [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning, semi-relaxed FGW) * [Eloi Tanguy](https://github.com/eloitanguy) (Generalized Wasserstein Barycenters) * [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug) * [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter) diff --git a/README.md b/README.md index d5e6854..e7241b8 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,7 @@ POT provides the following generic OT solvers (links to examples): * [Wasserstein distance on the circle](https://pythonot.github.io/auto_examples/plot_compute_wasserstein_circle.html) [44, 45] * [Spherical Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance_ssw.html) [46] * [Graph Dictionary Learning solvers](https://pythonot.github.io/auto_examples/gromov/plot_gromov_wasserstein_dictionary_learning.html) [38]. +* [Semi-relaxed (Fused) Gromov-Wasserstein divergences](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_fgw.html) [48]. * [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays. POT provides the following Machine Learning related solvers: @@ -300,4 +301,8 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer [45] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. [The statistics of circular optimal transport.](https://arxiv.org/abs/2103.15426) Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82. -[46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). [Spherical Sliced-Wasserstein](https://openreview.net/forum?id=jXQ0ipgMdU). International Conference on Learning Representations. \ No newline at end of file +[46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). [Spherical Sliced-Wasserstein](https://openreview.net/forum?id=jXQ0ipgMdU). International Conference on Learning Representations. + +[47] Chowdhury, S., & Mémoli, F. (2019). [The gromov–wasserstein distance between networks and stable network invariants](https://academic.oup.com/imaiai/article/8/4/757/5627736). Information and Inference: A Journal of the IMA, 8(4), 757-787. + +[48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty (2022). [Semi-relaxed Gromov-Wasserstein divergence and applications on graphs](https://openreview.net/pdf?id=RShaMexjc-x). International Conference on Learning Representations (ICLR), 2022. 
\ No newline at end of file diff --git a/RELEASES.md b/RELEASES.md index bf2ce2e..b51409b 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -3,7 +3,9 @@ ## 0.8.3dev #### New features - +- Added feature to (Fused) Gromov-Wasserstein solvers herited from `ot.optim` to support relative and absolute loss variations as stopping criterions (PR #431) +- Added feature to (Fused) Gromov-Wasserstein solvers to handle asymmetric matrices (PR #431) +- Added semi-relaxed (Fused) Gromov-Wasserstein solvers in `ot.gromov` + examples (PR #431) - Added the spherical sliced-Wasserstein discrepancy in `ot.sliced.sliced_wasserstein_sphere` and `ot.sliced.sliced_wasserstein_sphere_unif` + examples (PR #434) - Added the Wasserstein distance on the circle in ``ot.lp.solver_1d.wasserstein_circle`` (PR #434) - Added the Wasserstein distance on the circle (for p>=1) in `ot.lp.solver_1d.binary_search_circle` + examples (PR #434) diff --git a/docs/cache_nbrun b/docs/cache_nbrun deleted file mode 100644 index ac49515..0000000 --- a/docs/cache_nbrun +++ /dev/null @@ -1 +0,0 @@ -{"plot_otda_color_images.ipynb": "128d0435c08ebcf788913e4adcd7dd00", "plot_partial_wass_and_gromov.ipynb": "82242f8390df1d04806b333b745c72cf", "plot_WDA.ipynb": "27f8de4c6d7db46497076523673eedfb", "plot_screenkhorn_1D.ipynb": "af7b8a74a1be0f16f2c3908f5a178de0", "plot_otda_laplacian.ipynb": "d92cc0e528b9277f550daaa6f9d18415", "plot_OT_L1_vs_L2.ipynb": "288230c4e679d752a511353c96c134cb", "plot_otda_semi_supervised.ipynb": "568b39ffbdf6621dd6de162df42f4f21", "plot_fgw.ipynb": "f4de8e6939ce2b1339b3badc1fef0f37", "plot_otda_d2.ipynb": "07ef3212ff3123f16c32a5670e0167f8", "plot_compute_emd.ipynb": "299f6fffcdbf48b7c3268c0136e284f8", "plot_barycenter_fgw.ipynb": "9e813d3b07b7c0c0fcc35a778ca1243f", "plot_convolutional_barycenter.ipynb": "fdd259bfcd6d5fe8001efb4345795d2f", "plot_optim_OTreg.ipynb": "bddd8e49f092873d8980d41ae4974e19", "plot_UOT_1D.ipynb": "2658d5164165941b07539dae3cb80a0f", "plot_OT_1D_smooth.ipynb": "f3e1f0e362c9a78071a40c02b85d2305", "plot_barycenter_1D.ipynb": "f6fa5bc13d9811f09792f73b4de70aa0", "plot_otda_mapping.ipynb": "1bb321763f670fc945d77cfc91471e5e", "plot_OT_1D.ipynb": "0346a8c862606d11f36d0aa087ecab0d", "plot_gromov_barycenter.ipynb": "a7999fcc236d90a0adeb8da2c6370db3", "plot_UOT_barycenter_1D.ipynb": "dd9b857a8c66d71d0124d4a2c30a51dd", "plot_otda_mapping_colors_images.ipynb": "16faae80d6ea8b37d6b1f702149a10de", "plot_stochastic.ipynb": "64f23a8dcbab9823ae92f0fd6c3aceab", "plot_otda_linear_mapping.ipynb": "82417d9141e310bf1f2c2ecdb550094b", "plot_otda_classes.ipynb": "8836a924c9b562ef397af12034fa1abb", "plot_free_support_barycenter.ipynb": "be9d0823f9d7774a289311b9f14548eb", "plot_gromov.ipynb": "de06b1dbe8de99abae51c2e0b64b485d", "plot_otda_jcpot.ipynb": "65482cbfef5c6c1e5e73998aeb5f4b10", "plot_OT_2D_samples.ipynb": "9a9496792fa4216b1059fc70abca851a", "plot_barycenter_lp_vs_entropic.ipynb": "334840b69a86898813e50a6db0f3d0de"} \ No newline at end of file diff --git a/docs/source/_templates/module.rst b/docs/source/_templates/module.rst index 5ad89be..495995e 100644 --- a/docs/source/_templates/module.rst +++ b/docs/source/_templates/module.rst @@ -2,6 +2,7 @@ {{ underline }} .. automodule:: {{ fullname }} + :members: {% block functions %} {% if functions %} @@ -12,6 +13,7 @@ {% for item in functions %} .. autofunction:: {{ item }} + .. 
include:: backreferences/{{fullname}}.{{item}}.examples diff --git a/docs/source/all.rst b/docs/source/all.rst index 60cc85c..41d8e06 100644 --- a/docs/source/all.rst +++ b/docs/source/all.rst @@ -13,29 +13,33 @@ API and modules :toctree: gen_modules/ :template: module.rst - lp + backend bregman - smooth - gromov - optim da - dr - utils datasets + dr + factored + gaussian + gromov + lp + optim + partial plot - stochastic - unbalanced regpath - partial sliced + smooth + stochastic + unbalanced + utils weak - factored - gaussian + -.. autosummary:: - :toctree: ../modules/generated/ - :template: module.rst +Main :py:mod:`ot` functions +-------------- .. automodule:: ot :members: + + + diff --git a/docs/source/conf.py b/docs/source/conf.py index 3bec150..6e76291 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -119,7 +119,7 @@ release = __version__ # # This is also used if you do content translation via gettext catalogs. # Usually you set "language" from the command line for these cases. -language = None +language = "en" # There are two options for replacing |today|: either, you set today to some # non-false value, then it is used: @@ -341,6 +341,9 @@ texinfo_documents = [ # If true, do not generate a @detailmenu in the "Top" node's menu. #texinfo_no_detailmenu = False +autodoc_default_options = {'autosummary': True, + 'autosummary_imported_members': True} + # Example configuration for intersphinx: refer to the Python standard library. intersphinx_mapping = {'python': ('https://docs.python.org/3', None), diff --git a/examples/gromov/plot_gromov.py b/examples/gromov/plot_gromov.py index 5a362cf..05074dc 100644 --- a/examples/gromov/plot_gromov.py +++ b/examples/gromov/plot_gromov.py @@ -3,7 +3,6 @@ ========================== Gromov-Wasserstein example ========================== - This example is designed to show how to use the Gromov-Wassertsein distance computation in POT. """ diff --git a/examples/gromov/plot_semirelaxed_fgw.py b/examples/gromov/plot_semirelaxed_fgw.py new file mode 100644 index 0000000..8f879d4 --- /dev/null +++ b/examples/gromov/plot_semirelaxed_fgw.py @@ -0,0 +1,300 @@ +# -*- coding: utf-8 -*- +""" +========================== +Semi-relaxed (Fused) Gromov-Wasserstein example +========================== + +This example is designed to show how to use the semi-relaxed Gromov-Wasserstein +and the semi-relaxed Fused Gromov-Wasserstein divergences. + +sr(F)GW between two graphs G1 and G2 searches for a reweighing of the nodes of +G2 at a minimal (F)GW distance from G1. + +First, we generate two graphs following Stochastic Block Models, then show +how to compute their srGW matchings and illustrate them. These graphs are then +endowed with node features and we follow the same process with srFGW. + +[48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty. +"Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" +International Conference on Learning Representations (ICLR), 2021. +""" + +# Author: Cédric Vincent-Cuaz +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 1 + +import numpy as np +import matplotlib.pylab as pl +from ot.gromov import semirelaxed_gromov_wasserstein, semirelaxed_fused_gromov_wasserstein, gromov_wasserstein, fused_gromov_wasserstein +import networkx +from networkx.generators.community import stochastic_block_model as sbm + +# %% +# ============================================================================= +# Generate two graphs following Stochastic Block models of 2 and 3 clusters. 
+# ============================================================================= + + +N2 = 20 # 2 communities +N3 = 30 # 3 communities +p2 = [[1., 0.1], + [0.1, 0.9]] +p3 = [[1., 0.1, 0.], + [0.1, 0.95, 0.1], + [0., 0.1, 0.9]] +G2 = sbm(seed=0, sizes=[N2 // 2, N2 // 2], p=p2) +G3 = sbm(seed=0, sizes=[N3 // 3, N3 // 3, N3 // 3], p=p3) + + +C2 = networkx.to_numpy_array(G2) +C3 = networkx.to_numpy_array(G3) + +h2 = np.ones(C2.shape[0]) / C2.shape[0] +h3 = np.ones(C3.shape[0]) / C3.shape[0] + +# Add weights on the edges for visualization later on +weight_intra_G2 = 5 +weight_inter_G2 = 0.5 +weight_intra_G3 = 1. +weight_inter_G3 = 1.5 + +weightedG2 = networkx.Graph() +part_G2 = [G2.nodes[i]['block'] for i in range(N2)] + +for node in G2.nodes(): + weightedG2.add_node(node) +for i, j in G2.edges(): + if part_G2[i] == part_G2[j]: + weightedG2.add_edge(i, j, weight=weight_intra_G2) + else: + weightedG2.add_edge(i, j, weight=weight_inter_G2) + +weightedG3 = networkx.Graph() +part_G3 = [G3.nodes[i]['block'] for i in range(N3)] + +for node in G3.nodes(): + weightedG3.add_node(node) +for i, j in G3.edges(): + if part_G3[i] == part_G3[j]: + weightedG3.add_edge(i, j, weight=weight_intra_G3) + else: + weightedG3.add_edge(i, j, weight=weight_inter_G3) +# %% +# ============================================================================= +# Compute their semi-relaxed Gromov-Wasserstein divergences +# ============================================================================= + +# 0) GW(C2, h2, C3, h3) for reference +OT, log = gromov_wasserstein(C2, C3, h2, h3, symmetric=True, log=True) +gw = log['gw_dist'] + +# 1) srGW(C2, h2, C3) +OT_23, log_23 = semirelaxed_gromov_wasserstein(C2, C3, h2, symmetric=True, + log=True, G0=None) +srgw_23 = log_23['srgw_dist'] + +# 2) srGW(C3, h3, C2) + +OT_32, log_32 = semirelaxed_gromov_wasserstein(C3, C2, h3, symmetric=None, + log=True, G0=OT.T) +srgw_32 = log_32['srgw_dist'] + +print('GW(C2, C3) = ', gw) +print('srGW(C2, h2, C3) = ', srgw_23) +print('srGW(C3, h3, C2) = ', srgw_32) + + +# %% +# ============================================================================= +# Visualization of the semi-relaxed Gromov-Wasserstein matchings +# ============================================================================= + +# We color nodes of the graph on the right - then project its node colors +# based on the optimal transport plan from the srGW matching + + +def draw_graph(G, C, nodes_color_part, Gweights=None, + pos=None, edge_color='black', node_size=None, + shiftx=0, seed=0): + + if (pos is None): + pos = networkx.spring_layout(G, scale=1., seed=seed) + + if shiftx != 0: + for k, v in pos.items(): + v[0] = v[0] + shiftx + + alpha_edge = 0.7 + width_edge = 1.8 + if Gweights is None: + networkx.draw_networkx_edges(G, pos, width=width_edge, alpha=alpha_edge, edge_color=edge_color) + else: + # We make more visible connections between activated nodes + n = len(Gweights) + edgelist_activated = [] + edgelist_deactivated = [] + for i in range(n): + for j in range(n): + if Gweights[i] * Gweights[j] * C[i, j] > 0: + edgelist_activated.append((i, j)) + elif C[i, j] > 0: + edgelist_deactivated.append((i, j)) + + networkx.draw_networkx_edges(G, pos, edgelist=edgelist_activated, + width=width_edge, alpha=alpha_edge, + edge_color=edge_color) + networkx.draw_networkx_edges(G, pos, edgelist=edgelist_deactivated, + width=width_edge, alpha=0.1, + edge_color=edge_color) + + if Gweights is None: + for node, node_color in enumerate(nodes_color_part): + networkx.draw_networkx_nodes(G, pos, 
nodelist=[node], + node_size=node_size, alpha=1, + node_color=node_color) + else: + scaled_Gweights = Gweights / (0.5 * Gweights.max()) + nodes_size = node_size * scaled_Gweights + for node, node_color in enumerate(nodes_color_part): + networkx.draw_networkx_nodes(G, pos, nodelist=[node], + node_size=nodes_size[node], alpha=1, + node_color=node_color) + return pos + + +def draw_transp_colored_srGW(G1, C1, G2, C2, part_G1, + p1, p2, T, pos1=None, pos2=None, + shiftx=4, switchx=False, node_size=70, + seed_G1=0, seed_G2=0): + starting_color = 0 + # get graphs partition and their coloring + part1 = part_G1.copy() + unique_colors = ['C%s' % (starting_color + i) for i in np.unique(part1)] + nodes_color_part1 = [] + for cluster in part1: + nodes_color_part1.append(unique_colors[cluster]) + + nodes_color_part2 = [] + # T: getting colors assignment from argmin of columns + for i in range(len(G2.nodes())): + j = np.argmax(T[:, i]) + nodes_color_part2.append(nodes_color_part1[j]) + pos1 = draw_graph(G1, C1, nodes_color_part1, Gweights=p1, + pos=pos1, node_size=node_size, shiftx=0, seed=seed_G1) + pos2 = draw_graph(G2, C2, nodes_color_part2, Gweights=p2, pos=pos2, + node_size=node_size, shiftx=shiftx, seed=seed_G2) + for k1, v1 in pos1.items(): + for k2, v2 in pos2.items(): + if (T[k1, k2] > 0): + pl.plot([pos1[k1][0], pos2[k2][0]], + [pos1[k1][1], pos2[k2][1]], + '-', lw=0.8, alpha=0.5, + color=nodes_color_part1[k1]) + return pos1, pos2 + + +node_size = 40 +fontsize = 10 +seed_G2 = 0 +seed_G3 = 4 + +pl.figure(1, figsize=(8, 2.5)) +pl.clf() +pl.subplot(121) +pl.axis('off') +pl.axis +pl.title(r'srGW$(\mathbf{C_2},\mathbf{h_2},\mathbf{C_3}) =%s$' % (np.round(srgw_23, 3)), fontsize=fontsize) + +hbar2 = OT_23.sum(axis=0) +pos1, pos2 = draw_transp_colored_srGW( + weightedG2, C2, weightedG3, C3, part_G2, p1=None, p2=hbar2, T=OT_23, + shiftx=1.5, node_size=node_size, seed_G1=seed_G2, seed_G2=seed_G3) +pl.subplot(122) +pl.axis('off') +hbar3 = OT_32.sum(axis=0) +pl.title(r'srGW$(\mathbf{C_3}, \mathbf{h_3},\mathbf{C_2}) =%s$' % (np.round(srgw_32, 3)), fontsize=fontsize) +pos1, pos2 = draw_transp_colored_srGW( + weightedG3, C3, weightedG2, C2, part_G3, p1=None, p2=hbar3, T=OT_32, + pos1=pos2, pos2=pos1, shiftx=3., node_size=node_size, seed_G1=0, seed_G2=0) +pl.tight_layout() + +pl.show() + +# %% +# ============================================================================= +# Add node features +# ============================================================================= + +# We add node features with given mean - by clusters +# and inversely proportional to clusters' intra-connectivity + +F2 = np.zeros((N2, 1)) +for i, c in enumerate(part_G2): + F2[i, 0] = np.random.normal(loc=c, scale=0.01) + +F3 = np.zeros((N3, 1)) +for i, c in enumerate(part_G3): + F3[i, 0] = np.random.normal(loc=2. 
- c, scale=0.01) + +# %% +# ============================================================================= +# Compute their semi-relaxed Fused Gromov-Wasserstein divergences +# ============================================================================= + +alpha = 0.5 +# Compute pairwise euclidean distance between node features +M = (F2 ** 2).dot(np.ones((1, N3))) + np.ones((N2, 1)).dot((F3 ** 2).T) - 2 * F2.dot(F3.T) + +# 0) FGW_alpha(C2, F2, h2, C3, F3, h3) for reference + +OT, log = fused_gromov_wasserstein( + M, C2, C3, h2, h3, symmetric=True, alpha=alpha, log=True) +fgw = log['fgw_dist'] + +# 1) srFGW(C2, F2, h2, C3, F3) +OT_23, log_23 = semirelaxed_fused_gromov_wasserstein( + M, C2, C3, h2, symmetric=True, alpha=0.5, log=True, G0=None) +srfgw_23 = log_23['srfgw_dist'] + +# 2) srFGW(C3, F3, h3, C2, F2) + +OT_32, log_32 = semirelaxed_fused_gromov_wasserstein( + M.T, C3, C2, h3, symmetric=None, alpha=alpha, log=True, G0=None) +srfgw_32 = log_32['srfgw_dist'] + +print('FGW(C2, F2, C3, F3) = ', fgw) +print('srGW(C2, F2, h2, C3, F3) = ', srfgw_23) +print('srGW(C3, F3, h3, C2, F2) = ', srfgw_32) + +# %% +# ============================================================================= +# Visualization of the semi-relaxed Fused Gromov-Wasserstein matchings +# ============================================================================= + +# We color nodes of the graph on the right - then project its node colors +# based on the optimal transport plan from the srFGW matching +# NB: colors refer to clusters - not to node features + +pl.figure(2, figsize=(8, 2.5)) +pl.clf() +pl.subplot(121) +pl.axis('off') +pl.axis +pl.title(r'srFGW$(\mathbf{C_2},\mathbf{F_2},\mathbf{h_2},\mathbf{C_3},\mathbf{F_3}) =%s$' % (np.round(srfgw_23, 3)), fontsize=fontsize) + +hbar2 = OT_23.sum(axis=0) +pos1, pos2 = draw_transp_colored_srGW( + weightedG2, C2, weightedG3, C3, part_G2, p1=None, p2=hbar2, T=OT_23, + shiftx=1.5, node_size=node_size, seed_G1=seed_G2, seed_G2=seed_G3) +pl.subplot(122) +pl.axis('off') +hbar3 = OT_32.sum(axis=0) +pl.title(r'srFGW$(\mathbf{C_3}, \mathbf{F_3}, \mathbf{h_3}, \mathbf{C_2}, \mathbf{F_2}) =%s$' % (np.round(srfgw_32, 3)), fontsize=fontsize) +pos1, pos2 = draw_transp_colored_srGW( + weightedG3, C3, weightedG2, C2, part_G3, p1=None, p2=hbar3, T=OT_32, + pos1=pos2, pos2=pos1, shiftx=3., node_size=node_size, seed_G1=0, seed_G2=0) +pl.tight_layout() + +pl.show() diff --git a/ot/__init__.py b/ot/__init__.py index 45d5cfa..0e36459 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -8,7 +8,6 @@ , :py:mod:`ot.unbalanced`. The following sub-modules are not imported due to additional dependencies: - :any:`ot.dr` : depends on :code:`pymanopt` and :code:`autograd`. - - :any:`ot.gpu` : depends on :code:`cupy` and a CUDA GPU. - :any:`ot.plot` : depends on :code:`matplotlib` """ diff --git a/ot/bregman.py b/ot/bregman.py index 192a9e2..20bef7e 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -3048,7 +3048,7 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', M = nx.from_numpy(M, type_as=a) m1_cols.append( nx.sum(nx.exp(f[i:i + bs, None] + - g[None, :] - M / reg), axis=1) + g[None, :] - M / reg), axis=1) ) m1 = nx.concatenate(m1_cols, axis=0) err = nx.sum(nx.abs(m1 - a)) diff --git a/ot/dr.py b/ot/dr.py index 1b97841..b92cd14 100644 --- a/ot/dr.py +++ b/ot/dr.py @@ -167,7 +167,7 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, sinkhorn_method='sinkhorn', maxiter Size of dimensionnality reduction. 
reg : float, optional Regularization term >0 (entropic regularization) - solver : None | str, optional + solver : None | str, optional None for steepest descent or 'TrustRegions' for trust regions algorithm else should be a pymanopt.solvers sinkhorn_method : str diff --git a/ot/gromov.py b/ot/gromov.py deleted file mode 100644 index bc1c8e5..0000000 --- a/ot/gromov.py +++ /dev/null @@ -1,2838 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Gromov-Wasserstein and Fused-Gromov-Wasserstein solvers -""" - -# Author: Erwan Vautier -# Nicolas Courty -# Rémi Flamary -# Titouan Vayer -# Cédric Vincent-Cuaz -# -# License: MIT License - -import numpy as np - - -from .bregman import sinkhorn -from .utils import dist, UndefinedParameter, list_to_array -from .optim import cg -from .lp import emd_1d, emd -from .utils import check_random_state, unif -from .backend import get_backend - - -def init_matrix(C1, C2, p, q, loss_fun='square_loss'): - r"""Return loss matrices and tensors for Gromov-Wasserstein fast computation - - Returns the value of :math:`\mathcal{L}(\mathbf{C_1}, \mathbf{C_2}) \otimes \mathbf{T}` with the - selected loss function as the loss function of Gromow-Wasserstein discrepancy. - - The matrices are computed as described in Proposition 1 in :ref:`[12] ` - - Where : - - - :math:`\mathbf{C_1}`: Metric cost matrix in the source space - - :math:`\mathbf{C_2}`: Metric cost matrix in the target space - - :math:`\mathbf{T}`: A coupling between those two spaces - - The square-loss function :math:`L(a, b) = |a - b|^2` is read as : - - .. math:: - - L(a, b) = f_1(a) + f_2(b) - h_1(a) h_2(b) - - \mathrm{with} \ f_1(a) &= a^2 - - f_2(b) &= b^2 - - h_1(a) &= a - - h_2(b) &= 2b - - The kl-loss function :math:`L(a, b) = a \log\left(\frac{a}{b}\right) - a + b` is read as : - - .. math:: - - L(a, b) = f_1(a) + f_2(b) - h_1(a) h_2(b) - - \mathrm{with} \ f_1(a) &= a \log(a) - a - - f_2(b) &= b - - h_1(a) &= a - - h_2(b) &= \log(b) - - Parameters - ---------- - C1 : array-like, shape (ns, ns) - Metric cost matrix in the source space - C2 : array-like, shape (nt, nt) - Metric cost matrix in the target space - p : array-like, shape (ns,) - Probability distribution in the source space - q : array-like, shape (nt,) - Probability distribution in the target space - loss_fun : str, optional - Name of loss function to use: either 'square_loss' or 'kl_loss' (default='square_loss') - - Returns - ------- - constC : array-like, shape (ns, nt) - Constant :math:`\mathbf{C}` matrix in Eq. (6) - hC1 : array-like, shape (ns, ns) - :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6) - hC2 : array-like, shape (nt, nt) - :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6) - - - .. _references-init-matrix: - References - ---------- - .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, - "Gromov-Wasserstein averaging of kernel and distance matrices." - International Conference on Machine Learning (ICML). 2016. 
- - """ - C1, C2, p, q = list_to_array(C1, C2, p, q) - nx = get_backend(C1, C2, p, q) - - if loss_fun == 'square_loss': - def f1(a): - return (a**2) - - def f2(b): - return (b**2) - - def h1(a): - return a - - def h2(b): - return 2 * b - elif loss_fun == 'kl_loss': - def f1(a): - return a * nx.log(a + 1e-15) - a - - def f2(b): - return b - - def h1(a): - return a - - def h2(b): - return nx.log(b + 1e-15) - - constC1 = nx.dot( - nx.dot(f1(C1), nx.reshape(p, (-1, 1))), - nx.ones((1, len(q)), type_as=q) - ) - constC2 = nx.dot( - nx.ones((len(p), 1), type_as=p), - nx.dot(nx.reshape(q, (1, -1)), f2(C2).T) - ) - constC = constC1 + constC2 - hC1 = h1(C1) - hC2 = h2(C2) - - return constC, hC1, hC2 - - -def tensor_product(constC, hC1, hC2, T): - r"""Return the tensor for Gromov-Wasserstein fast computation - - The tensor is computed as described in Proposition 1 Eq. (6) in :ref:`[12] ` - - Parameters - ---------- - constC : array-like, shape (ns, nt) - Constant :math:`\mathbf{C}` matrix in Eq. (6) - hC1 : array-like, shape (ns, ns) - :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6) - hC2 : array-like, shape (nt, nt) - :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6) - - Returns - ------- - tens : array-like, shape (`ns`, `nt`) - :math:`\mathcal{L}(\mathbf{C_1}, \mathbf{C_2}) \otimes \mathbf{T}` tensor-matrix multiplication result - - - .. _references-tensor-product: - References - ---------- - .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, - "Gromov-Wasserstein averaging of kernel and distance matrices." - International Conference on Machine Learning (ICML). 2016. - - """ - constC, hC1, hC2, T = list_to_array(constC, hC1, hC2, T) - nx = get_backend(constC, hC1, hC2, T) - - A = - nx.dot( - nx.dot(hC1, T), hC2.T - ) - tens = constC + A - # tens -= tens.min() - return tens - - -def gwloss(constC, hC1, hC2, T): - r"""Return the Loss for Gromov-Wasserstein - - The loss is computed as described in Proposition 1 Eq. (6) in :ref:`[12] ` - - Parameters - ---------- - constC : array-like, shape (ns, nt) - Constant :math:`\mathbf{C}` matrix in Eq. (6) - hC1 : array-like, shape (ns, ns) - :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6) - hC2 : array-like, shape (nt, nt) - :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6) - T : array-like, shape (ns, nt) - Current value of transport matrix :math:`\mathbf{T}` - - Returns - ------- - loss : float - Gromov Wasserstein loss - - - .. _references-gwloss: - References - ---------- - .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, - "Gromov-Wasserstein averaging of kernel and distance matrices." - International Conference on Machine Learning (ICML). 2016. - - """ - - tens = tensor_product(constC, hC1, hC2, T) - - tens, T = list_to_array(tens, T) - nx = get_backend(tens, T) - - return nx.sum(tens * T) - - -def gwggrad(constC, hC1, hC2, T): - r"""Return the gradient for Gromov-Wasserstein - - The gradient is computed as described in Proposition 2 in :ref:`[12] ` - - Parameters - ---------- - constC : array-like, shape (ns, nt) - Constant :math:`\mathbf{C}` matrix in Eq. (6) - hC1 : array-like, shape (ns, ns) - :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6) - hC2 : array-like, shape (nt, nt) - :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6) - T : array-like, shape (ns, nt) - Current value of transport matrix :math:`\mathbf{T}` - - Returns - ------- - grad : array-like, shape (`ns`, `nt`) - Gromov Wasserstein gradient - - - .. _references-gwggrad: - References - ---------- - .. 
[12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, - "Gromov-Wasserstein averaging of kernel and distance matrices." - International Conference on Machine Learning (ICML). 2016. - - """ - return 2 * tensor_product(constC, hC1, hC2, - T) # [12] Prop. 2 misses a 2 factor - - -def update_square_loss(p, lambdas, T, Cs): - r""" - Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s` - couplings calculated at each iteration - - Parameters - ---------- - p : array-like, shape (N,) - Masses in the targeted barycenter. - lambdas : list of float - List of the `S` spaces' weights. - T : list of S array-like of shape (ns,N) - The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration. - Cs : list of S array-like, shape(ns,ns) - Metric cost matrices. - - Returns - ---------- - C : array-like, shape (`nt`, `nt`) - Updated :math:`\mathbf{C}` matrix. - """ - T = list_to_array(*T) - Cs = list_to_array(*Cs) - p = list_to_array(p) - nx = get_backend(p, *T, *Cs) - - tmpsum = sum([ - lambdas[s] * nx.dot( - nx.dot(T[s].T, Cs[s]), - T[s] - ) for s in range(len(T)) - ]) - ppt = nx.outer(p, p) - - return tmpsum / ppt - - -def update_kl_loss(p, lambdas, T, Cs): - r""" - Updates :math:`\mathbf{C}` according to the KL Loss kernel with the `S` :math:`\mathbf{T}_s` couplings calculated at each iteration - - - Parameters - ---------- - p : array-like, shape (N,) - Weights in the targeted barycenter. - lambdas : list of float - List of the `S` spaces' weights - T : list of S array-like of shape (ns,N) - The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration. - Cs : list of S array-like, shape(ns,ns) - Metric cost matrices. - - Returns - ---------- - C : array-like, shape (`ns`, `ns`) - updated :math:`\mathbf{C}` matrix - """ - Cs = list_to_array(*Cs) - T = list_to_array(*T) - p = list_to_array(p) - nx = get_backend(p, *T, *Cs) - - tmpsum = sum([ - lambdas[s] * nx.dot( - nx.dot(T[s].T, Cs[s]), - T[s] - ) for s in range(len(T)) - ]) - ppt = nx.outer(p, p) - - return nx.exp(tmpsum / ppt) - - -def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, G0=None, **kwargs): - r""" - Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` - - The function solves the following optimization problem: - - .. math:: - \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l} - L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - - Where : - - - :math:`\mathbf{C_1}`: Metric cost matrix in the source space - - :math:`\mathbf{C_2}`: Metric cost matrix in the target space - - :math:`\mathbf{p}`: distribution in the source space - - :math:`\mathbf{q}`: distribution in the target space - - `L`: loss function to account for the misfit between the similarity matrices - - .. note:: This function is backend-compatible and will work on arrays - from all compatible backends. But the algorithm uses the C++ CPU backend - which can lead to copy overhead on GPU arrays. 
- - Parameters - ---------- - C1 : array-like, shape (ns, ns) - Metric cost matrix in the source space - C2 : array-like, shape (nt, nt) - Metric cost matrix in the target space - p : array-like, shape (ns,) - Distribution in the source space - q : array-like, shape (nt,) - Distribution in the target space - loss_fun : str - loss function used for the solver either 'square_loss' or 'kl_loss' - max_iter : int, optional - Max number of iterations - tol : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - armijo : bool, optional - If True the step of the line-search is found via an armijo research. Else closed form is used. - If there are convergence issues use False. - G0: array-like, shape (ns,nt), optional - If None the initial transport plan of the solver is pq^T. - Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. - **kwargs : dict - parameters can be directly passed to the ot.optim.cg solver - - Returns - ------- - T : array-like, shape (`ns`, `nt`) - Coupling between the two spaces that minimizes: - - :math:`\sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}` - log : dict - Convergence information and loss. - - References - ---------- - .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, - "Gromov-Wasserstein averaging of kernel and distance matrices." - International Conference on Machine Learning (ICML). 2016. - - .. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the - metric approach to object matching. Foundations of computational - mathematics 11.4 (2011): 417-487. - - """ - p, q = list_to_array(p, q) - p0, q0, C10, C20 = p, q, C1, C2 - if G0 is None: - nx = get_backend(p0, q0, C10, C20) - else: - G0_ = G0 - nx = get_backend(p0, q0, C10, C20, G0_) - p = nx.to_numpy(p) - q = nx.to_numpy(q) - C1 = nx.to_numpy(C10) - C2 = nx.to_numpy(C20) - - if G0 is None: - G0 = p[:, None] * q[None, :] - else: - G0 = nx.to_numpy(G0_) - # Check marginals of G0 - np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08) - np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08) - - constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) - - def f(G): - return gwloss(constC, hC1, hC2, G) - - def df(G): - return gwggrad(constC, hC1, hC2, G) - - if log: - res, log = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) - log['gw_dist'] = nx.from_numpy(gwloss(constC, hC1, hC2, res), type_as=C10) - log['u'] = nx.from_numpy(log['u'], type_as=C10) - log['v'] = nx.from_numpy(log['v'], type_as=C10) - return nx.from_numpy(res, type_as=C10), log - else: - return nx.from_numpy(cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=False, **kwargs), type_as=C10) - - -def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, G0=None, **kwargs): - r""" - Returns the gromov-wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` - - The function solves the following optimization problem: - - .. 
math:: - GW = \min_\mathbf{T} \quad \sum_{i,j,k,l} - L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - - Where : - - - :math:`\mathbf{C_1}`: Metric cost matrix in the source space - - :math:`\mathbf{C_2}`: Metric cost matrix in the target space - - :math:`\mathbf{p}`: distribution in the source space - - :math:`\mathbf{q}`: distribution in the target space - - `L`: loss function to account for the misfit between the similarity - matrices - - Note that when using backends, this loss function is differentiable wrt the - marices and weights for quadratic loss using the gradients from [38]_. - - .. note:: This function is backend-compatible and will work on arrays - from all compatible backends. But the algorithm uses the C++ CPU backend - which can lead to copy overhead on GPU arrays. - - Parameters - ---------- - C1 : array-like, shape (ns, ns) - Metric cost matrix in the source space - C2 : array-like, shape (nt, nt) - Metric cost matrix in the target space - p : array-like, shape (ns,) - Distribution in the source space. - q : array-like, shape (nt,) - Distribution in the target space. - loss_fun : str - loss function used for the solver either 'square_loss' or 'kl_loss' - max_iter : int, optional - Max number of iterations - tol : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - armijo : bool, optional - If True the step of the line-search is found via an armijo research. Else closed form is used. - If there are convergence issues use False. - G0: array-like, shape (ns,nt), optional - If None the initial transport plan of the solver is pq^T. - Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. - - Returns - ------- - gw_dist : float - Gromov-Wasserstein distance - log : dict - convergence information and Coupling marix - - References - ---------- - .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, - "Gromov-Wasserstein averaging of kernel and distance matrices." - International Conference on Machine Learning (ICML). 2016. - - .. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the - metric approach to object matching. Foundations of computational - mathematics 11.4 (2011): 417-487. - - .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online - Graph Dictionary Learning, International Conference on Machine Learning - (ICML), 2021. 
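Because the returned value is wired to the backend gradients for the quadratic loss (see [38] above), a sketch of differentiating it through the torch backend (assuming torch is installed; data and sizes are illustrative) is:

    import numpy as np
    import torch
    import ot

    rng = np.random.RandomState(1)
    xs, xt = rng.randn(12, 2), rng.randn(9, 2)
    C1 = torch.tensor(ot.dist(xs, xs), requires_grad=True)   # differentiate wrt the source matrix
    C2 = torch.tensor(ot.dist(xt, xt))
    p = torch.ones(12, dtype=torch.float64) / 12
    q = torch.ones(9, dtype=torch.float64) / 9

    gw = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'square_loss')
    gw.backward()                              # closed-form gradients of [38] flow back to C1
    print(gw.item(), C1.grad.shape)            # scalar GW value and a (12, 12) gradient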
- - """ - p, q = list_to_array(p, q) - p0, q0, C10, C20 = p, q, C1, C2 - if G0 is None: - nx = get_backend(p0, q0, C10, C20) - else: - G0_ = G0 - nx = get_backend(p0, q0, C10, C20, G0_) - - p = nx.to_numpy(p) - q = nx.to_numpy(q) - C1 = nx.to_numpy(C10) - C2 = nx.to_numpy(C20) - - constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) - - if G0 is None: - G0 = p[:, None] * q[None, :] - else: - G0 = nx.to_numpy(G0_) - # Check marginals of G0 - np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08) - np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08) - - def f(G): - return gwloss(constC, hC1, hC2, G) - - def df(G): - return gwggrad(constC, hC1, hC2, G) - - T, log_gw = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) - - T0 = nx.from_numpy(T, type_as=C10) - - log_gw['gw_dist'] = nx.from_numpy(gwloss(constC, hC1, hC2, T), type_as=C10) - log_gw['u'] = nx.from_numpy(log_gw['u'], type_as=C10) - log_gw['v'] = nx.from_numpy(log_gw['v'], type_as=C10) - log_gw['T'] = T0 - - gw = log_gw['gw_dist'] - - if loss_fun == 'square_loss': - gC1 = 2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T) - gC2 = 2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T) - gC1 = nx.from_numpy(gC1, type_as=C10) - gC2 = nx.from_numpy(gC2, type_as=C10) - gw = nx.set_gradients(gw, (p0, q0, C10, C20), - (log_gw['u'] - nx.mean(log_gw['u']), - log_gw['v'] - nx.mean(log_gw['v']), gC1, gC2)) - - if log: - return gw, log_gw - else: - return gw - - -def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, G0=None, log=False, **kwargs): - r""" - Computes the FGW transport between two graphs (see :ref:`[24] `) - - .. math:: - \gamma = \mathop{\arg \min}_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F + - \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - - s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p} - - \mathbf{\gamma}^T \mathbf{1} &= \mathbf{q} - - \mathbf{\gamma} &\geq 0 - - where : - - - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix - - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1) - - `L` is a loss function to account for the misfit between the similarity matrices - - .. note:: This function is backend-compatible and will work on arrays - from all compatible backends. But the algorithm uses the C++ CPU backend - which can lead to copy overhead on GPU arrays. - - The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[24] ` - - Parameters - ---------- - M : array-like, shape (ns, nt) - Metric cost matrix between features across domains - C1 : array-like, shape (ns, ns) - Metric cost matrix representative of the structure in the source space - C2 : array-like, shape (nt, nt) - Metric cost matrix representative of the structure in the target space - p : array-like, shape (ns,) - Distribution in the source space - q : array-like, shape (nt,) - Distribution in the target space - loss_fun : str, optional - Loss function used for the solver - alpha : float, optional - Trade-off parameter (0 < alpha < 1) - armijo : bool, optional - If True the step of the line-search is found via an armijo research. Else closed form is used. - If there are convergence issues use False. - G0: array-like, shape (ns,nt), optional - If None the initial transport plan of the solver is pq^T. - Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. 
- log : bool, optional - record log if True - **kwargs : dict - parameters can be directly passed to the ot.optim.cg solver - - Returns - ------- - gamma : array-like, shape (`ns`, `nt`) - Optimal transportation matrix for the given parameters. - log : dict - Log dictionary return only if log==True in parameters. - - - .. _references-fused-gromov-wasserstein: - References - ---------- - .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain - and Courty Nicolas "Optimal Transport for structured data with - application on graphs", International Conference on Machine Learning - (ICML). 2019. - """ - p, q = list_to_array(p, q) - p0, q0, C10, C20, M0 = p, q, C1, C2, M - if G0 is None: - nx = get_backend(p0, q0, C10, C20, M0) - else: - G0_ = G0 - nx = get_backend(p0, q0, C10, C20, M0, G0_) - - p = nx.to_numpy(p) - q = nx.to_numpy(q) - C1 = nx.to_numpy(C10) - C2 = nx.to_numpy(C20) - M = nx.to_numpy(M0) - if G0 is None: - G0 = p[:, None] * q[None, :] - else: - G0 = nx.to_numpy(G0_) - # Check marginals of G0 - np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08) - np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08) - - constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) - - def f(G): - return gwloss(constC, hC1, hC2, G) - - def df(G): - return gwggrad(constC, hC1, hC2, G) - - if log: - res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs) - fgw_dist = nx.from_numpy(log['loss'][-1], type_as=C10) - log['fgw_dist'] = fgw_dist - log['u'] = nx.from_numpy(log['u'], type_as=C10) - log['v'] = nx.from_numpy(log['v'], type_as=C10) - return nx.from_numpy(res, type_as=C10), log - else: - return nx.from_numpy(cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs), type_as=C10) - - -def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, G0=None, log=False, **kwargs): - r""" - Computes the FGW distance between two graphs see (see :ref:`[24] `) - - .. math:: - \min_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F + \alpha \sum_{i,j,k,l} - L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - - s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p} - - \mathbf{\gamma}^T \mathbf{1} &= \mathbf{q} - - \mathbf{\gamma} &\geq 0 - - where : - - - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix - - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1) - - `L` is a loss function to account for the misfit between the similarity matrices - - The algorithm used for solving the problem is conditional gradient as - discussed in :ref:`[24] ` - - .. note:: This function is backend-compatible and will work on arrays - from all compatible backends. But the algorithm uses the C++ CPU backend - which can lead to copy overhead on GPU arrays. - - Note that when using backends, this loss function is differentiable wrt the - marices and weights for quadratic loss using the gradients from [38]_. - - Parameters - ---------- - M : array-like, shape (ns, nt) - Metric cost matrix between features across domains - C1 : array-like, shape (ns, ns) - Metric cost matrix representative of the structure in the source space. - C2 : array-like, shape (nt, nt) - Metric cost matrix representative of the structure in the target space. - p : array-like, shape (ns,) - Distribution in the source space. - q : array-like, shape (nt,) - Distribution in the target space. 
- loss_fun : str, optional - Loss function used for the solver. - alpha : float, optional - Trade-off parameter (0 < alpha < 1) - armijo : bool, optional - If True the step of the line-search is found via an armijo research. - Else closed form is used. If there are convergence issues use False. - G0: array-like, shape (ns,nt), optional - If None the initial transport plan of the solver is pq^T. - Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. - log : bool, optional - Record log if True. - **kwargs : dict - Parameters can be directly passed to the ot.optim.cg solver. - - Returns - ------- - fgw-distance : float - Fused gromov wasserstein distance for the given parameters. - log : dict - Log dictionary return only if log==True in parameters. - - - .. _references-fused-gromov-wasserstein2: - References - ---------- - .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain - and Courty Nicolas - "Optimal Transport for structured data with application on graphs" - International Conference on Machine Learning (ICML). 2019. - - .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online - Graph Dictionary Learning, International Conference on Machine Learning - (ICML), 2021. - """ - p, q = list_to_array(p, q) - - p0, q0, C10, C20, M0 = p, q, C1, C2, M - if G0 is None: - nx = get_backend(p0, q0, C10, C20, M0) - else: - G0_ = G0 - nx = get_backend(p0, q0, C10, C20, M0, G0_) - - p = nx.to_numpy(p) - q = nx.to_numpy(q) - C1 = nx.to_numpy(C10) - C2 = nx.to_numpy(C20) - M = nx.to_numpy(M0) - - constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) - - if G0 is None: - G0 = p[:, None] * q[None, :] - else: - G0 = nx.to_numpy(G0_) - # Check marginals of G0 - np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08) - np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08) - - def f(G): - return gwloss(constC, hC1, hC2, G) - - def df(G): - return gwggrad(constC, hC1, hC2, G) - - T, log_fgw = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs) - - fgw_dist = nx.from_numpy(log_fgw['loss'][-1], type_as=C10) - - T0 = nx.from_numpy(T, type_as=C10) - - log_fgw['fgw_dist'] = fgw_dist - log_fgw['u'] = nx.from_numpy(log_fgw['u'], type_as=C10) - log_fgw['v'] = nx.from_numpy(log_fgw['v'], type_as=C10) - log_fgw['T'] = T0 - - if loss_fun == 'square_loss': - gC1 = 2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T) - gC2 = 2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T) - gC1 = nx.from_numpy(gC1, type_as=C10) - gC2 = nx.from_numpy(gC2, type_as=C10) - fgw_dist = nx.set_gradients(fgw_dist, (p0, q0, C10, C20, M0), - (log_fgw['u'] - nx.mean(log_fgw['u']), - log_fgw['v'] - nx.mean(log_fgw['v']), - alpha * gC1, alpha * gC2, (1 - alpha) * T0)) - - if log: - return fgw_dist, log_fgw - else: - return fgw_dist - - -def GW_distance_estimation(C1, C2, p, q, loss_fun, T, - nb_samples_p=None, nb_samples_q=None, std=True, random_state=None): - r""" - Returns an approximation of the gromov-wasserstein cost between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` - with a fixed transport plan :math:`\mathbf{T}`. - - The function gives an unbiased approximation of the following equation: - - .. 
math:: - - GW = \sum_{i,j,k,l} L(\mathbf{C_{1}}_{i,k}, \mathbf{C_{2}}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - - Where : - - - :math:`\mathbf{C_1}`: Metric cost matrix in the source space - - :math:`\mathbf{C_2}`: Metric cost matrix in the target space - - `L` : Loss function to account for the misfit between the similarity matrices - - :math:`\mathbf{T}`: Matrix with marginal :math:`\mathbf{p}` and :math:`\mathbf{q}` - - Parameters - ---------- - C1 : array-like, shape (ns, ns) - Metric cost matrix in the source space - C2 : array-like, shape (nt, nt) - Metric cost matrix in the target space - p : array-like, shape (ns,) - Distribution in the source space - q : array-like, shape (nt,) - Distribution in the target space - loss_fun : function: :math:`\mathbb{R} \times \mathbb{R} \mapsto \mathbb{R}` - Loss function used for the distance, the transport plan does not depend on the loss function - T : csr or array-like, shape (ns, nt) - Transport plan matrix, either a sparse csr or a dense matrix - nb_samples_p : int, optional - `nb_samples_p` is the number of samples (without replacement) along the first dimension of :math:`\mathbf{T}` - nb_samples_q : int, optional - `nb_samples_q` is the number of samples along the second dimension of :math:`\mathbf{T}`, for each sample along the first - std : bool, optional - Standard deviation associated with the prediction of the gromov-wasserstein cost - random_state : int or RandomState instance, optional - Fix the seed for reproducibility - - Returns - ------- - : float - Gromov-wasserstein cost - - References - ---------- - .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc - "Sampled Gromov Wasserstein." - Machine Learning Journal (MLJ). 2021. - - """ - C1, C2, p, q = list_to_array(C1, C2, p, q) - nx = get_backend(C1, C2, p, q) - - generator = check_random_state(random_state) - - len_p = p.shape[0] - len_q = q.shape[0] - - # It is always better to sample from the biggest distribution first. - if len_p < len_q: - p, q = q, p - len_p, len_q = len_q, len_p - C1, C2 = C2, C1 - T = T.T - - if nb_samples_p is None: - if nx.issparse(T): - # If T is sparse, it probably mean that PoGroW was used, thus the number of sample is reduced - nb_samples_p = min(int(5 * (len_p * np.log(len_p)) ** 0.5), len_p) - else: - nb_samples_p = len_p - else: - # The number of sample along the first dimension is without replacement. - nb_samples_p = min(nb_samples_p, len_p) - if nb_samples_q is None: - nb_samples_q = 1 - if std: - nb_samples_q = max(2, nb_samples_q) - - index_k = np.zeros((nb_samples_p, nb_samples_q), dtype=int) - index_l = np.zeros((nb_samples_p, nb_samples_q), dtype=int) - - index_i = generator.choice( - len_p, size=nb_samples_p, p=nx.to_numpy(p), replace=False - ) - index_j = generator.choice( - len_p, size=nb_samples_p, p=nx.to_numpy(p), replace=False - ) - - for i in range(nb_samples_p): - if nx.issparse(T): - T_indexi = nx.reshape(nx.todense(T[index_i[i], :]), (-1,)) - T_indexj = nx.reshape(nx.todense(T[index_j[i], :]), (-1,)) - else: - T_indexi = T[index_i[i], :] - T_indexj = T[index_j[i], :] - # For each of the row sampled, the column is sampled. 
- index_k[i] = generator.choice( - len_q, - size=nb_samples_q, - p=nx.to_numpy(T_indexi / nx.sum(T_indexi)), - replace=True - ) - index_l[i] = generator.choice( - len_q, - size=nb_samples_q, - p=nx.to_numpy(T_indexj / nx.sum(T_indexj)), - replace=True - ) - - list_value_sample = nx.stack([ - loss_fun( - C1[np.ix_(index_i, index_j)], - C2[np.ix_(index_k[:, n], index_l[:, n])] - ) for n in range(nb_samples_q) - ], axis=2) - - if std: - std_value = nx.sum(nx.std(list_value_sample, axis=2) ** 2) ** 0.5 - return nx.mean(list_value_sample), std_value / (nb_samples_p * nb_samples_p) - else: - return nx.mean(list_value_sample) - - -def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun, - alpha=1, max_iter=100, threshold_plan=0, log=False, verbose=False, random_state=None): - r""" - Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` using a stochastic Frank-Wolfe. - This method has a :math:`\mathcal{O}(\mathrm{max\_iter} \times PN^2)` time complexity with `P` the number of Sinkhorn iterations. - - The function solves the following optimization problem: - - .. math:: - \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l} - L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - - s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} - - \mathbf{T}^T \mathbf{1} &= \mathbf{q} - - \mathbf{T} &\geq 0 - - Where : - - - :math:`\mathbf{C_1}`: Metric cost matrix in the source space - - :math:`\mathbf{C_2}`: Metric cost matrix in the target space - - :math:`\mathbf{p}`: distribution in the source space - - :math:`\mathbf{q}`: distribution in the target space - - `L`: loss function to account for the misfit between the similarity matrices - - Parameters - ---------- - C1 : array-like, shape (ns, ns) - Metric cost matrix in the source space - C2 : array-like, shape (nt, nt) - Metric cost matrix in the target space - p : array-like, shape (ns,) - Distribution in the source space - q : array-like, shape (nt,) - Distribution in the target space - loss_fun : function: :math:`\mathbb{R} \times \mathbb{R} \mapsto \mathbb{R}` - Loss function used for the distance, the transport plan does not depend on the loss function - alpha : float - Step of the Frank-Wolfe algorithm, should be between 0 and 1 - max_iter : int, optional - Max number of iterations - threshold_plan : float, optional - Deleting very small values in the transport plan. If above zero, it violates the marginal constraints. - verbose : bool, optional - Print information along iterations - log : bool, optional - Gives the distance estimated and the standard deviation - random_state : int or RandomState instance, optional - Fix the seed for reproducibility - - Returns - ------- - T : array-like, shape (`ns`, `nt`) - Optimal coupling between the two spaces - - References - ---------- - .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc - "Sampled Gromov Wasserstein." - Machine Learning Journal (MLJ). 2021. 
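A sketch of calling this stochastic solver with a callable element-wise loss (assuming it is exposed as `ot.gromov.pointwise_gromov_wasserstein`; the loss, sizes and seed are illustrative):

    import numpy as np
    import ot

    rng = np.random.RandomState(42)
    xs, xt = rng.randn(25, 2), rng.randn(30, 2)
    C1, C2 = ot.dist(xs, xs), ot.dist(xt, xt)
    p, q = ot.unif(25), ot.unif(30)

    def loss(x, y):                            # element-wise loss applied to sampled entries
        return np.abs(x - y)

    T, log = ot.gromov.pointwise_gromov_wasserstein(
        C1, C2, p, q, loss, max_iter=100, log=True, random_state=42)
    print(log['gw_dist_estimated'], log['gw_dist_std'])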
- - """ - C1, C2, p, q = list_to_array(C1, C2, p, q) - nx = get_backend(C1, C2, p, q) - - len_p = p.shape[0] - len_q = q.shape[0] - - generator = check_random_state(random_state) - - index = np.zeros(2, dtype=int) - - # Initialize with default marginal - index[0] = generator.choice(len_p, size=1, p=nx.to_numpy(p)) - index[1] = generator.choice(len_q, size=1, p=nx.to_numpy(q)) - T = nx.tocsr(emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False)) - - best_gw_dist_estimated = np.inf - for cpt in range(max_iter): - index[0] = generator.choice(len_p, size=1, p=nx.to_numpy(p)) - T_index0 = nx.reshape(nx.todense(T[index[0], :]), (-1,)) - index[1] = generator.choice( - len_q, size=1, p=nx.to_numpy(T_index0 / nx.sum(T_index0)) - ) - - if alpha == 1: - T = nx.tocsr( - emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False) - ) - else: - new_T = nx.tocsr( - emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False) - ) - T = (1 - alpha) * T + alpha * new_T - # To limit the number of non 0, the values below the threshold are set to 0. - T = nx.eliminate_zeros(T, threshold=threshold_plan) - - if cpt % 10 == 0 or cpt == (max_iter - 1): - gw_dist_estimated = GW_distance_estimation( - C1=C1, C2=C2, loss_fun=loss_fun, - p=p, q=q, T=T, std=False, random_state=generator - ) - - if gw_dist_estimated < best_gw_dist_estimated: - best_gw_dist_estimated = gw_dist_estimated - best_T = nx.copy(T) - - if verbose: - if cpt % 200 == 0: - print('{:5s}|{:12s}'.format('It.', 'Best gw estimated') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, best_gw_dist_estimated)) - - if log: - log = {} - log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation( - C1=C1, C2=C2, loss_fun=loss_fun, - p=p, q=q, T=best_T, random_state=generator - ) - return best_T, log - return best_T - - -def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun, - nb_samples_grad=100, epsilon=1, max_iter=500, log=False, verbose=False, - random_state=None): - r""" - Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` using a 1-stochastic Frank-Wolfe. - This method has a :math:`\mathcal{O}(\mathrm{max\_iter} \times N \log(N))` time complexity by relying on the 1D Optimal Transport solver. - - The function solves the following optimization problem: - - .. math:: - \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l} - L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - - s.t. 
\ \mathbf{T} \mathbf{1} &= \mathbf{p} - - \mathbf{T}^T \mathbf{1} &= \mathbf{q} - - \mathbf{T} &\geq 0 - - Where : - - - :math:`\mathbf{C_1}`: Metric cost matrix in the source space - - :math:`\mathbf{C_2}`: Metric cost matrix in the target space - - :math:`\mathbf{p}`: distribution in the source space - - :math:`\mathbf{q}`: distribution in the target space - - `L`: loss function to account for the misfit between the similarity matrices - - Parameters - ---------- - C1 : array-like, shape (ns, ns) - Metric cost matrix in the source space - C2 : array-like, shape (nt, nt) - Metric cost matrix in the target space - p : array-like, shape (ns,) - Distribution in the source space - q : array-like, shape (nt,) - Distribution in the target space - loss_fun : function: :math:`\mathbb{R} \times \mathbb{R} \mapsto \mathbb{R}` - Loss function used for the distance, the transport plan does not depend on the loss function - nb_samples_grad : int - Number of samples to approximate the gradient - epsilon : float - Weight of the Kullback-Leibler regularization - max_iter : int, optional - Max number of iterations - verbose : bool, optional - Print information along iterations - log : bool, optional - Gives the distance estimated and the standard deviation - random_state : int or RandomState instance, optional - Fix the seed for reproducibility - - Returns - ------- - T : array-like, shape (`ns`, `nt`) - Optimal coupling between the two spaces - - References - ---------- - .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc - "Sampled Gromov Wasserstein." - Machine Learning Journal (MLJ). 2021. - - """ - C1, C2, p, q = list_to_array(C1, C2, p, q) - nx = get_backend(C1, C2, p, q) - - len_p = p.shape[0] - len_q = q.shape[0] - - generator = check_random_state(random_state) - - # The most natural way to define nb_sample is with a simple integer. - if isinstance(nb_samples_grad, int): - if nb_samples_grad > len_p: - # As the sampling along the first dimension is done without replacement, the rest is reported to the second - # dimension. - nb_samples_grad_p, nb_samples_grad_q = len_p, nb_samples_grad // len_p - else: - nb_samples_grad_p, nb_samples_grad_q = nb_samples_grad, 1 - else: - nb_samples_grad_p, nb_samples_grad_q = nb_samples_grad - T = nx.outer(p, q) - # continue_loop allows to stop the loop if there is several successive small modification of T. - continue_loop = 0 - - # The gradient of GW is more complex if the two matrices are not symmetric. - C_are_symmetric = nx.allclose(C1, C1.T, rtol=1e-10, atol=1e-10) and nx.allclose(C2, C2.T, rtol=1e-10, atol=1e-10) - - for cpt in range(max_iter): - index0 = generator.choice( - len_p, size=nb_samples_grad_p, p=nx.to_numpy(p), replace=False - ) - Lik = 0 - for i, index0_i in enumerate(index0): - index1 = generator.choice( - len_q, size=nb_samples_grad_q, - p=nx.to_numpy(T[index0_i, :] / nx.sum(T[index0_i, :])), - replace=False - ) - # If the matrices C are not symmetric, the gradient has 2 terms, thus the term is chosen randomly. - if (not C_are_symmetric) and generator.rand(1) > 0.5: - Lik += nx.mean(loss_fun( - C1[:, [index0[i]] * nb_samples_grad_q][:, None, :], - C2[:, index1][None, :, :] - ), axis=2) - else: - Lik += nx.mean(loss_fun( - C1[[index0[i]] * nb_samples_grad_q, :][:, :, None], - C2[index1, :][:, None, :] - ), axis=0) - - max_Lik = nx.max(Lik) - if max_Lik == 0: - continue - # This division by the max is here to facilitate the choice of epsilon. 
- Lik /= max_Lik - - if epsilon > 0: - # Set to infinity all the numbers below exp(-200) to avoid log of 0. - log_T = nx.log(nx.clip(T, np.exp(-200), 1)) - log_T = nx.where(log_T == -200, -np.inf, log_T) - Lik = Lik - epsilon * log_T - - try: - new_T = sinkhorn(a=p, b=q, M=Lik, reg=epsilon) - except (RuntimeWarning, UserWarning): - print("Warning catched in Sinkhorn: Return last stable T") - break - else: - new_T = emd(a=p, b=q, M=Lik) - - change_T = nx.mean((T - new_T) ** 2) - if change_T <= 10e-20: - continue_loop += 1 - if continue_loop > 100: # Number max of low modifications of T - T = nx.copy(new_T) - break - else: - continue_loop = 0 - - if verbose and cpt % 10 == 0: - if cpt % 200 == 0: - print('{:5s}|{:12s}'.format('It.', '||T_n - T_{n+1}||') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, change_T)) - T = nx.copy(new_T) - - if log: - log = {} - log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation( - C1=C1, C2=C2, loss_fun=loss_fun, - p=p, q=q, T=T, random_state=generator - ) - return T, log - return T - - -def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, - max_iter=1000, tol=1e-9, verbose=False, log=False): - r""" - Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` - - The function solves the following optimization problem: - - .. math:: - \mathbf{GW} = \mathop{\arg\min}_\mathbf{T} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - \epsilon(H(\mathbf{T})) - - s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} - - \mathbf{T}^T \mathbf{1} &= \mathbf{q} - - \mathbf{T} &\geq 0 - - Where : - - - :math:`\mathbf{C_1}`: Metric cost matrix in the source space - - :math:`\mathbf{C_2}`: Metric cost matrix in the target space - - :math:`\mathbf{p}`: distribution in the source space - - :math:`\mathbf{q}`: distribution in the target space - - `L`: loss function to account for the misfit between the similarity matrices - - `H`: entropy - - Parameters - ---------- - C1 : array-like, shape (ns, ns) - Metric cost matrix in the source space - C2 : array-like, shape (nt, nt) - Metric cost matrix in the target space - p : array-like, shape (ns,) - Distribution in the source space - q : array-like, shape (nt,) - Distribution in the target space - loss_fun : string - Loss function used for the solver either 'square_loss' or 'kl_loss' - epsilon : float - Regularization term >0 - max_iter : int, optional - Max number of iterations - tol : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - Record log if True. - - Returns - ------- - T : array-like, shape (`ns`, `nt`) - Optimal coupling between the two spaces - - References - ---------- - .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, - "Gromov-Wasserstein averaging of kernel and distance matrices." - International Conference on Machine Learning (ICML). 2016. 
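A sketch of the entropic solver on toy data (assuming it is exposed as `ot.gromov.entropic_gromov_wasserstein`; the value of `epsilon` is illustrative):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    xs, xt = rng.randn(20, 2), rng.randn(30, 2)
    C1, C2 = ot.dist(xs, xs), ot.dist(xt, xt)
    C1, C2 = C1 / C1.max(), C2 / C2.max()
    p, q = ot.unif(20), ot.unif(30)

    T, log = ot.gromov.entropic_gromov_wasserstein(
        C1, C2, p, q, 'square_loss', epsilon=1e-2, log=True)
    print(T.shape, log['gw_dist'])             # (20, 30) coupling and the entropic GW loss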
- - """ - C1, C2, p, q = list_to_array(C1, C2, p, q) - nx = get_backend(C1, C2, p, q) - - T = nx.outer(p, q) - - constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) - - cpt = 0 - err = 1 - - if log: - log = {'err': []} - - while (err > tol and cpt < max_iter): - - Tprev = T - - # compute the gradient - tens = gwggrad(constC, hC1, hC2, T) - - T = sinkhorn(p, q, tens, epsilon, method='sinkhorn') - - if cpt % 10 == 0: - # we can speed up the process by checking for the error only all - # the 10th iterations - err = nx.norm(T - Tprev) - - if log: - log['err'].append(err) - - if verbose: - if cpt % 200 == 0: - print('{:5s}|{:12s}'.format( - 'It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) - - cpt += 1 - - if log: - log['gw_dist'] = gwloss(constC, hC1, hC2, T) - return T, log - else: - return T - - -def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, - max_iter=1000, tol=1e-9, verbose=False, log=False): - r""" - Returns the entropic gromov-wasserstein discrepancy between the two measured similarity matrices :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` - - The function solves the following optimization problem: - - .. math:: - GW = \min_\mathbf{T} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) - \mathbf{T}_{i,j} \mathbf{T}_{k,l} - \epsilon(H(\mathbf{T})) - - Where : - - - :math:`\mathbf{C_1}`: Metric cost matrix in the source space - - :math:`\mathbf{C_2}`: Metric cost matrix in the target space - - :math:`\mathbf{p}`: distribution in the source space - - :math:`\mathbf{q}`: distribution in the target space - - `L`: loss function to account for the misfit between the similarity matrices - - `H`: entropy - - Parameters - ---------- - C1 : array-like, shape (ns, ns) - Metric cost matrix in the source space - C2 : array-like, shape (nt, nt) - Metric cost matrix in the target space - p : array-like, shape (ns,) - Distribution in the source space - q : array-like, shape (nt,) - Distribution in the target space - loss_fun : str - Loss function used for the solver either 'square_loss' or 'kl_loss' - epsilon : float - Regularization term >0 - max_iter : int, optional - Max number of iterations - tol : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - Record log if True. - - Returns - ------- - gw_dist : float - Gromov-Wasserstein distance - - References - ---------- - .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, - "Gromov-Wasserstein averaging of kernel and distance matrices." - International Conference on Machine Learning (ICML). 2016. - - """ - gw, logv = entropic_gromov_wasserstein( - C1, C2, p, q, loss_fun, epsilon, max_iter, tol, verbose, log=True) - - logv['T'] = gw - - if log: - return logv['gw_dist'], logv - else: - return logv['gw_dist'] - - -def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, - max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None, random_state=None): - r""" - Returns the gromov-wasserstein barycenters of `S` measured similarity matrices :math:`(\mathbf{C}_s)_{1 \leq s \leq S}` - - The function solves the following optimization problem: - - .. 
math:: - - \mathbf{C} = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \quad \sum_s \lambda_s \mathrm{GW}(\mathbf{C}, \mathbf{C}_s, \mathbf{p}, \mathbf{p}_s) - - Where : - - - :math:`\mathbf{C}_s`: metric cost matrix - - :math:`\mathbf{p}_s`: distribution - - Parameters - ---------- - N : int - Size of the targeted barycenter - Cs : list of S array-like of shape (ns,ns) - Metric cost matrices - ps : list of S array-like of shape (ns,) - Sample weights in the `S` spaces - p : array-like, shape(N,) - Weights in the targeted barycenter - lambdas : list of float - List of the `S` spaces' weights. - loss_fun : callable - Tensor-matrix multiplication function based on specific loss function. - update : callable - function(:math:`\mathbf{p}`, lambdas, :math:`\mathbf{T}`, :math:`\mathbf{Cs}`) that updates - :math:`\mathbf{C}` according to a specific Kernel with the `S` :math:`\mathbf{T}_s` couplings - calculated at each iteration - epsilon : float - Regularization term >0 - max_iter : int, optional - Max number of iterations - tol : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information along iterations. - log : bool, optional - Record log if True. - init_C : bool | array-like, shape (N, N) - Random initial value for the :math:`\mathbf{C}` matrix provided by user. - random_state : int or RandomState instance, optional - Fix the seed for reproducibility - - Returns - ------- - C : array-like, shape (`N`, `N`) - Similarity matrix in the barycenter space (permutated arbitrarily) - log : dict - Log dictionary of error during iterations. Return only if `log=True` in parameters. - - References - ---------- - .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, - "Gromov-Wasserstein averaging of kernel and distance matrices." - International Conference on Machine Learning (ICML). 2016. - """ - Cs = list_to_array(*Cs) - ps = list_to_array(*ps) - p = list_to_array(p) - nx = get_backend(*Cs, *ps, p) - - S = len(Cs) - - # Initialization of C : random SPD matrix (if not provided by user) - if init_C is None: - generator = check_random_state(random_state) - xalea = generator.randn(N, 2) - C = dist(xalea, xalea) - C /= C.max() - C = nx.from_numpy(C, type_as=p) - else: - C = init_C - - cpt = 0 - err = 1 - - error = [] - - while (err > tol) and (cpt < max_iter): - Cprev = C - - T = [entropic_gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon, - max_iter, 1e-4, verbose, log=False) for s in range(S)] - if loss_fun == 'square_loss': - C = update_square_loss(p, lambdas, T, Cs) - - elif loss_fun == 'kl_loss': - C = update_kl_loss(p, lambdas, T, Cs) - - if cpt % 10 == 0: - # we can speed up the process by checking for the error only all - # the 10th iterations - err = nx.norm(C - Cprev) - error.append(err) - - if verbose: - if cpt % 200 == 0: - print('{:5s}|{:12s}'.format( - 'It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) - - cpt += 1 - - if log: - return C, {"err": error} - else: - return C - - -def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, - max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None, random_state=None): - r""" - Returns the gromov-wasserstein barycenters of `S` measured similarity matrices :math:`(\mathbf{C}_s)_{1 \leq s \leq S}` - - The function solves the following optimization problem with block coordinate descent: - - .. 
math:: - - \mathbf{C} = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \quad \sum_s \lambda_s \mathrm{GW}(\mathbf{C}, \mathbf{C}_s, \mathbf{p}, \mathbf{p}_s) - - Where : - - - :math:`\mathbf{C}_s`: metric cost matrix - - :math:`\mathbf{p}_s`: distribution - - Parameters - ---------- - N : int - Size of the targeted barycenter - Cs : list of S array-like of shape (ns, ns) - Metric cost matrices - ps : list of S array-like of shape (ns,) - Sample weights in the `S` spaces - p : array-like, shape (N,) - Weights in the targeted barycenter - lambdas : list of float - List of the `S` spaces' weights - loss_fun : callable - tensor-matrix multiplication function based on specific loss function - update : callable - function(:math:`\mathbf{p}`, lambdas, :math:`\mathbf{T}`, :math:`\mathbf{Cs}`) that updates - :math:`\mathbf{C}` according to a specific Kernel with the `S` :math:`\mathbf{T}_s` couplings - calculated at each iteration - max_iter : int, optional - Max number of iterations - tol : float, optional - Stop threshold on error (>0). - verbose : bool, optional - Print information along iterations. - log : bool, optional - Record log if True. - init_C : bool | array-like, shape(N,N) - Random initial value for the :math:`\mathbf{C}` matrix provided by user. - random_state : int or RandomState instance, optional - Fix the seed for reproducibility - - Returns - ------- - C : array-like, shape (`N`, `N`) - Similarity matrix in the barycenter space (permutated arbitrarily) - log : dict - Log dictionary of error during iterations. Return only if `log=True` in parameters. - - References - ---------- - .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, - "Gromov-Wasserstein averaging of kernel and distance matrices." - International Conference on Machine Learning (ICML). 2016. 
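A sketch of a barycenter of two toy structures with this solver (assuming it is exposed as `ot.gromov.gromov_barycenters`; sizes and weights are illustrative):

    import numpy as np
    import ot

    rng = np.random.RandomState(3)
    Cs = [ot.dist(rng.randn(n, 2)) for n in (12, 18)]   # two input structure matrices
    Cs = [C / C.max() for C in Cs]
    ps = [ot.unif(C.shape[0]) for C in Cs]
    N = 10                                              # number of points of the barycenter
    p = ot.unif(N)

    Cb = ot.gromov.gromov_barycenters(N, Cs, ps, p, [0.5, 0.5], 'square_loss',
                                      max_iter=100, tol=1e-5, random_state=0)
    print(Cb.shape)                                     # (10, 10) barycentric structure matrix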
- - """ - Cs = list_to_array(*Cs) - ps = list_to_array(*ps) - p = list_to_array(p) - nx = get_backend(*Cs, *ps, p) - - S = len(Cs) - - # Initialization of C : random SPD matrix (if not provided by user) - if init_C is None: - generator = check_random_state(random_state) - xalea = generator.randn(N, 2) - C = dist(xalea, xalea) - C /= C.max() - C = nx.from_numpy(C, type_as=p) - else: - C = init_C - - cpt = 0 - err = 1 - - error = [] - - while (err > tol and cpt < max_iter): - Cprev = C - - T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, - numItermax=max_iter, stopThr=1e-5, verbose=verbose, log=False) for s in range(S)] - if loss_fun == 'square_loss': - C = update_square_loss(p, lambdas, T, Cs) - - elif loss_fun == 'kl_loss': - C = update_kl_loss(p, lambdas, T, Cs) - - if cpt % 10 == 0: - # we can speed up the process by checking for the error only all - # the 10th iterations - err = nx.norm(C - Cprev) - error.append(err) - - if verbose: - if cpt % 200 == 0: - print('{:5s}|{:12s}'.format( - 'It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) - - cpt += 1 - - if log: - return C, {"err": error} - else: - return C - - -def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False, - p=None, loss_fun='square_loss', max_iter=100, tol=1e-9, - verbose=False, log=False, init_C=None, init_X=None, random_state=None): - r"""Compute the fgw barycenter as presented eq (5) in :ref:`[24] ` - - Parameters - ---------- - N : int - Desired number of samples of the target barycenter - Ys: list of array-like, each element has shape (ns,d) - Features of all samples - Cs : list of array-like, each element has shape (ns,ns) - Structure matrices of all samples - ps : list of array-like, each element has shape (ns,) - Masses of all samples. - lambdas : list of float - List of the `S` spaces' weights - alpha : float - Alpha parameter for the fgw distance - fixed_structure : bool - Whether to fix the structure of the barycenter during the updates - fixed_features : bool - Whether to fix the feature of the barycenter during the updates - loss_fun : str - Loss function used for the solver either 'square_loss' or 'kl_loss' - max_iter : int, optional - Max number of iterations - tol : float, optional - Stop threshold on error (>0). - verbose : bool, optional - Print information along iterations. - log : bool, optional - Record log if True. - init_C : array-like, shape (N,N), optional - Initialization for the barycenters' structure matrix. If not set - a random init is used. - init_X : array-like, shape (N,d), optional - Initialization for the barycenters' features. If not set a - random init is used. - random_state : int or RandomState instance, optional - Fix the seed for reproducibility - - Returns - ------- - X : array-like, shape (`N`, `d`) - Barycenters' features - C : array-like, shape (`N`, `N`) - Barycenters' structure matrix - log : dict - Only returned when log=True. It contains the keys: - - - :math:`\mathbf{T}`: list of (`N`, `ns`) transport matrices - - :math:`(\mathbf{M}_s)_s`: all distance matrices between the feature of the barycenter and the other features :math:`(dist(\mathbf{X}, \mathbf{Y}_s))_s` shape (`N`, `ns`) - - - .. _references-fgw-barycenters: - References - ---------- - .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain - and Courty Nicolas - "Optimal Transport for structured data with application on graphs" - International Conference on Machine Learning (ICML). 2019. 
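A sketch of an FGW barycenter of two toy attributed structures (assuming the function is exposed as `ot.gromov.fgw_barycenters`; sizes, features and `alpha` are illustrative):

    import numpy as np
    import ot

    rng = np.random.RandomState(12)
    ns = [10, 14]
    Ys = [rng.randn(n, 3) for n in ns]                  # node features of the two inputs
    Cs = [ot.dist(rng.randn(n, 2)) for n in ns]         # structure matrices of the two inputs
    Cs = [C / C.max() for C in Cs]
    ps = [ot.unif(n) for n in ns]
    N = 8                                               # number of nodes of the barycenter

    X, C, log = ot.gromov.fgw_barycenters(N, Ys, Cs, ps, [0.5, 0.5], alpha=0.5,
                                          max_iter=100, tol=1e-5, log=True,
                                          random_state=0)
    print(X.shape, C.shape)                             # (8, 3) features and (8, 8) structure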
- """ - Cs = list_to_array(*Cs) - ps = list_to_array(*ps) - Ys = list_to_array(*Ys) - p = list_to_array(p) - nx = get_backend(*Cs, *Ys, *ps) - - S = len(Cs) - d = Ys[0].shape[1] # dimension on the node features - if p is None: - p = nx.ones(N, type_as=Cs[0]) / N - - if fixed_structure: - if init_C is None: - raise UndefinedParameter('If C is fixed it must be initialized') - else: - C = init_C - else: - if init_C is None: - generator = check_random_state(random_state) - xalea = generator.randn(N, 2) - C = dist(xalea, xalea) - C = nx.from_numpy(C, type_as=ps[0]) - else: - C = init_C - - if fixed_features: - if init_X is None: - raise UndefinedParameter('If X is fixed it must be initialized') - else: - X = init_X - else: - if init_X is None: - X = nx.zeros((N, d), type_as=ps[0]) - else: - X = init_X - - T = [nx.outer(p, q) for q in ps] - - Ms = [dist(X, Ys[s]) for s in range(len(Ys))] - - cpt = 0 - err_feature = 1 - err_structure = 1 - - if log: - log_ = {} - log_['err_feature'] = [] - log_['err_structure'] = [] - log_['Ts_iter'] = [] - - while ((err_feature > tol or err_structure > tol) and cpt < max_iter): - Cprev = C - Xprev = X - - if not fixed_features: - Ys_temp = [y.T for y in Ys] - X = update_feature_matrix(lambdas, Ys_temp, T, p).T - - Ms = [dist(X, Ys[s]) for s in range(len(Ys))] - - if not fixed_structure: - if loss_fun == 'square_loss': - T_temp = [t.T for t in T] - C = update_structure_matrix(p, lambdas, T_temp, Cs) - - T = [fused_gromov_wasserstein(Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha, - numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)] - - # T is N,ns - err_feature = nx.norm(X - nx.reshape(Xprev, (N, d))) - err_structure = nx.norm(C - Cprev) - if log: - log_['err_feature'].append(err_feature) - log_['err_structure'].append(err_structure) - log_['Ts_iter'].append(T) - - if verbose: - if cpt % 200 == 0: - print('{:5s}|{:12s}'.format( - 'It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err_structure)) - print('{:5d}|{:8e}|'.format(cpt, err_feature)) - - cpt += 1 - - if log: - log_['T'] = T # from target to Ys - log_['p'] = p - log_['Ms'] = Ms - - if log: - return X, C, log_ - else: - return X, C - - -def update_structure_matrix(p, lambdas, T, Cs): - r"""Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s` couplings. - - It is calculated at each iteration - - Parameters - ---------- - p : array-like, shape (N,) - Masses in the targeted barycenter. - lambdas : list of float - List of the `S` spaces' weights. - T : list of S array-like of shape (ns, N) - The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration. - Cs : list of S array-like, shape (ns, ns) - Metric cost matrices. - - Returns - ------- - C : array-like, shape (`nt`, `nt`) - Updated :math:`\mathbf{C}` matrix. - """ - p = list_to_array(p) - T = list_to_array(*T) - Cs = list_to_array(*Cs) - nx = get_backend(*Cs, *T, p) - - tmpsum = sum([ - lambdas[s] * nx.dot( - nx.dot(T[s].T, Cs[s]), - T[s] - ) for s in range(len(T)) - ]) - ppt = nx.outer(p, p) - return tmpsum / ppt - - -def update_feature_matrix(lambdas, Ys, Ts, p): - r"""Updates the feature with respect to the `S` :math:`\mathbf{T}_s` couplings. 
- - - See "Solving the barycenter problem with Block Coordinate Descent (BCD)" - in :ref:`[24] ` calculated at each iteration - - Parameters - ---------- - p : array-like, shape (N,) - masses in the targeted barycenter - lambdas : list of float - List of the `S` spaces' weights - Ts : list of S array-like, shape (ns,N) - The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration - Ys : list of S array-like, shape (d,ns) - The features. - - Returns - ------- - X : array-like, shape (`d`, `N`) - - - .. _references-update-feature-matrix: - References - ---------- - .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas - "Optimal Transport for structured data with application on graphs" - International Conference on Machine Learning (ICML). 2019. - """ - p = list_to_array(p) - Ts = list_to_array(*Ts) - Ys = list_to_array(*Ys) - nx = get_backend(*Ys, *Ts, p) - - p = 1. / p - tmpsum = sum([ - lambdas[s] * nx.dot(Ys[s], Ts[s].T) * p[None, :] - for s in range(len(Ts)) - ]) - return tmpsum - - -def gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0., ps=None, q=None, epochs=20, batch_size=32, learning_rate=1., Cdict_init=None, projection='nonnegative_symmetric', use_log=True, - tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, **kwargs): - r""" - Infer Gromov-Wasserstein linear dictionary :math:`\{ (\mathbf{C_{dict}[d]}, q) \}_{d \in [D]}` from the list of structures :math:`\{ (\mathbf{C_s},\mathbf{p_s}) \}_s` - - .. math:: - \min_{\mathbf{C_{dict}}, \{\mathbf{w_s} \}_{s \leq S}} \sum_{s=1}^S GW_2(\mathbf{C_s}, \sum_{d=1}^D w_{s,d}\mathbf{C_{dict}[d]}, \mathbf{p_s}, \mathbf{q}) - reg\| \mathbf{w_s} \|_2^2 - - such that, :math:`\forall s \leq S` : - - - :math:`\mathbf{w_s}^\top \mathbf{1}_D = 1` - - :math:`\mathbf{w_s} \geq \mathbf{0}_D` - - Where : - - - :math:`\forall s \leq S, \mathbf{C_s}` is a (ns,ns) pairwise similarity matrix of variable size ns. - - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrix of fixed size nt. - - :math:`\forall s \leq S, \mathbf{p_s}` is the source distribution corresponding to :math:`\mathbf{C_s}` - - :math:`\mathbf{q}` is the target distribution assigned to every structures in the embedding space. - - reg is the regularization coefficient. - - The stochastic algorithm used for estimating the graph dictionary atoms as proposed in [38] - - Parameters - ---------- - Cs : list of S symmetric array-like, shape (ns, ns) - List of Metric/Graph cost matrices of variable size (ns, ns). - D: int - Number of dictionary atoms to learn - nt: int - Number of samples within each dictionary atoms - reg : float, optional - Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0. - ps : list of S array-like, shape (ns,), optional - Distribution in each source space C of Cs. Default is None and corresponds to uniform distibutions. - q : array-like, shape (nt,), optional - Distribution in the embedding space whose structure will be learned. Default is None and corresponds to uniform distributions. - epochs: int, optional - Number of epochs used to learn the dictionary. Default is 32. - batch_size: int, optional - Batch size for each stochastic gradient update of the dictionary. Set to the dataset size if the provided batch_size is higher than the dataset size. Default is 32. - learning_rate: float, optional - Learning rate used for the stochastic gradient descent. Default is 1. 
- Cdict_init: list of D array-like with shape (nt, nt), optional - Used to initialize the dictionary. - If set to None (Default), the dictionary will be initialized randomly. - Else Cdict must have shape (D, nt, nt) i.e match provided shape features. - projection: str , optional - If 'nonnegative' and/or 'symmetric' is in projection, the corresponding projection will be performed at each stochastic update of the dictionary - Else the set of atoms is :math:`R^{nt * nt}`. Default is 'nonnegative_symmetric' - log: bool, optional - If set to True, losses evolution by batches and epochs are tracked. Default is False. - use_adam_optimizer: bool, optional - If set to True, adam optimizer with default settings is used as adaptative learning rate strategy. - Else perform SGD with fixed learning rate. Default is True. - tol_outer : float, optional - Solver precision for the BCD algorithm, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`. - tol_inner : float, optional - Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`. - max_iter_outer : int, optional - Maximum number of iterations for the BCD. Default is 20. - max_iter_inner : int, optional - Maximum number of iterations for the Conjugate Gradient. Default is 200. - verbose : bool, optional - Print the reconstruction loss every epoch. Default is False. - - Returns - ------- - - Cdict_best_state : D array-like, shape (D,nt,nt) - Metric/Graph cost matrices composing the dictionary. - The dictionary leading to the best loss over an epoch is saved and returned. - log: dict - If use_log is True, contains loss evolutions by batches and epochs. - References - ------- - - ..[38] Cédric Vincent-Cuaz, Titouan Vayer, Rémi Flamary, Marco Corneli, Nicolas Courty. - "Online Graph Dictionary Learning" - International Conference on Machine Learning (ICML). 2021. - """ - # Handle backend of non-optional arguments - Cs0 = Cs - nx = get_backend(*Cs0) - Cs = [nx.to_numpy(C) for C in Cs0] - dataset_size = len(Cs) - # Handle backend of optional arguments - if ps is None: - ps = [unif(C.shape[0]) for C in Cs] - else: - ps = [nx.to_numpy(p) for p in ps] - if q is None: - q = unif(nt) - else: - q = nx.to_numpy(q) - if Cdict_init is None: - # Initialize randomly structures of dictionary atoms based on samples - dataset_means = [C.mean() for C in Cs] - Cdict = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt)) - else: - Cdict = nx.to_numpy(Cdict_init).copy() - assert Cdict.shape == (D, nt, nt) - - if 'symmetric' in projection: - Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1))) - if 'nonnegative' in projection: - Cdict[Cdict < 0.] = 0 - if use_adam_optimizer: - adam_moments = _initialize_adam_optimizer(Cdict) - - log = {'loss_batches': [], 'loss_epochs': []} - const_q = q[:, None] * q[None, :] - Cdict_best_state = Cdict.copy() - loss_best_state = np.inf - if batch_size > dataset_size: - batch_size = dataset_size - iter_by_epoch = dataset_size // batch_size + int((dataset_size % batch_size) > 0) - - for epoch in range(epochs): - cumulated_loss_over_epoch = 0. - - for _ in range(iter_by_epoch): - # batch sampling - batch = np.random.choice(range(dataset_size), size=batch_size, replace=False) - cumulated_loss_over_batch = 0. 
- unmixings = np.zeros((batch_size, D)) - Cs_embedded = np.zeros((batch_size, nt, nt)) - Ts = [None] * batch_size - - for batch_idx, C_idx in enumerate(batch): - # BCD solver for Gromov-Wassersteisn linear unmixing used independently on each structure of the sampled batch - unmixings[batch_idx], Cs_embedded[batch_idx], Ts[batch_idx], current_loss = gromov_wasserstein_linear_unmixing( - Cs[C_idx], Cdict, reg=reg, p=ps[C_idx], q=q, tol_outer=tol_outer, tol_inner=tol_inner, - max_iter_outer=max_iter_outer, max_iter_inner=max_iter_inner - ) - cumulated_loss_over_batch += current_loss - cumulated_loss_over_epoch += cumulated_loss_over_batch - - if use_log: - log['loss_batches'].append(cumulated_loss_over_batch) - - # Stochastic projected gradient step over dictionary atoms - grad_Cdict = np.zeros_like(Cdict) - for batch_idx, C_idx in enumerate(batch): - shared_term_structures = Cs_embedded[batch_idx] * const_q - (Cs[C_idx].dot(Ts[batch_idx])).T.dot(Ts[batch_idx]) - grad_Cdict += unmixings[batch_idx][:, None, None] * shared_term_structures[None, :, :] - grad_Cdict *= 2 / batch_size - if use_adam_optimizer: - Cdict, adam_moments = _adam_stochastic_updates(Cdict, grad_Cdict, learning_rate, adam_moments) - else: - Cdict -= learning_rate * grad_Cdict - if 'symmetric' in projection: - Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1))) - if 'nonnegative' in projection: - Cdict[Cdict < 0.] = 0. - - if use_log: - log['loss_epochs'].append(cumulated_loss_over_epoch) - if loss_best_state > cumulated_loss_over_epoch: - loss_best_state = cumulated_loss_over_epoch - Cdict_best_state = Cdict.copy() - if verbose: - print('--- epoch =', epoch, ' cumulated reconstruction error: ', cumulated_loss_over_epoch) - - return nx.from_numpy(Cdict_best_state), log - - -def _initialize_adam_optimizer(variable): - - # Initialization for our numpy implementation of adam optimizer - atoms_adam_m = np.zeros_like(variable) # Initialize first moment tensor - atoms_adam_v = np.zeros_like(variable) # Initialize second moment tensor - atoms_adam_count = 1 - - return {'mean': atoms_adam_m, 'var': atoms_adam_v, 'count': atoms_adam_count} - - -def _adam_stochastic_updates(variable, grad, learning_rate, adam_moments, beta_1=0.9, beta_2=0.99, eps=1e-09): - - adam_moments['mean'] = beta_1 * adam_moments['mean'] + (1 - beta_1) * grad - adam_moments['var'] = beta_2 * adam_moments['var'] + (1 - beta_2) * (grad**2) - unbiased_m = adam_moments['mean'] / (1 - beta_1**adam_moments['count']) - unbiased_v = adam_moments['var'] / (1 - beta_2**adam_moments['count']) - variable -= learning_rate * unbiased_m / (np.sqrt(unbiased_v) + eps) - adam_moments['count'] += 1 - - return variable, adam_moments - - -def gromov_wasserstein_linear_unmixing(C, Cdict, reg=0., p=None, q=None, tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, **kwargs): - r""" - Returns the Gromov-Wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary :math:`\{ (\mathbf{C_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}`. - - .. math:: - \min_{ \mathbf{w}} GW_2(\mathbf{C}, \sum_{d=1}^D w_d\mathbf{C_{dict}[d]}, \mathbf{p}, \mathbf{q}) - reg \| \mathbf{w} \|_2^2 - - such that: - - - :math:`\mathbf{w}^\top \mathbf{1}_D = 1` - - :math:`\mathbf{w} \geq \mathbf{0}_D` - - Where : - - - :math:`\mathbf{C}` is the (ns,ns) pairwise similarity matrix. - - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrices of size nt. - - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights. 
- - reg is the regularization coefficient. - - The algorithm used for solving the problem is a Block Coordinate Descent as discussed in [38], algorithm 1. - - Parameters - ---------- - C : array-like, shape (ns, ns) - Metric/Graph cost matrix. - Cdict : D array-like, shape (D,nt,nt) - Metric/Graph cost matrices composing the dictionary on which to embed C. - reg : float, optional. - Coefficient of the negative quadratic regularization used to promote sparsity of w. Default is 0. - p : array-like, shape (ns,), optional - Distribution in the source space C. Default is None and corresponds to uniform distribution. - q : array-like, shape (nt,), optional - Distribution in the space depicted by the dictionary. Default is None and corresponds to uniform distribution. - tol_outer : float, optional - Solver precision for the BCD algorithm. - tol_inner : float, optional - Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport. Default is :math:`10^{-5}`. - max_iter_outer : int, optional - Maximum number of iterations for the BCD. Default is 20. - max_iter_inner : int, optional - Maximum number of iterations for the Conjugate Gradient. Default is 200. - - Returns - ------- - w: array-like, shape (D,) - gromov-wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{p})` onto the span of the dictionary. - Cembedded: array-like, shape (nt,nt) - embedded structure of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary, :math:`\sum_d w_d\mathbf{C_{dict}[d]}`. - T: array-like (ns, nt) - Gromov-Wasserstein transport plan between :math:`(\mathbf{C},\mathbf{p})` and :math:`(\sum_d w_d\mathbf{C_{dict}[d]}, \mathbf{q})` - current_loss: float - reconstruction error - References - ------- - - ..[38] Cédric Vincent-Cuaz, Titouan Vayer, Rémi Flamary, Marco Corneli, Nicolas Courty. - "Online Graph Dictionary Learning" - International Conference on Machine Learning (ICML). 2021. - """ - C0, Cdict0 = C, Cdict - nx = get_backend(C0, Cdict0) - C = nx.to_numpy(C0) - Cdict = nx.to_numpy(Cdict0) - if p is None: - p = unif(C.shape[0]) - else: - p = nx.to_numpy(p) - - if q is None: - q = unif(Cdict.shape[-1]) - else: - q = nx.to_numpy(q) - - T = p[:, None] * q[None, :] - D = len(Cdict) - - w = unif(D) # Initialize uniformly the unmixing w - Cembedded = np.sum(w[:, None, None] * Cdict, axis=0) - - const_q = q[:, None] * q[None, :] - # Trackers for BCD convergence - convergence_criterion = np.inf - current_loss = 10**15 - outer_count = 0 - - while (convergence_criterion > tol_outer) and (outer_count < max_iter_outer): - previous_loss = current_loss - # 1. Solve GW transport between (C,p) and (\sum_d Cdictionary[d],q) fixing the unmixing w - T, log = gromov_wasserstein(C1=C, C2=Cembedded, p=p, q=q, loss_fun='square_loss', G0=T, log=True, armijo=False, **kwargs) - current_loss = log['gw_dist'] - if reg != 0: - current_loss -= reg * np.sum(w**2) - - # 2. 
Solve linear unmixing problem over w with a fixed transport plan T - w, Cembedded, current_loss = _cg_gromov_wasserstein_unmixing( - C=C, Cdict=Cdict, Cembedded=Cembedded, w=w, const_q=const_q, T=T, - starting_loss=current_loss, reg=reg, tol=tol_inner, max_iter=max_iter_inner, **kwargs - ) - - if previous_loss != 0: - convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss) - else: # handle numerical issues around 0 - convergence_criterion = abs(previous_loss - current_loss) / 10**(-15) - outer_count += 1 - - return nx.from_numpy(w), nx.from_numpy(Cembedded), nx.from_numpy(T), nx.from_numpy(current_loss) - - -def _cg_gromov_wasserstein_unmixing(C, Cdict, Cembedded, w, const_q, T, starting_loss, reg=0., tol=10**(-5), max_iter=200, **kwargs): - r""" - Returns for a fixed admissible transport plan, - the linear unmixing w minimizing the Gromov-Wasserstein cost between :math:`(\mathbf{C},\mathbf{p})` and :math:`(\sum_d w[d]*\mathbf{C_{dict}[d]}, \mathbf{q})` - - .. math:: - \min_{\mathbf{w}} \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D w_d*C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} - reg* \| \mathbf{w} \|_2^2 - - - Such that: - - - :math:`\mathbf{w}^\top \mathbf{1}_D = 1` - - :math:`\mathbf{w} \geq \mathbf{0}_D` - - Where : - - - :math:`\mathbf{C}` is the (ns,ns) pairwise similarity matrix. - - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrices of nt points. - - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights. - - :math:`\mathbf{w}` is the linear unmixing of :math:`(\mathbf{C}, \mathbf{p})` onto :math:`(\sum_d w_d \mathbf{Cdict[d]}, \mathbf{q})`. - - :math:`\mathbf{T}` is the optimal transport plan conditioned by the current state of :math:`\mathbf{w}`. - - reg is the regularization coefficient. - - The algorithm used for solving the problem is a Conditional Gradient Descent as discussed in [38] - - Parameters - ---------- - - C : array-like, shape (ns, ns) - Metric/Graph cost matrix. - Cdict : list of D array-like, shape (nt,nt) - Metric/Graph cost matrices composing the dictionary on which to embed C. - Each matrix in the dictionary must have the same size (nt,nt). - Cembedded: array-like, shape (nt,nt) - Embedded structure :math:`(\sum_d w[d]*Cdict[d],q)` of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary. Used to avoid redundant computations. - w: array-like, shape (D,) - Linear unmixing of the input structure onto the dictionary - const_q: array-like, shape (nt,nt) - product matrix :math:`\mathbf{q}\mathbf{q}^\top` where q is the target space distribution. Used to avoid redundant computations. - T: array-like, shape (ns,nt) - fixed transport plan between the input structure and its representation in the dictionary. - p : array-like, shape (ns,) - Distribution in the source space. - q : array-like, shape (nt,) - Distribution in the embedding space depicted by the dictionary. - reg : float, optional. - Coefficient of the negative quadratic regularization used to promote sparsity of w. Default is 0. - - Returns - ------- - w: ndarray (D,) - optimal unmixing of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary span given OT starting from previously optimal unmixing. 
- """ - convergence_criterion = np.inf - current_loss = starting_loss - count = 0 - const_TCT = np.transpose(C.dot(T)).dot(T) - - while (convergence_criterion > tol) and (count < max_iter): - - previous_loss = current_loss - # 1) Compute gradient at current point w - grad_w = 2 * np.sum(Cdict * (Cembedded[None, :, :] * const_q[None, :, :] - const_TCT[None, :, :]), axis=(1, 2)) - grad_w -= 2 * reg * w - - # 2) Conditional gradient direction finding: x= \argmin_x x^T.grad_w - min_ = np.min(grad_w) - x = (grad_w == min_).astype(np.float64) - x /= np.sum(x) - - # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c - gamma, a, b, Cembedded_diff = _linesearch_gromov_wasserstein_unmixing(w, grad_w, x, Cdict, Cembedded, const_q, const_TCT, reg) - - # 4) Updates: w <-- (1-gamma)*w + gamma*x - w += gamma * (x - w) - Cembedded += gamma * Cembedded_diff - current_loss += a * (gamma**2) + b * gamma - - if previous_loss != 0: # not that the loss can be negative if reg >0 - convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss) - else: # handle numerical issues around 0 - convergence_criterion = abs(previous_loss - current_loss) / 10**(-15) - count += 1 - - return w, Cembedded, current_loss - - -def _linesearch_gromov_wasserstein_unmixing(w, grad_w, x, Cdict, Cembedded, const_q, const_TCT, reg, **kwargs): - r""" - Compute optimal steps for the line search problem of Gromov-Wasserstein linear unmixing - .. math:: - \min_{\gamma \in [0,1]} \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D z_d(\gamma)C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} - reg\| \mathbf{z}(\gamma) \|_2^2 - - - Such that: - - - :math:`\mathbf{z}(\gamma) = (1- \gamma)\mathbf{w} + \gamma \mathbf{x}` - - Parameters - ---------- - - w : array-like, shape (D,) - Unmixing. - grad_w : array-like, shape (D, D) - Gradient of the reconstruction loss with respect to w. - x: array-like, shape (D,) - Conditional gradient direction. - Cdict : list of D array-like, shape (nt,nt) - Metric/Graph cost matrices composing the dictionary on which to embed C. - Each matrix in the dictionary must have the same size (nt,nt). - Cembedded: array-like, shape (nt,nt) - Embedded structure :math:`(\sum_d w_dCdict[d],q)` of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary. Used to avoid redundant computations. - const_q: array-like, shape (nt,nt) - product matrix :math:`\mathbf{q}\mathbf{q}^\top` where q is the target space distribution. Used to avoid redundant computations. - const_TCT: array-like, shape (nt, nt) - :math:`\mathbf{T}^\top \mathbf{C}^\top \mathbf{T}`. Used to avoid redundant computations. - Returns - ------- - gamma: float - Optimal value for the line-search step - a: float - Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss - b: float - Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss - Cembedded_diff: numpy array, shape (nt, nt) - Difference between models evaluated in :math:`\mathbf{w}` and in :math:`\mathbf{w}`. - reg : float, optional. - Coefficient of the negative quadratic regularization used to promote sparsity of :math:`\mathbf{w}`. 
- """ - - # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c - Cembedded_x = np.sum(x[:, None, None] * Cdict, axis=0) - Cembedded_diff = Cembedded_x - Cembedded - trace_diffx = np.sum(Cembedded_diff * Cembedded_x * const_q) - trace_diffw = np.sum(Cembedded_diff * Cembedded * const_q) - a = trace_diffx - trace_diffw - b = 2 * (trace_diffw - np.sum(Cembedded_diff * const_TCT)) - if reg != 0: - a -= reg * np.sum((x - w)**2) - b -= 2 * reg * np.sum(w * (x - w)) - - if a > 0: - gamma = min(1, max(0, - b / (2 * a))) - elif a + b < 0: - gamma = 1 - else: - gamma = 0 - - return gamma, a, b, Cembedded_diff - - -def fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, alpha, reg=0., ps=None, q=None, epochs=20, batch_size=32, learning_rate_C=1., learning_rate_Y=1., - Cdict_init=None, Ydict_init=None, projection='nonnegative_symmetric', use_log=False, - tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, **kwargs): - r""" - Infer Fused Gromov-Wasserstein linear dictionary :math:`\{ (\mathbf{C_{dict}[d]}, \mathbf{Y_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}` from the list of S attributed structures :math:`\{ (\mathbf{C_s}, \mathbf{Y_s},\mathbf{p_s}) \}_s` - - .. math:: - \min_{\mathbf{C_{dict}},\mathbf{Y_{dict}}, \{\mathbf{w_s}\}_{s}} \sum_{s=1}^S FGW_{2,\alpha}(\mathbf{C_s}, \mathbf{Y_s}, \sum_{d=1}^D w_{s,d}\mathbf{C_{dict}[d]},\sum_{d=1}^D w_{s,d}\mathbf{Y_{dict}[d]}, \mathbf{p_s}, \mathbf{q}) \\ - reg\| \mathbf{w_s} \|_2^2 - - - Such that :math:`\forall s \leq S` : - - - :math:`\mathbf{w_s}^\top \mathbf{1}_D = 1` - - :math:`\mathbf{w_s} \geq \mathbf{0}_D` - - Where : - - - :math:`\forall s \leq S, \mathbf{C_s}` is a (ns,ns) pairwise similarity matrix of variable size ns. - - :math:`\forall s \leq S, \mathbf{Y_s}` is a (ns,d) features matrix of variable size ns and fixed dimension d. - - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrix of fixed size nt. - - :math:`\mathbf{Y_{dict}}` is a (D, nt, d) tensor of D features matrix of fixed size nt and fixed dimension d. - - :math:`\forall s \leq S, \mathbf{p_s}` is the source distribution corresponding to :math:`\mathbf{C_s}` - - :math:`\mathbf{q}` is the target distribution assigned to every structures in the embedding space. - - :math:`\alpha` is the trade-off parameter of Fused Gromov-Wasserstein - - reg is the regularization coefficient. - - - The stochastic algorithm used for estimating the attributed graph dictionary atoms as proposed in [38] - - Parameters - ---------- - Cs : list of S symmetric array-like, shape (ns, ns) - List of Metric/Graph cost matrices of variable size (ns,ns). - Ys : list of S array-like, shape (ns, d) - List of feature matrix of variable size (ns,d) with d fixed. - D: int - Number of dictionary atoms to learn - nt: int - Number of samples within each dictionary atoms - alpha : float - Trade-off parameter of Fused Gromov-Wasserstein - reg : float, optional - Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0. - ps : list of S array-like, shape (ns,), optional - Distribution in each source space C of Cs. Default is None and corresponds to uniform distibutions. - q : array-like, shape (nt,), optional - Distribution in the embedding space whose structure will be learned. Default is None and corresponds to uniform distributions. - epochs: int, optional - Number of epochs used to learn the dictionary. Default is 32. 
- batch_size: int, optional - Batch size for each stochastic gradient update of the dictionary. Set to the dataset size if the provided batch_size is higher than the dataset size. Default is 32. - learning_rate_C: float, optional - Learning rate used for the stochastic gradient descent on Cdict. Default is 1. - learning_rate_Y: float, optional - Learning rate used for the stochastic gradient descent on Ydict. Default is 1. - Cdict_init: list of D array-like with shape (nt, nt), optional - Used to initialize the dictionary structures Cdict. - If set to None (Default), the dictionary will be initialized randomly. - Else Cdict must have shape (D, nt, nt) i.e match provided shape features. - Ydict_init: list of D array-like with shape (nt, d), optional - Used to initialize the dictionary features Ydict. - If set to None, the dictionary features will be initialized randomly. - Else Ydict must have shape (D, nt, d) where d is the features dimension of inputs Ys and also match provided shape features. - projection: str, optional - If 'nonnegative' and/or 'symmetric' is in projection, the corresponding projection will be performed at each stochastic update of the dictionary - Else the set of atoms is :math:`R^{nt * nt}`. Default is 'nonnegative_symmetric' - log: bool, optional - If set to True, losses evolution by batches and epochs are tracked. Default is False. - use_adam_optimizer: bool, optional - If set to True, adam optimizer with default settings is used as adaptative learning rate strategy. - Else perform SGD with fixed learning rate. Default is True. - tol_outer : float, optional - Solver precision for the BCD algorithm, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`. - tol_inner : float, optional - Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`. - max_iter_outer : int, optional - Maximum number of iterations for the BCD. Default is 20. - max_iter_inner : int, optional - Maximum number of iterations for the Conjugate Gradient. Default is 200. - verbose : bool, optional - Print the reconstruction loss every epoch. Default is False. - - Returns - ------- - - Cdict_best_state : D array-like, shape (D,nt,nt) - Metric/Graph cost matrices composing the dictionary. - The dictionary leading to the best loss over an epoch is saved and returned. - Ydict_best_state : D array-like, shape (D,nt,d) - Feature matrices composing the dictionary. - The dictionary leading to the best loss over an epoch is saved and returned. - log: dict - If use_log is True, contains loss evolutions by batches and epoches. - References - ------- - - ..[38] Cédric Vincent-Cuaz, Titouan Vayer, Rémi Flamary, Marco Corneli, Nicolas Courty. - "Online Graph Dictionary Learning" - International Conference on Machine Learning (ICML). 2021. 
- """ - Cs0, Ys0 = Cs, Ys - nx = get_backend(*Cs0, *Ys0) - Cs = [nx.to_numpy(C) for C in Cs0] - Ys = [nx.to_numpy(Y) for Y in Ys0] - - d = Ys[0].shape[-1] - dataset_size = len(Cs) - - if ps is None: - ps = [unif(C.shape[0]) for C in Cs] - else: - ps = [nx.to_numpy(p) for p in ps] - if q is None: - q = unif(nt) - else: - q = nx.to_numpy(q) - - if Cdict_init is None: - # Initialize randomly structures of dictionary atoms based on samples - dataset_means = [C.mean() for C in Cs] - Cdict = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt)) - else: - Cdict = nx.to_numpy(Cdict_init).copy() - assert Cdict.shape == (D, nt, nt) - if Ydict_init is None: - # Initialize randomly features of dictionary atoms based on samples distribution by feature component - dataset_feature_means = np.stack([F.mean(axis=0) for F in Ys]) - Ydict = np.random.normal(loc=dataset_feature_means.mean(axis=0), scale=dataset_feature_means.std(axis=0), size=(D, nt, d)) - else: - Ydict = nx.to_numpy(Ydict_init).copy() - assert Ydict.shape == (D, nt, d) - - if 'symmetric' in projection: - Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1))) - if 'nonnegative' in projection: - Cdict[Cdict < 0.] = 0. - - if use_adam_optimizer: - adam_moments_C = _initialize_adam_optimizer(Cdict) - adam_moments_Y = _initialize_adam_optimizer(Ydict) - - log = {'loss_batches': [], 'loss_epochs': []} - const_q = q[:, None] * q[None, :] - diag_q = np.diag(q) - Cdict_best_state = Cdict.copy() - Ydict_best_state = Ydict.copy() - loss_best_state = np.inf - if batch_size > dataset_size: - batch_size = dataset_size - iter_by_epoch = dataset_size // batch_size + int((dataset_size % batch_size) > 0) - - for epoch in range(epochs): - cumulated_loss_over_epoch = 0. - - for _ in range(iter_by_epoch): - - # Batch iterations - batch = np.random.choice(range(dataset_size), size=batch_size, replace=False) - cumulated_loss_over_batch = 0. 
- unmixings = np.zeros((batch_size, D)) - Cs_embedded = np.zeros((batch_size, nt, nt)) - Ys_embedded = np.zeros((batch_size, nt, d)) - Ts = [None] * batch_size - - for batch_idx, C_idx in enumerate(batch): - # BCD solver for Gromov-Wassersteisn linear unmixing used independently on each structure of the sampled batch - unmixings[batch_idx], Cs_embedded[batch_idx], Ys_embedded[batch_idx], Ts[batch_idx], current_loss = fused_gromov_wasserstein_linear_unmixing( - Cs[C_idx], Ys[C_idx], Cdict, Ydict, alpha, reg=reg, p=ps[C_idx], q=q, - tol_outer=tol_outer, tol_inner=tol_inner, max_iter_outer=max_iter_outer, max_iter_inner=max_iter_inner - ) - cumulated_loss_over_batch += current_loss - cumulated_loss_over_epoch += cumulated_loss_over_batch - if use_log: - log['loss_batches'].append(cumulated_loss_over_batch) - - # Stochastic projected gradient step over dictionary atoms - grad_Cdict = np.zeros_like(Cdict) - grad_Ydict = np.zeros_like(Ydict) - - for batch_idx, C_idx in enumerate(batch): - shared_term_structures = Cs_embedded[batch_idx] * const_q - (Cs[C_idx].dot(Ts[batch_idx])).T.dot(Ts[batch_idx]) - shared_term_features = diag_q.dot(Ys_embedded[batch_idx]) - Ts[batch_idx].T.dot(Ys[C_idx]) - grad_Cdict += alpha * unmixings[batch_idx][:, None, None] * shared_term_structures[None, :, :] - grad_Ydict += (1 - alpha) * unmixings[batch_idx][:, None, None] * shared_term_features[None, :, :] - grad_Cdict *= 2 / batch_size - grad_Ydict *= 2 / batch_size - - if use_adam_optimizer: - Cdict, adam_moments_C = _adam_stochastic_updates(Cdict, grad_Cdict, learning_rate_C, adam_moments_C) - Ydict, adam_moments_Y = _adam_stochastic_updates(Ydict, grad_Ydict, learning_rate_Y, adam_moments_Y) - else: - Cdict -= learning_rate_C * grad_Cdict - Ydict -= learning_rate_Y * grad_Ydict - - if 'symmetric' in projection: - Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1))) - if 'nonnegative' in projection: - Cdict[Cdict < 0.] = 0. - - if use_log: - log['loss_epochs'].append(cumulated_loss_over_epoch) - if loss_best_state > cumulated_loss_over_epoch: - loss_best_state = cumulated_loss_over_epoch - Cdict_best_state = Cdict.copy() - Ydict_best_state = Ydict.copy() - if verbose: - print('--- epoch: ', epoch, ' cumulated reconstruction error: ', cumulated_loss_over_epoch) - - return nx.from_numpy(Cdict_best_state), nx.from_numpy(Ydict_best_state), log - - -def fused_gromov_wasserstein_linear_unmixing(C, Y, Cdict, Ydict, alpha, reg=0., p=None, q=None, tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, **kwargs): - r""" - Returns the Fused Gromov-Wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` onto the attributed dictionary atoms :math:`\{ (\mathbf{C_{dict}[d]},\mathbf{Y_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}` - - .. math:: - \min_{\mathbf{w}} FGW_{2,\alpha}(\mathbf{C},\mathbf{Y}, \sum_{d=1}^D w_d\mathbf{C_{dict}[d]},\sum_{d=1}^D w_d\mathbf{Y_{dict}[d]}, \mathbf{p}, \mathbf{q}) - reg \| \mathbf{w} \|_2^2 - - such that, :math:`\forall s \leq S` : - - - :math:`\mathbf{w_s}^\top \mathbf{1}_D = 1` - - :math:`\mathbf{w_s} \geq \mathbf{0}_D` - - Where : - - - :math:`\mathbf{C}` is a (ns,ns) pairwise similarity matrix of variable size ns. - - :math:`\mathbf{Y}` is a (ns,d) features matrix of variable size ns and fixed dimension d. - - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrix of fixed size nt. - - :math:`\mathbf{Y_{dict}}` is a (D, nt, d) tensor of D features matrix of fixed size nt and fixed dimension d. 
- - :math:`\mathbf{p}` is the source distribution corresponding to :math:`\mathbf{C_s}` - - :math:`\mathbf{q}` is the target distribution assigned to every structures in the embedding space. - - :math:`\alpha` is the trade-off parameter of Fused Gromov-Wasserstein - - reg is the regularization coefficient. - - The algorithm used for solving the problem is a Block Coordinate Descent as discussed in [38], algorithm 6. - - Parameters - ---------- - C : array-like, shape (ns, ns) - Metric/Graph cost matrix. - Y : array-like, shape (ns, d) - Feature matrix. - Cdict : D array-like, shape (D,nt,nt) - Metric/Graph cost matrices composing the dictionary on which to embed (C,Y). - Ydict : D array-like, shape (D,nt,d) - Feature matrices composing the dictionary on which to embed (C,Y). - alpha: float, - Trade-off parameter of Fused Gromov-Wasserstein. - reg : float, optional - Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0. - p : array-like, shape (ns,), optional - Distribution in the source space C. Default is None and corresponds to uniform distribution. - q : array-like, shape (nt,), optional - Distribution in the space depicted by the dictionary. Default is None and corresponds to uniform distribution. - tol_outer : float, optional - Solver precision for the BCD algorithm. - tol_inner : float, optional - Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport. Default is :math:`10^{-5}`. - max_iter_outer : int, optional - Maximum number of iterations for the BCD. Default is 20. - max_iter_inner : int, optional - Maximum number of iterations for the Conjugate Gradient. Default is 200. - - Returns - ------- - w: array-like, shape (D,) - fused gromov-wasserstein linear unmixing of (C,Y,p) onto the span of the dictionary. - Cembedded: array-like, shape (nt,nt) - embedded structure of :math:`(\mathbf{C},\mathbf{Y}, \mathbf{p})` onto the dictionary, :math:`\sum_d w_d\mathbf{C_{dict}[d]}`. - Yembedded: array-like, shape (nt,d) - embedded features of :math:`(\mathbf{C},\mathbf{Y}, \mathbf{p})` onto the dictionary, :math:`\sum_d w_d\mathbf{Y_{dict}[d]}`. - T: array-like (ns,nt) - Fused Gromov-Wasserstein transport plan between :math:`(\mathbf{C},\mathbf{p})` and :math:`(\sum_d w_d\mathbf{C_{dict}[d]}, \sum_d w_d\mathbf{Y_{dict}[d]},\mathbf{q})`. - current_loss: float - reconstruction error - References - ------- - - ..[38] Cédric Vincent-Cuaz, Titouan Vayer, Rémi Flamary, Marco Corneli, Nicolas Courty. - "Online Graph Dictionary Learning" - International Conference on Machine Learning (ICML). 2021. 
- """ - C0, Y0, Cdict0, Ydict0 = C, Y, Cdict, Ydict - nx = get_backend(C0, Y0, Cdict0, Ydict0) - C = nx.to_numpy(C0) - Y = nx.to_numpy(Y0) - Cdict = nx.to_numpy(Cdict0) - Ydict = nx.to_numpy(Ydict0) - - if p is None: - p = unif(C.shape[0]) - else: - p = nx.to_numpy(p) - if q is None: - q = unif(Cdict.shape[-1]) - else: - q = nx.to_numpy(q) - - T = p[:, None] * q[None, :] - D = len(Cdict) - d = Y.shape[-1] - w = unif(D) # Initialize with uniform weights - ns = C.shape[-1] - nt = Cdict.shape[-1] - - # modeling (C,Y) - Cembedded = np.sum(w[:, None, None] * Cdict, axis=0) - Yembedded = np.sum(w[:, None, None] * Ydict, axis=0) - - # constants depending on q - const_q = q[:, None] * q[None, :] - diag_q = np.diag(q) - # Trackers for BCD convergence - convergence_criterion = np.inf - current_loss = 10**15 - outer_count = 0 - Ys_constM = (Y**2).dot(np.ones((d, nt))) # constant in computing euclidean pairwise feature matrix - - while (convergence_criterion > tol_outer) and (outer_count < max_iter_outer): - previous_loss = current_loss - - # 1. Solve GW transport between (C,p) and (\sum_d Cdictionary[d],q) fixing the unmixing w - Yt_varM = (np.ones((ns, d))).dot((Yembedded**2).T) - M = Ys_constM + Yt_varM - 2 * Y.dot(Yembedded.T) # euclidean distance matrix between features - T, log = fused_gromov_wasserstein(M, C, Cembedded, p, q, loss_fun='square_loss', alpha=alpha, armijo=False, G0=T, log=True) - current_loss = log['fgw_dist'] - if reg != 0: - current_loss -= reg * np.sum(w**2) - - # 2. Solve linear unmixing problem over w with a fixed transport plan T - w, Cembedded, Yembedded, current_loss = _cg_fused_gromov_wasserstein_unmixing(C, Y, Cdict, Ydict, Cembedded, Yembedded, w, - T, p, q, const_q, diag_q, current_loss, alpha, reg, - tol=tol_inner, max_iter=max_iter_inner, **kwargs) - if previous_loss != 0: - convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss) - else: - convergence_criterion = abs(previous_loss - current_loss) / 10**(-12) - outer_count += 1 - - return nx.from_numpy(w), nx.from_numpy(Cembedded), nx.from_numpy(Yembedded), nx.from_numpy(T), nx.from_numpy(current_loss) - - -def _cg_fused_gromov_wasserstein_unmixing(C, Y, Cdict, Ydict, Cembedded, Yembedded, w, T, p, q, const_q, diag_q, starting_loss, alpha, reg, tol=10**(-6), max_iter=200, **kwargs): - r""" - Returns for a fixed admissible transport plan, - the optimal linear unmixing :math:`\mathbf{w}` minimizing the Fused Gromov-Wasserstein cost between :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` and :math:`(\sum_d w_d \mathbf{C_{dict}[d]},\sum_d w_d*\mathbf{Y_{dict}[d]}, \mathbf{q})` - - .. math:: - \min_{\mathbf{w}} \alpha \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D w_d C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} \\+ (1-\alpha) \sum_{ij} \| \mathbf{Y_i} - \sum_d w_d \mathbf{Y_{dict}[d]_j} \|_2^2 T_{ij}- reg \| \mathbf{w} \|_2^2 - - Such that : - - - :math:`\mathbf{w}^\top \mathbf{1}_D = 1` - - :math:`\mathbf{w} \geq \mathbf{0}_D` - - Where : - - - :math:`\mathbf{C}` is a (ns,ns) pairwise similarity matrix of variable size ns. - - :math:`\mathbf{Y}` is a (ns,d) features matrix of variable size ns and fixed dimension d. - - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrix of fixed size nt. - - :math:`\mathbf{Y_{dict}}` is a (D, nt, d) tensor of D features matrix of fixed size nt and fixed dimension d. - - :math:`\mathbf{p}` is the source distribution corresponding to :math:`\mathbf{C_s}` - - :math:`\mathbf{q}` is the target distribution assigned to every structures in the embedding space. 
- - :math:`\mathbf{T}` is the optimal transport plan conditioned by the previous state of :math:`\mathbf{w}` - - :math:`\alpha` is the trade-off parameter of Fused Gromov-Wasserstein - - reg is the regularization coefficient. - - The algorithm used for solving the problem is a Conditional Gradient Descent as discussed in [38], algorithm 7. - - Parameters - ---------- - - C : array-like, shape (ns, ns) - Metric/Graph cost matrix. - Y : array-like, shape (ns, d) - Feature matrix. - Cdict : list of D array-like, shape (nt,nt) - Metric/Graph cost matrices composing the dictionary on which to embed (C,Y). - Each matrix in the dictionary must have the same size (nt,nt). - Ydict : list of D array-like, shape (nt,d) - Feature matrices composing the dictionary on which to embed (C,Y). - Each matrix in the dictionary must have the same size (nt,d). - Cembedded: array-like, shape (nt,nt) - Embedded structure of (C,Y) onto the dictionary - Yembedded: array-like, shape (nt,d) - Embedded features of (C,Y) onto the dictionary - w: array-like, shape (n_D,) - Linear unmixing of (C,Y) onto (Cdict,Ydict) - const_q: array-like, shape (nt,nt) - product matrix :math:`\mathbf{qq}^\top` where :math:`\mathbf{q}` is the target space distribution. - diag_q: array-like, shape (nt,nt) - diagonal matrix with values of q on the diagonal. - T: array-like, shape (ns,nt) - fixed transport plan between (C,Y) and its model - p : array-like, shape (ns,) - Distribution in the source space (C,Y). - q : array-like, shape (nt,) - Distribution in the embedding space depicted by the dictionary. - alpha: float, - Trade-off parameter of Fused Gromov-Wasserstein. - reg : float, optional - Coefficient of the negative quadratic regularization used to promote sparsity of w. - - Returns - ------- - w: ndarray (D,) - linear unmixing of :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` onto the span of :math:`(C_{dict},Y_{dict})` given OT corresponding to previous unmixing. 
- """ - convergence_criterion = np.inf - current_loss = starting_loss - count = 0 - const_TCT = np.transpose(C.dot(T)).dot(T) - ones_ns_d = np.ones(Y.shape) - - while (convergence_criterion > tol) and (count < max_iter): - previous_loss = current_loss - - # 1) Compute gradient at current point w - # structure - grad_w = alpha * np.sum(Cdict * (Cembedded[None, :, :] * const_q[None, :, :] - const_TCT[None, :, :]), axis=(1, 2)) - # feature - grad_w += (1 - alpha) * np.sum(Ydict * (diag_q.dot(Yembedded)[None, :, :] - T.T.dot(Y)[None, :, :]), axis=(1, 2)) - grad_w -= reg * w - grad_w *= 2 - - # 2) Conditional gradient direction finding: x= \argmin_x x^T.grad_w - min_ = np.min(grad_w) - x = (grad_w == min_).astype(np.float64) - x /= np.sum(x) - - # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c - gamma, a, b, Cembedded_diff, Yembedded_diff = _linesearch_fused_gromov_wasserstein_unmixing(w, grad_w, x, Y, Cdict, Ydict, Cembedded, Yembedded, T, const_q, const_TCT, ones_ns_d, alpha, reg) - - # 4) Updates: w <-- (1-gamma)*w + gamma*x - w += gamma * (x - w) - Cembedded += gamma * Cembedded_diff - Yembedded += gamma * Yembedded_diff - current_loss += a * (gamma**2) + b * gamma - - if previous_loss != 0: - convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss) - else: - convergence_criterion = abs(previous_loss - current_loss) / 10**(-12) - count += 1 - - return w, Cembedded, Yembedded, current_loss - - -def _linesearch_fused_gromov_wasserstein_unmixing(w, grad_w, x, Y, Cdict, Ydict, Cembedded, Yembedded, T, const_q, const_TCT, ones_ns_d, alpha, reg, **kwargs): - r""" - Compute optimal steps for the line search problem of Fused Gromov-Wasserstein linear unmixing - .. math:: - \min_{\gamma \in [0,1]} \alpha \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D z_d(\gamma)C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} \\ + (1-\alpha) \sum_{ij} \| \mathbf{Y_i} - \sum_d z_d(\gamma) \mathbf{Y_{dict}[d]_j} \|_2^2 - reg\| \mathbf{z}(\gamma) \|_2^2 - - - Such that : - - - :math:`\mathbf{z}(\gamma) = (1- \gamma)\mathbf{w} + \gamma \mathbf{x}` - - Parameters - ---------- - - w : array-like, shape (D,) - Unmixing. - grad_w : array-like, shape (D, D) - Gradient of the reconstruction loss with respect to w. - x: array-like, shape (D,) - Conditional gradient direction. - Y: arrat-like, shape (ns,d) - Feature matrix of the input space - Cdict : list of D array-like, shape (nt, nt) - Metric/Graph cost matrices composing the dictionary on which to embed (C,Y). - Each matrix in the dictionary must have the same size (nt,nt). - Ydict : list of D array-like, shape (nt, d) - Feature matrices composing the dictionary on which to embed (C,Y). - Each matrix in the dictionary must have the same size (nt,d). - Cembedded: array-like, shape (nt, nt) - Embedded structure of (C,Y) onto the dictionary - Yembedded: array-like, shape (nt, d) - Embedded features of (C,Y) onto the dictionary - T: array-like, shape (ns, nt) - Fixed transport plan between (C,Y) and its current model. - const_q: array-like, shape (nt,nt) - product matrix :math:`\mathbf{q}\mathbf{q}^\top` where q is the target space distribution. Used to avoid redundant computations. - const_TCT: array-like, shape (nt, nt) - :math:`\mathbf{T}^\top \mathbf{C}^\top \mathbf{T}`. Used to avoid redundant computations. - ones_ns_d: array-like, shape (ns, d) - :math:`\mathbf{1}_{ ns \times d}`. Used to avoid redundant computations. - alpha: float, - Trade-off parameter of Fused Gromov-Wasserstein. 
- reg : float, optional - Coefficient of the negative quadratic regularization used to promote sparsity of w. - - Returns - ------- - gamma: float - Optimal value for the line-search step - a: float - Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss - b: float - Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss - Cembedded_diff: numpy array, shape (nt, nt) - Difference between structure matrix of models evaluated in :math:`\mathbf{w}` and in :math:`\mathbf{w}`. - Yembedded_diff: numpy array, shape (nt, nt) - Difference between feature matrix of models evaluated in :math:`\mathbf{w}` and in :math:`\mathbf{w}`. - """ - # polynomial coefficients from quadratic objective (with respect to w) on structures - Cembedded_x = np.sum(x[:, None, None] * Cdict, axis=0) - Cembedded_diff = Cembedded_x - Cembedded - trace_diffx = np.sum(Cembedded_diff * Cembedded_x * const_q) - trace_diffw = np.sum(Cembedded_diff * Cembedded * const_q) - # Constant factor appearing in the factorization a*gamma^2 + b*g + c of the Gromov-Wasserstein reconstruction loss - a_gw = trace_diffx - trace_diffw - b_gw = 2 * (trace_diffw - np.sum(Cembedded_diff * const_TCT)) - - # polynomial coefficient from quadratic objective (with respect to w) on features - Yembedded_x = np.sum(x[:, None, None] * Ydict, axis=0) - Yembedded_diff = Yembedded_x - Yembedded - # Constant factor appearing in the factorization a*gamma^2 + b*g + c of the Gromov-Wasserstein reconstruction loss - a_w = np.sum(ones_ns_d.dot((Yembedded_diff**2).T) * T) - b_w = 2 * np.sum(T * (ones_ns_d.dot((Yembedded * Yembedded_diff).T) - Y.dot(Yembedded_diff.T))) - - a = alpha * a_gw + (1 - alpha) * a_w - b = alpha * b_gw + (1 - alpha) * b_w - if reg != 0: - a -= reg * np.sum((x - w)**2) - b -= 2 * reg * np.sum(w * (x - w)) - if a > 0: - gamma = min(1, max(0, -b / (2 * a))) - elif a + b < 0: - gamma = 1 - else: - gamma = 0 - - return gamma, a, b, Cembedded_diff, Yembedded_diff diff --git a/ot/gromov/__init__.py b/ot/gromov/__init__.py new file mode 100644 index 0000000..6184edf --- /dev/null +++ b/ot/gromov/__init__.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- +""" +Solvers related to Gromov-Wasserstein problems. 
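The solvers below are re-exported at the subpackage level, so the usual import
patterns keep working after this refactoring, e.g. ``from ot.gromov import
gromov_wasserstein`` or ``ot.gromov.entropic_gromov_wasserstein`` (see the
``__all__`` list below).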
+ +""" + +# Author: Remi Flamary +# Cedric Vincent-Cuaz +# +# License: MIT License + +# All submodules and packages +from ._utils import (init_matrix, tensor_product, gwloss, gwggrad, + update_square_loss, update_kl_loss, + init_matrix_semirelaxed) +from ._gw import (gromov_wasserstein, gromov_wasserstein2, + fused_gromov_wasserstein, fused_gromov_wasserstein2, + solve_gromov_linesearch, gromov_barycenters, fgw_barycenters, + update_structure_matrix, update_feature_matrix) +from ._bregman import (entropic_gromov_wasserstein, + entropic_gromov_wasserstein2, + entropic_gromov_barycenters) +from ._estimators import (GW_distance_estimation, pointwise_gromov_wasserstein, + sampled_gromov_wasserstein) +from ._semirelaxed import (semirelaxed_gromov_wasserstein, + semirelaxed_gromov_wasserstein2, + semirelaxed_fused_gromov_wasserstein, + semirelaxed_fused_gromov_wasserstein2, + solve_semirelaxed_gromov_linesearch) +from ._dictionary import (gromov_wasserstein_dictionary_learning, + gromov_wasserstein_linear_unmixing, + fused_gromov_wasserstein_dictionary_learning, + fused_gromov_wasserstein_linear_unmixing) + + +__all__ = ['init_matrix', 'tensor_product', 'gwloss', 'gwggrad', + 'update_square_loss', 'update_kl_loss', 'init_matrix_semirelaxed', + 'gromov_wasserstein', 'gromov_wasserstein2', 'fused_gromov_wasserstein', + 'fused_gromov_wasserstein2', 'solve_gromov_linesearch', 'gromov_barycenters', + 'fgw_barycenters', 'update_structure_matrix', 'update_feature_matrix', + 'entropic_gromov_wasserstein', 'entropic_gromov_wasserstein2', + 'entropic_gromov_barycenters', 'GW_distance_estimation', + 'pointwise_gromov_wasserstein', 'sampled_gromov_wasserstein', + 'semirelaxed_gromov_wasserstein', 'semirelaxed_gromov_wasserstein2', + 'semirelaxed_fused_gromov_wasserstein', 'semirelaxed_fused_gromov_wasserstein2', + 'solve_semirelaxed_gromov_linesearch', 'gromov_wasserstein_dictionary_learning', + 'gromov_wasserstein_linear_unmixing', 'fused_gromov_wasserstein_dictionary_learning', + 'fused_gromov_wasserstein_linear_unmixing'] diff --git a/ot/gromov/_bregman.py b/ot/gromov/_bregman.py new file mode 100644 index 0000000..5b2f959 --- /dev/null +++ b/ot/gromov/_bregman.py @@ -0,0 +1,351 @@ +# -*- coding: utf-8 -*- +""" +Bregman projections solvers for entropic Gromov-Wasserstein +""" + +# Author: Erwan Vautier +# Nicolas Courty +# Rémi Flamary +# Titouan Vayer +# Cédric Vincent-Cuaz +# +# License: MIT License + +import numpy as np + + +from ..bregman import sinkhorn +from ..utils import dist, list_to_array, check_random_state +from ..backend import get_backend + +from ._utils import init_matrix, gwloss, gwggrad +from ._utils import update_square_loss, update_kl_loss + + +def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, symmetric=None, G0=None, + max_iter=1000, tol=1e-9, verbose=False, log=False): + r""" + Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` + + The function solves the following optimization problem: + + .. math:: + \mathbf{GW} = \mathop{\arg\min}_\mathbf{T} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - \epsilon(H(\mathbf{T})) + + s.t. 
\ \mathbf{T} \mathbf{1} &= \mathbf{p} + + \mathbf{T}^T \mathbf{1} &= \mathbf{q} + + \mathbf{T} &\geq 0 + + Where : + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{p}`: distribution in the source space + - :math:`\mathbf{q}`: distribution in the target space + - `L`: loss function to account for the misfit between the similarity matrices + - `H`: entropy + + .. note:: If the inner solver `ot.sinkhorn` did not convergence, the + optimal coupling :math:`\mathbf{T}` returned by this function does not + necessarily satisfy the marginal constraints + :math:`\mathbf{T}\mathbf{1}=\mathbf{p}` and + :math:`\mathbf{T}^T\mathbf{1}=\mathbf{q}`. So the returned + Gromov-Wasserstein loss does not necessarily satisfy distance + properties and may be negative. + + Parameters + ---------- + C1 : array-like, shape (ns, ns) + Metric cost matrix in the source space + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,) + Distribution in the source space + q : array-like, shape (nt,) + Distribution in the target space + loss_fun : string + Loss function used for the solver either 'square_loss' or 'kl_loss' + epsilon : float + Regularization term >0 + symmetric : bool, optional + Either C1 and C2 are to be assumed symmetric or not. + If let to its default None value, a symmetry test will be conducted. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric). + G0: array-like, shape (ns,nt), optional + If None the initial transport plan of the solver is pq^T. + Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. + max_iter : int, optional + Max number of iterations + tol : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + Record log if True. + + Returns + ------- + T : array-like, shape (`ns`, `nt`) + Optimal coupling between the two spaces + + References + ---------- + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, + "Gromov-Wasserstein averaging of kernel and distance matrices." + International Conference on Machine Learning (ICML). 2016. + + .. [47] Chowdhury, S., & Mémoli, F. (2019). The gromov–wasserstein + distance between networks and stable network invariants. + Information and Inference: A Journal of the IMA, 8(4), 757-787. 
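    A minimal usage sketch (illustrative only: two small random point clouds with
    uniform weights, squared Euclidean structure matrices rescaled to [0, 1], and an
    arbitrary regularization strength)::

        import numpy as np
        import ot

        rng = np.random.RandomState(0)
        xs, xt = rng.randn(10, 2), rng.randn(12, 2)
        C1, C2 = ot.dist(xs, xs), ot.dist(xt, xt)
        C1, C2 = C1 / C1.max(), C2 / C2.max()
        p, q = ot.unif(10), ot.unif(12)
        T = ot.gromov.entropic_gromov_wasserstein(
            C1, C2, p, q, 'square_loss', epsilon=5e-3)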
+ """ + C1, C2, p, q = list_to_array(C1, C2, p, q) + if G0 is None: + nx = get_backend(p, q, C1, C2) + G0 = nx.outer(p, q) + else: + nx = get_backend(p, q, C1, C2, G0) + T = G0 + constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun, nx) + if symmetric is None: + symmetric = np.allclose(C1, C1.T, atol=1e-10) and np.allclose(C2, C2.T, atol=1e-10) + if not symmetric: + constCt, hC1t, hC2t = init_matrix(C1.T, C2.T, p, q, loss_fun, nx) + cpt = 0 + err = 1 + + if log: + log = {'err': []} + + while (err > tol and cpt < max_iter): + + Tprev = T + + # compute the gradient + if symmetric: + tens = gwggrad(constC, hC1, hC2, T, nx) + else: + tens = 0.5 * (gwggrad(constC, hC1, hC2, T, nx) + gwggrad(constCt, hC1t, hC2t, T, nx)) + T = sinkhorn(p, q, tens, epsilon, method='sinkhorn') + + if cpt % 10 == 0: + # we can speed up the process by checking for the error only all + # the 10th iterations + err = nx.norm(T - Tprev) + + if log: + log['err'].append(err) + + if verbose: + if cpt % 200 == 0: + print('{:5s}|{:12s}'.format( + 'It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) + + cpt += 1 + + if log: + log['gw_dist'] = gwloss(constC, hC1, hC2, T, nx) + return T, log + else: + return T + + +def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, symmetric=None, G0=None, + max_iter=1000, tol=1e-9, verbose=False, log=False): + r""" + Returns the entropic gromov-wasserstein discrepancy between the two measured similarity matrices :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` + + The function solves the following optimization problem: + + .. math:: + GW = \min_\mathbf{T} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) + \mathbf{T}_{i,j} \mathbf{T}_{k,l} - \epsilon(H(\mathbf{T})) + + Where : + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{p}`: distribution in the source space + - :math:`\mathbf{q}`: distribution in the target space + - `L`: loss function to account for the misfit between the similarity matrices + - `H`: entropy + + .. note:: If the inner solver `ot.sinkhorn` did not convergence, the + optimal coupling :math:`\mathbf{T}` returned by this function does not + necessarily satisfy the marginal constraints + :math:`\mathbf{T}\mathbf{1}=\mathbf{p}` and + :math:`\mathbf{T}^T\mathbf{1}=\mathbf{q}`. So the returned + Gromov-Wasserstein loss does not necessarily satisfy distance + properties and may be negative. + + Parameters + ---------- + C1 : array-like, shape (ns, ns) + Metric cost matrix in the source space + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,) + Distribution in the source space + q : array-like, shape (nt,) + Distribution in the target space + loss_fun : str + Loss function used for the solver either 'square_loss' or 'kl_loss' + epsilon : float + Regularization term >0 + symmetric : bool, optional + Either C1 and C2 are to be assumed symmetric or not. + If let to its default None value, a symmetry test will be conducted. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric). + G0: array-like, shape (ns,nt), optional + If None the initial transport plan of the solver is pq^T. + Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. 
+ max_iter : int, optional + Max number of iterations + tol : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + Record log if True. + + Returns + ------- + gw_dist : float + Gromov-Wasserstein distance + + References + ---------- + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, + "Gromov-Wasserstein averaging of kernel and distance matrices." + International Conference on Machine Learning (ICML). 2016. + + """ + gw, logv = entropic_gromov_wasserstein( + C1, C2, p, q, loss_fun, epsilon, symmetric, G0, max_iter, tol, verbose, log=True) + + logv['T'] = gw + + if log: + return logv['gw_dist'], logv + else: + return logv['gw_dist'] + + +def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, symmetric=True, + max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None, random_state=None): + r""" + Returns the gromov-wasserstein barycenters of `S` measured similarity matrices :math:`(\mathbf{C}_s)_{1 \leq s \leq S}` + + The function solves the following optimization problem: + + .. math:: + + \mathbf{C} = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \quad \sum_s \lambda_s \mathrm{GW}(\mathbf{C}, \mathbf{C}_s, \mathbf{p}, \mathbf{p}_s) + + Where : + + - :math:`\mathbf{C}_s`: metric cost matrix + - :math:`\mathbf{p}_s`: distribution + + Parameters + ---------- + N : int + Size of the targeted barycenter + Cs : list of S array-like of shape (ns,ns) + Metric cost matrices + ps : list of S array-like of shape (ns,) + Sample weights in the `S` spaces + p : array-like, shape(N,) + Weights in the targeted barycenter + lambdas : list of float + List of the `S` spaces' weights. + loss_fun : callable + Tensor-matrix multiplication function based on specific loss function. + update : callable + function(:math:`\mathbf{p}`, lambdas, :math:`\mathbf{T}`, :math:`\mathbf{Cs}`) that updates + :math:`\mathbf{C}` according to a specific Kernel with the `S` :math:`\mathbf{T}_s` couplings + calculated at each iteration + epsilon : float + Regularization term >0 + symmetric : bool, optional. + Either structures are to be assumed symmetric or not. Default value is True. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric). + max_iter : int, optional + Max number of iterations + tol : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations. + log : bool, optional + Record log if True. + init_C : bool | array-like, shape (N, N) + Random initial value for the :math:`\mathbf{C}` matrix provided by user. + random_state : int or RandomState instance, optional + Fix the seed for reproducibility + + Returns + ------- + C : array-like, shape (`N`, `N`) + Similarity matrix in the barycenter space (permutated arbitrarily) + log : dict + Log dictionary of error during iterations. Return only if `log=True` in parameters. + + References + ---------- + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, + "Gromov-Wasserstein averaging of kernel and distance matrices." + International Conference on Machine Learning (ICML). 2016. 
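    A minimal usage sketch (illustrative only: three small random structures with
    uniform weights, equal barycenter coefficients and an arbitrary regularization
    strength)::

        import numpy as np
        import ot

        rng = np.random.RandomState(0)
        Cs = [ot.dist(x, x) for x in (rng.randn(n, 2) for n in (8, 9, 10))]
        Cs = [C / C.max() for C in Cs]
        ps = [ot.unif(C.shape[0]) for C in Cs]
        C_bary = ot.gromov.entropic_gromov_barycenters(
            6, Cs, ps, ot.unif(6), [1 / 3] * 3, 'square_loss', epsilon=5e-3)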
+ """ + Cs = list_to_array(*Cs) + ps = list_to_array(*ps) + p = list_to_array(p) + nx = get_backend(*Cs, *ps, p) + + S = len(Cs) + + # Initialization of C : random SPD matrix (if not provided by user) + if init_C is None: + generator = check_random_state(random_state) + xalea = generator.randn(N, 2) + C = dist(xalea, xalea) + C /= C.max() + C = nx.from_numpy(C, type_as=p) + else: + C = init_C + + cpt = 0 + err = 1 + + error = [] + + while (err > tol) and (cpt < max_iter): + Cprev = C + + T = [entropic_gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon, symmetric, None, + max_iter, 1e-4, verbose, log=False) for s in range(S)] + if loss_fun == 'square_loss': + C = update_square_loss(p, lambdas, T, Cs) + + elif loss_fun == 'kl_loss': + C = update_kl_loss(p, lambdas, T, Cs) + + if cpt % 10 == 0: + # we can speed up the process by checking for the error only all + # the 10th iterations + err = nx.norm(C - Cprev) + error.append(err) + + if verbose: + if cpt % 200 == 0: + print('{:5s}|{:12s}'.format( + 'It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) + + cpt += 1 + + if log: + return C, {"err": error} + else: + return C diff --git a/ot/gromov/_dictionary.py b/ot/gromov/_dictionary.py new file mode 100644 index 0000000..5b32671 --- /dev/null +++ b/ot/gromov/_dictionary.py @@ -0,0 +1,1008 @@ +# -*- coding: utf-8 -*- +""" +(Fused) Gromov-Wasserstein dictionary learning. +""" + +# Author: Rémi Flamary +# Cédric Vincent-Cuaz +# +# License: MIT License + +import numpy as np + + +from ..utils import unif +from ..backend import get_backend +from ._gw import gromov_wasserstein, fused_gromov_wasserstein + + +def gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0., ps=None, q=None, epochs=20, batch_size=32, learning_rate=1., Cdict_init=None, projection='nonnegative_symmetric', use_log=True, + tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, **kwargs): + r""" + Infer Gromov-Wasserstein linear dictionary :math:`\{ (\mathbf{C_{dict}[d]}, q) \}_{d \in [D]}` from the list of structures :math:`\{ (\mathbf{C_s},\mathbf{p_s}) \}_s` + + .. math:: + \min_{\mathbf{C_{dict}}, \{\mathbf{w_s} \}_{s \leq S}} \sum_{s=1}^S GW_2(\mathbf{C_s}, \sum_{d=1}^D w_{s,d}\mathbf{C_{dict}[d]}, \mathbf{p_s}, \mathbf{q}) - reg\| \mathbf{w_s} \|_2^2 + + such that, :math:`\forall s \leq S` : + + - :math:`\mathbf{w_s}^\top \mathbf{1}_D = 1` + - :math:`\mathbf{w_s} \geq \mathbf{0}_D` + + Where : + + - :math:`\forall s \leq S, \mathbf{C_s}` is a (ns,ns) pairwise similarity matrix of variable size ns. + - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrix of fixed size nt. + - :math:`\forall s \leq S, \mathbf{p_s}` is the source distribution corresponding to :math:`\mathbf{C_s}` + - :math:`\mathbf{q}` is the target distribution assigned to every structures in the embedding space. + - reg is the regularization coefficient. + + The stochastic algorithm used for estimating the graph dictionary atoms as proposed in [38]_ + + Parameters + ---------- + Cs : list of S symmetric array-like, shape (ns, ns) + List of Metric/Graph cost matrices of variable size (ns, ns). + D: int + Number of dictionary atoms to learn + nt: int + Number of samples within each dictionary atoms + reg : float, optional + Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0. + ps : list of S array-like, shape (ns,), optional + Distribution in each source space C of Cs. 
Default is None and corresponds to uniform distibutions. + q : array-like, shape (nt,), optional + Distribution in the embedding space whose structure will be learned. Default is None and corresponds to uniform distributions. + epochs: int, optional + Number of epochs used to learn the dictionary. Default is 32. + batch_size: int, optional + Batch size for each stochastic gradient update of the dictionary. Set to the dataset size if the provided batch_size is higher than the dataset size. Default is 32. + learning_rate: float, optional + Learning rate used for the stochastic gradient descent. Default is 1. + Cdict_init: list of D array-like with shape (nt, nt), optional + Used to initialize the dictionary. + If set to None (Default), the dictionary will be initialized randomly. + Else Cdict must have shape (D, nt, nt) i.e match provided shape features. + projection: str , optional + If 'nonnegative' and/or 'symmetric' is in projection, the corresponding projection will be performed at each stochastic update of the dictionary + Else the set of atoms is :math:`R^{nt * nt}`. Default is 'nonnegative_symmetric' + log: bool, optional + If set to True, losses evolution by batches and epochs are tracked. Default is False. + use_adam_optimizer: bool, optional + If set to True, adam optimizer with default settings is used as adaptative learning rate strategy. + Else perform SGD with fixed learning rate. Default is True. + tol_outer : float, optional + Solver precision for the BCD algorithm, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`. + tol_inner : float, optional + Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`. + max_iter_outer : int, optional + Maximum number of iterations for the BCD. Default is 20. + max_iter_inner : int, optional + Maximum number of iterations for the Conjugate Gradient. Default is 200. + verbose : bool, optional + Print the reconstruction loss every epoch. Default is False. + + Returns + ------- + + Cdict_best_state : D array-like, shape (D,nt,nt) + Metric/Graph cost matrices composing the dictionary. + The dictionary leading to the best loss over an epoch is saved and returned. + log: dict + If use_log is True, contains loss evolutions by batches and epochs. + References + ------- + .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online + Graph Dictionary Learning, International Conference on Machine Learning + (ICML), 2021. + """ + # Handle backend of non-optional arguments + Cs0 = Cs + nx = get_backend(*Cs0) + Cs = [nx.to_numpy(C) for C in Cs0] + dataset_size = len(Cs) + # Handle backend of optional arguments + if ps is None: + ps = [unif(C.shape[0]) for C in Cs] + else: + ps = [nx.to_numpy(p) for p in ps] + if q is None: + q = unif(nt) + else: + q = nx.to_numpy(q) + if Cdict_init is None: + # Initialize randomly structures of dictionary atoms based on samples + dataset_means = [C.mean() for C in Cs] + Cdict = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt)) + else: + Cdict = nx.to_numpy(Cdict_init).copy() + assert Cdict.shape == (D, nt, nt) + + if 'symmetric' in projection: + Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1))) + symmetric = True + else: + symmetric = False + if 'nonnegative' in projection: + Cdict[Cdict < 0.] 
= 0 + if use_adam_optimizer: + adam_moments = _initialize_adam_optimizer(Cdict) + + log = {'loss_batches': [], 'loss_epochs': []} + const_q = q[:, None] * q[None, :] + Cdict_best_state = Cdict.copy() + loss_best_state = np.inf + if batch_size > dataset_size: + batch_size = dataset_size + iter_by_epoch = dataset_size // batch_size + int((dataset_size % batch_size) > 0) + + for epoch in range(epochs): + cumulated_loss_over_epoch = 0. + + for _ in range(iter_by_epoch): + # batch sampling + batch = np.random.choice(range(dataset_size), size=batch_size, replace=False) + cumulated_loss_over_batch = 0. + unmixings = np.zeros((batch_size, D)) + Cs_embedded = np.zeros((batch_size, nt, nt)) + Ts = [None] * batch_size + + for batch_idx, C_idx in enumerate(batch): + # BCD solver for Gromov-Wassersteisn linear unmixing used independently on each structure of the sampled batch + unmixings[batch_idx], Cs_embedded[batch_idx], Ts[batch_idx], current_loss = gromov_wasserstein_linear_unmixing( + Cs[C_idx], Cdict, reg=reg, p=ps[C_idx], q=q, tol_outer=tol_outer, tol_inner=tol_inner, + max_iter_outer=max_iter_outer, max_iter_inner=max_iter_inner, symmetric=symmetric, **kwargs + ) + cumulated_loss_over_batch += current_loss + cumulated_loss_over_epoch += cumulated_loss_over_batch + + if use_log: + log['loss_batches'].append(cumulated_loss_over_batch) + + # Stochastic projected gradient step over dictionary atoms + grad_Cdict = np.zeros_like(Cdict) + for batch_idx, C_idx in enumerate(batch): + shared_term_structures = Cs_embedded[batch_idx] * const_q - (Cs[C_idx].dot(Ts[batch_idx])).T.dot(Ts[batch_idx]) + grad_Cdict += unmixings[batch_idx][:, None, None] * shared_term_structures[None, :, :] + grad_Cdict *= 2 / batch_size + if use_adam_optimizer: + Cdict, adam_moments = _adam_stochastic_updates(Cdict, grad_Cdict, learning_rate, adam_moments) + else: + Cdict -= learning_rate * grad_Cdict + if 'symmetric' in projection: + Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1))) + if 'nonnegative' in projection: + Cdict[Cdict < 0.] = 0. 
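            # The projected gradient update above first takes a stochastic (or Adam)
            # step on the dictionary atoms, then maps them back onto the constraint
            # set chosen via `projection`: symmetrization Cdict <- (Cdict + Cdict^T)/2
            # when 'symmetric' is requested, and clipping of negative entries when
            # 'nonnegative' is requested.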
+ + if use_log: + log['loss_epochs'].append(cumulated_loss_over_epoch) + if loss_best_state > cumulated_loss_over_epoch: + loss_best_state = cumulated_loss_over_epoch + Cdict_best_state = Cdict.copy() + if verbose: + print('--- epoch =', epoch, ' cumulated reconstruction error: ', cumulated_loss_over_epoch) + + return nx.from_numpy(Cdict_best_state), log + + +def _initialize_adam_optimizer(variable): + + # Initialization for our numpy implementation of adam optimizer + atoms_adam_m = np.zeros_like(variable) # Initialize first moment tensor + atoms_adam_v = np.zeros_like(variable) # Initialize second moment tensor + atoms_adam_count = 1 + + return {'mean': atoms_adam_m, 'var': atoms_adam_v, 'count': atoms_adam_count} + + +def _adam_stochastic_updates(variable, grad, learning_rate, adam_moments, beta_1=0.9, beta_2=0.99, eps=1e-09): + + adam_moments['mean'] = beta_1 * adam_moments['mean'] + (1 - beta_1) * grad + adam_moments['var'] = beta_2 * adam_moments['var'] + (1 - beta_2) * (grad**2) + unbiased_m = adam_moments['mean'] / (1 - beta_1**adam_moments['count']) + unbiased_v = adam_moments['var'] / (1 - beta_2**adam_moments['count']) + variable -= learning_rate * unbiased_m / (np.sqrt(unbiased_v) + eps) + adam_moments['count'] += 1 + + return variable, adam_moments + + +def gromov_wasserstein_linear_unmixing(C, Cdict, reg=0., p=None, q=None, tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, symmetric=None, **kwargs): + r""" + Returns the Gromov-Wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary :math:`\{ (\mathbf{C_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}`. + + .. math:: + \min_{ \mathbf{w}} GW_2(\mathbf{C}, \sum_{d=1}^D w_d\mathbf{C_{dict}[d]}, \mathbf{p}, \mathbf{q}) - reg \| \mathbf{w} \|_2^2 + + such that: + + - :math:`\mathbf{w}^\top \mathbf{1}_D = 1` + - :math:`\mathbf{w} \geq \mathbf{0}_D` + + Where : + + - :math:`\mathbf{C}` is the (ns,ns) pairwise similarity matrix. + - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrices of size nt. + - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights. + - reg is the regularization coefficient. + + The algorithm used for solving the problem is a Block Coordinate Descent as discussed in [38]_ , algorithm 1. + + Parameters + ---------- + C : array-like, shape (ns, ns) + Metric/Graph cost matrix. + Cdict : D array-like, shape (D,nt,nt) + Metric/Graph cost matrices composing the dictionary on which to embed C. + reg : float, optional. + Coefficient of the negative quadratic regularization used to promote sparsity of w. Default is 0. + p : array-like, shape (ns,), optional + Distribution in the source space C. Default is None and corresponds to uniform distribution. + q : array-like, shape (nt,), optional + Distribution in the space depicted by the dictionary. Default is None and corresponds to uniform distribution. + tol_outer : float, optional + Solver precision for the BCD algorithm. + tol_inner : float, optional + Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport. Default is :math:`10^{-5}`. + max_iter_outer : int, optional + Maximum number of iterations for the BCD. Default is 20. + max_iter_inner : int, optional + Maximum number of iterations for the Conjugate Gradient. Default is 200. + + Returns + ------- + w: array-like, shape (D,) + gromov-wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{p})` onto the span of the dictionary. 
+ Cembedded: array-like, shape (nt,nt) + embedded structure of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary, :math:`\sum_d w_d\mathbf{C_{dict}[d]}`. + T: array-like (ns, nt) + Gromov-Wasserstein transport plan between :math:`(\mathbf{C},\mathbf{p})` and :math:`(\sum_d w_d\mathbf{C_{dict}[d]}, \mathbf{q})` + current_loss: float + reconstruction error + References + ------- + .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online + Graph Dictionary Learning, International Conference on Machine Learning + (ICML), 2021. + """ + C0, Cdict0 = C, Cdict + nx = get_backend(C0, Cdict0) + C = nx.to_numpy(C0) + Cdict = nx.to_numpy(Cdict0) + if p is None: + p = unif(C.shape[0]) + else: + p = nx.to_numpy(p) + + if q is None: + q = unif(Cdict.shape[-1]) + else: + q = nx.to_numpy(q) + + T = p[:, None] * q[None, :] + D = len(Cdict) + + w = unif(D) # Initialize uniformly the unmixing w + Cembedded = np.sum(w[:, None, None] * Cdict, axis=0) + + const_q = q[:, None] * q[None, :] + # Trackers for BCD convergence + convergence_criterion = np.inf + current_loss = 10**15 + outer_count = 0 + + while (convergence_criterion > tol_outer) and (outer_count < max_iter_outer): + previous_loss = current_loss + # 1. Solve GW transport between (C,p) and (\sum_d Cdictionary[d],q) fixing the unmixing w + T, log = gromov_wasserstein( + C1=C, C2=Cembedded, p=p, q=q, loss_fun='square_loss', G0=T, + max_iter=max_iter_inner, tol_rel=tol_inner, tol_abs=0., log=True, armijo=False, symmetric=symmetric, **kwargs) + current_loss = log['gw_dist'] + if reg != 0: + current_loss -= reg * np.sum(w**2) + + # 2. Solve linear unmixing problem over w with a fixed transport plan T + w, Cembedded, current_loss = _cg_gromov_wasserstein_unmixing( + C=C, Cdict=Cdict, Cembedded=Cembedded, w=w, const_q=const_q, T=T, + starting_loss=current_loss, reg=reg, tol=tol_inner, max_iter=max_iter_inner, **kwargs + ) + + if previous_loss != 0: + convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss) + else: # handle numerical issues around 0 + convergence_criterion = abs(previous_loss - current_loss) / 10**(-15) + outer_count += 1 + + return nx.from_numpy(w), nx.from_numpy(Cembedded), nx.from_numpy(T), nx.from_numpy(current_loss) + + +def _cg_gromov_wasserstein_unmixing(C, Cdict, Cembedded, w, const_q, T, starting_loss, reg=0., tol=10**(-5), max_iter=200, **kwargs): + r""" + Returns for a fixed admissible transport plan, + the linear unmixing w minimizing the Gromov-Wasserstein cost between :math:`(\mathbf{C},\mathbf{p})` and :math:`(\sum_d w[d]*\mathbf{C_{dict}[d]}, \mathbf{q})` + + .. math:: + \min_{\mathbf{w}} \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D w_d*C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} - reg* \| \mathbf{w} \|_2^2 + + + Such that: + + - :math:`\mathbf{w}^\top \mathbf{1}_D = 1` + - :math:`\mathbf{w} \geq \mathbf{0}_D` + + Where : + + - :math:`\mathbf{C}` is the (ns,ns) pairwise similarity matrix. + - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrices of nt points. + - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights. + - :math:`\mathbf{w}` is the linear unmixing of :math:`(\mathbf{C}, \mathbf{p})` onto :math:`(\sum_d w_d \mathbf{Cdict[d]}, \mathbf{q})`. + - :math:`\mathbf{T}` is the optimal transport plan conditioned by the current state of :math:`\mathbf{w}`. + - reg is the regularization coefficient. 
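+
+    Since the marginals of :math:`\mathbf{T}` are fixed (:math:`\sum_i T_{i,k} = q_k`), the gradient of this
+    objective with respect to :math:`\mathbf{w}` has a closed form; assuming :math:`\mathbf{C}` symmetric, its
+    :math:`d`-th partial derivative reads
+
+    .. math::
+        2 \left\langle \mathbf{C_{dict}[d]},\; \left(\sum_{d'} w_{d'} \mathbf{C_{dict}[d']}\right) \odot \mathbf{q}\mathbf{q}^\top - \mathbf{T}^\top \mathbf{C}^\top \mathbf{T} \right\rangle_F - 2\, reg\, w_d
+
+    where :math:`\langle \cdot, \cdot \rangle_F` is the Frobenius inner product.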
+ + The algorithm used for solving the problem is a Conditional Gradient Descent as discussed in [38]_ + + Parameters + ---------- + + C : array-like, shape (ns, ns) + Metric/Graph cost matrix. + Cdict : list of D array-like, shape (nt,nt) + Metric/Graph cost matrices composing the dictionary on which to embed C. + Each matrix in the dictionary must have the same size (nt,nt). + Cembedded: array-like, shape (nt,nt) + Embedded structure :math:`(\sum_d w[d]*Cdict[d],q)` of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary. Used to avoid redundant computations. + w: array-like, shape (D,) + Linear unmixing of the input structure onto the dictionary + const_q: array-like, shape (nt,nt) + product matrix :math:`\mathbf{q}\mathbf{q}^\top` where q is the target space distribution. Used to avoid redundant computations. + T: array-like, shape (ns,nt) + fixed transport plan between the input structure and its representation in the dictionary. + p : array-like, shape (ns,) + Distribution in the source space. + q : array-like, shape (nt,) + Distribution in the embedding space depicted by the dictionary. + reg : float, optional. + Coefficient of the negative quadratic regularization used to promote sparsity of w. Default is 0. + + Returns + ------- + w: ndarray (D,) + optimal unmixing of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary span given OT starting from previously optimal unmixing. + """ + convergence_criterion = np.inf + current_loss = starting_loss + count = 0 + const_TCT = np.transpose(C.dot(T)).dot(T) + + while (convergence_criterion > tol) and (count < max_iter): + + previous_loss = current_loss + # 1) Compute gradient at current point w + grad_w = 2 * np.sum(Cdict * (Cembedded[None, :, :] * const_q[None, :, :] - const_TCT[None, :, :]), axis=(1, 2)) + grad_w -= 2 * reg * w + + # 2) Conditional gradient direction finding: x= \argmin_x x^T.grad_w + min_ = np.min(grad_w) + x = (grad_w == min_).astype(np.float64) + x /= np.sum(x) + + # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c + gamma, a, b, Cembedded_diff = _linesearch_gromov_wasserstein_unmixing(w, grad_w, x, Cdict, Cembedded, const_q, const_TCT, reg) + + # 4) Updates: w <-- (1-gamma)*w + gamma*x + w += gamma * (x - w) + Cembedded += gamma * Cembedded_diff + current_loss += a * (gamma**2) + b * gamma + + if previous_loss != 0: # not that the loss can be negative if reg >0 + convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss) + else: # handle numerical issues around 0 + convergence_criterion = abs(previous_loss - current_loss) / 10**(-15) + count += 1 + + return w, Cembedded, current_loss + + +def _linesearch_gromov_wasserstein_unmixing(w, grad_w, x, Cdict, Cembedded, const_q, const_TCT, reg, **kwargs): + r""" + Compute optimal steps for the line search problem of Gromov-Wasserstein linear unmixing + .. math:: + \min_{\gamma \in [0,1]} \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D z_d(\gamma)C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} - reg\| \mathbf{z}(\gamma) \|_2^2 + + + Such that: + + - :math:`\mathbf{z}(\gamma) = (1- \gamma)\mathbf{w} + \gamma \mathbf{x}` + + Parameters + ---------- + + w : array-like, shape (D,) + Unmixing. + grad_w : array-like, shape (D, D) + Gradient of the reconstruction loss with respect to w. + x: array-like, shape (D,) + Conditional gradient direction. + Cdict : list of D array-like, shape (nt,nt) + Metric/Graph cost matrices composing the dictionary on which to embed C. + Each matrix in the dictionary must have the same size (nt,nt). 
+    Cembedded: array-like, shape (nt,nt)
+        Embedded structure :math:`(\sum_d w_d\mathbf{C_{dict}[d]}, \mathbf{q})` of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary. Used to avoid redundant computations.
+    const_q: array-like, shape (nt,nt)
+        Product matrix :math:`\mathbf{q}\mathbf{q}^\top` where :math:`\mathbf{q}` is the target space distribution. Used to avoid redundant computations.
+    const_TCT: array-like, shape (nt, nt)
+        :math:`\mathbf{T}^\top \mathbf{C}^\top \mathbf{T}`. Used to avoid redundant computations.
+    reg : float, optional.
+        Coefficient of the negative quadratic regularization used to promote sparsity of :math:`\mathbf{w}`.
+
+    Returns
+    -------
+    gamma: float
+        Optimal value for the line-search step
+    a: float
+        Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma + c` of the reconstruction loss
+    b: float
+        Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma + c` of the reconstruction loss
+    Cembedded_diff: numpy array, shape (nt, nt)
+        Difference between the models evaluated in :math:`\mathbf{x}` and in :math:`\mathbf{w}`.
+    """
+
+    # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c
+    Cembedded_x = np.sum(x[:, None, None] * Cdict, axis=0)
+    Cembedded_diff = Cembedded_x - Cembedded
+    trace_diffx = np.sum(Cembedded_diff * Cembedded_x * const_q)
+    trace_diffw = np.sum(Cembedded_diff * Cembedded * const_q)
+    a = trace_diffx - trace_diffw
+    b = 2 * (trace_diffw - np.sum(Cembedded_diff * const_TCT))
+    if reg != 0:
+        a -= reg * np.sum((x - w)**2)
+        b -= 2 * reg * np.sum(w * (x - w))
+
+    if a > 0:
+        gamma = min(1, max(0, - b / (2 * a)))
+    elif a + b < 0:
+        gamma = 1
+    else:
+        gamma = 0
+
+    return gamma, a, b, Cembedded_diff
+
+
+def fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, alpha, reg=0., ps=None, q=None, epochs=20, batch_size=32, learning_rate_C=1., learning_rate_Y=1.,
+                                                 Cdict_init=None, Ydict_init=None, projection='nonnegative_symmetric', use_log=False,
+                                                 tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, **kwargs):
+    r"""
+    Infer Fused Gromov-Wasserstein linear dictionary :math:`\{ (\mathbf{C_{dict}[d]}, \mathbf{Y_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}` from the list of S attributed structures :math:`\{ (\mathbf{C_s}, \mathbf{Y_s},\mathbf{p_s}) \}_s`
+
+    .. math::
+        \min_{\mathbf{C_{dict}},\mathbf{Y_{dict}}, \{\mathbf{w_s}\}_{s}} \sum_{s=1}^S FGW_{2,\alpha}(\mathbf{C_s}, \mathbf{Y_s}, \sum_{d=1}^D w_{s,d}\mathbf{C_{dict}[d]},\sum_{d=1}^D w_{s,d}\mathbf{Y_{dict}[d]}, \mathbf{p_s}, \mathbf{q}) - reg\| \mathbf{w_s} \|_2^2
+
+    Such that :math:`\forall s \leq S`:
+
+    - :math:`\mathbf{w_s}^\top \mathbf{1}_D = 1`
+    - :math:`\mathbf{w_s} \geq \mathbf{0}_D`
+
+    Where :
+
+    - :math:`\forall s \leq S, \mathbf{C_s}` is a (ns,ns) pairwise similarity matrix of variable size ns.
+    - :math:`\forall s \leq S, \mathbf{Y_s}` is a (ns,d) feature matrix of variable size ns and fixed dimension d.
+    - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrices of fixed size nt.
+    - :math:`\mathbf{Y_{dict}}` is a (D, nt, d) tensor of D feature matrices of fixed size nt and fixed dimension d.
+    - :math:`\forall s \leq S, \mathbf{p_s}` is the source distribution corresponding to :math:`\mathbf{C_s}`
+    - :math:`\mathbf{q}` is the target distribution assigned to every structure in the embedding space.
+    - :math:`\alpha` is the trade-off parameter of Fused Gromov-Wasserstein
+    - reg is the regularization coefficient.
+
+
+    The stochastic algorithm used for estimating the attributed graph dictionary atoms is the one proposed in [38]_
+
+    Parameters
+    ----------
+    Cs : list of S symmetric array-like, shape (ns, ns)
+        List of Metric/Graph cost matrices of variable size (ns, ns).
+    Ys : list of S array-like, shape (ns, d)
+        List of feature matrices of variable size (ns, d) with d fixed.
+    D: int
+        Number of dictionary atoms to learn
+    nt: int
+        Number of samples within each dictionary atom
+    alpha : float
+        Trade-off parameter of Fused Gromov-Wasserstein
+    reg : float, optional
+        Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0.
+    ps : list of S array-like, shape (ns,), optional
+        Distribution in each source space C of Cs. Default is None and corresponds to uniform distributions.
+    q : array-like, shape (nt,), optional
+        Distribution in the embedding space whose structure will be learned. Default is None and corresponds to a uniform distribution.
+    epochs: int, optional
+        Number of epochs used to learn the dictionary. Default is 20.
+    batch_size: int, optional
+        Batch size for each stochastic gradient update of the dictionary. Set to the dataset size if the provided batch_size is higher than the dataset size. Default is 32.
+    learning_rate_C: float, optional
+        Learning rate used for the stochastic gradient descent on Cdict. Default is 1.
+    learning_rate_Y: float, optional
+        Learning rate used for the stochastic gradient descent on Ydict. Default is 1.
+    Cdict_init: list of D array-like with shape (nt, nt), optional
+        Used to initialize the dictionary structures Cdict.
+        If set to None (default), the dictionary will be initialized randomly.
+        Else Cdict_init must have shape (D, nt, nt), i.e. match the expected dictionary shape.
+    Ydict_init: list of D array-like with shape (nt, d), optional
+        Used to initialize the dictionary features Ydict.
+        If set to None, the dictionary features will be initialized randomly.
+        Else Ydict_init must have shape (D, nt, d), where d is the feature dimension of the inputs Ys.
+    projection: str, optional
+        If 'nonnegative' and/or 'symmetric' is in projection, the corresponding projection will be performed at each stochastic update of the dictionary.
+        Else the set of atoms is :math:`\mathbb{R}^{nt \times nt}`. Default is 'nonnegative_symmetric'
+    use_log: bool, optional
+        If set to True, the loss evolution by batches and epochs is tracked. Default is False.
+    use_adam_optimizer: bool, optional
+        If set to True, the Adam optimizer with default settings is used as the adaptive learning rate strategy.
+        Else SGD with a fixed learning rate is performed. Default is True.
+    tol_outer : float, optional
+        Solver precision for the BCD algorithm, measured by the absolute relative error on consecutive losses. Default is :math:`10^{-5}`.
+    tol_inner : float, optional
+        Solver precision for the Conjugate Gradient algorithm used to get the optimal w at a fixed transport, measured by the absolute relative error on consecutive losses. Default is :math:`10^{-5}`.
+    max_iter_outer : int, optional
+        Maximum number of iterations for the BCD. Default is 20.
+    max_iter_inner : int, optional
+        Maximum number of iterations for the Conjugate Gradient. Default is 200.
+    verbose : bool, optional
+        Print the reconstruction loss every epoch. Default is False.
+
+    Returns
+    -------
+
+    Cdict_best_state : D array-like, shape (D,nt,nt)
+        Metric/Graph cost matrices composing the dictionary.
+        The dictionary leading to the best loss over an epoch is saved and returned.
+ Ydict_best_state : D array-like, shape (D,nt,d) + Feature matrices composing the dictionary. + The dictionary leading to the best loss over an epoch is saved and returned. + log: dict + If use_log is True, contains loss evolutions by batches and epoches. + References + ------- + .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online + Graph Dictionary Learning, International Conference on Machine Learning + (ICML), 2021. + """ + Cs0, Ys0 = Cs, Ys + nx = get_backend(*Cs0, *Ys0) + Cs = [nx.to_numpy(C) for C in Cs0] + Ys = [nx.to_numpy(Y) for Y in Ys0] + + d = Ys[0].shape[-1] + dataset_size = len(Cs) + + if ps is None: + ps = [unif(C.shape[0]) for C in Cs] + else: + ps = [nx.to_numpy(p) for p in ps] + if q is None: + q = unif(nt) + else: + q = nx.to_numpy(q) + + if Cdict_init is None: + # Initialize randomly structures of dictionary atoms based on samples + dataset_means = [C.mean() for C in Cs] + Cdict = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt)) + else: + Cdict = nx.to_numpy(Cdict_init).copy() + assert Cdict.shape == (D, nt, nt) + if Ydict_init is None: + # Initialize randomly features of dictionary atoms based on samples distribution by feature component + dataset_feature_means = np.stack([F.mean(axis=0) for F in Ys]) + Ydict = np.random.normal(loc=dataset_feature_means.mean(axis=0), scale=dataset_feature_means.std(axis=0), size=(D, nt, d)) + else: + Ydict = nx.to_numpy(Ydict_init).copy() + assert Ydict.shape == (D, nt, d) + + if 'symmetric' in projection: + Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1))) + symmetric = True + else: + symmetric = False + if 'nonnegative' in projection: + Cdict[Cdict < 0.] = 0. + + if use_adam_optimizer: + adam_moments_C = _initialize_adam_optimizer(Cdict) + adam_moments_Y = _initialize_adam_optimizer(Ydict) + + log = {'loss_batches': [], 'loss_epochs': []} + const_q = q[:, None] * q[None, :] + diag_q = np.diag(q) + Cdict_best_state = Cdict.copy() + Ydict_best_state = Ydict.copy() + loss_best_state = np.inf + if batch_size > dataset_size: + batch_size = dataset_size + iter_by_epoch = dataset_size // batch_size + int((dataset_size % batch_size) > 0) + + for epoch in range(epochs): + cumulated_loss_over_epoch = 0. + + for _ in range(iter_by_epoch): + + # Batch iterations + batch = np.random.choice(range(dataset_size), size=batch_size, replace=False) + cumulated_loss_over_batch = 0. 
+ unmixings = np.zeros((batch_size, D)) + Cs_embedded = np.zeros((batch_size, nt, nt)) + Ys_embedded = np.zeros((batch_size, nt, d)) + Ts = [None] * batch_size + + for batch_idx, C_idx in enumerate(batch): + # BCD solver for Gromov-Wassersteisn linear unmixing used independently on each structure of the sampled batch + unmixings[batch_idx], Cs_embedded[batch_idx], Ys_embedded[batch_idx], Ts[batch_idx], current_loss = fused_gromov_wasserstein_linear_unmixing( + Cs[C_idx], Ys[C_idx], Cdict, Ydict, alpha, reg=reg, p=ps[C_idx], q=q, + tol_outer=tol_outer, tol_inner=tol_inner, max_iter_outer=max_iter_outer, max_iter_inner=max_iter_inner, symmetric=symmetric, **kwargs + ) + cumulated_loss_over_batch += current_loss + cumulated_loss_over_epoch += cumulated_loss_over_batch + if use_log: + log['loss_batches'].append(cumulated_loss_over_batch) + + # Stochastic projected gradient step over dictionary atoms + grad_Cdict = np.zeros_like(Cdict) + grad_Ydict = np.zeros_like(Ydict) + + for batch_idx, C_idx in enumerate(batch): + shared_term_structures = Cs_embedded[batch_idx] * const_q - (Cs[C_idx].dot(Ts[batch_idx])).T.dot(Ts[batch_idx]) + shared_term_features = diag_q.dot(Ys_embedded[batch_idx]) - Ts[batch_idx].T.dot(Ys[C_idx]) + grad_Cdict += alpha * unmixings[batch_idx][:, None, None] * shared_term_structures[None, :, :] + grad_Ydict += (1 - alpha) * unmixings[batch_idx][:, None, None] * shared_term_features[None, :, :] + grad_Cdict *= 2 / batch_size + grad_Ydict *= 2 / batch_size + + if use_adam_optimizer: + Cdict, adam_moments_C = _adam_stochastic_updates(Cdict, grad_Cdict, learning_rate_C, adam_moments_C) + Ydict, adam_moments_Y = _adam_stochastic_updates(Ydict, grad_Ydict, learning_rate_Y, adam_moments_Y) + else: + Cdict -= learning_rate_C * grad_Cdict + Ydict -= learning_rate_Y * grad_Ydict + + if 'symmetric' in projection: + Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1))) + if 'nonnegative' in projection: + Cdict[Cdict < 0.] = 0. + + if use_log: + log['loss_epochs'].append(cumulated_loss_over_epoch) + if loss_best_state > cumulated_loss_over_epoch: + loss_best_state = cumulated_loss_over_epoch + Cdict_best_state = Cdict.copy() + Ydict_best_state = Ydict.copy() + if verbose: + print('--- epoch: ', epoch, ' cumulated reconstruction error: ', cumulated_loss_over_epoch) + + return nx.from_numpy(Cdict_best_state), nx.from_numpy(Ydict_best_state), log + + +def fused_gromov_wasserstein_linear_unmixing(C, Y, Cdict, Ydict, alpha, reg=0., p=None, q=None, tol_outer=10**(-5), + tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, symmetric=True, **kwargs): + r""" + Returns the Fused Gromov-Wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` onto the attributed dictionary atoms :math:`\{ (\mathbf{C_{dict}[d]},\mathbf{Y_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}` + + .. math:: + \min_{\mathbf{w}} FGW_{2,\alpha}(\mathbf{C},\mathbf{Y}, \sum_{d=1}^D w_d\mathbf{C_{dict}[d]},\sum_{d=1}^D w_d\mathbf{Y_{dict}[d]}, \mathbf{p}, \mathbf{q}) - reg \| \mathbf{w} \|_2^2 + + such that, :math:`\forall s \leq S` : + + - :math:`\mathbf{w_s}^\top \mathbf{1}_D = 1` + - :math:`\mathbf{w_s} \geq \mathbf{0}_D` + + Where : + + - :math:`\mathbf{C}` is a (ns,ns) pairwise similarity matrix of variable size ns. + - :math:`\mathbf{Y}` is a (ns,d) features matrix of variable size ns and fixed dimension d. + - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrix of fixed size nt. 
+ - :math:`\mathbf{Y_{dict}}` is a (D, nt, d) tensor of D features matrix of fixed size nt and fixed dimension d. + - :math:`\mathbf{p}` is the source distribution corresponding to :math:`\mathbf{C_s}` + - :math:`\mathbf{q}` is the target distribution assigned to every structures in the embedding space. + - :math:`\alpha` is the trade-off parameter of Fused Gromov-Wasserstein + - reg is the regularization coefficient. + + The algorithm used for solving the problem is a Block Coordinate Descent as discussed in [38]_, algorithm 6. + + Parameters + ---------- + C : array-like, shape (ns, ns) + Metric/Graph cost matrix. + Y : array-like, shape (ns, d) + Feature matrix. + Cdict : D array-like, shape (D,nt,nt) + Metric/Graph cost matrices composing the dictionary on which to embed (C,Y). + Ydict : D array-like, shape (D,nt,d) + Feature matrices composing the dictionary on which to embed (C,Y). + alpha: float, + Trade-off parameter of Fused Gromov-Wasserstein. + reg : float, optional + Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0. + p : array-like, shape (ns,), optional + Distribution in the source space C. Default is None and corresponds to uniform distribution. + q : array-like, shape (nt,), optional + Distribution in the space depicted by the dictionary. Default is None and corresponds to uniform distribution. + tol_outer : float, optional + Solver precision for the BCD algorithm. + tol_inner : float, optional + Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport. Default is :math:`10^{-5}`. + max_iter_outer : int, optional + Maximum number of iterations for the BCD. Default is 20. + max_iter_inner : int, optional + Maximum number of iterations for the Conjugate Gradient. Default is 200. + + Returns + ------- + w: array-like, shape (D,) + fused gromov-wasserstein linear unmixing of (C,Y,p) onto the span of the dictionary. + Cembedded: array-like, shape (nt,nt) + embedded structure of :math:`(\mathbf{C},\mathbf{Y}, \mathbf{p})` onto the dictionary, :math:`\sum_d w_d\mathbf{C_{dict}[d]}`. + Yembedded: array-like, shape (nt,d) + embedded features of :math:`(\mathbf{C},\mathbf{Y}, \mathbf{p})` onto the dictionary, :math:`\sum_d w_d\mathbf{Y_{dict}[d]}`. + T: array-like (ns,nt) + Fused Gromov-Wasserstein transport plan between :math:`(\mathbf{C},\mathbf{p})` and :math:`(\sum_d w_d\mathbf{C_{dict}[d]}, \sum_d w_d\mathbf{Y_{dict}[d]},\mathbf{q})`. + current_loss: float + reconstruction error + References + ------- + .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online + Graph Dictionary Learning, International Conference on Machine Learning + (ICML), 2021. 
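+
+    Examples
+    --------
+    A small illustrative sketch with random inputs, assuming :mod:`ot` is installed and this
+    function is exposed as :func:`ot.gromov.fused_gromov_wasserstein_linear_unmixing`; the
+    dictionary (Cdict, Ydict) below is random and only serves to show the call signature, so
+    no output is shown:
+
+    >>> import numpy as np
+    >>> import ot
+    >>> rng = np.random.RandomState(0)
+    >>> C = ot.dist(rng.randn(10, 2))                          # (ns, ns) symmetric structure
+    >>> Y = rng.randn(10, 3)                                   # (ns, d) features
+    >>> Cdict = rng.rand(2, 6, 6)
+    >>> Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1)))     # symmetric atoms
+    >>> Ydict = rng.randn(2, 6, 3)
+    >>> w, Cemb, Yemb, T, loss = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
+    ...     C, Y, Cdict, Ydict, alpha=0.5)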
+ """ + C0, Y0, Cdict0, Ydict0 = C, Y, Cdict, Ydict + nx = get_backend(C0, Y0, Cdict0, Ydict0) + C = nx.to_numpy(C0) + Y = nx.to_numpy(Y0) + Cdict = nx.to_numpy(Cdict0) + Ydict = nx.to_numpy(Ydict0) + + if p is None: + p = unif(C.shape[0]) + else: + p = nx.to_numpy(p) + if q is None: + q = unif(Cdict.shape[-1]) + else: + q = nx.to_numpy(q) + + T = p[:, None] * q[None, :] + D = len(Cdict) + d = Y.shape[-1] + w = unif(D) # Initialize with uniform weights + ns = C.shape[-1] + nt = Cdict.shape[-1] + + # modeling (C,Y) + Cembedded = np.sum(w[:, None, None] * Cdict, axis=0) + Yembedded = np.sum(w[:, None, None] * Ydict, axis=0) + + # constants depending on q + const_q = q[:, None] * q[None, :] + diag_q = np.diag(q) + # Trackers for BCD convergence + convergence_criterion = np.inf + current_loss = 10**15 + outer_count = 0 + Ys_constM = (Y**2).dot(np.ones((d, nt))) # constant in computing euclidean pairwise feature matrix + + while (convergence_criterion > tol_outer) and (outer_count < max_iter_outer): + previous_loss = current_loss + + # 1. Solve GW transport between (C,p) and (\sum_d Cdictionary[d],q) fixing the unmixing w + Yt_varM = (np.ones((ns, d))).dot((Yembedded**2).T) + M = Ys_constM + Yt_varM - 2 * Y.dot(Yembedded.T) # euclidean distance matrix between features + T, log = fused_gromov_wasserstein( + M, C, Cembedded, p, q, loss_fun='square_loss', alpha=alpha, + max_iter=max_iter_inner, tol_rel=tol_inner, tol_abs=0., armijo=False, G0=T, log=True, symmetric=symmetric, **kwargs) + current_loss = log['fgw_dist'] + if reg != 0: + current_loss -= reg * np.sum(w**2) + + # 2. Solve linear unmixing problem over w with a fixed transport plan T + w, Cembedded, Yembedded, current_loss = _cg_fused_gromov_wasserstein_unmixing(C, Y, Cdict, Ydict, Cembedded, Yembedded, w, + T, p, q, const_q, diag_q, current_loss, alpha, reg, + tol=tol_inner, max_iter=max_iter_inner, **kwargs) + if previous_loss != 0: + convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss) + else: + convergence_criterion = abs(previous_loss - current_loss) / 10**(-12) + outer_count += 1 + + return nx.from_numpy(w), nx.from_numpy(Cembedded), nx.from_numpy(Yembedded), nx.from_numpy(T), nx.from_numpy(current_loss) + + +def _cg_fused_gromov_wasserstein_unmixing(C, Y, Cdict, Ydict, Cembedded, Yembedded, w, T, p, q, const_q, diag_q, starting_loss, alpha, reg, tol=10**(-6), max_iter=200, **kwargs): + r""" + Returns for a fixed admissible transport plan, + the optimal linear unmixing :math:`\mathbf{w}` minimizing the Fused Gromov-Wasserstein cost between :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` and :math:`(\sum_d w_d \mathbf{C_{dict}[d]},\sum_d w_d*\mathbf{Y_{dict}[d]}, \mathbf{q})` + + .. math:: + \min_{\mathbf{w}} \alpha \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D w_d C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} \\+ (1-\alpha) \sum_{ij} \| \mathbf{Y_i} - \sum_d w_d \mathbf{Y_{dict}[d]_j} \|_2^2 T_{ij}- reg \| \mathbf{w} \|_2^2 + + Such that : + + - :math:`\mathbf{w}^\top \mathbf{1}_D = 1` + - :math:`\mathbf{w} \geq \mathbf{0}_D` + + Where : + + - :math:`\mathbf{C}` is a (ns,ns) pairwise similarity matrix of variable size ns. + - :math:`\mathbf{Y}` is a (ns,d) features matrix of variable size ns and fixed dimension d. + - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrix of fixed size nt. + - :math:`\mathbf{Y_{dict}}` is a (D, nt, d) tensor of D features matrix of fixed size nt and fixed dimension d. 
+ - :math:`\mathbf{p}` is the source distribution corresponding to :math:`\mathbf{C_s}` + - :math:`\mathbf{q}` is the target distribution assigned to every structures in the embedding space. + - :math:`\mathbf{T}` is the optimal transport plan conditioned by the previous state of :math:`\mathbf{w}` + - :math:`\alpha` is the trade-off parameter of Fused Gromov-Wasserstein + - reg is the regularization coefficient. + + The algorithm used for solving the problem is a Conditional Gradient Descent as discussed in [38]_, algorithm 7. + + Parameters + ---------- + + C : array-like, shape (ns, ns) + Metric/Graph cost matrix. + Y : array-like, shape (ns, d) + Feature matrix. + Cdict : list of D array-like, shape (nt,nt) + Metric/Graph cost matrices composing the dictionary on which to embed (C,Y). + Each matrix in the dictionary must have the same size (nt,nt). + Ydict : list of D array-like, shape (nt,d) + Feature matrices composing the dictionary on which to embed (C,Y). + Each matrix in the dictionary must have the same size (nt,d). + Cembedded: array-like, shape (nt,nt) + Embedded structure of (C,Y) onto the dictionary + Yembedded: array-like, shape (nt,d) + Embedded features of (C,Y) onto the dictionary + w: array-like, shape (n_D,) + Linear unmixing of (C,Y) onto (Cdict,Ydict) + const_q: array-like, shape (nt,nt) + product matrix :math:`\mathbf{qq}^\top` where :math:`\mathbf{q}` is the target space distribution. + diag_q: array-like, shape (nt,nt) + diagonal matrix with values of q on the diagonal. + T: array-like, shape (ns,nt) + fixed transport plan between (C,Y) and its model + p : array-like, shape (ns,) + Distribution in the source space (C,Y). + q : array-like, shape (nt,) + Distribution in the embedding space depicted by the dictionary. + alpha: float, + Trade-off parameter of Fused Gromov-Wasserstein. + reg : float, optional + Coefficient of the negative quadratic regularization used to promote sparsity of w. + + Returns + ------- + w: ndarray (D,) + linear unmixing of :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` onto the span of :math:`(C_{dict},Y_{dict})` given OT corresponding to previous unmixing. 
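+
+    Notes
+    -----
+    At each iteration the descent direction is a vertex (or an average of tied vertices) of the
+    simplex: writing :math:`\mathbf{g}` for the gradient of the objective above at the current
+    :math:`\mathbf{w}`, the direction solves
+
+    .. math::
+        \mathbf{x} \in \mathop{\arg \min}_{\mathbf{x}^\top \mathbf{1}_D = 1,\; \mathbf{x} \geq \mathbf{0}_D} \langle \mathbf{x}, \mathbf{g} \rangle
+
+    and the step :math:`\gamma \in [0,1]` exactly minimizes the quadratic restriction
+    :math:`a\gamma^2 + b\gamma + c` of the objective along :math:`(1-\gamma)\mathbf{w} + \gamma\mathbf{x}`,
+    i.e. :math:`\gamma = \min(1, \max(0, -b/2a))` whenever :math:`a > 0`.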
+ """ + convergence_criterion = np.inf + current_loss = starting_loss + count = 0 + const_TCT = np.transpose(C.dot(T)).dot(T) + ones_ns_d = np.ones(Y.shape) + + while (convergence_criterion > tol) and (count < max_iter): + previous_loss = current_loss + + # 1) Compute gradient at current point w + # structure + grad_w = alpha * np.sum(Cdict * (Cembedded[None, :, :] * const_q[None, :, :] - const_TCT[None, :, :]), axis=(1, 2)) + # feature + grad_w += (1 - alpha) * np.sum(Ydict * (diag_q.dot(Yembedded)[None, :, :] - T.T.dot(Y)[None, :, :]), axis=(1, 2)) + grad_w -= reg * w + grad_w *= 2 + + # 2) Conditional gradient direction finding: x= \argmin_x x^T.grad_w + min_ = np.min(grad_w) + x = (grad_w == min_).astype(np.float64) + x /= np.sum(x) + + # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c + gamma, a, b, Cembedded_diff, Yembedded_diff = _linesearch_fused_gromov_wasserstein_unmixing(w, grad_w, x, Y, Cdict, Ydict, Cembedded, Yembedded, T, const_q, const_TCT, ones_ns_d, alpha, reg) + + # 4) Updates: w <-- (1-gamma)*w + gamma*x + w += gamma * (x - w) + Cembedded += gamma * Cembedded_diff + Yembedded += gamma * Yembedded_diff + current_loss += a * (gamma**2) + b * gamma + + if previous_loss != 0: + convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss) + else: + convergence_criterion = abs(previous_loss - current_loss) / 10**(-12) + count += 1 + + return w, Cembedded, Yembedded, current_loss + + +def _linesearch_fused_gromov_wasserstein_unmixing(w, grad_w, x, Y, Cdict, Ydict, Cembedded, Yembedded, T, const_q, const_TCT, ones_ns_d, alpha, reg, **kwargs): + r""" + Compute optimal steps for the line search problem of Fused Gromov-Wasserstein linear unmixing + .. math:: + \min_{\gamma \in [0,1]} \alpha \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D z_d(\gamma)C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} \\ + (1-\alpha) \sum_{ij} \| \mathbf{Y_i} - \sum_d z_d(\gamma) \mathbf{Y_{dict}[d]_j} \|_2^2 - reg\| \mathbf{z}(\gamma) \|_2^2 + + + Such that : + + - :math:`\mathbf{z}(\gamma) = (1- \gamma)\mathbf{w} + \gamma \mathbf{x}` + + Parameters + ---------- + + w : array-like, shape (D,) + Unmixing. + grad_w : array-like, shape (D, D) + Gradient of the reconstruction loss with respect to w. + x: array-like, shape (D,) + Conditional gradient direction. + Y: arrat-like, shape (ns,d) + Feature matrix of the input space + Cdict : list of D array-like, shape (nt, nt) + Metric/Graph cost matrices composing the dictionary on which to embed (C,Y). + Each matrix in the dictionary must have the same size (nt,nt). + Ydict : list of D array-like, shape (nt, d) + Feature matrices composing the dictionary on which to embed (C,Y). + Each matrix in the dictionary must have the same size (nt,d). + Cembedded: array-like, shape (nt, nt) + Embedded structure of (C,Y) onto the dictionary + Yembedded: array-like, shape (nt, d) + Embedded features of (C,Y) onto the dictionary + T: array-like, shape (ns, nt) + Fixed transport plan between (C,Y) and its current model. + const_q: array-like, shape (nt,nt) + product matrix :math:`\mathbf{q}\mathbf{q}^\top` where q is the target space distribution. Used to avoid redundant computations. + const_TCT: array-like, shape (nt, nt) + :math:`\mathbf{T}^\top \mathbf{C}^\top \mathbf{T}`. Used to avoid redundant computations. + ones_ns_d: array-like, shape (ns, d) + :math:`\mathbf{1}_{ ns \times d}`. Used to avoid redundant computations. + alpha: float, + Trade-off parameter of Fused Gromov-Wasserstein. 
+ reg : float, optional + Coefficient of the negative quadratic regularization used to promote sparsity of w. + + Returns + ------- + gamma: float + Optimal value for the line-search step + a: float + Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss + b: float + Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss + Cembedded_diff: numpy array, shape (nt, nt) + Difference between structure matrix of models evaluated in :math:`\mathbf{w}` and in :math:`\mathbf{w}`. + Yembedded_diff: numpy array, shape (nt, nt) + Difference between feature matrix of models evaluated in :math:`\mathbf{w}` and in :math:`\mathbf{w}`. + """ + # polynomial coefficients from quadratic objective (with respect to w) on structures + Cembedded_x = np.sum(x[:, None, None] * Cdict, axis=0) + Cembedded_diff = Cembedded_x - Cembedded + trace_diffx = np.sum(Cembedded_diff * Cembedded_x * const_q) + trace_diffw = np.sum(Cembedded_diff * Cembedded * const_q) + # Constant factor appearing in the factorization a*gamma^2 + b*g + c of the Gromov-Wasserstein reconstruction loss + a_gw = trace_diffx - trace_diffw + b_gw = 2 * (trace_diffw - np.sum(Cembedded_diff * const_TCT)) + + # polynomial coefficient from quadratic objective (with respect to w) on features + Yembedded_x = np.sum(x[:, None, None] * Ydict, axis=0) + Yembedded_diff = Yembedded_x - Yembedded + # Constant factor appearing in the factorization a*gamma^2 + b*g + c of the Gromov-Wasserstein reconstruction loss + a_w = np.sum(ones_ns_d.dot((Yembedded_diff**2).T) * T) + b_w = 2 * np.sum(T * (ones_ns_d.dot((Yembedded * Yembedded_diff).T) - Y.dot(Yembedded_diff.T))) + + a = alpha * a_gw + (1 - alpha) * a_w + b = alpha * b_gw + (1 - alpha) * b_w + if reg != 0: + a -= reg * np.sum((x - w)**2) + b -= 2 * reg * np.sum(w * (x - w)) + if a > 0: + gamma = min(1, max(0, -b / (2 * a))) + elif a + b < 0: + gamma = 1 + else: + gamma = 0 + + return gamma, a, b, Cembedded_diff, Yembedded_diff diff --git a/ot/gromov/_estimators.py b/ot/gromov/_estimators.py new file mode 100644 index 0000000..0a29a91 --- /dev/null +++ b/ot/gromov/_estimators.py @@ -0,0 +1,425 @@ +# -*- coding: utf-8 -*- +""" +Gromov-Wasserstein and Fused-Gromov-Wasserstein stochastic estimators. +""" + +# Author: Rémi Flamary +# Tanguy Kerdoncuff +# +# License: MIT License + +import numpy as np + + +from ..bregman import sinkhorn +from ..utils import list_to_array, check_random_state +from ..lp import emd_1d, emd +from ..backend import get_backend + + +def GW_distance_estimation(C1, C2, p, q, loss_fun, T, + nb_samples_p=None, nb_samples_q=None, std=True, random_state=None): + r""" + Returns an approximation of the gromov-wasserstein cost between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` + with a fixed transport plan :math:`\mathbf{T}`. + + The function gives an unbiased approximation of the following equation: + + .. 
math:: + + GW = \sum_{i,j,k,l} L(\mathbf{C_{1}}_{i,k}, \mathbf{C_{2}}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + Where : + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - `L` : Loss function to account for the misfit between the similarity matrices + - :math:`\mathbf{T}`: Matrix with marginal :math:`\mathbf{p}` and :math:`\mathbf{q}` + + Parameters + ---------- + C1 : array-like, shape (ns, ns) + Metric cost matrix in the source space + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,) + Distribution in the source space + q : array-like, shape (nt,) + Distribution in the target space + loss_fun : function: :math:`\mathbb{R} \times \mathbb{R} \mapsto \mathbb{R}` + Loss function used for the distance, the transport plan does not depend on the loss function + T : csr or array-like, shape (ns, nt) + Transport plan matrix, either a sparse csr or a dense matrix + nb_samples_p : int, optional + `nb_samples_p` is the number of samples (without replacement) along the first dimension of :math:`\mathbf{T}` + nb_samples_q : int, optional + `nb_samples_q` is the number of samples along the second dimension of :math:`\mathbf{T}`, for each sample along the first + std : bool, optional + Standard deviation associated with the prediction of the gromov-wasserstein cost + random_state : int or RandomState instance, optional + Fix the seed for reproducibility + + Returns + ------- + : float + Gromov-wasserstein cost + + References + ---------- + .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc + "Sampled Gromov Wasserstein." + Machine Learning Journal (MLJ). 2021. + + """ + C1, C2, p, q = list_to_array(C1, C2, p, q) + nx = get_backend(C1, C2, p, q) + + generator = check_random_state(random_state) + + len_p = p.shape[0] + len_q = q.shape[0] + + # It is always better to sample from the biggest distribution first. + if len_p < len_q: + p, q = q, p + len_p, len_q = len_q, len_p + C1, C2 = C2, C1 + T = T.T + + if nb_samples_p is None: + if nx.issparse(T): + # If T is sparse, it probably mean that PoGroW was used, thus the number of sample is reduced + nb_samples_p = min(int(5 * (len_p * np.log(len_p)) ** 0.5), len_p) + else: + nb_samples_p = len_p + else: + # The number of sample along the first dimension is without replacement. + nb_samples_p = min(nb_samples_p, len_p) + if nb_samples_q is None: + nb_samples_q = 1 + if std: + nb_samples_q = max(2, nb_samples_q) + + index_k = np.zeros((nb_samples_p, nb_samples_q), dtype=int) + index_l = np.zeros((nb_samples_p, nb_samples_q), dtype=int) + + index_i = generator.choice( + len_p, size=nb_samples_p, p=nx.to_numpy(p), replace=False + ) + index_j = generator.choice( + len_p, size=nb_samples_p, p=nx.to_numpy(p), replace=False + ) + + for i in range(nb_samples_p): + if nx.issparse(T): + T_indexi = nx.reshape(nx.todense(T[index_i[i], :]), (-1,)) + T_indexj = nx.reshape(nx.todense(T[index_j[i], :]), (-1,)) + else: + T_indexi = T[index_i[i], :] + T_indexj = T[index_j[i], :] + # For each of the row sampled, the column is sampled. 
+ index_k[i] = generator.choice( + len_q, + size=nb_samples_q, + p=nx.to_numpy(T_indexi / nx.sum(T_indexi)), + replace=True + ) + index_l[i] = generator.choice( + len_q, + size=nb_samples_q, + p=nx.to_numpy(T_indexj / nx.sum(T_indexj)), + replace=True + ) + + list_value_sample = nx.stack([ + loss_fun( + C1[np.ix_(index_i, index_j)], + C2[np.ix_(index_k[:, n], index_l[:, n])] + ) for n in range(nb_samples_q) + ], axis=2) + + if std: + std_value = nx.sum(nx.std(list_value_sample, axis=2) ** 2) ** 0.5 + return nx.mean(list_value_sample), std_value / (nb_samples_p * nb_samples_p) + else: + return nx.mean(list_value_sample) + + +def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun, + alpha=1, max_iter=100, threshold_plan=0, log=False, verbose=False, random_state=None): + r""" + Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` using a stochastic Frank-Wolfe. + This method has a :math:`\mathcal{O}(\mathrm{max\_iter} \times PN^2)` time complexity with `P` the number of Sinkhorn iterations. + + The function solves the following optimization problem: + + .. math:: + \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l} + L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} + + \mathbf{T}^T \mathbf{1} &= \mathbf{q} + + \mathbf{T} &\geq 0 + + Where : + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{p}`: distribution in the source space + - :math:`\mathbf{q}`: distribution in the target space + - `L`: loss function to account for the misfit between the similarity matrices + + Parameters + ---------- + C1 : array-like, shape (ns, ns) + Metric cost matrix in the source space + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,) + Distribution in the source space + q : array-like, shape (nt,) + Distribution in the target space + loss_fun : function: :math:`\mathbb{R} \times \mathbb{R} \mapsto \mathbb{R}` + Loss function used for the distance, the transport plan does not depend on the loss function + alpha : float + Step of the Frank-Wolfe algorithm, should be between 0 and 1 + max_iter : int, optional + Max number of iterations + threshold_plan : float, optional + Deleting very small values in the transport plan. If above zero, it violates the marginal constraints. + verbose : bool, optional + Print information along iterations + log : bool, optional + Gives the distance estimated and the standard deviation + random_state : int or RandomState instance, optional + Fix the seed for reproducibility + + Returns + ------- + T : array-like, shape (`ns`, `nt`) + Optimal coupling between the two spaces + + References + ---------- + .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc + "Sampled Gromov Wasserstein." + Machine Learning Journal (MLJ). 2021. 
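+
+    Examples
+    --------
+    A small illustrative sketch with random point clouds, assuming :mod:`ot` is installed and
+    this estimator is exposed as :func:`ot.gromov.pointwise_gromov_wasserstein`; `loss` below
+    is just one possible pointwise loss, and no output is shown since the inputs are random:
+
+    >>> import numpy as np
+    >>> import ot
+    >>> rng = np.random.RandomState(0)
+    >>> xs, xt = rng.randn(15, 2), rng.randn(15, 2)
+    >>> C1, C2 = ot.dist(xs, xs), ot.dist(xt, xt)
+    >>> p, q = ot.unif(15), ot.unif(15)
+    >>> def loss(x, y):
+    ...     return np.abs(x - y)
+    >>> T = ot.gromov.pointwise_gromov_wasserstein(C1, C2, p, q, loss, max_iter=20, random_state=42)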
+ + """ + C1, C2, p, q = list_to_array(C1, C2, p, q) + nx = get_backend(C1, C2, p, q) + + len_p = p.shape[0] + len_q = q.shape[0] + + generator = check_random_state(random_state) + + index = np.zeros(2, dtype=int) + + # Initialize with default marginal + index[0] = generator.choice(len_p, size=1, p=nx.to_numpy(p)) + index[1] = generator.choice(len_q, size=1, p=nx.to_numpy(q)) + T = nx.tocsr(emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False)) + + best_gw_dist_estimated = np.inf + for cpt in range(max_iter): + index[0] = generator.choice(len_p, size=1, p=nx.to_numpy(p)) + T_index0 = nx.reshape(nx.todense(T[index[0], :]), (-1,)) + index[1] = generator.choice( + len_q, size=1, p=nx.to_numpy(T_index0 / nx.sum(T_index0)) + ) + + if alpha == 1: + T = nx.tocsr( + emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False) + ) + else: + new_T = nx.tocsr( + emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False) + ) + T = (1 - alpha) * T + alpha * new_T + # To limit the number of non 0, the values below the threshold are set to 0. + T = nx.eliminate_zeros(T, threshold=threshold_plan) + + if cpt % 10 == 0 or cpt == (max_iter - 1): + gw_dist_estimated = GW_distance_estimation( + C1=C1, C2=C2, loss_fun=loss_fun, + p=p, q=q, T=T, std=False, random_state=generator + ) + + if gw_dist_estimated < best_gw_dist_estimated: + best_gw_dist_estimated = gw_dist_estimated + best_T = nx.copy(T) + + if verbose: + if cpt % 200 == 0: + print('{:5s}|{:12s}'.format('It.', 'Best gw estimated') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, best_gw_dist_estimated)) + + if log: + log = {} + log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation( + C1=C1, C2=C2, loss_fun=loss_fun, + p=p, q=q, T=best_T, random_state=generator + ) + return best_T, log + return best_T + + +def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun, + nb_samples_grad=100, epsilon=1, max_iter=500, log=False, verbose=False, + random_state=None): + r""" + Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` using a 1-stochastic Frank-Wolfe. + This method has a :math:`\mathcal{O}(\mathrm{max\_iter} \times N \log(N))` time complexity by relying on the 1D Optimal Transport solver. + + The function solves the following optimization problem: + + .. math:: + \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l} + L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + s.t. 
\ \mathbf{T} \mathbf{1} &= \mathbf{p} + + \mathbf{T}^T \mathbf{1} &= \mathbf{q} + + \mathbf{T} &\geq 0 + + Where : + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{p}`: distribution in the source space + - :math:`\mathbf{q}`: distribution in the target space + - `L`: loss function to account for the misfit between the similarity matrices + + Parameters + ---------- + C1 : array-like, shape (ns, ns) + Metric cost matrix in the source space + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,) + Distribution in the source space + q : array-like, shape (nt,) + Distribution in the target space + loss_fun : function: :math:`\mathbb{R} \times \mathbb{R} \mapsto \mathbb{R}` + Loss function used for the distance, the transport plan does not depend on the loss function + nb_samples_grad : int + Number of samples to approximate the gradient + epsilon : float + Weight of the Kullback-Leibler regularization + max_iter : int, optional + Max number of iterations + verbose : bool, optional + Print information along iterations + log : bool, optional + Gives the distance estimated and the standard deviation + random_state : int or RandomState instance, optional + Fix the seed for reproducibility + + Returns + ------- + T : array-like, shape (`ns`, `nt`) + Optimal coupling between the two spaces + + References + ---------- + .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc + "Sampled Gromov Wasserstein." + Machine Learning Journal (MLJ). 2021. + + """ + C1, C2, p, q = list_to_array(C1, C2, p, q) + nx = get_backend(C1, C2, p, q) + + len_p = p.shape[0] + len_q = q.shape[0] + + generator = check_random_state(random_state) + + # The most natural way to define nb_sample is with a simple integer. + if isinstance(nb_samples_grad, int): + if nb_samples_grad > len_p: + # As the sampling along the first dimension is done without replacement, the rest is reported to the second + # dimension. + nb_samples_grad_p, nb_samples_grad_q = len_p, nb_samples_grad // len_p + else: + nb_samples_grad_p, nb_samples_grad_q = nb_samples_grad, 1 + else: + nb_samples_grad_p, nb_samples_grad_q = nb_samples_grad + T = nx.outer(p, q) + # continue_loop allows to stop the loop if there is several successive small modification of T. + continue_loop = 0 + + # The gradient of GW is more complex if the two matrices are not symmetric. + C_are_symmetric = nx.allclose(C1, C1.T, rtol=1e-10, atol=1e-10) and nx.allclose(C2, C2.T, rtol=1e-10, atol=1e-10) + + for cpt in range(max_iter): + index0 = generator.choice( + len_p, size=nb_samples_grad_p, p=nx.to_numpy(p), replace=False + ) + Lik = 0 + for i, index0_i in enumerate(index0): + index1 = generator.choice( + len_q, size=nb_samples_grad_q, + p=nx.to_numpy(T[index0_i, :] / nx.sum(T[index0_i, :])), + replace=False + ) + # If the matrices C are not symmetric, the gradient has 2 terms, thus the term is chosen randomly. + if (not C_are_symmetric) and generator.rand(1) > 0.5: + Lik += nx.mean(loss_fun( + C1[:, [index0[i]] * nb_samples_grad_q][:, None, :], + C2[:, index1][None, :, :] + ), axis=2) + else: + Lik += nx.mean(loss_fun( + C1[[index0[i]] * nb_samples_grad_q, :][:, :, None], + C2[index1, :][:, None, :] + ), axis=0) + + max_Lik = nx.max(Lik) + if max_Lik == 0: + continue + # This division by the max is here to facilitate the choice of epsilon. 
+ Lik /= max_Lik + + if epsilon > 0: + # Set to infinity all the numbers below exp(-200) to avoid log of 0. + log_T = nx.log(nx.clip(T, np.exp(-200), 1)) + log_T = nx.where(log_T == -200, -np.inf, log_T) + Lik = Lik - epsilon * log_T + + try: + new_T = sinkhorn(a=p, b=q, M=Lik, reg=epsilon) + except (RuntimeWarning, UserWarning): + print("Warning catched in Sinkhorn: Return last stable T") + break + else: + new_T = emd(a=p, b=q, M=Lik) + + change_T = nx.mean((T - new_T) ** 2) + if change_T <= 10e-20: + continue_loop += 1 + if continue_loop > 100: # Number max of low modifications of T + T = nx.copy(new_T) + break + else: + continue_loop = 0 + + if verbose and cpt % 10 == 0: + if cpt % 200 == 0: + print('{:5s}|{:12s}'.format('It.', '||T_n - T_{n+1}||') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, change_T)) + T = nx.copy(new_T) + + if log: + log = {} + log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation( + C1=C1, C2=C2, loss_fun=loss_fun, + p=p, q=q, T=T, random_state=generator + ) + return T, log + return T diff --git a/ot/gromov/_gw.py b/ot/gromov/_gw.py new file mode 100644 index 0000000..c6e4076 --- /dev/null +++ b/ot/gromov/_gw.py @@ -0,0 +1,978 @@ +# -*- coding: utf-8 -*- +""" +Gromov-Wasserstein and Fused-Gromov-Wasserstein conditional gradient solvers. +""" + +# Author: Erwan Vautier +# Nicolas Courty +# Rémi Flamary +# Titouan Vayer +# Cédric Vincent-Cuaz +# +# License: MIT License + +import numpy as np + + +from ..utils import dist, UndefinedParameter, list_to_array +from ..optim import cg, line_search_armijo, solve_1d_linesearch_quad +from ..utils import check_random_state +from ..backend import get_backend, NumpyBackend + +from ._utils import init_matrix, gwloss, gwggrad +from ._utils import update_square_loss, update_kl_loss + + +def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', symmetric=None, log=False, armijo=False, G0=None, + max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): + r""" + Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` + + The function solves the following optimization problem: + + .. math:: + \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l} + L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p} + + \mathbf{\gamma}^T \mathbf{1} &= \mathbf{q} + + \mathbf{\gamma} &\geq 0 + Where : + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{p}`: distribution in the source space + - :math:`\mathbf{q}`: distribution in the target space + - `L`: loss function to account for the misfit between the similarity matrices + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + .. note:: All computations in the conjugate gradient solver are done with + numpy to limit memory overhead. 
+
+    Parameters
+    ----------
+    C1 : array-like, shape (ns, ns)
+        Metric cost matrix in the source space
+    C2 : array-like, shape (nt, nt)
+        Metric cost matrix in the target space
+    p : array-like, shape (ns,)
+        Distribution in the source space
+    q : array-like, shape (nt,)
+        Distribution in the target space
+    loss_fun : str
+        Loss function used for the solver, either 'square_loss' or 'kl_loss'
+    symmetric : bool, optional
+        Whether C1 and C2 are assumed to be symmetric.
+        If left to its default None value, a symmetry test will be conducted.
+        Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric).
+    verbose : bool, optional
+        Print information along iterations
+    log : bool, optional
+        Record the log if True
+    armijo : bool, optional
+        If True, the line-search step is found via an Armijo search. Else a closed form is used.
+        If there are convergence issues, use False.
+    G0: array-like, shape (ns,nt), optional
+        If None, the initial transport plan of the solver is :math:`\mathbf{p}\mathbf{q}^\top`.
+        Otherwise G0 must satisfy the marginal constraints and will be used as the initial transport of the solver.
+    max_iter : int, optional
+        Max number of iterations
+    tol_rel : float, optional
+        Stop threshold on relative error (>0)
+    tol_abs : float, optional
+        Stop threshold on absolute error (>0)
+    **kwargs : dict
+        Parameters that can be passed directly to the ot.optim.cg solver
+
+    Returns
+    -------
+    T : array-like, shape (`ns`, `nt`)
+        Coupling between the two spaces that minimizes:
+
+        :math:`\sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}`
+    log : dict
+        Convergence information and loss.
+
+    References
+    ----------
+    .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
+        "Gromov-Wasserstein averaging of kernel and distance matrices."
+        International Conference on Machine Learning (ICML). 2016.
+
+    .. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the
+        metric approach to object matching. Foundations of computational
+        mathematics 11.4 (2011): 417-487.
+
+    .. [47] Chowdhury, S., & Mémoli, F. (2019). The gromov–wasserstein
+        distance between networks and stable network invariants.
+        Information and Inference: A Journal of the IMA, 8(4), 757-787.
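+
+    Examples
+    --------
+    A small illustrative sketch with random point clouds, assuming :mod:`ot` is installed and
+    this solver is exposed as :func:`ot.gromov.gromov_wasserstein`; no output is shown since
+    the inputs are random:
+
+    >>> import numpy as np
+    >>> import ot
+    >>> rng = np.random.RandomState(42)
+    >>> xs, xt = rng.randn(10, 2), rng.randn(12, 2)
+    >>> C1, C2 = ot.dist(xs, xs), ot.dist(xt, xt)
+    >>> C1, C2 = C1 / C1.max(), C2 / C2.max()
+    >>> p, q = ot.unif(10), ot.unif(12)
+    >>> T = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss')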
+ """ + p, q = list_to_array(p, q) + p0, q0, C10, C20 = p, q, C1, C2 + if G0 is None: + nx = get_backend(p0, q0, C10, C20) + else: + G0_ = G0 + nx = get_backend(p0, q0, C10, C20, G0_) + p = nx.to_numpy(p) + q = nx.to_numpy(q) + C1 = nx.to_numpy(C10) + C2 = nx.to_numpy(C20) + if symmetric is None: + symmetric = np.allclose(C1, C1.T, atol=1e-10) and np.allclose(C2, C2.T, atol=1e-10) + + if G0 is None: + G0 = p[:, None] * q[None, :] + else: + G0 = nx.to_numpy(G0_) + # Check marginals of G0 + np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08) + np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08) + # cg for GW is implemented using numpy on CPU + np_ = NumpyBackend() + + constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun, np_) + + def f(G): + return gwloss(constC, hC1, hC2, G, np_) + + if symmetric: + def df(G): + return gwggrad(constC, hC1, hC2, G, np_) + else: + constCt, hC1t, hC2t = init_matrix(C1.T, C2.T, p, q, loss_fun, np_) + + def df(G): + return 0.5 * (gwggrad(constC, hC1, hC2, G, np_) + gwggrad(constCt, hC1t, hC2t, G, np_)) + if loss_fun == 'kl_loss': + armijo = True # there is no closed form line-search with KL + + if armijo: + def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): + return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=np_, **kwargs) + else: + def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): + return solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M=0., reg=1., nx=np_, **kwargs) + if log: + res, log = cg(p, q, 0., 1., f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs) + log['gw_dist'] = nx.from_numpy(log['loss'][-1], type_as=C10) + log['u'] = nx.from_numpy(log['u'], type_as=C10) + log['v'] = nx.from_numpy(log['v'], type_as=C10) + return nx.from_numpy(res, type_as=C10), log + else: + return nx.from_numpy(cg(p, q, 0., 1., f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs), type_as=C10) + + +def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', symmetric=None, log=False, armijo=False, G0=None, + max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): + r""" + Returns the gromov-wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` + + The function solves the following optimization problem: + + .. math:: + GW = \min_\mathbf{T} \quad \sum_{i,j,k,l} + L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p} + + \mathbf{\gamma}^T \mathbf{1} &= \mathbf{q} + + \mathbf{\gamma} &\geq 0 + Where : + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{p}`: distribution in the source space + - :math:`\mathbf{q}`: distribution in the target space + - `L`: loss function to account for the misfit between the similarity + matrices + + Note that when using backends, this loss function is differentiable wrt the + matrices (C1, C2) and weights (p, q) for quadratic loss using the gradients from [38]_. + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + .. note:: All computations in the conjugate gradient solver are done with + numpy to limit memory overhead. 
+
+    Parameters
+    ----------
+    C1 : array-like, shape (ns, ns)
+        Metric cost matrix in the source space
+    C2 : array-like, shape (nt, nt)
+        Metric cost matrix in the target space
+    p : array-like, shape (ns,)
+        Distribution in the source space.
+    q : array-like, shape (nt,)
+        Distribution in the target space.
+    loss_fun : str
+        Loss function used for the solver, either 'square_loss' or 'kl_loss'
+    symmetric : bool, optional
+        Whether C1 and C2 are assumed to be symmetric.
+        If left to its default None value, a symmetry test will be conducted.
+        Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric).
+    verbose : bool, optional
+        Print information along iterations
+    log : bool, optional
+        Record the log if True
+    armijo : bool, optional
+        If True, the line-search step is found via an Armijo search. Else a closed form is used.
+        If there are convergence issues, use False.
+    G0: array-like, shape (ns,nt), optional
+        If None, the initial transport plan of the solver is :math:`\mathbf{p}\mathbf{q}^\top`.
+        Otherwise G0 must satisfy the marginal constraints and will be used as the initial transport of the solver.
+    max_iter : int, optional
+        Max number of iterations
+    tol_rel : float, optional
+        Stop threshold on relative error (>0)
+    tol_abs : float, optional
+        Stop threshold on absolute error (>0)
+    **kwargs : dict
+        Parameters that can be passed directly to the ot.optim.cg solver
+
+    Returns
+    -------
+    gw_dist : float
+        Gromov-Wasserstein distance
+    log : dict
+        Convergence information and coupling matrix
+
+    References
+    ----------
+    .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
+        "Gromov-Wasserstein averaging of kernel and distance matrices."
+        International Conference on Machine Learning (ICML). 2016.
+
+    .. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the
+        metric approach to object matching. Foundations of computational
+        mathematics 11.4 (2011): 417-487.
+
+    .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online
+        Graph Dictionary Learning, International Conference on Machine Learning
+        (ICML), 2021.
+
+    .. [47] Chowdhury, S., & Mémoli, F. (2019). The gromov–wasserstein
+        distance between networks and stable network invariants.
+        Information and Inference: A Journal of the IMA, 8(4), 757-787.
+    """
+    # simple get_backend as the full one will be handled in gromov_wasserstein
+    nx = get_backend(C1, C2)
+
+    T, log_gw = gromov_wasserstein(
+        C1, C2, p, q, loss_fun, symmetric, log=True, armijo=armijo, G0=G0,
+        max_iter=max_iter, tol_rel=tol_rel, tol_abs=tol_abs, **kwargs)
+
+    log_gw['T'] = T
+    gw = log_gw['gw_dist']
+
+    if loss_fun == 'square_loss':
+        gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T))
+        gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T))
+        gw = nx.set_gradients(gw, (p, q, C1, C2),
+                              (log_gw['u'] - nx.mean(log_gw['u']),
+                               log_gw['v'] - nx.mean(log_gw['v']), gC1, gC2))
+
+    if log:
+        return gw, log_gw
+    else:
+        return gw
+
+
+def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', symmetric=None, alpha=0.5,
+                             armijo=False, G0=None, log=False, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
+    r"""
+    Computes the FGW transport between two graphs (see :ref:`[24] <references-fused-gromov-wasserstein>`)
+
+    .. math::
+        \gamma = \mathop{\arg \min}_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F +
+        \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+
+        s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p}
\ \mathbf{\gamma} \mathbf{1} &= \mathbf{p} + + \mathbf{\gamma}^T \mathbf{1} &= \mathbf{q} + + \mathbf{\gamma} &\geq 0 + + where : + + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix + - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1) + - `L` is a loss function to account for the misfit between the similarity matrices + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + .. note:: All computations in the conjugate gradient solver are done with + numpy to limit memory overhead. + The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[24] ` + + Parameters + ---------- + M : array-like, shape (ns, nt) + Metric cost matrix between features across domains + C1 : array-like, shape (ns, ns) + Metric cost matrix representative of the structure in the source space + C2 : array-like, shape (nt, nt) + Metric cost matrix representative of the structure in the target space + p : array-like, shape (ns,) + Distribution in the source space + q : array-like, shape (nt,) + Distribution in the target space + loss_fun : str, optional + Loss function used for the solver + symmetric : bool, optional + Either C1 and C2 are to be assumed symmetric or not. + If let to its default None value, a symmetry test will be conducted. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric). + alpha : float, optional + Trade-off parameter (0 < alpha < 1) + armijo : bool, optional + If True the step of the line-search is found via an armijo research. Else closed form is used. + If there are convergence issues use False. + G0: array-like, shape (ns,nt), optional + If None the initial transport plan of the solver is pq^T. + Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. + log : bool, optional + record log if True + max_iter : int, optional + Max number of iterations + tol_rel : float, optional + Stop threshold on relative error (>0) + tol_abs : float, optional + Stop threshold on absolute error (>0) + **kwargs : dict + parameters can be directly passed to the ot.optim.cg solver + + Returns + ------- + gamma : array-like, shape (`ns`, `nt`) + Optimal transportation matrix for the given parameters. + log : dict + Log dictionary return only if log==True in parameters. + + + .. _references-fused-gromov-wasserstein: + References + ---------- + .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain + and Courty Nicolas "Optimal Transport for structured data with + application on graphs", International Conference on Machine Learning + (ICML). 2019. + + .. [47] Chowdhury, S., & Mémoli, F. (2019). The gromov–wasserstein + distance between networks and stable network invariants. + Information and Inference: A Journal of the IMA, 8(4), 757-787. 
+ """ + p, q = list_to_array(p, q) + p0, q0, C10, C20, M0 = p, q, C1, C2, M + if G0 is None: + nx = get_backend(p0, q0, C10, C20, M0) + else: + G0_ = G0 + nx = get_backend(p0, q0, C10, C20, M0, G0_) + + p = nx.to_numpy(p) + q = nx.to_numpy(q) + C1 = nx.to_numpy(C10) + C2 = nx.to_numpy(C20) + M = nx.to_numpy(M0) + + if symmetric is None: + symmetric = np.allclose(C1, C1.T, atol=1e-10) and np.allclose(C2, C2.T, atol=1e-10) + + if G0 is None: + G0 = p[:, None] * q[None, :] + else: + G0 = nx.to_numpy(G0_) + # Check marginals of G0 + np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08) + np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08) + # cg for GW is implemented using numpy on CPU + np_ = NumpyBackend() + + constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun, np_) + + def f(G): + return gwloss(constC, hC1, hC2, G, np_) + + if symmetric: + def df(G): + return gwggrad(constC, hC1, hC2, G, np_) + else: + constCt, hC1t, hC2t = init_matrix(C1.T, C2.T, p, q, loss_fun, np_) + + def df(G): + return 0.5 * (gwggrad(constC, hC1, hC2, G, np_) + gwggrad(constCt, hC1t, hC2t, G, np_)) + + if loss_fun == 'kl_loss': + armijo = True # there is no closed form line-search with KL + + if armijo: + def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): + return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=np_, **kwargs) + else: + def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): + return solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M=(1 - alpha) * M, reg=alpha, nx=np_, **kwargs) + if log: + res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs) + log['fgw_dist'] = nx.from_numpy(log['loss'][-1], type_as=C10) + log['u'] = nx.from_numpy(log['u'], type_as=C10) + log['v'] = nx.from_numpy(log['v'], type_as=C10) + return nx.from_numpy(res, type_as=C10), log + else: + return nx.from_numpy(cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs), type_as=C10) + + +def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', symmetric=None, alpha=0.5, + armijo=False, G0=None, log=False, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): + r""" + Computes the FGW distance between two graphs see (see :ref:`[24] `) + + .. math:: + \min_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F + \alpha \sum_{i,j,k,l} + L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p} + + \mathbf{\gamma}^T \mathbf{1} &= \mathbf{q} + + \mathbf{\gamma} &\geq 0 + + where : + + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix + - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1) + - `L` is a loss function to account for the misfit between the similarity matrices + + The algorithm used for solving the problem is conditional gradient as + discussed in :ref:`[24] ` + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + .. note:: All computations in the conjugate gradient solver are done with + numpy to limit memory overhead. + + Note that when using backends, this loss function is differentiable wrt the + matrices (C1, C2, M) and weights (p, q) for quadratic loss using the gradients from [38]_. 
+ + Parameters + ---------- + M : array-like, shape (ns, nt) + Metric cost matrix between features across domains + C1 : array-like, shape (ns, ns) + Metric cost matrix representative of the structure in the source space. + C2 : array-like, shape (nt, nt) + Metric cost matrix representative of the structure in the target space. + p : array-like, shape (ns,) + Distribution in the source space. + q : array-like, shape (nt,) + Distribution in the target space. + loss_fun : str, optional + Loss function used for the solver. + symmetric : bool, optional + Either C1 and C2 are to be assumed symmetric or not. + If let to its default None value, a symmetry test will be conducted. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric). + alpha : float, optional + Trade-off parameter (0 < alpha < 1) + armijo : bool, optional + If True the step of the line-search is found via an armijo research. + Else closed form is used. If there are convergence issues use False. + G0: array-like, shape (ns,nt), optional + If None the initial transport plan of the solver is pq^T. + Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. + log : bool, optional + Record log if True. + max_iter : int, optional + Max number of iterations + tol_rel : float, optional + Stop threshold on relative error (>0) + tol_abs : float, optional + Stop threshold on absolute error (>0) + **kwargs : dict + Parameters can be directly passed to the ot.optim.cg solver. + + Returns + ------- + fgw-distance : float + Fused gromov wasserstein distance for the given parameters. + log : dict + Log dictionary return only if log==True in parameters. + + + .. _references-fused-gromov-wasserstein2: + References + ---------- + .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain + and Courty Nicolas + "Optimal Transport for structured data with application on graphs" + International Conference on Machine Learning (ICML). 2019. + + .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online + Graph Dictionary Learning, International Conference on Machine Learning + (ICML), 2021. + + .. [47] Chowdhury, S., & Mémoli, F. (2019). The gromov–wasserstein + distance between networks and stable network invariants. + Information and Inference: A Journal of the IMA, 8(4), 757-787. 
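Since alpha balances the feature term against the structure term, a quick way to see its effect is to sweep it on toy data (illustrative sketch only):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    xs, xt = rng.randn(6, 3), rng.randn(7, 3)
    C1, C2, M = ot.dist(xs, xs), ot.dist(xt, xt), ot.dist(xs, xt)
    p, q = ot.unif(6), ot.unif(7)

    # alpha near 0 weights the Wasserstein (feature) term,
    # alpha near 1 weights the Gromov-Wasserstein (structure) term
    for alpha in (0.1, 0.5, 0.9):
        print(alpha, ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, alpha=alpha))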
+ """ + nx = get_backend(C1, C2, M) + + T, log_fgw = fused_gromov_wasserstein( + M, C1, C2, p, q, loss_fun, symmetric, alpha, armijo, G0, log=True, + max_iter=max_iter, tol_rel=tol_rel, tol_abs=tol_abs, **kwargs) + + fgw_dist = log_fgw['fgw_dist'] + log_fgw['T'] = T + + if loss_fun == 'square_loss': + gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T)) + gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T)) + fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M), + (log_fgw['u'] - nx.mean(log_fgw['u']), + log_fgw['v'] - nx.mean(log_fgw['v']), + alpha * gC1, alpha * gC2, (1 - alpha) * T)) + + if log: + return fgw_dist, log_fgw + else: + return fgw_dist + + +def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg, + alpha_min=None, alpha_max=None, nx=None, **kwargs): + """ + Solve the linesearch in the FW iterations + + Parameters + ---------- + + G : array-like, shape(ns,nt) + The transport map at a given iteration of the FW + deltaG : array-like (ns,nt) + Difference between the optimal map found by linearization in the FW algorithm and the value at a given iteration + cost_G : float + Value of the cost at `G` + C1 : array-like (ns,ns), optional + Structure matrix in the source domain. + C2 : array-like (nt,nt), optional + Structure matrix in the target domain. + M : array-like (ns,nt) + Cost matrix between the features. + reg : float + Regularization parameter. + alpha_min : float, optional + Minimum value for alpha + alpha_max : float, optional + Maximum value for alpha + nx : backend, optional + If let to its default value None, a backend test will be conducted. + Returns + ------- + alpha : float + The optimal step size of the FW + fc : int + nb of function call. Useless here + cost_G : float + The value of the cost for the next iteration + + + .. _references-solve-linesearch: + References + ---------- + .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas + "Optimal Transport for structured data with application on graphs" + International Conference on Machine Learning (ICML). 2019. + """ + if nx is None: + G, deltaG, C1, C2, M = list_to_array(G, deltaG, C1, C2, M) + + if isinstance(M, int) or isinstance(M, float): + nx = get_backend(G, deltaG, C1, C2) + else: + nx = get_backend(G, deltaG, C1, C2, M) + + dot = nx.dot(nx.dot(C1, deltaG), C2.T) + a = -2 * reg * nx.sum(dot * deltaG) + b = nx.sum(M * deltaG) - 2 * reg * (nx.sum(dot * G) + nx.sum(nx.dot(nx.dot(C1, G), C2.T) * deltaG)) + + alpha = solve_1d_linesearch_quad(a, b) + if alpha_min is not None or alpha_max is not None: + alpha = np.clip(alpha, alpha_min, alpha_max) + + # the new cost is deduced from the line search quadratic function + cost_G = cost_G + a * (alpha ** 2) + b * alpha + + return alpha, 1, cost_G + + +def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, symmetric=True, armijo=False, + max_iter=1000, tol=1e-9, verbose=False, log=False, + init_C=None, random_state=None, **kwargs): + r""" + Returns the gromov-wasserstein barycenters of `S` measured similarity matrices :math:`(\mathbf{C}_s)_{1 \leq s \leq S}` + + The function solves the following optimization problem with block coordinate descent: + + .. 
math:: + + \mathbf{C} = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \quad \sum_s \lambda_s \mathrm{GW}(\mathbf{C}, \mathbf{C}_s, \mathbf{p}, \mathbf{p}_s) + + Where : + + - :math:`\mathbf{C}_s`: metric cost matrix + - :math:`\mathbf{p}_s`: distribution + + Parameters + ---------- + N : int + Size of the targeted barycenter + Cs : list of S array-like of shape (ns, ns) + Metric cost matrices + ps : list of S array-like of shape (ns,) + Sample weights in the `S` spaces + p : array-like, shape (N,) + Weights in the targeted barycenter + lambdas : list of float + List of the `S` spaces' weights + loss_fun : callable + tensor-matrix multiplication function based on specific loss function + symmetric : bool, optional. + Either structures are to be assumed symmetric or not. Default value is True. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric). + update : callable + function(:math:`\mathbf{p}`, lambdas, :math:`\mathbf{T}`, :math:`\mathbf{Cs}`) that updates + :math:`\mathbf{C}` according to a specific Kernel with the `S` :math:`\mathbf{T}_s` couplings + calculated at each iteration + max_iter : int, optional + Max number of iterations + tol : float, optional + Stop threshold on relative error (>0) + verbose : bool, optional + Print information along iterations. + log : bool, optional + Record log if True. + init_C : bool | array-like, shape(N,N) + Random initial value for the :math:`\mathbf{C}` matrix provided by user. + random_state : int or RandomState instance, optional + Fix the seed for reproducibility + + Returns + ------- + C : array-like, shape (`N`, `N`) + Similarity matrix in the barycenter space (permutated arbitrarily) + log : dict + Log dictionary of error during iterations. Return only if `log=True` in parameters. + + References + ---------- + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, + "Gromov-Wasserstein averaging of kernel and distance matrices." + International Conference on Machine Learning (ICML). 2016. 
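A minimal barycenter sketch on two random structure matrices (illustrative only; as noted in the Returns section, the output C is defined up to a permutation of its rows and columns):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    x1, x2 = rng.randn(10, 2), rng.randn(15, 2)
    Cs = [ot.dist(x1, x1), ot.dist(x2, x2)]          # the S = 2 input structures
    ps = [ot.unif(10), ot.unif(15)]
    N = 8                                            # size of the barycenter
    C = ot.gromov.gromov_barycenters(N, Cs, ps, ot.unif(N), lambdas=[0.5, 0.5],
                                     loss_fun='square_loss', max_iter=50,
                                     random_state=0)
    print(C.shape)                                   # (8, 8)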
+ + """ + Cs = list_to_array(*Cs) + ps = list_to_array(*ps) + p = list_to_array(p) + nx = get_backend(*Cs, *ps, p) + + S = len(Cs) + + # Initialization of C : random SPD matrix (if not provided by user) + if init_C is None: + generator = check_random_state(random_state) + xalea = generator.randn(N, 2) + C = dist(xalea, xalea) + C /= C.max() + C = nx.from_numpy(C, type_as=p) + else: + C = init_C + + if loss_fun == 'kl_loss': + armijo = True + + cpt = 0 + err = 1 + + error = [] + + while (err > tol and cpt < max_iter): + Cprev = C + + T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, symmetric=symmetric, armijo=armijo, + max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, log=False, **kwargs) for s in range(S)] + if loss_fun == 'square_loss': + C = update_square_loss(p, lambdas, T, Cs) + + elif loss_fun == 'kl_loss': + C = update_kl_loss(p, lambdas, T, Cs) + + if cpt % 10 == 0: + # we can speed up the process by checking for the error only all + # the 10th iterations + err = nx.norm(C - Cprev) + error.append(err) + + if verbose: + if cpt % 200 == 0: + print('{:5s}|{:12s}'.format( + 'It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) + + cpt += 1 + + if log: + return C, {"err": error} + else: + return C + + +def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False, + p=None, loss_fun='square_loss', armijo=False, symmetric=True, max_iter=100, tol=1e-9, + verbose=False, log=False, init_C=None, init_X=None, random_state=None, **kwargs): + r"""Compute the fgw barycenter as presented eq (5) in :ref:`[24] ` + + Parameters + ---------- + N : int + Desired number of samples of the target barycenter + Ys: list of array-like, each element has shape (ns,d) + Features of all samples + Cs : list of array-like, each element has shape (ns,ns) + Structure matrices of all samples + ps : list of array-like, each element has shape (ns,) + Masses of all samples. + lambdas : list of float + List of the `S` spaces' weights + alpha : float + Alpha parameter for the fgw distance + fixed_structure : bool + Whether to fix the structure of the barycenter during the updates + fixed_features : bool + Whether to fix the feature of the barycenter during the updates + loss_fun : str + Loss function used for the solver either 'square_loss' or 'kl_loss' + symmetric : bool, optional + Either structures are to be assumed symmetric or not. Default value is True. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric). + max_iter : int, optional + Max number of iterations + tol : float, optional + Stop threshold on relative error (>0) + verbose : bool, optional + Print information along iterations. + log : bool, optional + Record log if True. + init_C : array-like, shape (N,N), optional + Initialization for the barycenters' structure matrix. If not set + a random init is used. + init_X : array-like, shape (N,d), optional + Initialization for the barycenters' features. If not set a + random init is used. + random_state : int or RandomState instance, optional + Fix the seed for reproducibility + + Returns + ------- + X : array-like, shape (`N`, `d`) + Barycenters' features + C : array-like, shape (`N`, `N`) + Barycenters' structure matrix + log : dict + Only returned when log=True. 
It contains the keys: + + - :math:`\mathbf{T}`: list of (`N`, `ns`) transport matrices + - :math:`(\mathbf{M}_s)_s`: all distance matrices between the feature of the barycenter and the other features :math:`(dist(\mathbf{X}, \mathbf{Y}_s))_s` shape (`N`, `ns`) + + + .. _references-fgw-barycenters: + References + ---------- + .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain + and Courty Nicolas + "Optimal Transport for structured data with application on graphs" + International Conference on Machine Learning (ICML). 2019. + """ + Cs = list_to_array(*Cs) + ps = list_to_array(*ps) + Ys = list_to_array(*Ys) + p = list_to_array(p) + nx = get_backend(*Cs, *Ys, *ps) + + S = len(Cs) + d = Ys[0].shape[1] # dimension on the node features + if p is None: + p = nx.ones(N, type_as=Cs[0]) / N + + if fixed_structure: + if init_C is None: + raise UndefinedParameter('If C is fixed it must be initialized') + else: + C = init_C + else: + if init_C is None: + generator = check_random_state(random_state) + xalea = generator.randn(N, 2) + C = dist(xalea, xalea) + C = nx.from_numpy(C, type_as=ps[0]) + else: + C = init_C + + if fixed_features: + if init_X is None: + raise UndefinedParameter('If X is fixed it must be initialized') + else: + X = init_X + else: + if init_X is None: + X = nx.zeros((N, d), type_as=ps[0]) + else: + X = init_X + + T = [nx.outer(p, q) for q in ps] + + Ms = [dist(X, Ys[s]) for s in range(len(Ys))] + + if loss_fun == 'kl_loss': + armijo = True + + cpt = 0 + err_feature = 1 + err_structure = 1 + + if log: + log_ = {} + log_['err_feature'] = [] + log_['err_structure'] = [] + log_['Ts_iter'] = [] + + while ((err_feature > tol or err_structure > tol) and cpt < max_iter): + Cprev = C + Xprev = X + + if not fixed_features: + Ys_temp = [y.T for y in Ys] + X = update_feature_matrix(lambdas, Ys_temp, T, p).T + + Ms = [dist(X, Ys[s]) for s in range(len(Ys))] + + if not fixed_structure: + if loss_fun == 'square_loss': + T_temp = [t.T for t in T] + C = update_structure_matrix(p, lambdas, T_temp, Cs) + + T = [fused_gromov_wasserstein(Ms[s], C, Cs[s], p, ps[s], loss_fun=loss_fun, alpha=alpha, armijo=armijo, symmetric=symmetric, + max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, **kwargs) for s in range(S)] + + # T is N,ns + err_feature = nx.norm(X - nx.reshape(Xprev, (N, d))) + err_structure = nx.norm(C - Cprev) + if log: + log_['err_feature'].append(err_feature) + log_['err_structure'].append(err_structure) + log_['Ts_iter'].append(T) + + if verbose: + if cpt % 200 == 0: + print('{:5s}|{:12s}'.format( + 'It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err_structure)) + print('{:5d}|{:8e}|'.format(cpt, err_feature)) + + cpt += 1 + + if log: + log_['T'] = T # from target to Ys + log_['p'] = p + log_['Ms'] = Ms + + if log: + return X, C, log_ + else: + return X, C + + +def update_structure_matrix(p, lambdas, T, Cs): + r"""Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s` couplings. + + It is calculated at each iteration + + Parameters + ---------- + p : array-like, shape (N,) + Masses in the targeted barycenter. + lambdas : list of float + List of the `S` spaces' weights. + T : list of S array-like of shape (ns, N) + The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration. + Cs : list of S array-like, shape (ns, ns) + Metric cost matrices. + + Returns + ------- + C : array-like, shape (`nt`, `nt`) + Updated :math:`\mathbf{C}` matrix. 
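A toy numpy sketch of the closed-form update this helper computes, C = (sum_s lambda_s T_s^T C_s T_s) / (p p^T), which makes the shapes explicit (illustrative only):

    import numpy as np

    rng = np.random.RandomState(0)
    ns, N = 5, 3
    p = np.ones(N) / N
    lambdas = [0.5, 0.5]
    T = [np.ones((ns, N)) / (ns * N) for _ in range(2)]   # couplings with marginals (1/ns, p)
    Cs = [rng.rand(ns, ns) for _ in range(2)]

    C = sum(l * T[s].T @ Cs[s] @ T[s] for s, l in enumerate(lambdas)) / np.outer(p, p)
    print(C.shape)   # (N, N) = (3, 3)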
+ """ + p = list_to_array(p) + T = list_to_array(*T) + Cs = list_to_array(*Cs) + nx = get_backend(*Cs, *T, p) + + tmpsum = sum([ + lambdas[s] * nx.dot( + nx.dot(T[s].T, Cs[s]), + T[s] + ) for s in range(len(T)) + ]) + ppt = nx.outer(p, p) + return tmpsum / ppt + + +def update_feature_matrix(lambdas, Ys, Ts, p): + r"""Updates the feature with respect to the `S` :math:`\mathbf{T}_s` couplings. + + + See "Solving the barycenter problem with Block Coordinate Descent (BCD)" + in :ref:`[24] ` calculated at each iteration + + Parameters + ---------- + p : array-like, shape (N,) + masses in the targeted barycenter + lambdas : list of float + List of the `S` spaces' weights + Ts : list of S array-like, shape (ns,N) + The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration + Ys : list of S array-like, shape (d,ns) + The features. + + Returns + ------- + X : array-like, shape (`d`, `N`) + + + .. _references-update-feature-matrix: + References + ---------- + .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas + "Optimal Transport for structured data with application on graphs" + International Conference on Machine Learning (ICML). 2019. + """ + p = list_to_array(p) + Ts = list_to_array(*Ts) + Ys = list_to_array(*Ys) + nx = get_backend(*Ys, *Ts, p) + + p = 1. / p + tmpsum = sum([ + lambdas[s] * nx.dot(Ys[s], Ts[s].T) * p[None, :] + for s in range(len(Ts)) + ]) + return tmpsum diff --git a/ot/gromov/_semirelaxed.py b/ot/gromov/_semirelaxed.py new file mode 100644 index 0000000..638bb1c --- /dev/null +++ b/ot/gromov/_semirelaxed.py @@ -0,0 +1,543 @@ +# -*- coding: utf-8 -*- +""" +Semi-relaxed Gromov-Wasserstein and Fused-Gromov-Wasserstein solvers. +""" + +# Author: Rémi Flamary +# Cédric Vincent-Cuaz +# +# License: MIT License + +import numpy as np + + +from ..utils import list_to_array, unif +from ..optim import semirelaxed_cg, solve_1d_linesearch_quad +from ..backend import get_backend + +from ._utils import init_matrix_semirelaxed, gwloss, gwggrad + + +def semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', symmetric=None, log=False, G0=None, + max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): + r""" + Returns the semi-relaxed gromov-wasserstein divergence transport from :math:`(\mathbf{C_1}, \mathbf{p})` to :math:`\mathbf{C_2}` + + The function solves the following optimization problem: + + .. math:: + \mathbf{srGW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l} + L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p} + + \mathbf{\gamma} &\geq 0 + Where : + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{p}`: distribution in the source space + + - `L`: loss function to account for the misfit between the similarity matrices + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. However all the steps in the conditional + gradient are not differentiable. + + Parameters + ---------- + C1 : array-like, shape (ns, ns) + Metric cost matrix in the source space + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,) + Distribution in the source space + loss_fun : str + loss function used for the solver either 'square_loss' or 'kl_loss'. + 'kl_loss' is not implemented yet and will raise an error. 
+ symmetric : bool, optional + Either C1 and C2 are to be assumed symmetric or not. + If let to its default None value, a symmetry test will be conducted. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric). + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + G0: array-like, shape (ns,nt), optional + If None the initial transport plan of the solver is pq^T. + Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. + max_iter : int, optional + Max number of iterations + tol_rel : float, optional + Stop threshold on relative error (>0) + tol_abs : float, optional + Stop threshold on absolute error (>0) + **kwargs : dict + parameters can be directly passed to the ot.optim.cg solver + + Returns + ------- + T : array-like, shape (`ns`, `nt`) + Coupling between the two spaces that minimizes: + + :math:`\sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}` + log : dict + Convergence information and loss. + + References + ---------- + .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty. + "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" + International Conference on Learning Representations (ICLR), 2022. + """ + if loss_fun == 'kl_loss': + raise NotImplementedError() + p = list_to_array(p) + if G0 is None: + nx = get_backend(p, C1, C2) + else: + nx = get_backend(p, C1, C2, G0) + + if symmetric is None: + symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10) + if G0 is None: + q = unif(C2.shape[0], type_as=p) + G0 = nx.outer(p, q) + else: + q = nx.sum(G0, 0) + # Check first marginal of G0 + np.testing.assert_allclose(nx.sum(G0, 1), p, atol=1e-08) + + constC, hC1, hC2, fC2t = init_matrix_semirelaxed(C1, C2, p, loss_fun, nx) + + ones_p = nx.ones(p.shape[0], type_as=p) + + def f(G): + qG = nx.sum(G, 0) + marginal_product = nx.outer(ones_p, nx.dot(qG, fC2t)) + return gwloss(constC + marginal_product, hC1, hC2, G, nx) + + if symmetric: + def df(G): + qG = nx.sum(G, 0) + marginal_product = nx.outer(ones_p, nx.dot(qG, fC2t)) + return gwggrad(constC + marginal_product, hC1, hC2, G, nx) + else: + constCt, hC1t, hC2t, fC2 = init_matrix_semirelaxed(C1.T, C2.T, p, loss_fun, nx) + + def df(G): + qG = nx.sum(G, 0) + marginal_product_1 = nx.outer(ones_p, nx.dot(qG, fC2t)) + marginal_product_2 = nx.outer(ones_p, nx.dot(qG, fC2)) + return 0.5 * (gwggrad(constC + marginal_product_1, hC1, hC2, G, nx) + gwggrad(constCt + marginal_product_2, hC1t, hC2t, G, nx)) + + def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): + return solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p, M=0., reg=1., nx=nx, **kwargs) + + if log: + res, log = semirelaxed_cg(p, q, 0., 1., f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs) + log['srgw_dist'] = log['loss'][-1] + return res, log + else: + return semirelaxed_cg(p, q, 0., 1., f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs) + + +def semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric=None, log=False, G0=None, + max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): + r""" + Returns the semi-relaxed gromov-wasserstein divergence from :math:`(\mathbf{C_1}, \mathbf{p})` to :math:`\mathbf{C_2}` + + The function solves the following optimization problem: + + .. 
math:: + srGW = \min_\mathbf{T} \quad \sum_{i,j,k,l} + L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p} + + \mathbf{\gamma} &\geq 0 + Where : + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{p}`: distribution in the source space + - `L`: loss function to account for the misfit between the similarity + matrices + + Note that when using backends, this loss function is differentiable wrt the + matrices (C1, C2) but not yet for the weights p. + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. However all the steps in the conditional + gradient are not differentiable. + + Parameters + ---------- + C1 : array-like, shape (ns, ns) + Metric cost matrix in the source space + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,) + Distribution in the source space. + loss_fun : str + loss function used for the solver either 'square_loss' or 'kl_loss'. + 'kl_loss' is not implemented yet and will raise an error. + symmetric : bool, optional + Either C1 and C2 are to be assumed symmetric or not. + If let to its default None value, a symmetry test will be conducted. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric). + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + G0: array-like, shape (ns,nt), optional + If None the initial transport plan of the solver is pq^T. + Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. + max_iter : int, optional + Max number of iterations + tol_rel : float, optional + Stop threshold on relative error (>0) + tol_abs : float, optional + Stop threshold on absolute error (>0) + **kwargs : dict + parameters can be directly passed to the ot.optim.cg solver + + Returns + ------- + srgw : float + Semi-relaxed Gromov-Wasserstein divergence + log : dict + convergence information and Coupling matrix + + References + ---------- + + .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty. + "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" + International Conference on Learning Representations (ICLR), 2022. + """ + nx = get_backend(p, C1, C2) + + T, log_srgw = semirelaxed_gromov_wasserstein( + C1, C2, p, loss_fun, symmetric, log=True, G0=G0, + max_iter=max_iter, tol_rel=tol_rel, tol_abs=tol_abs, **kwargs) + + q = nx.sum(T, 0) + log_srgw['T'] = T + srgw = log_srgw['srgw_dist'] + + if loss_fun == 'square_loss': + gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T)) + gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T)) + srgw = nx.set_gradients(srgw, (C1, C2), (gC1, gC2)) + + if log: + return srgw, log_srgw + else: + return srgw + + +def semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun='square_loss', symmetric=None, alpha=0.5, G0=None, log=False, + max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): + r""" + Computes the semi-relaxed FGW transport between two graphs (see :ref:`[48] `) + + .. math:: + \gamma = \mathop{\arg \min}_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F + + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + s.t. 
\ \mathbf{\gamma} \mathbf{1} &= \mathbf{p} + + \mathbf{\gamma} &\geq 0 + + where : + + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix + - :math:`\mathbf{p}` source weights (sum to 1) + - `L` is a loss function to account for the misfit between the similarity matrices + + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. However all the steps in the conditional + gradient are not differentiable. + + The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[48] ` + + Parameters + ---------- + M : array-like, shape (ns, nt) + Metric cost matrix between features across domains + C1 : array-like, shape (ns, ns) + Metric cost matrix representative of the structure in the source space + C2 : array-like, shape (nt, nt) + Metric cost matrix representative of the structure in the target space + p : array-like, shape (ns,) + Distribution in the source space + loss_fun : str + loss function used for the solver either 'square_loss' or 'kl_loss'. + 'kl_loss' is not implemented yet and will raise an error. + symmetric : bool, optional + Either C1 and C2 are to be assumed symmetric or not. + If let to its default None value, a symmetry test will be conducted. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric). + alpha : float, optional + Trade-off parameter (0 < alpha < 1) + G0: array-like, shape (ns,nt), optional + If None the initial transport plan of the solver is pq^T. + Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. + log : bool, optional + record log if True + max_iter : int, optional + Max number of iterations + tol_rel : float, optional + Stop threshold on relative error (>0) + tol_abs : float, optional + Stop threshold on absolute error (>0) + **kwargs : dict + parameters can be directly passed to the ot.optim.cg solver + + Returns + ------- + gamma : array-like, shape (`ns`, `nt`) + Optimal transportation matrix for the given parameters. + log : dict + Log dictionary return only if log==True in parameters. + + + .. _references-semirelaxed-fused-gromov-wasserstein: + References + ---------- + .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain + and Courty Nicolas "Optimal Transport for structured data with + application on graphs", International Conference on Machine Learning + (ICML). 2019. + + .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty. + "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" + International Conference on Learning Representations (ICLR), 2022. 
+ """ + if loss_fun == 'kl_loss': + raise NotImplementedError() + + p = list_to_array(p) + if G0 is None: + nx = get_backend(p, C1, C2, M) + else: + nx = get_backend(p, C1, C2, M, G0) + + if symmetric is None: + symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10) + + if G0 is None: + q = unif(C2.shape[0], type_as=p) + G0 = nx.outer(p, q) + else: + q = nx.sum(G0, 0) + # Check marginals of G0 + np.testing.assert_allclose(nx.sum(G0, 1), p, atol=1e-08) + + constC, hC1, hC2, fC2t = init_matrix_semirelaxed(C1, C2, p, loss_fun, nx) + + ones_p = nx.ones(p.shape[0], type_as=p) + + def f(G): + qG = nx.sum(G, 0) + marginal_product = nx.outer(ones_p, nx.dot(qG, fC2t)) + return gwloss(constC + marginal_product, hC1, hC2, G, nx) + + if symmetric: + def df(G): + qG = nx.sum(G, 0) + marginal_product = nx.outer(ones_p, nx.dot(qG, fC2t)) + return gwggrad(constC + marginal_product, hC1, hC2, G, nx) + else: + constCt, hC1t, hC2t, fC2 = init_matrix_semirelaxed(C1.T, C2.T, p, loss_fun, nx) + + def df(G): + qG = nx.sum(G, 0) + marginal_product_1 = nx.outer(ones_p, nx.dot(qG, fC2t)) + marginal_product_2 = nx.outer(ones_p, nx.dot(qG, fC2)) + return 0.5 * (gwggrad(constC + marginal_product_1, hC1, hC2, G, nx) + gwggrad(constCt + marginal_product_2, hC1t, hC2t, G, nx)) + + def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): + return solve_semirelaxed_gromov_linesearch( + G, deltaG, cost_G, C1, C2, ones_p, M=(1 - alpha) * M, reg=alpha, nx=nx, **kwargs) + + if log: + res, log = semirelaxed_cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs) + log['srfgw_dist'] = log['loss'][-1] + return res, log + else: + return semirelaxed_cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs) + + +def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', symmetric=None, alpha=0.5, G0=None, log=False, + max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): + r""" + Computes the semi-relaxed FGW divergence between two graphs (see :ref:`[48] `) + + .. math:: + \min_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F + \alpha \sum_{i,j,k,l} + L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p} + + \mathbf{\gamma} &\geq 0 + + where : + + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix + - :math:`\mathbf{p}` source weights (sum to 1) + - `L` is a loss function to account for the misfit between the similarity matrices + + The algorithm used for solving the problem is conditional gradient as + discussed in :ref:`[48] ` + + Note that when using backends, this loss function is differentiable wrt the + matrices (C1, C2) but not yet for the weights p. + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. However all the steps in the conditional + gradient are not differentiable. + + Parameters + ---------- + M : array-like, shape (ns, nt) + Metric cost matrix between features across domains + C1 : array-like, shape (ns, ns) + Metric cost matrix representative of the structure in the source space. + C2 : array-like, shape (nt, nt) + Metric cost matrix representative of the structure in the target space. + p : array-like, shape (ns,) + Distribution in the source space. + loss_fun : str, optional + loss function used for the solver either 'square_loss' or 'kl_loss'. 
+ 'kl_loss' is not implemented yet and will raise an error. + symmetric : bool, optional + Either C1 and C2 are to be assumed symmetric or not. + If let to its default None value, a symmetry test will be conducted. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric). + alpha : float, optional + Trade-off parameter (0 < alpha < 1) + G0: array-like, shape (ns,nt), optional + If None the initial transport plan of the solver is pq^T. + Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. + log : bool, optional + Record log if True. + max_iter : int, optional + Max number of iterations + tol_rel : float, optional + Stop threshold on relative error (>0) + tol_abs : float, optional + Stop threshold on absolute error (>0) + **kwargs : dict + Parameters can be directly passed to the ot.optim.cg solver. + + Returns + ------- + srfgw-divergence : float + Semi-relaxed Fused gromov wasserstein divergence for the given parameters. + log : dict + Log dictionary return only if log==True in parameters. + + + .. _references-semirelaxed-fused-gromov-wasserstein2: + References + ---------- + .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain + and Courty Nicolas "Optimal Transport for structured data with + application on graphs", International Conference on Machine Learning + (ICML). 2019. + + .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty. + "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" + International Conference on Learning Representations (ICLR), 2022. + """ + nx = get_backend(p, C1, C2, M) + + T, log_fgw = semirelaxed_fused_gromov_wasserstein( + M, C1, C2, p, loss_fun, symmetric, alpha, G0, log=True, + max_iter=max_iter, tol_rel=tol_rel, tol_abs=tol_abs, **kwargs) + q = nx.sum(T, 0) + srfgw_dist = log_fgw['srfgw_dist'] + log_fgw['T'] = T + + if loss_fun == 'square_loss': + gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T)) + gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T)) + srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M), + (alpha * gC1, alpha * gC2, (1 - alpha) * T)) + + if log: + return srfgw_dist, log_fgw + else: + return srfgw_dist + + +def solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p, + M, reg, alpha_min=None, alpha_max=None, nx=None, **kwargs): + """ + Solve the linesearch in the FW iterations + + Parameters + ---------- + + G : array-like, shape(ns,nt) + The transport map at a given iteration of the FW + deltaG : array-like (ns,nt) + Difference between the optimal map found by linearization in the FW algorithm and the value at a given iteration + cost_G : float + Value of the cost at `G` + C1 : array-like (ns,ns) + Structure matrix in the source domain. + C2 : array-like (nt,nt) + Structure matrix in the target domain. + ones_p: array-like (ns,1) + Array of ones of size ns + M : array-like (ns,nt) + Cost matrix between the features. + reg : float + Regularization parameter. + alpha_min : float, optional + Minimum value for alpha + alpha_max : float, optional + Maximum value for alpha + nx : backend, optional + If let to its default value None, a backend test will be conducted. + Returns + ------- + alpha : float + The optimal step size of the FW + fc : int + nb of function call. Useless here + cost_G : float + The value of the cost for the next iteration + + References + ---------- + .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty. 
+ "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" + International Conference on Learning Representations (ICLR), 2021. + """ + if nx is None: + G, deltaG, C1, C2, M = list_to_array(G, deltaG, C1, C2, M) + + if isinstance(M, int) or isinstance(M, float): + nx = get_backend(G, deltaG, C1, C2) + else: + nx = get_backend(G, deltaG, C1, C2, M) + + qG, qdeltaG = nx.sum(G, 0), nx.sum(deltaG, 0) + dot = nx.dot(nx.dot(C1, deltaG), C2.T) + C2t_square = C2.T ** 2 + dot_qG = nx.dot(nx.outer(ones_p, qG), C2t_square) + dot_qdeltaG = nx.dot(nx.outer(ones_p, qdeltaG), C2t_square) + a = reg * nx.sum((dot_qdeltaG - 2 * dot) * deltaG) + b = nx.sum(M * deltaG) + reg * (nx.sum((dot_qdeltaG - 2 * dot) * G) + nx.sum((dot_qG - 2 * nx.dot(nx.dot(C1, G), C2.T)) * deltaG)) + alpha = solve_1d_linesearch_quad(a, b) + if alpha_min is not None or alpha_max is not None: + alpha = np.clip(alpha, alpha_min, alpha_max) + + # the new cost can be deduced from the line search quadratic function + cost_G = cost_G + a * (alpha ** 2) + b * alpha + + return alpha, 1, cost_G diff --git a/ot/gromov/_utils.py b/ot/gromov/_utils.py new file mode 100644 index 0000000..e842250 --- /dev/null +++ b/ot/gromov/_utils.py @@ -0,0 +1,413 @@ +# -*- coding: utf-8 -*- +""" +Gromov-Wasserstein and Fused-Gromov-Wasserstein utils. +""" + +# Author: Erwan Vautier +# Nicolas Courty +# Rémi Flamary +# Titouan Vayer +# Cédric Vincent-Cuaz +# +# License: MIT License + + +from ..utils import list_to_array +from ..backend import get_backend + + +def init_matrix(C1, C2, p, q, loss_fun='square_loss', nx=None): + r"""Return loss matrices and tensors for Gromov-Wasserstein fast computation + + Returns the value of :math:`\mathcal{L}(\mathbf{C_1}, \mathbf{C_2}) \otimes \mathbf{T}` with the + selected loss function as the loss function of Gromow-Wasserstein discrepancy. + + The matrices are computed as described in Proposition 1 in :ref:`[12] ` + + Where : + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{T}`: A coupling between those two spaces + + The square-loss function :math:`L(a, b) = |a - b|^2` is read as : + + .. math:: + + L(a, b) = f_1(a) + f_2(b) - h_1(a) h_2(b) + + \mathrm{with} \ f_1(a) &= a^2 + + f_2(b) &= b^2 + + h_1(a) &= a + + h_2(b) &= 2b + + The kl-loss function :math:`L(a, b) = a \log\left(\frac{a}{b}\right) - a + b` is read as : + + .. math:: + + L(a, b) = f_1(a) + f_2(b) - h_1(a) h_2(b) + + \mathrm{with} \ f_1(a) &= a \log(a) - a + + f_2(b) &= b + + h_1(a) &= a + + h_2(b) &= \log(b) + + Parameters + ---------- + C1 : array-like, shape (ns, ns) + Metric cost matrix in the source space + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,) + Probability distribution in the source space + q : array-like, shape (nt,) + Probability distribution in the target space + loss_fun : str, optional + Name of loss function to use: either 'square_loss' or 'kl_loss' (default='square_loss') + nx : backend, optional + If let to its default value None, a backend test will be conducted. + Returns + ------- + constC : array-like, shape (ns, nt) + Constant :math:`\mathbf{C}` matrix in Eq. (6) + hC1 : array-like, shape (ns, ns) + :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6) + hC2 : array-like, shape (nt, nt) + :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6) + + + .. _references-init-matrix: + References + ---------- + .. 
[12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, + "Gromov-Wasserstein averaging of kernel and distance matrices." + International Conference on Machine Learning (ICML). 2016. + + """ + if nx is None: + C1, C2, p, q = list_to_array(C1, C2, p, q) + nx = get_backend(C1, C2, p, q) + + if loss_fun == 'square_loss': + def f1(a): + return (a**2) + + def f2(b): + return (b**2) + + def h1(a): + return a + + def h2(b): + return 2 * b + elif loss_fun == 'kl_loss': + def f1(a): + return a * nx.log(a + 1e-15) - a + + def f2(b): + return b + + def h1(a): + return a + + def h2(b): + return nx.log(b + 1e-15) + + constC1 = nx.dot( + nx.dot(f1(C1), nx.reshape(p, (-1, 1))), + nx.ones((1, len(q)), type_as=q) + ) + constC2 = nx.dot( + nx.ones((len(p), 1), type_as=p), + nx.dot(nx.reshape(q, (1, -1)), f2(C2).T) + ) + constC = constC1 + constC2 + hC1 = h1(C1) + hC2 = h2(C2) + + return constC, hC1, hC2 + + +def tensor_product(constC, hC1, hC2, T, nx=None): + r"""Return the tensor for Gromov-Wasserstein fast computation + + The tensor is computed as described in Proposition 1 Eq. (6) in :ref:`[12] ` + + Parameters + ---------- + constC : array-like, shape (ns, nt) + Constant :math:`\mathbf{C}` matrix in Eq. (6) + hC1 : array-like, shape (ns, ns) + :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6) + hC2 : array-like, shape (nt, nt) + :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6) + nx : backend, optional + If let to its default value None, a backend test will be conducted. + Returns + ------- + tens : array-like, shape (`ns`, `nt`) + :math:`\mathcal{L}(\mathbf{C_1}, \mathbf{C_2}) \otimes \mathbf{T}` tensor-matrix multiplication result + + + .. _references-tensor-product: + References + ---------- + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, + "Gromov-Wasserstein averaging of kernel and distance matrices." + International Conference on Machine Learning (ICML). 2016. + + """ + if nx is None: + constC, hC1, hC2, T = list_to_array(constC, hC1, hC2, T) + nx = get_backend(constC, hC1, hC2, T) + + A = - nx.dot( + nx.dot(hC1, T), hC2.T + ) + tens = constC + A + # tens -= tens.min() + return tens + + +def gwloss(constC, hC1, hC2, T, nx=None): + r"""Return the Loss for Gromov-Wasserstein + + The loss is computed as described in Proposition 1 Eq. (6) in :ref:`[12] ` + + Parameters + ---------- + constC : array-like, shape (ns, nt) + Constant :math:`\mathbf{C}` matrix in Eq. (6) + hC1 : array-like, shape (ns, ns) + :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6) + hC2 : array-like, shape (nt, nt) + :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6) + T : array-like, shape (ns, nt) + Current value of transport matrix :math:`\mathbf{T}` + nx : backend, optional + If let to its default value None, a backend test will be conducted. + Returns + ------- + loss : float + Gromov Wasserstein loss + + + .. _references-gwloss: + References + ---------- + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, + "Gromov-Wasserstein averaging of kernel and distance matrices." + International Conference on Machine Learning (ICML). 2016. + + """ + + tens = tensor_product(constC, hC1, hC2, T, nx) + if nx is None: + tens, T = list_to_array(tens, T) + nx = get_backend(tens, T) + + return nx.sum(tens * T) + + +def gwggrad(constC, hC1, hC2, T, nx=None): + r"""Return the gradient for Gromov-Wasserstein + + The gradient is computed as described in Proposition 2 in :ref:`[12] ` + + Parameters + ---------- + constC : array-like, shape (ns, nt) + Constant :math:`\mathbf{C}` matrix in Eq. 
(6) + hC1 : array-like, shape (ns, ns) + :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6) + hC2 : array-like, shape (nt, nt) + :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6) + T : array-like, shape (ns, nt) + Current value of transport matrix :math:`\mathbf{T}` + nx : backend, optional + If let to its default value None, a backend test will be conducted. + Returns + ------- + grad : array-like, shape (`ns`, `nt`) + Gromov Wasserstein gradient + + + .. _references-gwggrad: + References + ---------- + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, + "Gromov-Wasserstein averaging of kernel and distance matrices." + International Conference on Machine Learning (ICML). 2016. + + """ + return 2 * tensor_product(constC, hC1, hC2, + T, nx) # [12] Prop. 2 misses a 2 factor + + +def update_square_loss(p, lambdas, T, Cs): + r""" + Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s` + couplings calculated at each iteration + + Parameters + ---------- + p : array-like, shape (N,) + Masses in the targeted barycenter. + lambdas : list of float + List of the `S` spaces' weights. + T : list of S array-like of shape (ns,N) + The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration. + Cs : list of S array-like, shape(ns,ns) + Metric cost matrices. + + Returns + ---------- + C : array-like, shape (`nt`, `nt`) + Updated :math:`\mathbf{C}` matrix. + """ + T = list_to_array(*T) + Cs = list_to_array(*Cs) + p = list_to_array(p) + nx = get_backend(p, *T, *Cs) + + tmpsum = sum([ + lambdas[s] * nx.dot( + nx.dot(T[s].T, Cs[s]), + T[s] + ) for s in range(len(T)) + ]) + ppt = nx.outer(p, p) + + return tmpsum / ppt + + +def update_kl_loss(p, lambdas, T, Cs): + r""" + Updates :math:`\mathbf{C}` according to the KL Loss kernel with the `S` :math:`\mathbf{T}_s` couplings calculated at each iteration + + + Parameters + ---------- + p : array-like, shape (N,) + Weights in the targeted barycenter. + lambdas : list of float + List of the `S` spaces' weights + T : list of S array-like of shape (ns,N) + The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration. + Cs : list of S array-like, shape(ns,ns) + Metric cost matrices. + + Returns + ---------- + C : array-like, shape (`ns`, `ns`) + updated :math:`\mathbf{C}` matrix + """ + Cs = list_to_array(*Cs) + T = list_to_array(*T) + p = list_to_array(p) + nx = get_backend(p, *T, *Cs) + + tmpsum = sum([ + lambdas[s] * nx.dot( + nx.dot(T[s].T, Cs[s]), + T[s] + ) for s in range(len(T)) + ]) + ppt = nx.outer(p, p) + + return nx.exp(tmpsum / ppt) + + +def init_matrix_semirelaxed(C1, C2, p, loss_fun='square_loss', nx=None): + r"""Return loss matrices and tensors for semi-relaxed Gromov-Wasserstein fast computation + + Returns the value of :math:`\mathcal{L}(\mathbf{C_1}, \mathbf{C_2}) \otimes \mathbf{T}` with the + selected loss function as the loss function of semi-relaxed Gromow-Wasserstein discrepancy. + + The matrices are computed as described in Proposition 1 in :ref:`[12] ` + and adapted to the semi-relaxed problem where the second marginal is not a constant anymore. + + Where : + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{T}`: A coupling between those two spaces + + The square-loss function :math:`L(a, b) = |a - b|^2` is read as : + + .. 
math:: + + L(a, b) = f_1(a) + f_2(b) - h_1(a) h_2(b) + + \mathrm{with} \ f_1(a) &= a^2 + + f_2(b) &= b^2 + + h_1(a) &= a + + h_2(b) &= 2b + + Parameters + ---------- + C1 : array-like, shape (ns, ns) + Metric cost matrix in the source space + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + T : array-like, shape (ns, nt) + Coupling between source and target spaces + p : array-like, shape (ns,) + nx : backend, optional + If let to its default value None, a backend test will be conducted. + Returns + ------- + constC : array-like, shape (ns, nt) + Constant :math:`\mathbf{C}` matrix in Eq. (6) adapted to srGW + hC1 : array-like, shape (ns, ns) + :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6) + hC2 : array-like, shape (nt, nt) + :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6) + fC2t: array-like, shape (nt, nt) + :math:`\mathbf{f2}(\mathbf{C2})^\top` matrix in Eq. (6) + + + .. _references-init-matrix: + References + ---------- + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, + "Gromov-Wasserstein averaging of kernel and distance matrices." + International Conference on Machine Learning (ICML). 2016. + + .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty. + "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" + International Conference on Learning Representations (ICLR), 2022. + """ + if nx is None: + C1, C2, p = list_to_array(C1, C2, p) + nx = get_backend(C1, C2, p) + + if loss_fun == 'square_loss': + def f1(a): + return (a**2) + + def f2(b): + return (b**2) + + def h1(a): + return a + + def h2(b): + return 2 * b + + constC = nx.dot(nx.dot(f1(C1), nx.reshape(p, (-1, 1))), + nx.ones((1, C2.shape[0]), type_as=p)) + + hC1 = h1(C1) + hC2 = h2(C2) + fC2t = f2(C2).T + return constC, hC1, hC2, fC2t diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 7d0640f..2ff02ab 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- """ -Solvers for the original linear program OT problem +Solvers for the original linear program OT problem. """ diff --git a/ot/optim.py b/ot/optim.py index 5a1d605..201f898 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -1,11 +1,11 @@ # -*- coding: utf-8 -*- """ -Generic solvers for regularized OT +Generic solvers for regularized OT or its semi-relaxed version. """ # Author: Remi Flamary # Titouan Vayer -# +# Cédric Vincent-Cuaz # License: MIT License import numpy as np @@ -27,7 +27,7 @@ with warnings.catch_warnings(): def line_search_armijo( f, xk, pk, gfk, old_fval, args=(), c1=1e-4, - alpha0=0.99, alpha_min=None, alpha_max=None + alpha0=0.99, alpha_min=None, alpha_max=None, nx=None, **kwargs ): r""" Armijo linesearch function that works with matrices @@ -57,7 +57,8 @@ def line_search_armijo( minimum value for alpha alpha_max : float, optional maximum value for alpha - + nx : backend, optional + If let to its default value None, a backend test will be conducted. 
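For reference, the Armijo routine can be exercised on a toy quadratic with the signature documented here (illustrative sketch only):

    import numpy as np
    from ot.optim import line_search_armijo

    def f(x):                        # toy objective 0.5 * ||x||^2
        return 0.5 * np.sum(x ** 2)

    xk = np.array([1.0, -2.0])
    gfk = xk.copy()                  # gradient of f at xk
    pk = -gfk                        # descent direction
    alpha, fc, fval = line_search_armijo(f, xk, pk, gfk, f(xk), alpha0=0.99)
    print(alpha, fval)               # accepted step size and objective at the new point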
Returns ------- alpha : float @@ -68,9 +69,9 @@ def line_search_armijo( loss value at step alpha """ - - xk, pk, gfk = list_to_array(xk, pk, gfk) - nx = get_backend(xk, pk) + if nx is None: + xk, pk, gfk = list_to_array(xk, pk, gfk) + nx = get_backend(xk, pk) if len(xk.shape) == 0: xk = nx.reshape(xk, (-1,)) @@ -98,97 +99,38 @@ def line_search_armijo( return float(alpha), fc[0], phi1 -def solve_linesearch( - cost, G, deltaG, Mi, f_val, armijo=True, C1=None, C2=None, - reg=None, Gc=None, constC=None, M=None, alpha_min=None, alpha_max=None -): - """ - Solve the linesearch in the FW iterations - - Parameters - ---------- - cost : method - Cost in the FW for the linesearch - G : array-like, shape(ns,nt) - The transport map at a given iteration of the FW - deltaG : array-like (ns,nt) - Difference between the optimal map found by linearization in the FW algorithm and the value at a given iteration - Mi : array-like (ns,nt) - Cost matrix of the linearized transport problem. Corresponds to the gradient of the cost - f_val : float - Value of the cost at `G` - armijo : bool, optional - If True the steps of the line-search is found via an armijo research. Else closed form is used. - If there is convergence issues use False. - C1 : array-like (ns,ns), optional - Structure matrix in the source domain. Only used and necessary when armijo=False - C2 : array-like (nt,nt), optional - Structure matrix in the target domain. Only used and necessary when armijo=False - reg : float, optional - Regularization parameter. Only used and necessary when armijo=False - Gc : array-like (ns,nt) - Optimal map found by linearization in the FW algorithm. Only used and necessary when armijo=False - constC : array-like (ns,nt) - Constant for the gromov cost. See :ref:`[24] `. Only used and necessary when armijo=False - M : array-like (ns,nt), optional - Cost matrix between the features. Only used and necessary when armijo=False - alpha_min : float, optional - Minimum value for alpha - alpha_max : float, optional - Maximum value for alpha - - Returns - ------- - alpha : float - The optimal step size of the FW - fc : int - nb of function call. Useless here - f_val : float - The value of the cost for the next iteration +def generic_conditional_gradient(a, b, M, f, df, reg1, reg2, lp_solver, line_search, G0=None, + numItermax=200, stopThr=1e-9, + stopThr2=1e-9, verbose=False, log=False, **kwargs): + r""" + Solve the general regularized OT problem or its semi-relaxed version with + conditional gradient or generalized conditional gradient depending on the + provided linear program solver. + The function solves the following optimization problem if set as a conditional gradient: - .. _references-solve-linesearch: - References - ---------- - .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas - "Optimal Transport for structured data with application on graphs" - International Conference on Machine Learning (ICML). 2019. - """ - if armijo: - alpha, fc, f_val = line_search_armijo( - cost, G, deltaG, Mi, f_val, alpha_min=alpha_min, alpha_max=alpha_max - ) - else: # requires symetric matrices - G, deltaG, C1, C2, constC, M = list_to_array(G, deltaG, C1, C2, constC, M) - if isinstance(M, int) or isinstance(M, float): - nx = get_backend(G, deltaG, C1, C2, constC) - else: - nx = get_backend(G, deltaG, C1, C2, constC, M) + .. 
math:: + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg_1} \cdot f(\gamma) - dot = nx.dot(nx.dot(C1, deltaG), C2) - a = -2 * reg * nx.sum(dot * deltaG) - b = nx.sum((M + reg * constC) * deltaG) - 2 * reg * (nx.sum(dot * G) + nx.sum(nx.dot(nx.dot(C1, G), C2) * deltaG)) - c = cost(G) + s.t. \ \gamma \mathbf{1} &= \mathbf{a} - alpha = solve_1d_linesearch_quad(a, b, c) - if alpha_min is not None or alpha_max is not None: - alpha = np.clip(alpha, alpha_min, alpha_max) - fc = None - f_val = cost(G + alpha * deltaG) + \gamma^T \mathbf{1} &= \mathbf{b} (optional constraint) - return alpha, fc, f_val + \gamma &\geq 0 + where : + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix + - :math:`f` is the regularization term (and `df` is its gradient) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) -def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, - stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs): - r""" - Solve the general regularized OT problem with conditional gradient + The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[1] ` - The function solves the following optimization problem: + The function solves the following optimization problem if set a generalized conditional gradient: .. math:: \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + - \mathrm{reg} \cdot f(\gamma) + \mathrm{reg_1}\cdot f(\gamma) + \mathrm{reg_2}\cdot\Omega(\gamma) s.t. \ \gamma \mathbf{1} &= \mathbf{a} @@ -197,29 +139,39 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, \gamma &\geq 0 where : - - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix - - :math:`f` is the regularization term (and `df` is its gradient) - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) - - The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[1] ` + - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + The algorithm used for solving the problem is the generalized conditional gradient as discussed in :ref:`[5, 7] ` Parameters ---------- a : array-like, shape (ns,) samples weights in the source domain b : array-like, shape (nt,) - samples in the target domain + samples weights in the target domain M : array-like, shape (ns, nt) loss matrix - reg : float + f : function + Regularization function taking a transportation matrix as argument + df: function + Gradient of the regularization function taking a transportation matrix as argument + reg1 : float Regularization term >0 + reg2 : float, + Entropic Regularization term >0. Ignored if set to None. + lp_solver: function, + linear program solver for direction finding of the (generalized) conditional gradient. + If set to emd will solve the general regularized OT problem using cg. + If set to lp_semi_relaxed_OT will solve the general regularized semi-relaxed OT problem using cg. + If set to sinkhorn will solve the general regularized OT problem using generalized cg. + line_search: function, + Function to find the optimal step. Currently used instances are: + line_search_armijo (generic solver). solve_gromov_linesearch for (F)GW problem. + solve_semirelaxed_gromov_linesearch for sr(F)GW problem. gcg_linesearch for the Generalized cg. 
G0 : array-like, shape (ns,nt), optional initial guess (default is indep joint density) numItermax : int, optional Max number of iterations - numItermaxEmd : int, optional - Max number of iterations for emd stopThr : float, optional Stop threshold on the relative variation (>0) stopThr2 : float, optional @@ -240,16 +192,20 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, .. _references-cg: + .. _references_gcg: References ---------- .. [1] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882. + .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 + + .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567. + See Also -------- ot.lp.emd : Unregularized optimal ransport ot.bregman.sinkhorn : Entropic regularized optimal transport - """ a, b, M, G0 = list_to_array(a, b, M, G0) if isinstance(M, int) or isinstance(M, float): @@ -265,42 +221,45 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, if G0 is None: G = nx.outer(a, b) else: - G = G0 - - def cost(G): - return nx.sum(M * G) + reg * f(G) + # to not change G0 in place. + G = nx.copy(G0) - f_val = cost(G) + if reg2 is None: + def cost(G): + return nx.sum(M * G) + reg1 * f(G) + else: + def cost(G): + return nx.sum(M * G) + reg1 * f(G) + reg2 * nx.sum(G * nx.log(G)) + cost_G = cost(G) if log: - log['loss'].append(f_val) + log['loss'].append(cost_G) it = 0 if verbose: print('{:5s}|{:12s}|{:8s}|{:8s}'.format( 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48) - print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, 0, 0)) + print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, cost_G, 0, 0)) while loop: it += 1 - old_fval = f_val - + old_cost_G = cost_G # problem linearization - Mi = M + reg * df(G) + Mi = M + reg1 * df(G) + + if not (reg2 is None): + Mi = Mi + reg2 * (1 + nx.log(G)) # set M positive - Mi += nx.min(Mi) + Mi = Mi + nx.min(Mi) # solve linear program - Gc, logemd = emd(a, b, Mi, numItermax=numItermaxEmd, log=True) + Gc, innerlog_ = lp_solver(a, b, Mi, **kwargs) + # line search deltaG = Gc - G - # line search - alpha, fc, f_val = solve_linesearch( - cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, - alpha_min=0., alpha_max=1., **kwargs - ) + alpha, fc, cost_G = line_search(cost, G, deltaG, Mi, cost_G, **kwargs) G = G + alpha * deltaG @@ -308,29 +267,197 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, if it >= numItermax: loop = 0 - abs_delta_fval = abs(f_val - old_fval) - relative_delta_fval = abs_delta_fval / abs(f_val) - if relative_delta_fval < stopThr or abs_delta_fval < stopThr2: + abs_delta_cost_G = abs(cost_G - old_cost_G) + relative_delta_cost_G = abs_delta_cost_G / abs(cost_G) + if relative_delta_cost_G < stopThr or abs_delta_cost_G < stopThr2: loop = 0 if log: - log['loss'].append(f_val) + log['loss'].append(cost_G) if verbose: if it % 20 == 0: print('{:5s}|{:12s}|{:8s}|{:8s}'.format( 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48) - print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval)) + print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, cost_G, relative_delta_cost_G, abs_delta_cost_G)) if log: - log.update(logemd) + log.update(innerlog_) return 
G, log else: return G +def cg(a, b, M, reg, f, df, G0=None, line_search=line_search_armijo, + numItermax=200, numItermaxEmd=100000, stopThr=1e-9, stopThr2=1e-9, + verbose=False, log=False, **kwargs): + r""" + Solve the general regularized OT problem with conditional gradient + + The function solves the following optimization problem: + + .. math:: + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg} \cdot f(\gamma) + + s.t. \ \gamma \mathbf{1} &= \mathbf{a} + + \gamma^T \mathbf{1} &= \mathbf{b} + + \gamma &\geq 0 + where : + + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix + - :math:`f` is the regularization term (and `df` is its gradient) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) + + The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[1] ` + + + Parameters + ---------- + a : array-like, shape (ns,) + samples weights in the source domain + b : array-like, shape (nt,) + samples in the target domain + M : array-like, shape (ns, nt) + loss matrix + reg : float + Regularization term >0 + G0 : array-like, shape (ns,nt), optional + initial guess (default is indep joint density) + line_search: function, + Function to find the optimal step. + Default is line_search_armijo. + numItermax : int, optional + Max number of iterations + numItermaxEmd : int, optional + Max number of iterations for emd + stopThr : float, optional + Stop threshold on the relative variation (>0) + stopThr2 : float, optional + Stop threshold on the absolute variation (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + **kwargs : dict + Parameters for linesearch + + Returns + ------- + gamma : (ns x nt) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + + .. _references-cg: + References + ---------- + + .. [1] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882. + + See Also + -------- + ot.lp.emd : Unregularized optimal ransport + ot.bregman.sinkhorn : Entropic regularized optimal transport + + """ + + def lp_solver(a, b, M, **kwargs): + return emd(a, b, M, numItermaxEmd, log=True) + + return generic_conditional_gradient(a, b, M, f, df, reg, None, lp_solver, line_search, G0=G0, + numItermax=numItermax, stopThr=stopThr, + stopThr2=stopThr2, verbose=verbose, log=log, **kwargs) + + +def semirelaxed_cg(a, b, M, reg, f, df, G0=None, line_search=line_search_armijo, + numItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs): + r""" + Solve the general regularized and semi-relaxed OT problem with conditional gradient + + The function solves the following optimization problem: + + .. math:: + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg} \cdot f(\gamma) + + s.t. 
\ \gamma \mathbf{1} &= \mathbf{a} + + \gamma &\geq 0 + where : + + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix + - :math:`f` is the regularization term (and `df` is its gradient) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) + + The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[1] ` + + + Parameters + ---------- + a : array-like, shape (ns,) + samples weights in the source domain + b : array-like, shape (nt,) + currently estimated samples weights in the target domain + M : array-like, shape (ns, nt) + loss matrix + reg : float + Regularization term >0 + G0 : array-like, shape (ns,nt), optional + initial guess (default is indep joint density) + line_search: function, + Function to find the optimal step. + Default is the armijo line-search. + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on the relative variation (>0) + stopThr2 : float, optional + Stop threshold on the absolute variation (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + **kwargs : dict + Parameters for linesearch + + Returns + ------- + gamma : (ns x nt) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + + .. _references-cg: + References + ---------- + + .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty. + "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" + International Conference on Learning Representations (ICLR), 2021. + + """ + + nx = get_backend(a, b) + + def lp_solver(a, b, Mi, **kwargs): + # get minimum by rows as binary mask + Gc = nx.ones(1, type_as=a) * (Mi == nx.reshape(nx.min(Mi, axis=1), (-1, 1))) + Gc *= nx.reshape((a / nx.sum(Gc, axis=1)), (-1, 1)) + # return by default an empty inner_log + return Gc, {} + + return generic_conditional_gradient(a, b, M, f, df, reg, None, lp_solver, line_search, G0=G0, + numItermax=numItermax, stopThr=stopThr, + stopThr2=stopThr2, verbose=verbose, log=log, **kwargs) + + def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, - numInnerItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False): + numInnerItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs): r""" Solve the general regularized OT problem with the generalized conditional gradient @@ -403,81 +530,18 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, ot.optim.cg : conditional gradient """ - a, b, M, G0 = list_to_array(a, b, M, G0) - nx = get_backend(a, b, M) - - loop = 1 - - if log: - log = {'loss': []} - - if G0 is None: - G = nx.outer(a, b) - else: - G = G0 - - def cost(G): - return nx.sum(M * G) + reg1 * nx.sum(G * nx.log(G)) + reg2 * f(G) - - f_val = cost(G) - if log: - log['loss'].append(f_val) - - it = 0 - - if verbose: - print('{:5s}|{:12s}|{:8s}|{:8s}'.format( - 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48) - print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, 0, 0)) - while loop: - - it += 1 - old_fval = f_val - - # problem linearization - Mi = M + reg2 * df(G) - - # solve linear program with Sinkhorn - # Gc = sinkhorn_stabilized(a,b, Mi, reg1, numItermax = numInnerItermax) - Gc = sinkhorn(a, b, Mi, reg1, numItermax=numInnerItermax) - - deltaG = Gc - G - - # line search - dcost = Mi + reg1 * (1 + nx.log(G)) # ?? 
- alpha, fc, f_val = line_search_armijo( - cost, G, deltaG, dcost, f_val, alpha_min=0., alpha_max=1. - ) - - G = G + alpha * deltaG - - # test convergence - if it >= numItermax: - loop = 0 - - abs_delta_fval = abs(f_val - old_fval) - relative_delta_fval = abs_delta_fval / abs(f_val) + def lp_solver(a, b, Mi, **kwargs): + return sinkhorn(a, b, Mi, reg1, numItermax=numInnerItermax, log=True, **kwargs) - if relative_delta_fval < stopThr or abs_delta_fval < stopThr2: - loop = 0 + def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): + return line_search_armijo(cost, G, deltaG, Mi, cost_G, **kwargs) - if log: - log['loss'].append(f_val) + return generic_conditional_gradient(a, b, M, f, df, reg2, reg1, lp_solver, line_search, G0=G0, + numItermax=numItermax, stopThr=stopThr, stopThr2=stopThr2, verbose=verbose, log=log, **kwargs) - if verbose: - if it % 20 == 0: - print('{:5s}|{:12s}|{:8s}|{:8s}'.format( - 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48) - print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval)) - if log: - return G, log - else: - return G - - -def solve_1d_linesearch_quad(a, b, c): +def solve_1d_linesearch_quad(a, b): r""" For any convex or non-convex 1d quadratic function `f`, solve the following problem: @@ -487,7 +551,7 @@ def solve_1d_linesearch_quad(a, b, c): Parameters ---------- - a,b,c : float + a,b : float or tensors (1,) The coefficients of the quadratic function Returns @@ -495,15 +559,11 @@ def solve_1d_linesearch_quad(a, b, c): x : float The optimal value which leads to the minimal cost """ - f0 = c - df0 = b - f1 = a + f0 + df0 - if a > 0: # convex - minimum = min(1, max(0, np.divide(-b, 2.0 * a))) + minimum = min(1., max(0., -b / (2.0 * a))) return minimum else: # non convex - if f0 > f1: - return 1 + if a + b < 0: + return 1. else: - return 0 + return 0. 
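Editor's note on the `ot/optim.py` refactor above: `cg`, `semirelaxed_cg` and `gcg` now only differ by the `lp_solver` and `line_search` they hand to `generic_conditional_gradient`, while the public `ot.optim.cg` signature keeps its old positional arguments. Below is a minimal usage sketch (not part of the patch) that calls the refactored `cg` with the squared-Frobenius regularizer used in the existing tests; the toy data and the regularization strength are illustrative assumptions.

```python
import numpy as np
import ot

# toy marginals and cost matrix (illustrative values, not from the patch)
n = 50
rng = np.random.RandomState(0)
xs = rng.randn(n, 2)
xt = rng.randn(n, 2) + 1.0
a, b = ot.unif(n), ot.unif(n)
M = ot.dist(xs, xt)
M /= M.max()

# quadratic regularization f(G) = 0.5 * ||G||_F^2, with gradient df(G) = G
def f(G):
    return 0.5 * np.sum(G ** 2)

def df(G):
    return G

reg = 1e-1
# same public call as before the refactor; line_search defaults to
# line_search_armijo and the EMD solver is plugged in internally
G, log = ot.optim.cg(a, b, M, reg, f, df, log=True)

# the returned coupling still satisfies both marginal constraints
np.testing.assert_allclose(a, G.sum(1), atol=1e-5)
np.testing.assert_allclose(b, G.sum(0), atol=1e-5)
```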
diff --git a/test/test_gromov.py b/test/test_gromov.py index 9c85b92..cfccce7 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -3,7 +3,7 @@ # Author: Erwan Vautier # Nicolas Courty # Titouan Vayer -# Cédric Vincent-Cuaz +# Cédric Vincent-Cuaz # # License: MIT License @@ -11,18 +11,15 @@ import numpy as np import ot from ot.backend import NumpyBackend from ot.backend import torch, tf - import pytest def test_gromov(nx): n_samples = 50 # nb samples - mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) - xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) - + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=1) xt = xs[::-1].copy() p = ot.unif(n_samples) @@ -38,7 +35,7 @@ def test_gromov(nx): C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', G0=G0, verbose=True) - Gb = nx.to_numpy(ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True)) + Gb = nx.to_numpy(ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', symmetric=True, G0=G0b, verbose=True)) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) @@ -51,13 +48,13 @@ def test_gromov(nx): np.testing.assert_allclose(Gb, np.flipud(Id), atol=1e-04) - gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=True) - gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=True) + gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', armijo=True, log=True) + gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, log=True) gwb = nx.to_numpy(gwb) - gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', G0=G0, log=False) + gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', armijo=True, G0=G0, log=False) gw_valb = nx.to_numpy( - ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', G0=G0b, log=False) + ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, G0=G0b, log=False) ) G = log['T'] @@ -77,6 +74,49 @@ def test_gromov(nx): q, Gb.sum(0), atol=1e-04) # cf convergence gromov +def test_asymmetric_gromov(nx): + n_samples = 30 # nb samples + np.random.seed(0) + C1 = np.random.uniform(low=0., high=10, size=(n_samples, n_samples)) + idx = np.arange(n_samples) + np.random.shuffle(idx) + C2 = C1[idx, :][:, idx] + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + + C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) + + G, log = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', G0=G0, log=True, symmetric=False, verbose=True) + Gb, logb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', log=True, symmetric=False, G0=G0b, verbose=True) + Gb = nx.to_numpy(Gb) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + np.testing.assert_allclose(log['gw_dist'], 0., atol=1e-04) + np.testing.assert_allclose(logb['gw_dist'], 0., atol=1e-04) + + gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'square_loss', G0=G0, log=True, symmetric=False, verbose=True) + gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'square_loss', log=True, symmetric=False, G0=G0b, verbose=True) + + G = log['T'] + Gb = nx.to_numpy(logb['T']) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, 
Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + np.testing.assert_allclose(log['gw_dist'], 0., atol=1e-04) + np.testing.assert_allclose(logb['gw_dist'], 0., atol=1e-04) + + def test_gromov_dtype_device(nx): # setup n_samples = 50 # nb samples @@ -104,7 +144,7 @@ def test_gromov_dtype_device(nx): C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0, type_as=tp) Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True) - gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', G0=G0b, log=False) + gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, G0=G0b, log=False) nx.assert_same_dtype_device(C1b, Gb) nx.assert_same_dtype_device(C1b, gw_valb) @@ -130,7 +170,7 @@ def test_gromov_device_tf(): with tf.device("/CPU:0"): C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True) - gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', G0=G0b, log=False) + gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, G0=G0b, log=False) nx.assert_same_dtype_device(C1b, Gb) nx.assert_same_dtype_device(C1b, gw_valb) @@ -138,7 +178,7 @@ def test_gromov_device_tf(): # Check that everything happens on the GPU C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True) - gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False) + gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, log=False) nx.assert_same_dtype_device(C1b, Gb) nx.assert_same_dtype_device(C1b, gw_valb) assert nx.dtype_device(Gb)[1].startswith("GPU") @@ -185,6 +225,45 @@ def test_gromov2_gradients(): assert C12.shape == C12.grad.shape +def test_gw_helper_backend(nx): + n_samples = 20 # nb samples + + mu = np.array([0, 0]) + cov = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=0) + xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=1) + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) + Gb, logb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', armijo=False, symmetric=True, G0=G0b, log=True) + + # calls with nx=None + constCb, hC1b, hC2b = ot.gromov.init_matrix(C1b, C2b, pb, qb, loss_fun='square_loss') + + def f(G): + return ot.gromov.gwloss(constCb, hC1b, hC2b, G, None) + + def df(G): + return ot.gromov.gwggrad(constCb, hC1b, hC2b, G, None) + + def line_search(cost, G, deltaG, Mi, cost_G): + return ot.gromov.solve_gromov_linesearch(G, deltaG, cost_G, C1b, C2b, M=0., reg=1., nx=None) + # feed the precomputed local optimum Gb to cg + res, log = ot.optim.cg(pb, qb, 0., 1., f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9) + # check constraints + np.testing.assert_allclose(res, Gb, atol=1e-06) + + @pytest.skip_backend("jax", reason="test very slow with jax backend") @pytest.skip_backend("tf", reason="test very slow with tf backend") def test_entropic_gromov(nx): @@ -199,19 +278,21 @@ def test_entropic_gromov(nx): p = ot.unif(n_samples) q = ot.unif(n_samples) - + G0 = p[:, None] * q[None, :] C1 = ot.dist(xs, xs) C2 = ot.dist(xt, xt) C1 /= C1.max() C2 
/= C2.max() - C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q) + C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) - G = ot.gromov.entropic_gromov_wasserstein( - C1, C2, p, q, 'square_loss', epsilon=5e-4, verbose=True) + G, log = ot.gromov.entropic_gromov_wasserstein( + C1, C2, p, q, 'square_loss', symmetric=None, G0=G0, + epsilon=1e-2, verbose=True, log=True) Gb = nx.to_numpy(ot.gromov.entropic_gromov_wasserstein( - C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True + C1b, C2b, pb, qb, 'square_loss', symmetric=True, G0=None, + epsilon=1e-2, verbose=True, log=False )) # check constraints @@ -222,9 +303,11 @@ def test_entropic_gromov(nx): q, Gb.sum(0), atol=1e-04) # cf convergence gromov gw, log = ot.gromov.entropic_gromov_wasserstein2( - C1, C2, p, q, 'kl_loss', max_iter=10, epsilon=1e-2, log=True) + C1, C2, p, q, 'kl_loss', symmetric=True, G0=None, + max_iter=10, epsilon=1e-2, log=True) gwb, logb = ot.gromov.entropic_gromov_wasserstein2( - C1b, C2b, pb, qb, 'kl_loss', max_iter=10, epsilon=1e-2, log=True) + C1b, C2b, pb, qb, 'kl_loss', symmetric=None, G0=G0b, + max_iter=10, epsilon=1e-2, log=True) gwb = nx.to_numpy(gwb) G = log['T'] @@ -241,6 +324,45 @@ def test_entropic_gromov(nx): q, Gb.sum(0), atol=1e-04) # cf convergence gromov +def test_asymmetric_entropic_gromov(nx): + n_samples = 10 # nb samples + np.random.seed(0) + C1 = np.random.uniform(low=0., high=10, size=(n_samples, n_samples)) + idx = np.arange(n_samples) + np.random.shuffle(idx) + C2 = C1[idx, :][:, idx] + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + + C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) + G = ot.gromov.entropic_gromov_wasserstein( + C1, C2, p, q, 'square_loss', symmetric=None, G0=G0, + epsilon=1e-1, verbose=True, log=False) + Gb = nx.to_numpy(ot.gromov.entropic_gromov_wasserstein( + C1b, C2b, pb, qb, 'square_loss', symmetric=False, G0=None, + epsilon=1e-1, verbose=True, log=False + )) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + gw = ot.gromov.entropic_gromov_wasserstein2( + C1, C2, p, q, 'kl_loss', symmetric=False, G0=None, + max_iter=10, epsilon=1e-1, log=False) + gwb = ot.gromov.entropic_gromov_wasserstein2( + C1b, C2b, pb, qb, 'kl_loss', symmetric=None, G0=G0b, + max_iter=10, epsilon=1e-1, log=False) + gwb = nx.to_numpy(gwb) + + np.testing.assert_allclose(gw, gwb, atol=1e-06) + np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1) + + @pytest.skip_backend("jax", reason="test very slow with jax backend") @pytest.skip_backend("tf", reason="test very slow with tf backend") def test_entropic_gromov_dtype_device(nx): @@ -539,8 +661,8 @@ def test_fgw(nx): Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0) - G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, G0=G0, log=True) - Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, G0=G0b, log=True) + G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, armijo=True, symmetric=None, G0=G0, log=True) + Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, armijo=True, symmetric=True, G0=G0b, log=True) Gb = nx.to_numpy(Gb) # check constraints @@ -555,8 +677,8 @@ def test_fgw(nx): np.testing.assert_allclose( Gb, np.flipud(Id), atol=1e-04) # cf convergence gromov 
- fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', G0=None, alpha=0.5, log=True) - fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', G0=G0b, alpha=0.5, log=True) + fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', armijo=True, symmetric=True, G0=None, alpha=0.5, log=True) + fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', armijo=True, symmetric=None, G0=G0b, alpha=0.5, log=True) fgwb = nx.to_numpy(fgwb) G = log['T'] @@ -573,6 +695,82 @@ def test_fgw(nx): q, Gb.sum(0), atol=1e-04) # cf convergence gromov +def test_asymmetric_fgw(nx): + n_samples = 50 # nb samples + np.random.seed(0) + C1 = np.random.uniform(low=0., high=10, size=(n_samples, n_samples)) + idx = np.arange(n_samples) + np.random.shuffle(idx) + C2 = C1[idx, :][:, idx] + + # add features + F1 = np.random.uniform(low=0., high=10, size=(n_samples, 1)) + F2 = F1[idx, :] + p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + + M = ot.dist(F1, F2) + Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0) + + G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, G0=G0, log=True, symmetric=False, verbose=True) + Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True, symmetric=None, G0=G0b, verbose=True) + Gb = nx.to_numpy(Gb) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04) + np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04) + + fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', alpha=0.5, G0=G0, log=True, symmetric=None, verbose=True) + fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True, symmetric=False, G0=G0b, verbose=True) + + G = log['T'] + Gb = nx.to_numpy(logb['T']) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04) + np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04) + + # Tests with kl-loss: + G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'kl_loss', alpha=0.5, G0=G0, log=True, symmetric=False, verbose=True) + Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'kl_loss', alpha=0.5, log=True, symmetric=None, G0=G0b, verbose=True) + Gb = nx.to_numpy(Gb) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04) + np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04) + + fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'kl_loss', alpha=0.5, G0=G0, log=True, symmetric=None, verbose=True) + fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'kl_loss', alpha=0.5, log=True, symmetric=False, G0=G0b, verbose=True) + + G = log['T'] + Gb = nx.to_numpy(logb['T']) + # check constraints + 
np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04) + np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04) + + def test_fgw2_gradients(): n_samples = 20 # nb samples @@ -617,6 +815,57 @@ def test_fgw2_gradients(): assert M1.shape == M1.grad.shape +def test_fgw_helper_backend(nx): + n_samples = 20 # nb samples + + mu = np.array([0, 0]) + cov = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=0) + ys = np.random.randn(xs.shape[0], 2) + xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=1) + yt = np.random.randn(xt.shape[0], 2) + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + M = ot.dist(ys, yt) + M /= M.max() + + Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0) + alpha = 0.5 + Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, armijo=False, symmetric=True, G0=G0b, log=True) + + # calls with nx=None + constCb, hC1b, hC2b = ot.gromov.init_matrix(C1b, C2b, pb, qb, loss_fun='square_loss') + + def f(G): + return ot.gromov.gwloss(constCb, hC1b, hC2b, G, None) + + def df(G): + return ot.gromov.gwggrad(constCb, hC1b, hC2b, G, None) + + def line_search(cost, G, deltaG, Mi, cost_G): + return ot.gromov.solve_gromov_linesearch(G, deltaG, cost_G, C1b, C2b, M=(1 - alpha) * Mb, reg=alpha, nx=None) + # feed the precomputed local optimum Gb to cg + res, log = ot.optim.cg(pb, qb, (1 - alpha) * Mb, alpha, f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9) + + def line_search(cost, G, deltaG, Mi, cost_G): + return ot.optim.line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=None) + # feed the precomputed local optimum Gb to cg + res_armijo, log_armijo = ot.optim.cg(pb, qb, (1 - alpha) * Mb, alpha, f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9) + # check constraints + np.testing.assert_allclose(res, Gb, atol=1e-06) + np.testing.assert_allclose(res_armijo, Gb, atol=1e-06) + + def test_fgw_barycenter(nx): np.random.seed(42) @@ -1186,3 +1435,327 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx): # > Compare results with/without backend total_reconstruction_b_bis2 = nx.to_numpy(total_reconstruction_b_bis2) np.testing.assert_allclose(total_reconstruction_bis2, total_reconstruction_b_bis2, atol=1e-05) + + +def test_semirelaxed_gromov(nx): + np.random.seed(0) + # unbalanced proportions + list_n = [30, 15] + nt = 2 + ns = np.sum(list_n) + # create directed sbm with C2 as connectivity matrix + C1 = np.zeros((ns, ns), dtype=np.float64) + C2 = np.array([[0.8, 0.05], + [0.05, 1.]], dtype=np.float64) + for i in range(nt): + for j in range(nt): + ni, nj = list_n[i], list_n[j] + xij = np.random.binomial(size=(ni, nj), n=1, p=C2[i, j]) + C1[i * ni: (i + 1) * ni, j * nj: (j + 1) * nj] = xij + p = ot.unif(ns, type_as=C1) + q0 = ot.unif(C2.shape[0], type_as=C1) + G0 = p[:, None] * q0[None, :] + # asymmetric + C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0) + + G, log = ot.gromov.semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', symmetric=None, log=True, G0=G0) + Gb, logb = ot.gromov.semirelaxed_gromov_wasserstein(C1b, C2b, pb, loss_fun='square_loss', 
symmetric=False, log=True, G0=None) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) + np.testing.assert_allclose(list_n / ns, np.sum(G, axis=0), atol=1e-01) + np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-01) + + srgw, log2 = ot.gromov.semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric=False, log=True, G0=G0) + srgwb, logb2 = ot.gromov.semirelaxed_gromov_wasserstein2(C1b, C2b, pb, loss_fun='square_loss', symmetric=None, log=True, G0=None) + + G = log2['T'] + Gb = nx.to_numpy(logb2['T']) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(list_n / ns, Gb.sum(0), atol=1e-04) # cf convergence gromov + + np.testing.assert_allclose(log2['srgw_dist'], logb['srgw_dist'], atol=1e-07) + np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07) + + # symmetric + C1 = 0.5 * (C1 + C1.T) + C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0) + + G, log = ot.gromov.semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', symmetric=None, log=True, G0=None) + Gb = ot.gromov.semirelaxed_gromov_wasserstein(C1b, C2b, pb, loss_fun='square_loss', symmetric=True, log=False, G0=G0b) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov + + srgw, log2 = ot.gromov.semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric=True, log=True, G0=G0) + srgwb, logb2 = ot.gromov.semirelaxed_gromov_wasserstein2(C1b, C2b, pb, loss_fun='square_loss', symmetric=None, log=True, G0=None) + + srgw_ = ot.gromov.semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric=True, log=False, G0=G0) + + G = log2['T'] + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, nx.sum(Gb, 1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(list_n / ns, np.sum(G, axis=0), atol=1e-01) + np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-01) + + np.testing.assert_allclose(log2['srgw_dist'], log['srgw_dist'], atol=1e-07) + np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07) + np.testing.assert_allclose(srgw, srgw_, atol=1e-07) + + +def test_semirelaxed_gromov2_gradients(): + n_samples = 50 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) + + xt = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=5) + + p = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + if torch: + + devices = [torch.device("cpu")] + if torch.cuda.is_available(): + devices.append(torch.device("cuda")) + for device in devices: + # semirelaxed solvers do not support gradients over masses yet. 
+ p1 = torch.tensor(p, requires_grad=False, device=device) + C11 = torch.tensor(C1, requires_grad=True, device=device) + C12 = torch.tensor(C2, requires_grad=True, device=device) + + val = ot.gromov.semirelaxed_gromov_wasserstein2(C11, C12, p1) + + val.backward() + + assert val.device == p1.device + assert p1.grad is None + assert C11.shape == C11.grad.shape + assert C12.shape == C12.grad.shape + + +def test_srgw_helper_backend(nx): + n_samples = 20 # nb samples + + mu = np.array([0, 0]) + cov = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=0) + xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=1) + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q) + Gb, logb = ot.gromov.semirelaxed_gromov_wasserstein(C1b, C2b, pb, 'square_loss', armijo=False, symmetric=True, G0=None, log=True) + + # calls with nx=None + constCb, hC1b, hC2b, fC2tb = ot.gromov.init_matrix_semirelaxed(C1b, C2b, pb, loss_fun='square_loss') + ones_pb = nx.ones(pb.shape[0], type_as=pb) + + def f(G): + qG = nx.sum(G, 0) + marginal_product = nx.outer(ones_pb, nx.dot(qG, fC2tb)) + return ot.gromov.gwloss(constCb + marginal_product, hC1b, hC2b, G, nx=None) + + def df(G): + qG = nx.sum(G, 0) + marginal_product = nx.outer(ones_pb, nx.dot(qG, fC2tb)) + return ot.gromov.gwggrad(constCb + marginal_product, hC1b, hC2b, G, nx=None) + + def line_search(cost, G, deltaG, Mi, cost_G): + return ot.gromov.solve_semirelaxed_gromov_linesearch( + G, deltaG, cost_G, C1b, C2b, ones_pb, 0., 1., nx=None) + # feed the precomputed local optimum Gb to semirelaxed_cg + res, log = ot.optim.semirelaxed_cg(pb, qb, 0., 1., f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9) + # check constraints + np.testing.assert_allclose(res, Gb, atol=1e-06) + + +def test_semirelaxed_fgw(nx): + np.random.seed(0) + list_n = [16, 8] + nt = 2 + ns = 24 + # create directed sbm with C2 as connectivity matrix + C1 = np.zeros((ns, ns)) + C2 = np.array([[0.7, 0.05], + [0.05, 0.9]]) + for i in range(nt): + for j in range(nt): + ni, nj = list_n[i], list_n[j] + xij = np.random.binomial(size=(ni, nj), n=1, p=C2[i, j]) + C1[i * ni: (i + 1) * ni, j * nj: (j + 1) * nj] = xij + F1 = np.zeros((ns, 1)) + F1[:16] = np.random.normal(loc=0., scale=0.01, size=(16, 1)) + F1[16:] = np.random.normal(loc=1., scale=0.01, size=(8, 1)) + F2 = np.zeros((2, 1)) + F2[1, :] = 1. 
+ M = (F1 ** 2).dot(np.ones((1, nt))) + np.ones((ns, 1)).dot((F2 ** 2).T) - 2 * F1.dot(F2.T) + + p = ot.unif(ns) + q0 = ot.unif(C2.shape[0]) + G0 = p[:, None] * q0[None, :] + + # asymmetric + Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0) + G, log = ot.gromov.semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None) + Gb, logb = ot.gromov.semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun='square_loss', alpha=0.5, symmetric=False, log=True, G0=G0b) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov + + srgw, log2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=False, log=True, G0=G0) + srgwb, logb2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, pb, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None) + + G = log2['T'] + Gb = nx.to_numpy(logb2['T']) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose([2 / 3, 1 / 3], Gb.sum(0), atol=1e-04) # cf convergence gromov + + np.testing.assert_allclose(log2['srfgw_dist'], logb['srfgw_dist'], atol=1e-07) + np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) + + # symmetric + C1 = 0.5 * (C1 + C1.T) + Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0) + + G, log = ot.gromov.semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None) + Gb = ot.gromov.semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun='square_loss', alpha=0.5, symmetric=True, log=False, G0=G0b) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov + + srgw, log2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=True, log=True, G0=G0) + srgwb, logb2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, pb, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None) + + srgw_ = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=True, log=False, G0=G0) + + G = log2['T'] + Gb = nx.to_numpy(logb2['T']) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose([2 / 3, 1 / 3], Gb.sum(0), atol=1e-04) # cf convergence gromov + + np.testing.assert_allclose(log2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) + np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) + np.testing.assert_allclose(srgw, srgw_, atol=1e-07) + + +def test_semirelaxed_fgw2_gradients(): + n_samples = 20 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) + + xt = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=5) + + p = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + M = 
ot.dist(xs, xt) + + C1 /= C1.max() + C2 /= C2.max() + + if torch: + + devices = [torch.device("cpu")] + if torch.cuda.is_available(): + devices.append(torch.device("cuda")) + for device in devices: + # semirelaxed solvers do not support gradients over masses yet. + p1 = torch.tensor(p, requires_grad=False, device=device) + C11 = torch.tensor(C1, requires_grad=True, device=device) + C12 = torch.tensor(C2, requires_grad=True, device=device) + M1 = torch.tensor(M, requires_grad=True, device=device) + + val = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M1, C11, C12, p1) + + val.backward() + + assert val.device == p1.device + assert p1.grad is None + assert C11.shape == C11.grad.shape + assert C12.shape == C12.grad.shape + assert M1.shape == M1.grad.shape + + +def test_srfgw_helper_backend(nx): + n_samples = 20 # nb samples + + mu = np.array([0, 0]) + cov = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=0) + ys = np.random.randn(xs.shape[0], 2) + xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=1) + yt = np.random.randn(xt.shape[0], 2) + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + M = ot.dist(ys, yt) + M /= M.max() + + Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0) + alpha = 0.5 + Gb, logb = ot.gromov.semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, 'square_loss', alpha=0.5, armijo=False, symmetric=True, G0=G0b, log=True) + + # calls with nx=None + constCb, hC1b, hC2b, fC2tb = ot.gromov.init_matrix_semirelaxed(C1b, C2b, pb, loss_fun='square_loss') + ones_pb = nx.ones(pb.shape[0], type_as=pb) + + def f(G): + qG = nx.sum(G, 0) + marginal_product = nx.outer(ones_pb, nx.dot(qG, fC2tb)) + return ot.gromov.gwloss(constCb + marginal_product, hC1b, hC2b, G, nx=None) + + def df(G): + qG = nx.sum(G, 0) + marginal_product = nx.outer(ones_pb, nx.dot(qG, fC2tb)) + return ot.gromov.gwggrad(constCb + marginal_product, hC1b, hC2b, G, nx=None) + + def line_search(cost, G, deltaG, Mi, cost_G): + return ot.gromov.solve_semirelaxed_gromov_linesearch( + G, deltaG, cost_G, C1b, C2b, ones_pb, M=(1 - alpha) * Mb, reg=alpha, nx=None) + # feed the precomputed local optimum Gb to semirelaxed_cg + res, log = ot.optim.semirelaxed_cg(pb, qb, (1 - alpha) * Mb, alpha, f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9) + # check constraints + np.testing.assert_allclose(res, Gb, atol=1e-06) diff --git a/test/test_optim.py b/test/test_optim.py index 67e9d13..129fe22 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -120,15 +120,15 @@ def test_generalized_conditional_gradient(nx): Gb, log = ot.optim.gcg(ab, bb, Mb, reg1, reg2, fb, df, verbose=True, log=True) Gb = nx.to_numpy(Gb) - np.testing.assert_allclose(Gb, G) + np.testing.assert_allclose(Gb, G, atol=1e-12) np.testing.assert_allclose(a, Gb.sum(1), atol=1e-05) np.testing.assert_allclose(b, Gb.sum(0), atol=1e-05) def test_solve_1d_linesearch_quad_funct(): - np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(1, -1, 0), 0.5) - np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 5, 0), 0) - np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 0.5, 0), 1) + np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(1, -1), 0.5) + np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 5), 0) + np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 0.5), 1) def 
test_line_search_armijo(nx): -- cgit v1.2.3
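Closing note on the simplified line search exercised by the updated `test_solve_1d_linesearch_quad_funct` above: since the constant term of the quadratic does not affect the minimiser on [0, 1], `solve_1d_linesearch_quad` now only needs the coefficients `a` and `b`. The sketch below (a standalone illustration, with a hypothetical function name, mirroring the logic of the patched function) shows the closed-form rule and checks the same values as the updated test.

```python
import numpy as np

def solve_1d_quad_on_unit_interval(a, b):
    """Minimise f(x) = a * x**2 + b * x (+ const) over x in [0, 1]."""
    if a > 0:
        # convex case: clip the unconstrained minimiser -b / (2a) to [0, 1]
        return min(1., max(0., -b / (2.0 * a)))
    else:
        # concave (or linear) case: the minimum sits at an endpoint;
        # f(1) - f(0) = a + b decides which one
        return 1. if a + b < 0 else 0.

# same values as the updated test_solve_1d_linesearch_quad_funct
np.testing.assert_allclose(solve_1d_quad_on_unit_interval(1, -1), 0.5)
np.testing.assert_allclose(solve_1d_quad_on_unit_interval(-1, 5), 0.)
np.testing.assert_allclose(solve_1d_quad_on_unit_interval(-1, 0.5), 1.)
```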