From ccc076e0fc535b2c734214c0ac1936e9e2cbeb62 Mon Sep 17 00:00:00 2001 From: eloitanguy <69361683+eloitanguy@users.noreply.github.com> Date: Fri, 6 May 2022 08:43:21 +0200 Subject: [WIP] Generalized Wasserstein Barycenters (#372) * GWB first solver version * tests + example for gwb (untested) + free_bar doc fix * improved doc, fixed minor bugs, better example visu * minor doc + visu fixes * plot GWB pep8 fix * fixed partial gromov test reproductibility * added an animation for the GWB visu * added PR num * minor doc fixes + better gwb logo --- CONTRIBUTORS.md | 1 + 1 file changed, 1 insertion(+) (limited to 'CONTRIBUTORS.md') diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index ab64fba..0909b14 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -37,6 +37,7 @@ The contributors to this library are: * [Minhui Huang](https://mhhuang95.github.io) (Projection Robust Wasserstein Distance) * [Nathan Cassereau](https://github.com/ncassereau-idris) (Backends) * [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning) +* [Eloi Tanguy](https://github.com/eloitanguy) (Generalized Wasserstein Barycenters) ## Acknowledgments -- cgit v1.2.3 From 7c2a9523747c90aebfef711fdf34b5bbdb6f2f4d Mon Sep 17 00:00:00 2001 From: clecoz Date: Tue, 21 Jun 2022 17:36:22 +0200 Subject: [MRG] raise error if mass mismatch in emd2 (#386) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Two lines added in the function emd2 to ensure that the distributions have the same mass (same as it already was in the function emd). * The same mass test has been moved inside the function f(b) to be compatible with emd2 with multiple b. * Test added. The function test_emd_dimension_and_mass_mismatch (in test/test_ot.py) has been modified to check for mass mismatch with emd2. * Add PR in releases.md * Merge and add PR in releases.md * Add name in contributors.md * Correction contribution in contributors.md * Move test on mass outside of functions f(b) * Update doc of emd and emd2 Co-authored-by: Camille Le Coz Co-authored-by: Rémi Flamary --- CONTRIBUTORS.md | 1 + RELEASES.md | 1 + ot/lp/__init__.py | 9 +++++++++ test/test_ot.py | 3 +++ 4 files changed, 14 insertions(+) (limited to 'CONTRIBUTORS.md') diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 0909b14..c535c09 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -38,6 +38,7 @@ The contributors to this library are: * [Nathan Cassereau](https://github.com/ncassereau-idris) (Backends) * [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning) * [Eloi Tanguy](https://github.com/eloitanguy) (Generalized Wasserstein Barycenters) +* [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug) ## Acknowledgments diff --git a/RELEASES.md b/RELEASES.md index b384617..78a7d9e 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -17,6 +17,7 @@ - Fixed an issue where pointers would overflow in the EMD solver, returning an incomplete transport plan above a certain size (slightly above 46k, its square being roughly 2^31) (PR #381) +- Error raised when mass mismatch in emd2 (PR #386) ## 0.8.2 diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 572781d..17411d0 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -230,6 +230,8 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1): If this behaviour is unwanted, please make sure to provide a floating point input. + .. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value. + Uses the algorithm proposed in :ref:`[1] `. Parameters @@ -389,6 +391,8 @@ def emd2(a, b, M, processes=1, If this behaviour is unwanted, please make sure to provide a floating point input. + .. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value. + Uses the algorithm proposed in :ref:`[1] `. Parameters @@ -481,6 +485,11 @@ def emd2(a, b, M, processes=1, assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \ "Dimension mismatch, check dimensions of M with a and b" + # ensure that same mass + np.testing.assert_almost_equal(a.sum(0), + b.sum(0,keepdims=True), err_msg='a and b vector must have the same sum') + b = b * a.sum(0) / b.sum(0,keepdims=True) + asel = a != 0 numThreads = check_number_threads(numThreads) diff --git a/test/test_ot.py b/test/test_ot.py index ba3ef6a..9a4e175 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -29,9 +29,12 @@ def test_emd_dimension_and_mass_mismatch(): np.testing.assert_raises(AssertionError, ot.emd2, a, a, M) + # test emd and emd2 for mass mismatch + a = ot.utils.unif(n_samples) b = a.copy() a[0] = 100 np.testing.assert_raises(AssertionError, ot.emd, a, b, M) + np.testing.assert_raises(AssertionError, ot.emd2, a, b, M) def test_emd_backends(nx): -- cgit v1.2.3 From 818c7ace20da36d8042b0d7ad7a712b27f7afd59 Mon Sep 17 00:00:00 2001 From: Eduardo Fernandes Montesuma Date: Wed, 27 Jul 2022 11:16:14 +0200 Subject: [MRG] Free support Sinkhorn barycenters (#387) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Adding function for computing Sinkhorn Free Support barycenters * Adding exampel on Free Support Sinkhorn Barycenter * Fixing typo on free support sinkhorn barycenter example * Adding info on new Free Support Barycenter solver * Removing extra line so that code follows pep8 * Fixing issues with pep8 in example * Correcting issues with pep8 standards * Adding tests for free support sinkhorn barycenter * Adding section on Sinkhorn barycenter to the example * Changing distributions for the Sinkhorn barycenter example * Removing file that should not be on the last commit * Adding PR number to REALEASES.md * Adding new contributors * Update CONTRIBUTORS.md Co-authored-by: Rémi Flamary --- CONTRIBUTORS.md | 1 + RELEASES.md | 1 + .../barycenters/plot_free_support_barycenter.py | 28 +++- .../plot_free_support_sinkhorn_barycenter.py | 151 +++++++++++++++++++++ ot/bregman.py | 120 ++++++++++++++++ test/test_bregman.py | 26 ++++ 6 files changed, 324 insertions(+), 3 deletions(-) create mode 100644 examples/barycenters/plot_free_support_sinkhorn_barycenter.py (limited to 'CONTRIBUTORS.md') diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index c535c09..0524151 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -39,6 +39,7 @@ The contributors to this library are: * [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning) * [Eloi Tanguy](https://github.com/eloitanguy) (Generalized Wasserstein Barycenters) * [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug) +* [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter) ## Acknowledgments diff --git a/RELEASES.md b/RELEASES.md index 78a7d9e..14d11c4 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -5,6 +5,7 @@ #### New features - Added Generalized Wasserstein Barycenter solver + example (PR #372), fixed graphical details on the example (PR #376) +- Added Free Support Sinkhorn Barycenter + example (PR #387) #### Closed issues diff --git a/examples/barycenters/plot_free_support_barycenter.py b/examples/barycenters/plot_free_support_barycenter.py index 226dfeb..f4a13dd 100644 --- a/examples/barycenters/plot_free_support_barycenter.py +++ b/examples/barycenters/plot_free_support_barycenter.py @@ -4,13 +4,14 @@ 2D free support Wasserstein barycenters of distributions ======================================================== -Illustration of 2D Wasserstein barycenters if distributions are weighted +Illustration of 2D Wasserstein and Sinkhorn barycenters if distributions are weighted sum of diracs. """ # Authors: Vivien Seguy # Rémi Flamary +# Eduardo Fernandes Montesuma # # License: MIT License @@ -48,7 +49,7 @@ pl.title('Distributions') # %% -# Compute free support barycenter +# Compute free support Wasserstein barycenter # ------------------------------- k = 200 # number of Diracs of the barycenter @@ -58,7 +59,28 @@ b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, on X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init, b) # %% -# Plot the barycenter +# Plot the Wasserstein barycenter +# --------- + +pl.figure(2, (8, 3)) +pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5) +pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5) +pl.scatter(X[:, 0], X[:, 1], s=b * 1000, marker='s', label='2-Wasserstein barycenter') +pl.title('Data measures and their barycenter') +pl.legend(loc="lower right") +pl.show() + +# %% +# Compute free support Sinkhorn barycenter + +k = 200 # number of Diracs of the barycenter +X_init = np.random.normal(0., 1., (k, d)) # initial Dirac locations +b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, only the locations are optimized) + +X = ot.bregman.free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, 20, b, numItermax=15) + +# %% +# Plot the Wasserstein barycenter # --------- pl.figure(2, (8, 3)) diff --git a/examples/barycenters/plot_free_support_sinkhorn_barycenter.py b/examples/barycenters/plot_free_support_sinkhorn_barycenter.py new file mode 100644 index 0000000..ebe1f3b --- /dev/null +++ b/examples/barycenters/plot_free_support_sinkhorn_barycenter.py @@ -0,0 +1,151 @@ +# -*- coding: utf-8 -*- +""" +======================================================== +2D free support Sinkhorn barycenters of distributions +======================================================== + +Illustration of Sinkhorn barycenter calculation between empirical distributions understood as point clouds + +""" + +# Authors: Eduardo Fernandes Montesuma +# +# License: MIT License + +import numpy as np +import matplotlib.pyplot as plt +import ot + +# %% +# General Parameters +# ------------------ +reg = 1e-2 # Entropic Regularization +numItermax = 20 # Maximum number of iterations for the Barycenter algorithm +numInnerItermax = 50 # Maximum number of sinkhorn iterations +n_samples = 200 + +# %% +# Generate Data +# ------------- + +X1 = np.random.randn(200, 2) +X2 = 2 * np.concatenate([ + np.concatenate([- np.ones([50, 1]), np.linspace(-1, 1, 50)[:, None]], axis=1), + np.concatenate([np.linspace(-1, 1, 50)[:, None], np.ones([50, 1])], axis=1), + np.concatenate([np.ones([50, 1]), np.linspace(1, -1, 50)[:, None]], axis=1), + np.concatenate([np.linspace(1, -1, 50)[:, None], - np.ones([50, 1])], axis=1), +], axis=0) +X3 = np.random.randn(200, 2) +X3 = 2 * (X3 / np.linalg.norm(X3, axis=1)[:, None]) +X4 = np.random.multivariate_normal(np.array([0, 0]), np.array([[1., 0.5], [0.5, 1.]]), size=200) + +a1, a2, a3, a4 = ot.unif(len(X1)), ot.unif(len(X1)), ot.unif(len(X1)), ot.unif(len(X1)) + +# %% +# Inspect generated distributions +# ------------------------------- + +fig, axes = plt.subplots(1, 4, figsize=(16, 4)) + +axes[0].scatter(x=X1[:, 0], y=X1[:, 1], c='steelblue', edgecolor='k') +axes[1].scatter(x=X2[:, 0], y=X2[:, 1], c='steelblue', edgecolor='k') +axes[2].scatter(x=X3[:, 0], y=X3[:, 1], c='steelblue', edgecolor='k') +axes[3].scatter(x=X4[:, 0], y=X4[:, 1], c='steelblue', edgecolor='k') + +axes[0].set_xlim([-3, 3]) +axes[0].set_ylim([-3, 3]) +axes[0].set_title('Distribution 1') + +axes[1].set_xlim([-3, 3]) +axes[1].set_ylim([-3, 3]) +axes[1].set_title('Distribution 2') + +axes[2].set_xlim([-3, 3]) +axes[2].set_ylim([-3, 3]) +axes[2].set_title('Distribution 3') + +axes[3].set_xlim([-3, 3]) +axes[3].set_ylim([-3, 3]) +axes[3].set_title('Distribution 4') + +plt.tight_layout() +plt.show() + +# %% +# Interpolating Empirical Distributions +# ------------------------------------- + +fig = plt.figure(figsize=(10, 10)) + +weights = np.array([ + [3 / 3, 0 / 3], + [2 / 3, 1 / 3], + [1 / 3, 2 / 3], + [0 / 3, 3 / 3], +]).astype(np.float32) + +for k in range(4): + XB_init = np.random.randn(n_samples, 2) + XB = ot.bregman.free_support_sinkhorn_barycenter( + measures_locations=[X1, X2], + measures_weights=[a1, a2], + weights=weights[k], + X_init=XB_init, + reg=reg, + numItermax=numItermax, + numInnerItermax=numInnerItermax + ) + ax = plt.subplot2grid((4, 4), (0, k)) + ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k') + ax.set_xlim([-3, 3]) + ax.set_ylim([-3, 3]) + +for k in range(1, 4, 1): + XB_init = np.random.randn(n_samples, 2) + XB = ot.bregman.free_support_sinkhorn_barycenter( + measures_locations=[X1, X3], + measures_weights=[a1, a2], + weights=weights[k], + X_init=XB_init, + reg=reg, + numItermax=numItermax, + numInnerItermax=numInnerItermax + ) + ax = plt.subplot2grid((4, 4), (k, 0)) + ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k') + ax.set_xlim([-3, 3]) + ax.set_ylim([-3, 3]) + +for k in range(1, 4, 1): + XB_init = np.random.randn(n_samples, 2) + XB = ot.bregman.free_support_sinkhorn_barycenter( + measures_locations=[X3, X4], + measures_weights=[a1, a2], + weights=weights[k], + X_init=XB_init, + reg=reg, + numItermax=numItermax, + numInnerItermax=numInnerItermax + ) + ax = plt.subplot2grid((4, 4), (3, k)) + ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k') + ax.set_xlim([-3, 3]) + ax.set_ylim([-3, 3]) + +for k in range(1, 3, 1): + XB_init = np.random.randn(n_samples, 2) + XB = ot.bregman.free_support_sinkhorn_barycenter( + measures_locations=[X2, X4], + measures_weights=[a1, a2], + weights=weights[k], + X_init=XB_init, + reg=reg, + numItermax=numItermax, + numInnerItermax=numInnerItermax + ) + ax = plt.subplot2grid((4, 4), (k, 3)) + ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k') + ax.set_xlim([-3, 3]) + ax.set_ylim([-3, 3]) + +plt.show() diff --git a/ot/bregman.py b/ot/bregman.py index 34dcadb..b1321a4 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -1540,6 +1540,126 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, return geometricBar(weights, UKv) +def free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, reg, b=None, weights=None, + numItermax=100, numInnerItermax=1000, stopThr=1e-7, verbose=False, log=None, + **kwargs): + r""" + Solves the free support (locations of the barycenters are optimized, not the weights) regularized Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Sinkhorn divergence), formally: + + .. math:: + \min_\mathbf{X} \quad \sum_{i=1}^N w_i W_{reg}^2(\mathbf{b}, \mathbf{X}, \mathbf{a}_i, \mathbf{X}_i) + + where : + + - :math:`w \in \mathbb{(0, 1)}^{N}`'s are the barycenter weights and sum to one + - `measure_weights` denotes the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}`: empirical measures weights (on simplex) + - `measures_locations` denotes the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d}`: empirical measures atoms locations + - :math:`\mathbf{b} \in \mathbb{R}^{k}` is the desired weights vector of the barycenter + + This problem is considered in :ref:`[20] ` (Algorithm 2). + There are two differences with the following codes: + + - we do not optimize over the weights + - we do not do line search for the locations updates, we use i.e. :math:`\theta = 1` in + :ref:`[20] ` (Algorithm 2). This can be seen as a discrete + implementation of the fixed-point algorithm of + :ref:`[43] ` proposed in the continuous setting. + - at each iteration, instead of solving an exact OT problem, we use the Sinkhorn algorithm for calculating the + transport plan in :ref:`[20] ` (Algorithm 2). + + Parameters + ---------- + measures_locations : list of N (k_i,d) array-like + The discrete support of a measure supported on :math:`k_i` locations of a `d`-dimensional space + (:math:`k_i` can be different for each element of the list) + measures_weights : list of N (k_i,) array-like + Numpy arrays where each numpy array has :math:`k_i` non-negatives values summing to one + representing the weights of each discrete input measure + + X_init : (k,d) array-like + Initialization of the support locations (on `k` atoms) of the barycenter + reg : float + Regularization term >0 + b : (k,) array-like + Initialization of the weights of the barycenter (non-negatives, sum to 1) + weights : (N,) array-like + Initialization of the coefficients of the barycenter (non-negatives, sum to 1) + + numItermax : int, optional + Max number of iterations + numInnerItermax : int, optional + Max number of iterations when calculating the transport plans with Sinkhorn + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + Returns + ------- + X : (k,d) array-like + Support locations (on k atoms) of the barycenter + + See Also + -------- + ot.bregman.sinkhorn : Entropic regularized OT solver + ot.lp.free_support_barycenter : Barycenter solver based on Linear Programming + + .. _references-free-support-barycenter: + References + ---------- + .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. + + .. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. + + """ + nx = get_backend(*measures_locations, *measures_weights, X_init) + + iter_count = 0 + + N = len(measures_locations) + k = X_init.shape[0] + d = X_init.shape[1] + if b is None: + b = nx.ones((k,), type_as=X_init) / k + if weights is None: + weights = nx.ones((N,), type_as=X_init) / N + + X = X_init + + log_dict = {} + displacement_square_norms = [] + + displacement_square_norm = stopThr + 1. + + while (displacement_square_norm > stopThr and iter_count < numItermax): + + T_sum = nx.zeros((k, d), type_as=X_init) + + for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights): + M_i = dist(X, measure_locations_i) + T_i = sinkhorn(b, measure_weights_i, M_i, reg=reg, numItermax=numInnerItermax, **kwargs) + T_sum = T_sum + weight_i * 1. / b[:, None] * nx.dot(T_i, measure_locations_i) + + displacement_square_norm = nx.sum((T_sum - X) ** 2) + if log: + displacement_square_norms.append(displacement_square_norm) + + X = T_sum + + if verbose: + print('iteration %d, displacement_square_norm=%f\n', iter_count, displacement_square_norm) + + iter_count += 1 + + if log: + log_dict['displacement_square_norms'] = displacement_square_norms + return X, log_dict + else: + return X + + def _barycenter_sinkhorn_log(A, M, reg, weights=None, numItermax=1000, stopThr=1e-4, verbose=False, log=False, warn=True): r"""Compute the entropic wasserstein barycenter in log-domain diff --git a/test/test_bregman.py b/test/test_bregman.py index 112bfca..e128ea2 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -3,6 +3,7 @@ # Author: Remi Flamary # Kilian Fatras # Quang Huy Tran +# Eduardo Fernandes Montesuma # # License: MIT License @@ -490,6 +491,31 @@ def test_barycenter(nx, method, verbose, warn): ot.bregman.barycenter(A_nx, M_nx, reg, log=True) +def test_free_support_sinkhorn_barycenter(): + measures_locations = [ + np.array([-1.]).reshape((1, 1)), # First dirac support + np.array([1.]).reshape((1, 1)) # Second dirac support + ] + + measures_weights = [ + np.array([1.]), # First dirac sample weights + np.array([1.]) # Second dirac sample weights + ] + + # Barycenter initialization + X_init = np.array([-12.]).reshape((1, 1)) + + # Obvious barycenter locations. Take a look on test_ot.py, test_free_support_barycenter + bar_locations = np.array([0.]).reshape((1, 1)) + + # Calculate free support barycenter w/ Sinkhorn algorithm. We set the entropic regularization + # term to 1, but this should be, in general, fine-tuned to the problem. + X = ot.bregman.free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, reg=1) + + # Verifies if calculated barycenter matches ground-truth + np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7) + + @pytest.mark.parametrize("method, verbose, warn", product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"], [True, False], [True, False])) -- cgit v1.2.3 From 97feeb32b6c069d7bb44cd995531c2b820d59771 Mon Sep 17 00:00:00 2001 From: tgnassou <66993815+tgnassou@users.noreply.github.com> Date: Mon, 16 Jan 2023 18:09:44 +0100 Subject: [MRG] OT for Gaussian distributions (#428) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add gaussian modules * add gaussian modules * add PR to release.md * Apply suggestions from code review Co-authored-by: Alexandre Gramfort * Apply suggestions from code review Co-authored-by: Alexandre Gramfort * Update ot/gaussian.py * Update ot/gaussian.py * add empirical bures wassertsein distance, fix docstring and test * update to fit with new networkx API * add test for jax et tf" * fix test * fix test? * add empirical_bures_wasserstein_mapping * fix docs * fix doc * fix docstring * add tgnassou to contributors * add more coverage for gaussian.py * add deprecated function * fix doc math" " * fix doc math" " * add remi flamary to authors of gaussiansmodule * fix equation Co-authored-by: Rémi Flamary Co-authored-by: Alexandre Gramfort --- CONTRIBUTORS.md | 1 + RELEASES.md | 1 + docs/source/all.rst | 1 + docs/source/quickstart.rst | 6 +- .../domain-adaptation/plot_otda_linear_mapping.py | 2 +- examples/gromov/plot_barycenter_fgw.py | 2 +- ot/__init__.py | 3 +- ot/da.py | 118 +------- ot/gaussian.py | 333 +++++++++++++++++++++ test/test_da.py | 21 -- test/test_gaussian.py | 98 ++++++ 11 files changed, 448 insertions(+), 138 deletions(-) create mode 100644 ot/gaussian.py create mode 100644 test/test_gaussian.py (limited to 'CONTRIBUTORS.md') diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 0524151..67d8337 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -40,6 +40,7 @@ The contributors to this library are: * [Eloi Tanguy](https://github.com/eloitanguy) (Generalized Wasserstein Barycenters) * [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug) * [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter) +* [Theo Gnassounou](https://github.com/tgnassou) (OT between Gaussian distributions) ## Acknowledgments diff --git a/RELEASES.md b/RELEASES.md index c78319d..4ed3625 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -4,6 +4,7 @@ #### New features +- Added Bures Wasserstein distance in `ot.gaussian` (PR ##428) - Added Generalized Wasserstein Barycenter solver + example (PR #372), fixed graphical details on the example (PR #376) - Added Free Support Sinkhorn Barycenter + example (PR #387) - New API for OT solver using function `ot.solve` (PR #388) diff --git a/docs/source/all.rst b/docs/source/all.rst index 1ec6be3..60cc85c 100644 --- a/docs/source/all.rst +++ b/docs/source/all.rst @@ -31,6 +31,7 @@ API and modules sliced weak factored + gaussian .. autosummary:: :toctree: ../modules/generated/ diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index b4cc8ab..c8eac30 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -279,7 +279,7 @@ distributions. In this case there exists a close form solution given in Remark 2.29 in [15]_ and the Monge mapping is an affine function and can be also computed from the covariances and means of the source and target distributions. In the case when the finite sample dataset is supposed Gaussian, -we provide :any:`ot.da.OT_mapping_linear` that returns the parameters for the +we provide :any:`ot.gaussian.bures_wasserstein_mapping` that returns the parameters for the Monge mapping. @@ -628,7 +628,7 @@ approximate a Monge mapping from finite distributions. First note that when the source and target distributions are supposed to be Gaussian distributions, there exists a close form solution for the mapping and its an affine function [14]_ of the form :math:`T(x)=Ax+b` . In this case we provide the function -:any:`ot.da.OT_mapping_linear` that returns the operator :math:`A` and vector +:any:`ot.gaussian.bures_wasserstein_mapping` that returns the operator :math:`A` and vector :math:`b`. Note that if the number of samples is too small there is a parameter :code:`reg` that provides a regularization for the covariance matrix estimation. @@ -640,7 +640,7 @@ method proposed in [8]_ that estimates a continuous mapping approximating the barycentric mapping is provided in :any:`ot.da.joint_OT_mapping_linear` for linear mapping and :any:`ot.da.joint_OT_mapping_kernel` for non-linear mapping. -.. minigallery:: ot.da.joint_OT_mapping_linear ot.da.joint_OT_mapping_linear ot.da.OT_mapping_linear +.. minigallery:: ot.da.joint_OT_mapping_linear ot.da.joint_OT_mapping_linear ot.gaussian.bures_wasserstein_mapping :add-heading: Examples of Monge mapping estimation :heading-level: " diff --git a/examples/domain-adaptation/plot_otda_linear_mapping.py b/examples/domain-adaptation/plot_otda_linear_mapping.py index a44096a..8284a2a 100644 --- a/examples/domain-adaptation/plot_otda_linear_mapping.py +++ b/examples/domain-adaptation/plot_otda_linear_mapping.py @@ -61,7 +61,7 @@ plt.plot(xt[:, 0], xt[:, 1], 'o') # Estimate linear mapping and transport # ------------------------------------- -Ae, be = ot.da.OT_mapping_linear(xs, xt) +Ae, be = ot.gaussian.empirical_bures_wasserstein_mapping(xs, xt) xst = xs.dot(Ae) + be diff --git a/examples/gromov/plot_barycenter_fgw.py b/examples/gromov/plot_barycenter_fgw.py index 556e08f..dc3c6aa 100644 --- a/examples/gromov/plot_barycenter_fgw.py +++ b/examples/gromov/plot_barycenter_fgw.py @@ -174,7 +174,7 @@ A, C, log = fgw_barycenters(sizebary, Ys, Cs, ps, lambdas, alpha=0.95, log=True) # ------------------------- #%% Create the barycenter -bary = nx.from_numpy_matrix(sp_to_adjency(C, threshinf=0, threshsup=find_thresh(C, sup=100, step=100)[0])) +bary = nx.from_numpy_array(sp_to_adjency(C, threshinf=0, threshsup=find_thresh(C, sup=100, step=100)[0])) for i, v in enumerate(A.ravel()): bary.add_node(i, attr_name=v) diff --git a/ot/__init__.py b/ot/__init__.py index 51eb726..0b55e0c 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -35,6 +35,7 @@ from . import regpath from . import weak from . import factored from . import solvers +from . import gaussian # OT functions from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d @@ -56,7 +57,7 @@ __version__ = "0.8.3dev" __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov', - 'emd2_1d', 'wasserstein_1d', 'backend', + 'emd2_1d', 'wasserstein_1d', 'backend', 'gaussian', 'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim', 'sinkhorn_unbalanced', 'barycenter_unbalanced', 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', diff --git a/ot/da.py b/ot/da.py index 083663c..35e303b 100644 --- a/ot/da.py +++ b/ot/da.py @@ -17,8 +17,9 @@ from .backend import get_backend from .bregman import sinkhorn, jcpot_barycenter from .lp import emd from .utils import unif, dist, kernel, cost_normalization, label_normalization, laplacian, dots -from .utils import list_to_array, check_params, BaseEstimator +from .utils import list_to_array, check_params, BaseEstimator, deprecated from .unbalanced import sinkhorn_unbalanced +from .gaussian import empirical_bures_wasserstein_mapping from .optim import cg from .optim import gcg @@ -679,112 +680,7 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', return G, L -def OT_mapping_linear(xs, xt, reg=1e-6, ws=None, - wt=None, bias=True, log=False): - r"""Return OT linear operator between samples. - - The function estimates the optimal linear operator that aligns the two - empirical distributions. This is equivalent to estimating the closed - form mapping between two Gaussian distributions :math:`\mathcal{N}(\mu_s,\Sigma_s)` - and :math:`\mathcal{N}(\mu_t,\Sigma_t)` as proposed in - :ref:`[14] ` and discussed in remark 2.29 in - :ref:`[15] `. - - The linear operator from source to target :math:`M` - - .. math:: - M(\mathbf{x})= \mathbf{A} \mathbf{x} + \mathbf{b} - - where : - - .. math:: - \mathbf{A} &= \Sigma_s^{-1/2} \left(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2} \right)^{1/2} - \Sigma_s^{-1/2} - - \mathbf{b} &= \mu_t - \mathbf{A} \mu_s - - Parameters - ---------- - xs : array-like (ns,d) - samples in the source domain - xt : array-like (nt,d) - samples in the target domain - reg : float,optional - regularization added to the diagonals of covariances (>0) - ws : array-like (ns,1), optional - weights for the source samples - wt : array-like (ns,1), optional - weights for the target samples - bias: boolean, optional - estimate bias :math:`\mathbf{b}` else :math:`\mathbf{b} = 0` (default:True) - log : bool, optional - record log if True - - - Returns - ------- - A : (d, d) array-like - Linear operator - b : (1, d) array-like - bias - log : dict - log dictionary return only if log==True in parameters - - - .. _references-OT-mapping-linear: - References - ---------- - .. [14] Knott, M. and Smith, C. S. "On the optimal mapping of - distributions", Journal of Optimization Theory and Applications - Vol 43, 1984 - - .. [15] Peyré, G., & Cuturi, M. (2017). "Computational Optimal - Transport", 2018. - - - """ - xs, xt = list_to_array(xs, xt) - nx = get_backend(xs, xt) - - d = xs.shape[1] - - if bias: - mxs = nx.mean(xs, axis=0)[None, :] - mxt = nx.mean(xt, axis=0)[None, :] - - xs = xs - mxs - xt = xt - mxt - else: - mxs = nx.zeros((1, d), type_as=xs) - mxt = nx.zeros((1, d), type_as=xs) - - if ws is None: - ws = nx.ones((xs.shape[0], 1), type_as=xs) / xs.shape[0] - - if wt is None: - wt = nx.ones((xt.shape[0], 1), type_as=xt) / xt.shape[0] - - Cs = nx.dot((xs * ws).T, xs) / nx.sum(ws) + reg * nx.eye(d, type_as=xs) - Ct = nx.dot((xt * wt).T, xt) / nx.sum(wt) + reg * nx.eye(d, type_as=xt) - - Cs12 = nx.sqrtm(Cs) - Cs_12 = nx.inv(Cs12) - - M0 = nx.sqrtm(dots(Cs12, Ct, Cs12)) - - A = dots(Cs_12, M0, Cs_12) - - b = mxt - nx.dot(mxs, A) - - if log: - log = {} - log['Cs'] = Cs - log['Ct'] = Ct - log['Cs12'] = Cs12 - log['Cs_12'] = Cs_12 - return A, b, log - else: - return A, b +OT_mapping_linear = deprecated(empirical_bures_wasserstein_mapping) def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, alpha=.5, @@ -1378,10 +1274,10 @@ class LinearTransport(BaseTransport): self.mu_t = self.distribution_estimation(Xt) # coupling estimation - returned_ = OT_mapping_linear(Xs, Xt, reg=self.reg, - ws=nx.reshape(self.mu_s, (-1, 1)), - wt=nx.reshape(self.mu_t, (-1, 1)), - bias=self.bias, log=self.log) + returned_ = empirical_bures_wasserstein_mapping(Xs, Xt, reg=self.reg, + ws=nx.reshape(self.mu_s, (-1, 1)), + wt=nx.reshape(self.mu_t, (-1, 1)), + bias=self.bias, log=self.log) # deal with the value of log if self.log: diff --git a/ot/gaussian.py b/ot/gaussian.py new file mode 100644 index 0000000..4ffb726 --- /dev/null +++ b/ot/gaussian.py @@ -0,0 +1,333 @@ +# -*- coding: utf-8 -*- +""" +Optimal transport for Gaussian distributions +""" + +# Author: Theo Gnassounou +# Remi Flamary +# +# License: MIT License + +from .backend import get_backend +from .utils import dots +from .utils import list_to_array + + +def bures_wasserstein_mapping(ms, mt, Cs, Ct, log=False): + r"""Return OT linear operator between samples. + + The function estimates the optimal linear operator that aligns the two + empirical distributions. This is equivalent to estimating the closed + form mapping between two Gaussian distributions :math:`\mathcal{N}(\mu_s,\Sigma_s)` + and :math:`\mathcal{N}(\mu_t,\Sigma_t)` as proposed in + :ref:`[1] ` and discussed in remark 2.29 in + :ref:`[2] `. + + The linear operator from source to target :math:`M` + + .. math:: + M(\mathbf{x})= \mathbf{A} \mathbf{x} + \mathbf{b} + + where : + + .. math:: + \mathbf{A} &= \Sigma_s^{-1/2} \left(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2} \right)^{1/2} + \Sigma_s^{-1/2} + + \mathbf{b} &= \mu_t - \mathbf{A} \mu_s + + Parameters + ---------- + ms : array-like (d,) + mean of the source distribution + mt : array-like (d,) + mean of the target distribution + Cs : array-like (d,) + covariance of the source distribution + Ct : array-like (d,) + covariance of the target distribution + log : bool, optional + record log if True + + + Returns + ------- + A : (d, d) array-like + Linear operator + b : (1, d) array-like + bias + log : dict + log dictionary return only if log==True in parameters + + + .. _references-OT-mapping-linear: + References + ---------- + .. [1] Knott, M. and Smith, C. S. "On the optimal mapping of + distributions", Journal of Optimization Theory and Applications + Vol 43, 1984 + + .. [2] Peyré, G., & Cuturi, M. (2017). "Computational Optimal + Transport", 2018. + """ + ms, mt, Cs, Ct = list_to_array(ms, mt, Cs, Ct) + nx = get_backend(ms, mt, Cs, Ct) + + Cs12 = nx.sqrtm(Cs) + Cs12inv = nx.inv(Cs12) + + M0 = nx.sqrtm(dots(Cs12, Ct, Cs12)) + + A = dots(Cs12inv, M0, Cs12inv) + + b = mt - nx.dot(ms, A) + + if log: + log = {} + log['Cs12'] = Cs12 + log['Cs12inv'] = Cs12inv + return A, b, log + else: + return A, b + + +def empirical_bures_wasserstein_mapping(xs, xt, reg=1e-6, ws=None, + wt=None, bias=True, log=False): + r"""Return OT linear operator between samples. + + The function estimates the optimal linear operator that aligns the two + empirical distributions. This is equivalent to estimating the closed + form mapping between two Gaussian distributions :math:`\mathcal{N}(\mu_s,\Sigma_s)` + and :math:`\mathcal{N}(\mu_t,\Sigma_t)` as proposed in + :ref:`[1] ` and discussed in remark 2.29 in + :ref:`[2] `. + + The linear operator from source to target :math:`M` + + .. math:: + M(\mathbf{x})= \mathbf{A} \mathbf{x} + \mathbf{b} + + where : + + .. math:: + \mathbf{A} &= \Sigma_s^{-1/2} \left(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2} \right)^{1/2} + \Sigma_s^{-1/2} + + \mathbf{b} &= \mu_t - \mathbf{A} \mu_s + + Parameters + ---------- + xs : array-like (ns,d) + samples in the source domain + xt : array-like (nt,d) + samples in the target domain + reg : float,optional + regularization added to the diagonals of covariances (>0) + ws : array-like (ns,1), optional + weights for the source samples + wt : array-like (ns,1), optional + weights for the target samples + bias: boolean, optional + estimate bias :math:`\mathbf{b}` else :math:`\mathbf{b} = 0` (default:True) + log : bool, optional + record log if True + + + Returns + ------- + A : (d, d) array-like + Linear operator + b : (1, d) array-like + bias + log : dict + log dictionary return only if log==True in parameters + + + .. _references-OT-mapping-linear: + References + ---------- + .. [1] Knott, M. and Smith, C. S. "On the optimal mapping of + distributions", Journal of Optimization Theory and Applications + Vol 43, 1984 + + .. [2] Peyré, G., & Cuturi, M. (2017). "Computational Optimal + Transport", 2018. + """ + xs, xt = list_to_array(xs, xt) + nx = get_backend(xs, xt) + + d = xs.shape[1] + + if bias: + mxs = nx.mean(xs, axis=0)[None, :] + mxt = nx.mean(xt, axis=0)[None, :] + + xs = xs - mxs + xt = xt - mxt + else: + mxs = nx.zeros((1, d), type_as=xs) + mxt = nx.zeros((1, d), type_as=xs) + + if ws is None: + ws = nx.ones((xs.shape[0], 1), type_as=xs) / xs.shape[0] + + if wt is None: + wt = nx.ones((xt.shape[0], 1), type_as=xt) / xt.shape[0] + + Cs = nx.dot((xs * ws).T, xs) / nx.sum(ws) + reg * nx.eye(d, type_as=xs) + Ct = nx.dot((xt * wt).T, xt) / nx.sum(wt) + reg * nx.eye(d, type_as=xt) + + if log: + A, b, log = bures_wasserstein_mapping(mxs, mxt, Cs, Ct, log=log) + log['Cs'] = Cs + log['Ct'] = Ct + return A, b, log + else: + A, b = bures_wasserstein_mapping(mxs, mxt, Cs, Ct) + return A, b + + +def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False): + r"""Return Bures Wasserstein distance between samples. + + The function estimates the Bures-Wasserstein distance between two + empirical distributions source :math:`\mu_s` and target :math:`\mu_t`, + discussed in remark 2.31 :ref:`[1] `. + + The Bures Wasserstein distance between source and target distribution :math:`\mathcal{W}` + + .. math:: + \mathcal{W}(\mu_s, \mu_t)_2^2= \left\lVert \mathbf{m}_s - \mathbf{m}_t \right\rVert^2 + \mathcal{B}(\Sigma_s, \Sigma_t)^{2} + + where : + + .. math:: + \mathbf{B}(\Sigma_s, \Sigma_t)^{2} = \text{Tr}\left(\Sigma_s^{1/2} + \Sigma_t^{1/2} - 2 \sqrt{\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}} \right) + + Parameters + ---------- + ms : array-like (d,) + mean of the source distribution + mt : array-like (d,) + mean of the target distribution + Cs : array-like (d,) + covariance of the source distribution + Ct : array-like (d,) + covariance of the target distribution + log : bool, optional + record log if True + + + Returns + ------- + W : float + Bures Wasserstein distance + log : dict + log dictionary return only if log==True in parameters + + + .. _references-bures-wasserstein-distance: + References + ---------- + + .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal + Transport", 2018. + """ + ms, mt, Cs, Ct = list_to_array(ms, mt, Cs, Ct) + nx = get_backend(ms, mt, Cs, Ct) + + Cs12 = nx.sqrtm(Cs) + + B = nx.trace(Cs + Ct - 2 * nx.sqrtm(dots(Cs12, Ct, Cs12))) + W = nx.sqrt(nx.norm(ms - mt)**2 + B) + if log: + log = {} + log['Cs12'] = Cs12 + return W, log + else: + return W + + +def empirical_bures_wasserstein_distance(xs, xt, reg=1e-6, ws=None, + wt=None, bias=True, log=False): + r"""Return Bures Wasserstein distance from mean and covariance of distribution. + + The function estimates the Bures-Wasserstein distance between two + empirical distributions source :math:`\mu_s` and target :math:`\mu_t`, + discussed in remark 2.31 :ref:`[1] `. + + The Bures Wasserstein distance between source and target distribution :math:`\mathcal{W}` + + .. math:: + \mathcal{W}(\mu_s, \mu_t)_2^2= \left\lVert \mathbf{m}_s - \mathbf{m}_t \right\rVert^2 + \mathcal{B}(\Sigma_s, \Sigma_t)^{2} + + where : + + .. math:: + \mathbf{B}(\Sigma_s, \Sigma_t)^{2} = \text{Tr}\left(\Sigma_s^{1/2} + \Sigma_t^{1/2} - 2 \sqrt{\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}} \right) + + Parameters + ---------- + xs : array-like (ns,d) + samples in the source domain + xt : array-like (nt,d) + samples in the target domain + reg : float,optional + regularization added to the diagonals of covariances (>0) + ws : array-like (ns,1), optional + weights for the source samples + wt : array-like (ns,1), optional + weights for the target samples + bias: boolean, optional + estimate bias :math:`\mathbf{b}` else :math:`\mathbf{b} = 0` (default:True) + log : bool, optional + record log if True + + + Returns + ------- + W : float + Bures Wasserstein distance + log : dict + log dictionary return only if log==True in parameters + + + .. _references-bures-wasserstein-distance: + References + ---------- + + .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal + Transport", 2018. + """ + xs, xt = list_to_array(xs, xt) + nx = get_backend(xs, xt) + + d = xs.shape[1] + + if bias: + mxs = nx.mean(xs, axis=0)[None, :] + mxt = nx.mean(xt, axis=0)[None, :] + + xs = xs - mxs + xt = xt - mxt + else: + mxs = nx.zeros((1, d), type_as=xs) + mxt = nx.zeros((1, d), type_as=xs) + + if ws is None: + ws = nx.ones((xs.shape[0], 1), type_as=xs) / xs.shape[0] + + if wt is None: + wt = nx.ones((xt.shape[0], 1), type_as=xt) / xt.shape[0] + + Cs = nx.dot((xs * ws).T, xs) / nx.sum(ws) + reg * nx.eye(d, type_as=xs) + Ct = nx.dot((xt * wt).T, xt) / nx.sum(wt) + reg * nx.eye(d, type_as=xt) + + if log: + W, log = bures_wasserstein_distance(mxs, mxt, Cs, Ct, log=log) + log['Cs'] = Cs + log['Ct'] = Ct + return W, log + else: + W = bures_wasserstein_distance(mxs, mxt, Cs, Ct) + return W diff --git a/test/test_da.py b/test/test_da.py index 138936f..c5f08d6 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -575,27 +575,6 @@ def test_mapping_transport_class_specific_seed(nx): np.random.seed(None) -@pytest.skip_backend("jax") -@pytest.skip_backend("tf") -def test_linear_mapping(nx): - ns = 50 - nt = 50 - - Xs, ys = make_data_classif('3gauss', ns) - Xt, yt = make_data_classif('3gauss2', nt) - - Xsb, Xtb = nx.from_numpy(Xs, Xt) - - A, b = ot.da.OT_mapping_linear(Xsb, Xtb) - - Xst = nx.to_numpy(nx.dot(Xsb, A) + b) - - Ct = np.cov(Xt.T) - Cst = np.cov(Xst.T) - - np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2) - - @pytest.skip_backend("jax") @pytest.skip_backend("tf") def test_linear_mapping_class(nx): diff --git a/test/test_gaussian.py b/test/test_gaussian.py new file mode 100644 index 0000000..be7a806 --- /dev/null +++ b/test/test_gaussian.py @@ -0,0 +1,98 @@ +"""Tests for module gaussian""" + +# Author: Theo Gnassounou +# Remi Flamary +# +# License: MIT License + +import numpy as np + +import pytest + +import ot +from ot.datasets import make_data_classif + + +def test_bures_wasserstein_mapping(nx): + ns = 50 + nt = 50 + + Xs, ys = make_data_classif('3gauss', ns) + Xt, yt = make_data_classif('3gauss2', nt) + ms = np.mean(Xs, axis=0)[None, :] + mt = np.mean(Xt, axis=0)[None, :] + Cs = np.cov(Xs.T) + Ct = np.cov(Xt.T) + + Xsb, msb, mtb, Csb, Ctb = nx.from_numpy(Xs, ms, mt, Cs, Ct) + + A_log, b_log, log = ot.gaussian.bures_wasserstein_mapping(msb, mtb, Csb, Ctb, log=True) + A, b = ot.gaussian.bures_wasserstein_mapping(msb, mtb, Csb, Ctb, log=False) + + Xst = nx.to_numpy(nx.dot(Xsb, A) + b) + Xst_log = nx.to_numpy(nx.dot(Xsb, A_log) + b_log) + + Cst = np.cov(Xst.T) + Cst_log = np.cov(Xst_log.T) + + np.testing.assert_allclose(Cst_log, Cst, rtol=1e-2, atol=1e-2) + np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2) + + +@pytest.mark.parametrize("bias", [True, False]) +def test_empirical_bures_wasserstein_mapping(nx, bias): + ns = 50 + nt = 50 + + Xs, ys = make_data_classif('3gauss', ns) + Xt, yt = make_data_classif('3gauss2', nt) + + if not bias: + ms = np.mean(Xs, axis=0)[None, :] + mt = np.mean(Xt, axis=0)[None, :] + + Xs = Xs - ms + Xt = Xt - mt + + Xsb, Xtb = nx.from_numpy(Xs, Xt) + + A, b, log = ot.gaussian.empirical_bures_wasserstein_mapping(Xsb, Xtb, log=True, bias=bias) + A_log, b_log = ot.gaussian.empirical_bures_wasserstein_mapping(Xsb, Xtb, log=False, bias=bias) + + Xst = nx.to_numpy(nx.dot(Xsb, A) + b) + Xst_log = nx.to_numpy(nx.dot(Xsb, A_log) + b_log) + + Ct = np.cov(Xt.T) + Cst = np.cov(Xst.T) + Cst_log = np.cov(Xst_log.T) + + np.testing.assert_allclose(Cst_log, Cst, rtol=1e-2, atol=1e-2) + np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2) + + +def test_bures_wasserstein_distance(nx): + ms, mt = np.array([0]), np.array([10]) + Cs, Ct = np.array([[1]]).astype(np.float32), np.array([[1]]).astype(np.float32) + msb, mtb, Csb, Ctb = nx.from_numpy(ms, mt, Cs, Ct) + Wb_log, log = ot.gaussian.bures_wasserstein_distance(msb, mtb, Csb, Ctb, log=True) + Wb = ot.gaussian.bures_wasserstein_distance(msb, mtb, Csb, Ctb, log=False) + + np.testing.assert_allclose(nx.to_numpy(Wb_log), nx.to_numpy(Wb), rtol=1e-2, atol=1e-2) + np.testing.assert_allclose(10, nx.to_numpy(Wb), rtol=1e-2, atol=1e-2) + + +@pytest.mark.parametrize("bias", [True, False]) +def test_empirical_bures_wasserstein_distance(nx, bias): + ns = 400 + nt = 400 + + rng = np.random.RandomState(10) + Xs = rng.normal(0, 1, ns)[:, np.newaxis] + Xt = rng.normal(10 * bias, 1, nt)[:, np.newaxis] + + Xsb, Xtb = nx.from_numpy(Xs, Xt) + Wb_log, log = ot.gaussian.empirical_bures_wasserstein_distance(Xsb, Xtb, log=True, bias=bias) + Wb = ot.gaussian.empirical_bures_wasserstein_distance(Xsb, Xtb, log=False, bias=bias) + + np.testing.assert_allclose(nx.to_numpy(Wb_log), nx.to_numpy(Wb), rtol=1e-2, atol=1e-2) + np.testing.assert_allclose(10 * bias, nx.to_numpy(Wb), rtol=1e-2, atol=1e-2) -- cgit v1.2.3 From 80e3c23bc968f866fd20344ddc443a3c7fcb3b0d Mon Sep 17 00:00:00 2001 From: Clément Bonet <32179275+clbonet@users.noreply.github.com> Date: Thu, 23 Feb 2023 08:31:01 +0100 Subject: [WIP] Wasserstein distance on the circle and Spherical Sliced-Wasserstein (#434) * W circle + SSW * Tests + Example SSW_1 * Example Wasserstein Circle + Tests * Wasserstein on the circle wrt Unif * Example SSW unif * pep8 * np.linalg.qr for numpy < 1.22 by batch + add python3.11 to tests * np qr * rm test python 3.11 * update names, tests, backend transpose * Comment error batchs * semidiscrete_wasserstein2_unif_circle example * torch permute method instead of torch.permute for previous versions * update comments and doc * doc wasserstein circle model as [0,1[ * Added ot.utils.get_coordinate_circle to get coordinates on the circle in turn --- CONTRIBUTORS.md | 1 + README.md | 10 +- RELEASES.md | 4 + examples/backends/plot_ssw_unif_torch.py | 153 ++++++ examples/plot_compute_wasserstein_circle.py | 161 ++++++ examples/sliced-wasserstein/plot_variance_ssw.py | 111 ++++ ot/__init__.py | 13 +- ot/backend.py | 204 +++++++- ot/lp/__init__.py | 7 +- ot/lp/solver_1d.py | 627 ++++++++++++++++++++++- ot/sliced.py | 185 ++++++- ot/utils.py | 30 ++ test/test_1d_solver.py | 127 +++++ test/test_backend.py | 46 ++ test/test_sliced.py | 186 +++++++ test/test_utils.py | 10 + 16 files changed, 1852 insertions(+), 23 deletions(-) create mode 100644 examples/backends/plot_ssw_unif_torch.py create mode 100644 examples/plot_compute_wasserstein_circle.py create mode 100644 examples/sliced-wasserstein/plot_variance_ssw.py (limited to 'CONTRIBUTORS.md') diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 67d8337..1437821 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -41,6 +41,7 @@ The contributors to this library are: * [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug) * [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter) * [Theo Gnassounou](https://github.com/tgnassou) (OT between Gaussian distributions) +* [Clément Bonet](https://clbonet.github.io) (Wassertstein on circle, Spherical Sliced-Wasserstein) ## Acknowledgments diff --git a/README.md b/README.md index 7c9475b..d5e6854 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,8 @@ POT provides the following generic OT solvers (links to examples): * [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3] formulations). * [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32] and Max-sliced Wasserstein [35] that can be used for gradient flows [36]. +* [Wasserstein distance on the circle](https://pythonot.github.io/auto_examples/plot_compute_wasserstein_circle.html) [44, 45] +* [Spherical Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance_ssw.html) [46] * [Graph Dictionary Learning solvers](https://pythonot.github.io/auto_examples/gromov/plot_gromov_wasserstein_dictionary_learning.html) [38]. * [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays. @@ -292,4 +294,10 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer [42] Delon, J., Gozlan, N., and Saint-Dizier, A. [Generalized Wasserstein barycenters between probability measures living on different subspaces](https://arxiv.org/pdf/2105.09755). arXiv preprint arXiv:2105.09755, 2021. -[43] Álvarez-Esteban, Pedro C., et al. [A fixed-point approach to barycenters in Wasserstein space.](https://arxiv.org/pdf/1511.05355.pdf) Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. \ No newline at end of file +[43] Álvarez-Esteban, Pedro C., et al. [A fixed-point approach to barycenters in Wasserstein space.](https://arxiv.org/pdf/1511.05355.pdf) Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. + +[44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. [Fast transport optimization for Monge costs on the circle.](https://arxiv.org/abs/0902.3527) SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258. + +[45] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. [The statistics of circular optimal transport.](https://arxiv.org/abs/2103.15426) Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82. + +[46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). [Spherical Sliced-Wasserstein](https://openreview.net/forum?id=jXQ0ipgMdU). International Conference on Learning Representations. \ No newline at end of file diff --git a/RELEASES.md b/RELEASES.md index 4ed3625..f8ef653 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -4,6 +4,10 @@ #### New features +- Added the spherical sliced-Wasserstein discrepancy in `ot.sliced.sliced_wasserstein_sphere` and `ot.sliced.sliced_wasserstein_sphere_unif` + examples (PR #434) +- Added the Wasserstein distance on the circle in ``ot.lp.solver_1d.wasserstein_circle`` (PR #434) +- Added the Wasserstein distance on the circle (for p>=1) in `ot.lp.solver_1d.binary_search_circle` + examples (PR #434) +- Added the 2-Wasserstein distance on the circle w.r.t a uniform distribution in `ot.lp.solver_1d.semidiscrete_wasserstein2_unif_circle` (PR #434) - Added Bures Wasserstein distance in `ot.gaussian` (PR ##428) - Added Generalized Wasserstein Barycenter solver + example (PR #372), fixed graphical details on the example (PR #376) - Added Free Support Sinkhorn Barycenter + example (PR #387) diff --git a/examples/backends/plot_ssw_unif_torch.py b/examples/backends/plot_ssw_unif_torch.py new file mode 100644 index 0000000..d1de5a9 --- /dev/null +++ b/examples/backends/plot_ssw_unif_torch.py @@ -0,0 +1,153 @@ +# -*- coding: utf-8 -*- +r""" +================================================ +Spherical Sliced-Wasserstein Embedding on Sphere +================================================ + +Here, we aim at transforming samples into a uniform +distribution on the sphere by minimizing SSW: + +.. math:: + \min_{x} SSW_2(\nu, \frac{1}{n}\sum_{i=1}^n \delta_{x_i}) + +where :math:`\nu=\mathrm{Unif}(S^1)`. + +""" + +# Author: Clément Bonet +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 3 + +import numpy as np +import matplotlib.pyplot as pl +import matplotlib.animation as animation +import torch +import torch.nn.functional as F + +import ot + + +# %% +# Data generation +# --------------- + +torch.manual_seed(1) + +N = 1000 +x0 = torch.rand(N, 3) +x0 = F.normalize(x0, dim=-1) + + +# %% +# Plot data +# --------- + +def plot_sphere(ax): + xlist = np.linspace(-1.0, 1.0, 50) + ylist = np.linspace(-1.0, 1.0, 50) + r = np.linspace(1.0, 1.0, 50) + X, Y = np.meshgrid(xlist, ylist) + + Z = np.sqrt(r**2 - X**2 - Y**2) + + ax.plot_wireframe(X, Y, Z, color="gray", alpha=.3) + ax.plot_wireframe(X, Y, -Z, color="gray", alpha=.3) # Now plot the bottom half + + +# plot the distributions +pl.figure(1) +ax = pl.axes(projection='3d') +plot_sphere(ax) +ax.scatter(x0[:, 0], x0[:, 1], x0[:, 2], label='Data samples', alpha=0.5) +ax.set_title('Data distribution') +ax.legend() + + +# %% +# Gradient descent +# ---------------- + +x = x0.clone() +x.requires_grad_(True) + +n_iter = 500 +lr = 100 + +losses = [] +xvisu = torch.zeros(n_iter, N, 3) + +for i in range(n_iter): + sw = ot.sliced_wasserstein_sphere_unif(x, n_projections=500) + grad_x = torch.autograd.grad(sw, x)[0] + + x = x - lr * grad_x + x = F.normalize(x, p=2, dim=1) + + losses.append(sw.item()) + xvisu[i, :, :] = x.detach().clone() + + if i % 100 == 0: + print("Iter: {:3d}, loss={}".format(i, losses[-1])) + +pl.figure(1) +pl.semilogy(losses) +pl.grid() +pl.title('SSW') +pl.xlabel("Iterations") + + +# %% +# Plot trajectories of generated samples along iterations +# ------------------------------------------------------- + +ivisu = [0, 25, 50, 75, 100, 150, 200, 350, 499] + +fig = pl.figure(3, (10, 10)) +for i in range(9): + # pl.subplot(3, 3, i + 1) + # ax = pl.axes(projection='3d') + ax = fig.add_subplot(3, 3, i + 1, projection='3d') + plot_sphere(ax) + ax.scatter(xvisu[ivisu[i], :, 0], xvisu[ivisu[i], :, 1], xvisu[ivisu[i], :, 2], label='Data samples', alpha=0.5) + ax.set_title('Iter. {}'.format(ivisu[i])) + #ax.axis("off") + if i == 0: + ax.legend() + + +# %% +# Animate trajectories of generated samples along iteration +# ------------------------------------------------------- + +pl.figure(4, (8, 8)) + + +def _update_plot(i): + i = 3 * i + pl.clf() + ax = pl.axes(projection='3d') + plot_sphere(ax) + ax.scatter(xvisu[i, :, 0], xvisu[i, :, 1], xvisu[i, :, 2], label='Data samples$', alpha=0.5) + ax.axis("off") + ax.set_xlim((-1.5, 1.5)) + ax.set_ylim((-1.5, 1.5)) + ax.set_title('Iter. {}'.format(i)) + return 1 + + +print(xvisu.shape) + +i = 0 +ax = pl.axes(projection='3d') +plot_sphere(ax) +ax.scatter(xvisu[i, :, 0], xvisu[i, :, 1], xvisu[i, :, 2], label='Data samples from $G\#\mu_n$', alpha=0.5) +ax.axis("off") +ax.set_xlim((-1.5, 1.5)) +ax.set_ylim((-1.5, 1.5)) +ax.set_title('Iter. {}'.format(ivisu[i])) + + +ani = animation.FuncAnimation(pl.gcf(), _update_plot, n_iter // 5, interval=100, repeat_delay=2000) +# %% diff --git a/examples/plot_compute_wasserstein_circle.py b/examples/plot_compute_wasserstein_circle.py new file mode 100644 index 0000000..3ede96f --- /dev/null +++ b/examples/plot_compute_wasserstein_circle.py @@ -0,0 +1,161 @@ +# -*- coding: utf-8 -*- +""" +========================= +OT distance on the Circle +========================= + +Shows how to compute the Wasserstein distance on the circle + + +""" + +# Author: Clément Bonet +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 2 + +import numpy as np +import matplotlib.pylab as pl +import ot + +from scipy.special import iv + +############################################################################## +# Plot data +# --------- + +#%% plot the distributions + + +def pdf_von_Mises(theta, mu, kappa): + pdf = np.exp(kappa * np.cos(theta - mu)) / (2.0 * np.pi * iv(0, kappa)) + return pdf + + +t = np.linspace(0, 2 * np.pi, 1000, endpoint=False) + +mu1 = 1 +kappa1 = 20 + +mu_targets = np.linspace(mu1, mu1 + 2 * np.pi, 10) + + +pdf1 = pdf_von_Mises(t, mu1, kappa1) + + +pl.figure(1) +for k, mu in enumerate(mu_targets): + pdf_t = pdf_von_Mises(t, mu, kappa1) + if k == 0: + label = "Source distributions" + else: + label = None + pl.plot(t / (2 * np.pi), pdf_t, c='b', label=label) + +pl.plot(t / (2 * np.pi), pdf1, c="r", label="Target distribution") +pl.legend() + +mu2 = 0 +kappa2 = kappa1 + +x1 = np.random.vonmises(mu1, kappa1, size=(10,)) + np.pi +x2 = np.random.vonmises(mu2, kappa2, size=(10,)) + np.pi + +angles = np.linspace(0, 2 * np.pi, 150) + +pl.figure(2) +pl.plot(np.cos(angles), np.sin(angles), c="k") +pl.xlim(-1.25, 1.25) +pl.ylim(-1.25, 1.25) +pl.scatter(np.cos(x1), np.sin(x1), c="b") +pl.scatter(np.cos(x2), np.sin(x2), c="r") + +######################################################################################### +# Compare the Euclidean Wasserstein distance with the Wasserstein distance on the circle +# --------------------------------------------------------------------------------------- +# This examples illustrates the periodicity of the Wasserstein distance on the circle. +# We choose as target distribution a von Mises distribution with mean :math:`\mu_{\mathrm{target}}` +# and :math:`\kappa=20`. Then, we compare the distances with samples obtained from a von Mises distribution +# with parameters :math:`\mu_{\mathrm{source}}` and :math:`\kappa=20`. +# The Wasserstein distance on the circle takes into account the periodicity +# and attains its maximum in :math:`\mu_{\mathrm{target}}+1` (the antipodal point) contrary to the +# Euclidean version. + +#%% Compute and plot distributions + +mu_targets = np.linspace(0, 2 * np.pi, 200) +xs = np.random.vonmises(mu1 - np.pi, kappa1, size=(500,)) + np.pi + +n_try = 5 + +xts = np.zeros((n_try, 200, 500)) +for i in range(n_try): + for k, mu in enumerate(mu_targets): + # np.random.vonmises deals with data on [-pi, pi[ + xt = np.random.vonmises(mu - np.pi, kappa2, size=(500,)) + np.pi + xts[i, k] = xt + +# Put data on S^1=[0,1[ +xts2 = xts / (2 * np.pi) +xs2 = np.concatenate([xs[None] for k in range(200)], axis=0) / (2 * np.pi) + +L_w2_circle = np.zeros((n_try, 200)) +L_w2 = np.zeros((n_try, 200)) + +for i in range(n_try): + w2_circle = ot.wasserstein_circle(xs2.T, xts2[i].T, p=2) + w2 = ot.wasserstein_1d(xs2.T, xts2[i].T, p=2) + + L_w2_circle[i] = w2_circle + L_w2[i] = w2 + +m_w2_circle = np.mean(L_w2_circle, axis=0) +std_w2_circle = np.std(L_w2_circle, axis=0) + +m_w2 = np.mean(L_w2, axis=0) +std_w2 = np.std(L_w2, axis=0) + +pl.figure(1) +pl.plot(mu_targets / (2 * np.pi), m_w2_circle, label="Wasserstein circle") +pl.fill_between(mu_targets / (2 * np.pi), m_w2_circle - 2 * std_w2_circle, m_w2_circle + 2 * std_w2_circle, alpha=0.5) +pl.plot(mu_targets / (2 * np.pi), m_w2, label="Euclidean Wasserstein") +pl.fill_between(mu_targets / (2 * np.pi), m_w2 - 2 * std_w2, m_w2 + 2 * std_w2, alpha=0.5) +pl.vlines(x=[mu1 / (2 * np.pi)], ymin=0, ymax=np.max(w2), linestyle="--", color="k", label=r"$\mu_{\mathrm{target}}$") +pl.legend() +pl.xlabel(r"$\mu_{\mathrm{source}}$") +pl.show() + + +######################################################################## +# Wasserstein distance between von Mises and uniform for different kappa +# ---------------------------------------------------------------------- +# When :math:`\kappa=0`, the von Mises distribution is the uniform distribution on :math:`S^1`. + +#%% Compute Wasserstein between Von Mises and uniform + +kappas = np.logspace(-5, 2, 100) +n_try = 20 + +xts = np.zeros((n_try, 100, 500)) +for i in range(n_try): + for k, kappa in enumerate(kappas): + # np.random.vonmises deals with data on [-pi, pi[ + xt = np.random.vonmises(0, kappa, size=(500,)) + np.pi + xts[i, k] = xt / (2 * np.pi) + +L_w2 = np.zeros((n_try, 100)) +for i in range(n_try): + L_w2[i] = ot.semidiscrete_wasserstein2_unif_circle(xts[i].T) + +m_w2 = np.mean(L_w2, axis=0) +std_w2 = np.std(L_w2, axis=0) + +pl.figure(1) +pl.plot(kappas, m_w2) +pl.fill_between(kappas, m_w2 - std_w2, m_w2 + std_w2, alpha=0.5) +pl.title(r"Evolution of $W_2^2(vM(0,\kappa), Unif(S^1))$") +pl.xlabel(r"$\kappa$") +pl.show() + +# %% diff --git a/examples/sliced-wasserstein/plot_variance_ssw.py b/examples/sliced-wasserstein/plot_variance_ssw.py new file mode 100644 index 0000000..83d458f --- /dev/null +++ b/examples/sliced-wasserstein/plot_variance_ssw.py @@ -0,0 +1,111 @@ +# -*- coding: utf-8 -*- +""" +==================================================== +Spherical Sliced Wasserstein on distributions in S^2 +==================================================== + +This example illustrates the computation of the spherical sliced Wasserstein discrepancy as +proposed in [46]. + +[46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). 'Spherical Sliced-Wasserstein". International Conference on Learning Representations. + +""" + +# Author: Clément Bonet +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 2 + +import matplotlib.pylab as pl +import numpy as np + +import ot + +############################################################################## +# Generate data +# ------------- + +# %% parameters and data generation + +n = 500 # nb samples + +xs = np.random.randn(n, 3) +xt = np.random.randn(n, 3) + +xs = xs / np.sqrt(np.sum(xs**2, -1, keepdims=True)) +xt = xt / np.sqrt(np.sum(xt**2, -1, keepdims=True)) + +a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples + +############################################################################## +# Plot data +# --------- + +# %% plot samples + +fig = pl.figure(figsize=(10, 10)) +ax = pl.axes(projection='3d') +ax.grid(False) + +u, v = np.mgrid[0:2 * np.pi:30j, 0:np.pi:30j] +x = np.cos(u) * np.sin(v) +y = np.sin(u) * np.sin(v) +z = np.cos(v) +ax.plot_surface(x, y, z, color="gray", alpha=0.03) +ax.plot_wireframe(x, y, z, linewidth=1, alpha=0.25, color="gray") + +ax.scatter(xs[:, 0], xs[:, 1], xs[:, 2], label="Source") +ax.scatter(xt[:, 0], xt[:, 1], xt[:, 2], label="Target") + +fs = 10 +# Labels +ax.set_xlabel('x', fontsize=fs) +ax.set_ylabel('y', fontsize=fs) +ax.set_zlabel('z', fontsize=fs) + +ax.view_init(20, 120) +ax.set_xlim(-1.5, 1.5) +ax.set_ylim(-1.5, 1.5) +ax.set_zlim(-1.5, 1.5) + +# Ticks +ax.set_xticks([-1, 0, 1]) +ax.set_yticks([-1, 0, 1]) +ax.set_zticks([-1, 0, 1]) + +pl.legend(loc=0) +pl.title("Source and Target distribution") + +############################################################################### +# Spherical Sliced Wasserstein for different seeds and number of projections +# -------------------------------------------------------------------------- + +n_seed = 50 +n_projections_arr = np.logspace(0, 3, 25, dtype=int) +res = np.empty((n_seed, 25)) + +# %% Compute statistics +for seed in range(n_seed): + for i, n_projections in enumerate(n_projections_arr): + res[seed, i] = ot.sliced_wasserstein_sphere(xs, xt, a, b, n_projections, seed=seed, p=1) + +res_mean = np.mean(res, axis=0) +res_std = np.std(res, axis=0) + +############################################################################### +# Plot Spherical Sliced Wasserstein +# --------------------------------- + +pl.figure(2) +pl.plot(n_projections_arr, res_mean, label=r"$SSW_1$") +pl.fill_between(n_projections_arr, res_mean - 2 * res_std, res_mean + 2 * res_std, alpha=0.5) + +pl.legend() +pl.xscale('log') + +pl.xlabel("Number of projections") +pl.ylabel("Distance") +pl.title('Spherical Sliced Wasserstein Distance with 95% confidence inverval') + +pl.show() diff --git a/ot/__init__.py b/ot/__init__.py index 0b55e0c..45d5cfa 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -38,12 +38,15 @@ from . import solvers from . import gaussian # OT functions -from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d +from .lp import (emd, emd2, emd_1d, emd2_1d, wasserstein_1d, + binary_search_circle, wasserstein_circle, + semidiscrete_wasserstein2_unif_circle) from .bregman import sinkhorn, sinkhorn2, barycenter from .unbalanced import (sinkhorn_unbalanced, barycenter_unbalanced, sinkhorn_unbalanced2) from .da import sinkhorn_lpl1_mm -from .sliced import sliced_wasserstein_distance, max_sliced_wasserstein_distance +from .sliced import (sliced_wasserstein_distance, max_sliced_wasserstein_distance, + sliced_wasserstein_sphere, sliced_wasserstein_sphere_unif) from .gromov import (gromov_wasserstein, gromov_wasserstein2, gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2) from .weak import weak_optimal_transport @@ -60,8 +63,10 @@ __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'emd2_1d', 'wasserstein_1d', 'backend', 'gaussian', 'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim', 'sinkhorn_unbalanced', 'barycenter_unbalanced', - 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', + 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', 'sliced_wasserstein_sphere', 'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2', 'max_sliced_wasserstein_distance', 'weak_optimal_transport', 'factored_optimal_transport', 'solve', - 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers'] + 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers', + 'binary_search_circle', 'wasserstein_circle', + 'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif'] diff --git a/ot/backend.py b/ot/backend.py index 337e040..0779243 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -534,9 +534,9 @@ class Backend(): """ raise NotImplementedError() - def zero_pad(self, a, pad_width): + def zero_pad(self, a, pad_width, value=0): r""" - Pads a tensor. + Pads a tensor with a given value (0 by default). This function follows the api from :any:`numpy.pad` @@ -895,6 +895,62 @@ class Backend(): """ raise NotImplementedError() + def tile(self, a, reps): + r""" + Construct an array by repeating a the number of times given by reps + + See: https://numpy.org/doc/stable/reference/generated/numpy.tile.html + """ + raise NotImplementedError() + + def floor(self, a): + r""" + Return the floor of the input element-wise + + See: https://numpy.org/doc/stable/reference/generated/numpy.floor.html + """ + raise NotImplementedError() + + def prod(self, a, axis=None): + r""" + Return the product of all elements. + + See: https://numpy.org/doc/stable/reference/generated/numpy.prod.html + """ + raise NotImplementedError() + + def sort2(self, a, axis=None): + r""" + Return the sorted array and the indices to sort the array + + See: https://pytorch.org/docs/stable/generated/torch.sort.html + """ + raise NotImplementedError() + + def qr(self, a): + r""" + Return the QR factorization + + See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.qr.html + """ + raise NotImplementedError() + + def atan2(self, a, b): + r""" + Element wise arctangent + + See: https://numpy.org/doc/stable/reference/generated/numpy.arctan2.html + """ + raise NotImplementedError() + + def transpose(self, a, axes=None): + r""" + Returns a tensor that is a transposed version of a. The given dimensions dim0 and dim1 are swapped. + + See: https://numpy.org/doc/stable/reference/generated/numpy.transpose.html + """ + raise NotImplementedError() + class NumpyBackend(Backend): """ @@ -1039,8 +1095,8 @@ class NumpyBackend(Backend): def concatenate(self, arrays, axis=0): return np.concatenate(arrays, axis) - def zero_pad(self, a, pad_width): - return np.pad(a, pad_width) + def zero_pad(self, a, pad_width, value=0): + return np.pad(a, pad_width, constant_values=value) def argmax(self, a, axis=None): return np.argmax(a, axis=axis) @@ -1185,6 +1241,44 @@ class NumpyBackend(Backend): def is_floating_point(self, a): return a.dtype.kind == "f" + def tile(self, a, reps): + return np.tile(a, reps) + + def floor(self, a): + return np.floor(a) + + def prod(self, a, axis=0): + return np.prod(a, axis=axis) + + def sort2(self, a, axis=-1): + return self.sort(a, axis), self.argsort(a, axis) + + def qr(self, a): + np_version = tuple([int(k) for k in np.__version__.split(".")]) + if np_version < (1, 22, 0): + M, N = a.shape[-2], a.shape[-1] + K = min(M, N) + + if len(a.shape) >= 3: + n = a.shape[0] + + qs, rs = np.zeros((n, M, K)), np.zeros((n, K, N)) + + for i in range(a.shape[0]): + qs[i], rs[i] = np.linalg.qr(a[i]) + + else: + return np.linalg.qr(a) + + return qs, rs + return np.linalg.qr(a) + + def atan2(self, a, b): + return np.arctan2(a, b) + + def transpose(self, a, axes=None): + return np.transpose(a, axes) + class JaxBackend(Backend): """ @@ -1351,8 +1445,8 @@ class JaxBackend(Backend): def concatenate(self, arrays, axis=0): return jnp.concatenate(arrays, axis) - def zero_pad(self, a, pad_width): - return jnp.pad(a, pad_width) + def zero_pad(self, a, pad_width, value=0): + return jnp.pad(a, pad_width, constant_values=value) def argmax(self, a, axis=None): return jnp.argmax(a, axis=axis) @@ -1511,6 +1605,27 @@ class JaxBackend(Backend): def is_floating_point(self, a): return a.dtype.kind == "f" + def tile(self, a, reps): + return jnp.tile(a, reps) + + def floor(self, a): + return jnp.floor(a) + + def prod(self, a, axis=0): + return jnp.prod(a, axis=axis) + + def sort2(self, a, axis=-1): + return self.sort(a, axis), self.argsort(a, axis) + + def qr(self, a): + return jnp.linalg.qr(a) + + def atan2(self, a, b): + return jnp.arctan2(a, b) + + def transpose(self, a, axes=None): + return jnp.transpose(a, axes) + class TorchBackend(Backend): """ @@ -1729,13 +1844,13 @@ class TorchBackend(Backend): def concatenate(self, arrays, axis=0): return torch.cat(arrays, dim=axis) - def zero_pad(self, a, pad_width): + def zero_pad(self, a, pad_width, value=0): from torch.nn.functional import pad # pad_width is an array of ndim tuples indicating how many 0 before and after # we need to add. We first need to make it compliant with torch syntax, that # starts with the last dim, then second last, etc. how_pad = tuple(element for tupl in pad_width[::-1] for element in tupl) - return pad(a, how_pad) + return pad(a, how_pad, value=value) def argmax(self, a, axis=None): return torch.argmax(a, dim=axis) @@ -1934,6 +2049,29 @@ class TorchBackend(Backend): def is_floating_point(self, a): return a.dtype.is_floating_point + def tile(self, a, reps): + return a.repeat(reps) + + def floor(self, a): + return torch.floor(a) + + def prod(self, a, axis=0): + return torch.prod(a, dim=axis) + + def sort2(self, a, axis=-1): + return torch.sort(a, axis) + + def qr(self, a): + return torch.linalg.qr(a) + + def atan2(self, a, b): + return torch.atan2(a, b) + + def transpose(self, a, axes=None): + if axes is None: + axes = tuple(range(a.ndim)[::-1]) + return a.permute(axes) + class CupyBackend(Backend): # pragma: no cover """ @@ -2096,8 +2234,8 @@ class CupyBackend(Backend): # pragma: no cover def concatenate(self, arrays, axis=0): return cp.concatenate(arrays, axis) - def zero_pad(self, a, pad_width): - return cp.pad(a, pad_width) + def zero_pad(self, a, pad_width, value=0): + return cp.pad(a, pad_width, constant_values=value) def argmax(self, a, axis=None): return cp.argmax(a, axis=axis) @@ -2284,6 +2422,27 @@ class CupyBackend(Backend): # pragma: no cover def is_floating_point(self, a): return a.dtype.kind == "f" + def tile(self, a, reps): + return cp.tile(a, reps) + + def floor(self, a): + return cp.floor(a) + + def prod(self, a, axis=0): + return cp.prod(a, axis=axis) + + def sort2(self, a, axis=-1): + return self.sort(a, axis), self.argsort(a, axis) + + def qr(self, a): + return cp.linalg.qr(a) + + def atan2(self, a, b): + return cp.arctan2(a, b) + + def transpose(self, a, axes=None): + return cp.transpose(a, axes) + class TensorflowBackend(Backend): @@ -2454,8 +2613,8 @@ class TensorflowBackend(Backend): def concatenate(self, arrays, axis=0): return tnp.concatenate(arrays, axis) - def zero_pad(self, a, pad_width): - return tnp.pad(a, pad_width, mode="constant") + def zero_pad(self, a, pad_width, value=0): + return tnp.pad(a, pad_width, mode="constant", constant_values=value) def argmax(self, a, axis=None): return tnp.argmax(a, axis=axis) @@ -2646,3 +2805,24 @@ class TensorflowBackend(Backend): def is_floating_point(self, a): return a.dtype.is_floating + + def tile(self, a, reps): + return tnp.tile(a, reps) + + def floor(self, a): + return tf.floor(a) + + def prod(self, a, axis=0): + return tnp.prod(a, axis=axis) + + def sort2(self, a, axis=-1): + return self.sort(a, axis), self.argsort(a, axis) + + def qr(self, a): + return tf.linalg.qr(a) + + def atan2(self, a, b): + return tf.math.atan2(a, b) + + def transpose(self, a, axes=None): + return tf.transpose(a, perm=axes) diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 17411d0..7d0640f 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -20,14 +20,17 @@ from .cvx import barycenter # import compiled emd from .emd_wrap import emd_c, check_result, emd_1d_sorted -from .solver_1d import emd_1d, emd2_1d, wasserstein_1d +from .solver_1d import (emd_1d, emd2_1d, wasserstein_1d, + binary_search_circle, wasserstein_circle, + semidiscrete_wasserstein2_unif_circle) from ..utils import dist, list_to_array from ..utils import parmap from ..backend import get_backend __all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', ' emd_1d_sorted', - 'emd_1d', 'emd2_1d', 'wasserstein_1d', 'generalized_free_support_barycenter'] + 'emd_1d', 'emd2_1d', 'wasserstein_1d', 'generalized_free_support_barycenter', + 'binary_search_circle', 'wasserstein_circle', 'semidiscrete_wasserstein2_unif_circle'] def check_number_threads(numThreads): diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py index 43763a9..e7add89 100644 --- a/ot/lp/solver_1d.py +++ b/ot/lp/solver_1d.py @@ -53,7 +53,7 @@ def wasserstein_1d(u_values, v_values, u_weights=None, v_weights=None, p=1, requ distributions .. math: - OT_{loss} = \int_0^1 |cdf_u^{-1}(q) cdf_v^{-1}(q)|^p dq + OT_{loss} = \int_0^1 |cdf_u^{-1}(q) - cdf_v^{-1}(q)|^p dq It is formally the p-Wasserstein distance raised to the power p. We do so in a vectorized way by first building the individual quantile functions then integrating them. @@ -365,3 +365,628 @@ def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, log_emd = {'G': G} return cost, log_emd return cost + + +def roll_cols(M, shifts): + r""" + Utils functions which allow to shift the order of each row of a 2d matrix + + Parameters + ---------- + M : (nr, nc) ndarray + Matrix to shift + shifts: int or (nr,) ndarray + + Returns + ------- + Shifted array + + Examples + -------- + >>> M = np.array([[1,2,3],[4,5,6],[7,8,9]]) + >>> roll_cols(M, 2) + array([[2, 3, 1], + [5, 6, 4], + [8, 9, 7]]) + >>> roll_cols(M, np.array([[1],[2],[1]])) + array([[3, 1, 2], + [5, 6, 4], + [9, 7, 8]]) + + References + ---------- + https://stackoverflow.com/questions/66596699/how-to-shift-columns-or-rows-in-a-tensor-with-different-offsets-in-pytorch + """ + nx = get_backend(M) + + n_rows, n_cols = M.shape + + arange1 = nx.tile(nx.reshape(nx.arange(n_cols), (1, n_cols)), (n_rows, 1)) + arange2 = (arange1 - shifts) % n_cols + + return nx.take_along_axis(M, arange2, 1) + + +def derivative_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p=2): + r""" Computes the left and right derivative of the cost (Equation (6.3) and (6.4) of [1]) + + Parameters + ---------- + theta: array-like, shape (n_batch, n) + Cuts on the circle + u_values: array-like, shape (n_batch, n) + locations of the first empirical distribution + v_values: array-like, shape (n_batch, n) + locations of the second empirical distribution + u_cdf: array-like, shape (n_batch, n) + cdf of the first empirical distribution + v_cdf: array-like, shape (n_batch, n) + cdf of the second empirical distribution + p: float, optional = 2 + Power p used for computing the Wasserstein distance + + Returns + ------- + dCp: array-like, shape (n_batch, 1) + The batched right derivative + dCm: array-like, shape (n_batch, 1) + The batched left derivative + + References + --------- + .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258. + """ + nx = get_backend(theta, u_values, v_values, u_cdf, v_cdf) + + v_values = nx.copy(v_values) + + n = u_values.shape[-1] + m_batch, m = v_values.shape + + v_cdf_theta = v_cdf - (theta - nx.floor(theta)) + + mask_p = v_cdf_theta >= 0 + mask_n = v_cdf_theta < 0 + + v_values[mask_n] += nx.floor(theta)[mask_n] + 1 + v_values[mask_p] += nx.floor(theta)[mask_p] + + if nx.any(mask_n) and nx.any(mask_p): + v_cdf_theta[mask_n] += 1 + + v_cdf_theta2 = nx.copy(v_cdf_theta) + v_cdf_theta2[mask_n] = np.inf + shift = (-nx.argmin(v_cdf_theta2, axis=-1)) + + v_cdf_theta = roll_cols(v_cdf_theta, nx.reshape(shift, (-1, 1))) + v_values = roll_cols(v_values, nx.reshape(shift, (-1, 1))) + v_values = nx.concatenate([v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1) + + if nx.__name__ == 'torch': + # this is to ensure the best performance for torch searchsorted + # and avoid a warninng related to non-contiguous arrays + u_cdf = u_cdf.contiguous() + v_cdf_theta = v_cdf_theta.contiguous() + + # quantiles of F_u evaluated in F_v^\theta + u_index = nx.searchsorted(u_cdf, v_cdf_theta) + u_icdf_theta = nx.take_along_axis(u_values, nx.clip(u_index, 0, n - 1), -1) + + # Deal with 1 + u_cdfm = nx.concatenate([u_cdf, nx.reshape(u_cdf[:, 0], (-1, 1)) + 1], axis=1) + u_valuesm = nx.concatenate([u_values, nx.reshape(u_values[:, 0], (-1, 1)) + 1], axis=1) + + if nx.__name__ == 'torch': + # this is to ensure the best performance for torch searchsorted + # and avoid a warninng related to non-contiguous arrays + u_cdfm = u_cdfm.contiguous() + v_cdf_theta = v_cdf_theta.contiguous() + + u_indexm = nx.searchsorted(u_cdfm, v_cdf_theta, side="right") + u_icdfm_theta = nx.take_along_axis(u_valuesm, nx.clip(u_indexm, 0, n), -1) + + dCp = nx.sum(nx.power(nx.abs(u_icdf_theta - v_values[:, 1:]), p) + - nx.power(nx.abs(u_icdf_theta - v_values[:, :-1]), p), axis=-1) + + dCm = nx.sum(nx.power(nx.abs(u_icdfm_theta - v_values[:, 1:]), p) + - nx.power(nx.abs(u_icdfm_theta - v_values[:, :-1]), p), axis=-1) + + return dCp.reshape(-1, 1), dCm.reshape(-1, 1) + + +def ot_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p): + r""" Computes the the cost (Equation (6.2) of [1]) + + Parameters + ---------- + theta: array-like, shape (n_batch, n) + Cuts on the circle + u_values: array-like, shape (n_batch, n) + locations of the first empirical distribution + v_values: array-like, shape (n_batch, n) + locations of the second empirical distribution + u_cdf: array-like, shape (n_batch, n) + cdf of the first empirical distribution + v_cdf: array-like, shape (n_batch, n) + cdf of the second empirical distribution + p: float, optional = 2 + Power p used for computing the Wasserstein distance + + Returns + ------- + ot_cost: array-like, shape (n_batch,) + OT cost evaluated at theta + + References + --------- + .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258. + """ + nx = get_backend(theta, u_values, v_values, u_cdf, v_cdf) + + v_values = nx.copy(v_values) + + m_batch, m = v_values.shape + n_batch, n = u_values.shape + + v_cdf_theta = v_cdf - (theta - nx.floor(theta)) + + mask_p = v_cdf_theta >= 0 + mask_n = v_cdf_theta < 0 + + v_values[mask_n] += nx.floor(theta)[mask_n] + 1 + v_values[mask_p] += nx.floor(theta)[mask_p] + + if nx.any(mask_n) and nx.any(mask_p): + v_cdf_theta[mask_n] += 1 + + # Put negative values at the end + v_cdf_theta2 = nx.copy(v_cdf_theta) + v_cdf_theta2[mask_n] = np.inf + shift = (-nx.argmin(v_cdf_theta2, axis=-1)) + + v_cdf_theta = roll_cols(v_cdf_theta, nx.reshape(shift, (-1, 1))) + v_values = roll_cols(v_values, nx.reshape(shift, (-1, 1))) + v_values = nx.concatenate([v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1) + + # Compute absciss + cdf_axis = nx.sort(nx.concatenate((u_cdf, v_cdf_theta), -1), -1) + cdf_axis_pad = nx.zero_pad(cdf_axis, pad_width=[(0, 0), (1, 0)]) + + delta = cdf_axis_pad[..., 1:] - cdf_axis_pad[..., :-1] + + if nx.__name__ == 'torch': + # this is to ensure the best performance for torch searchsorted + # and avoid a warninng related to non-contiguous arrays + u_cdf = u_cdf.contiguous() + v_cdf_theta = v_cdf_theta.contiguous() + cdf_axis = cdf_axis.contiguous() + + # Compute icdf + u_index = nx.searchsorted(u_cdf, cdf_axis) + u_icdf = nx.take_along_axis(u_values, u_index.clip(0, n - 1), -1) + + v_values = nx.concatenate([v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1) + v_index = nx.searchsorted(v_cdf_theta, cdf_axis) + v_icdf = nx.take_along_axis(v_values, v_index.clip(0, m), -1) + + if p == 1: + ot_cost = nx.sum(delta * nx.abs(u_icdf - v_icdf), axis=-1) + else: + ot_cost = nx.sum(delta * nx.power(nx.abs(u_icdf - v_icdf), p), axis=-1) + + return ot_cost + + +def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1, + Lm=10, Lp=10, tm=-1, tp=1, eps=1e-6, require_sort=True, + log=False): + r"""Computes the Wasserstein distance on the circle using the Binary search algorithm proposed in [44]. + Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`, + takes the value modulo 1. + If the values are on :math:`S^1\subset\mathbb{R}^2`, it is required to first find the coordinates + using e.g. the atan2 function. + + .. math:: + W_p^p(u,v) = \inf_{\theta\in\mathbb{R}}\int_0^1 |F_u^{-1}(q) - (F_v-\theta)^{-1}(q)|^p\ \mathrm{d}q + + where: + + - :math:`F_u` and :math:`F_v` are respectively the cdfs of :math:`u` and :math:`v` + + For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with + + .. math:: + u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi} + + using e.g. ot.utils.get_coordinate_circle(x) + + The function runs on backend but tensorflow is not supported. + + Parameters + ---------- + u_values : ndarray, shape (n, ...) + samples in the source domain (coordinates on [0,1[) + v_values : ndarray, shape (n, ...) + samples in the target domain (coordinates on [0,1[) + u_weights : ndarray, shape (n, ...), optional + samples weights in the source domain + v_weights : ndarray, shape (n, ...), optional + samples weights in the target domain + p : float, optional (default=1) + Power p used for computing the Wasserstein distance + Lm : int, optional + Lower bound dC + Lp : int, optional + Upper bound dC + tm: float, optional + Lower bound theta + tp: float, optional + Upper bound theta + eps: float, optional + Stopping condition + require_sort: bool, optional + If True, sort the values. + log: bool, optional + If True, returns also the optimal theta + + Returns + ------- + loss: float + Cost associated to the optimal transportation + log: dict, optional + log dictionary returned only if log==True in parameters + + Examples + -------- + >>> u = np.array([[0.2,0.5,0.8]])%1 + >>> v = np.array([[0.4,0.5,0.7]])%1 + >>> binary_search_circle(u.T, v.T, p=1) + array([0.1]) + + References + ---------- + .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258. + .. Matlab Code: https://users.mccme.ru/ansobol/otarie/software.html + """ + assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p) + + if u_weights is not None and v_weights is not None: + nx = get_backend(u_values, v_values, u_weights, v_weights) + else: + nx = get_backend(u_values, v_values) + + n = u_values.shape[0] + m = v_values.shape[0] + + if len(u_values.shape) == 1: + u_values = nx.reshape(u_values, (n, 1)) + if len(v_values.shape) == 1: + v_values = nx.reshape(v_values, (m, 1)) + + if u_values.shape[1] != v_values.shape[1]: + raise ValueError( + "u and v must have the same number of batchs {} and {} respectively given".format(u_values.shape[1], + v_values.shape[1])) + + u_values = u_values % 1 + v_values = v_values % 1 + + if u_weights is None: + u_weights = nx.full(u_values.shape, 1. / n, type_as=u_values) + elif u_weights.ndim != u_values.ndim: + u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) + if v_weights is None: + v_weights = nx.full(v_values.shape, 1. / m, type_as=v_values) + elif v_weights.ndim != v_values.ndim: + v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1) + + if require_sort: + u_sorter = nx.argsort(u_values, 0) + u_values = nx.take_along_axis(u_values, u_sorter, 0) + + v_sorter = nx.argsort(v_values, 0) + v_values = nx.take_along_axis(v_values, v_sorter, 0) + + u_weights = nx.take_along_axis(u_weights, u_sorter, 0) + v_weights = nx.take_along_axis(v_weights, v_sorter, 0) + + u_cdf = nx.cumsum(u_weights, 0).T + v_cdf = nx.cumsum(v_weights, 0).T + + u_values = u_values.T + v_values = v_values.T + + L = max(Lm, Lp) + + tm = tm * nx.reshape(nx.ones((u_values.shape[0],), type_as=u_values), (-1, 1)) + tm = nx.tile(tm, (1, m)) + tp = tp * nx.reshape(nx.ones((u_values.shape[0],), type_as=u_values), (-1, 1)) + tp = nx.tile(tp, (1, m)) + tc = (tm + tp) / 2 + + done = nx.zeros((u_values.shape[0], m)) + + cpt = 0 + while nx.any(1 - done): + cpt += 1 + + dCp, dCm = derivative_cost_on_circle(tc, u_values, v_values, u_cdf, v_cdf, p) + done = ((dCp * dCm) <= 0) * 1 + + mask = ((tp - tm) < eps / L) * (1 - done) + + if nx.any(mask): + # can probably be improved by computing only relevant values + dCptp, dCmtp = derivative_cost_on_circle(tp, u_values, v_values, u_cdf, v_cdf, p) + dCptm, dCmtm = derivative_cost_on_circle(tm, u_values, v_values, u_cdf, v_cdf, p) + Ctm = ot_cost_on_circle(tm, u_values, v_values, u_cdf, v_cdf, p).reshape(-1, 1) + Ctp = ot_cost_on_circle(tp, u_values, v_values, u_cdf, v_cdf, p).reshape(-1, 1) + + mask_end = mask * (nx.abs(dCptm - dCmtp) > 0.001) + tc[mask_end > 0] = ((Ctp - Ctm + tm * dCptm - tp * dCmtp) / (dCptm - dCmtp))[mask_end > 0] + done[nx.prod(mask, axis=-1) > 0] = 1 + elif nx.any(1 - done): + tm[((1 - mask) * (dCp < 0)) > 0] = tc[((1 - mask) * (dCp < 0)) > 0] + tp[((1 - mask) * (dCp >= 0)) > 0] = tc[((1 - mask) * (dCp >= 0)) > 0] + tc[((1 - mask) * (1 - done)) > 0] = (tm[((1 - mask) * (1 - done)) > 0] + tp[((1 - mask) * (1 - done)) > 0]) / 2 + + w = ot_cost_on_circle(tc, u_values, v_values, u_cdf, v_cdf, p) + + if log: + return w, {"optimal_theta": tc[:, 0]} + return w + + +def wasserstein1_circle(u_values, v_values, u_weights=None, v_weights=None, require_sort=True): + r"""Computes the 1-Wasserstein distance on the circle using the level median [45]. + Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`, + takes the value modulo 1. + If the values are on :math:`S^1\subset\mathbb{R}^2`, first find the coordinates + using e.g. the atan2 function. + The function runs on backend but tensorflow is not supported. + + .. math:: + W_1(u,v) = \int_0^1 |F_u(t)-F_v(t)-LevMed(F_u-F_v)|\ \mathrm{d}t + + Parameters + ---------- + u_values : ndarray, shape (n, ...) + samples in the source domain (coordinates on [0,1[) + v_values : ndarray, shape (n, ...) + samples in the target domain (coordinates on [0,1[) + u_weights : ndarray, shape (n, ...), optional + samples weights in the source domain + v_weights : ndarray, shape (n, ...), optional + samples weights in the target domain + require_sort: bool, optional + If True, sort the values. + + Returns + ------- + loss: float + Cost associated to the optimal transportation + + Examples + -------- + >>> u = np.array([[0.2,0.5,0.8]])%1 + >>> v = np.array([[0.4,0.5,0.7]])%1 + >>> wasserstein1_circle(u.T, v.T) + array([0.1]) + + References + ---------- + .. [45] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. "The statistics of circular optimal transport." Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82. + .. Code R: https://gitlab.gwdg.de/shundri/circularOT/-/tree/master/ + """ + + if u_weights is not None and v_weights is not None: + nx = get_backend(u_values, v_values, u_weights, v_weights) + else: + nx = get_backend(u_values, v_values) + + n = u_values.shape[0] + m = v_values.shape[0] + + if len(u_values.shape) == 1: + u_values = nx.reshape(u_values, (n, 1)) + if len(v_values.shape) == 1: + v_values = nx.reshape(v_values, (m, 1)) + + if u_values.shape[1] != v_values.shape[1]: + raise ValueError( + "u and v must have the same number of batchs {} and {} respectively given".format(u_values.shape[1], + v_values.shape[1])) + + u_values = u_values % 1 + v_values = v_values % 1 + + if u_weights is None: + u_weights = nx.full(u_values.shape, 1. / n, type_as=u_values) + elif u_weights.ndim != u_values.ndim: + u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) + if v_weights is None: + v_weights = nx.full(v_values.shape, 1. / m, type_as=v_values) + elif v_weights.ndim != v_values.ndim: + v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1) + + if require_sort: + u_sorter = nx.argsort(u_values, 0) + u_values = nx.take_along_axis(u_values, u_sorter, 0) + + v_sorter = nx.argsort(v_values, 0) + v_values = nx.take_along_axis(v_values, v_sorter, 0) + + u_weights = nx.take_along_axis(u_weights, u_sorter, 0) + v_weights = nx.take_along_axis(v_weights, v_sorter, 0) + + # Code inspired from https://gitlab.gwdg.de/shundri/circularOT/-/tree/master/ + values_sorted, values_sorter = nx.sort2(nx.concatenate((u_values, v_values), 0), 0) + + cdf_diff = nx.cumsum(nx.take_along_axis(nx.concatenate((u_weights, -v_weights), 0), values_sorter, 0), 0) + cdf_diff_sorted, cdf_diff_sorter = nx.sort2(cdf_diff, axis=0) + + values_sorted = nx.zero_pad(values_sorted, pad_width=[(0, 1), (0, 0)], value=1) + delta = values_sorted[1:, ...] - values_sorted[:-1, ...] + weight_sorted = nx.take_along_axis(delta, cdf_diff_sorter, 0) + + sum_weights = nx.cumsum(weight_sorted, axis=0) - 0.5 + sum_weights[sum_weights < 0] = np.inf + inds = nx.argmin(sum_weights, axis=0) + + levMed = nx.take_along_axis(cdf_diff_sorted, nx.reshape(inds, (1, -1)), 0) + + return nx.sum(delta * nx.abs(cdf_diff - levMed), axis=0) + + +def wasserstein_circle(u_values, v_values, u_weights=None, v_weights=None, p=1, + Lm=10, Lp=10, tm=-1, tp=1, eps=1e-6, require_sort=True): + r"""Computes the Wasserstein distance on the circle using either [45] for p=1 or + the binary search algorithm proposed in [44] otherwise. + Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`, + takes the value modulo 1. + If the values are on :math:`S^1\subset\mathbb{R}^2`, it requires to first find the coordinates + using e.g. the atan2 function. + + General loss returned: + + .. math:: + OT_{loss} = \inf_{\theta\in\mathbb{R}}\int_0^1 |cdf_u^{-1}(q) - (cdf_v-\theta)^{-1}(q)|^p\ \mathrm{d}q + + For p=1, [45] + + .. math:: + W_1(u,v) = \int_0^1 |F_u(t)-F_v(t)-LevMed(F_u-F_v)|\ \mathrm{d}t + + For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with + + .. math:: + u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi} + + using e.g. ot.utils.get_coordinate_circle(x) + + The function runs on backend but tensorflow is not supported. + + Parameters + ---------- + u_values : ndarray, shape (n, ...) + samples in the source domain (coordinates on [0,1[) + v_values : ndarray, shape (n, ...) + samples in the target domain (coordinates on [0,1[) + u_weights : ndarray, shape (n, ...), optional + samples weights in the source domain + v_weights : ndarray, shape (n, ...), optional + samples weights in the target domain + p : float, optional (default=1) + Power p used for computing the Wasserstein distance + Lm : int, optional + Lower bound dC. For p>1. + Lp : int, optional + Upper bound dC. For p>1. + tm: float, optional + Lower bound theta. For p>1. + tp: float, optional + Upper bound theta. For p>1. + eps: float, optional + Stopping condition. For p>1. + require_sort: bool, optional + If True, sort the values. + + Returns + ------- + loss: float + Cost associated to the optimal transportation + + Examples + -------- + >>> u = np.array([[0.2,0.5,0.8]])%1 + >>> v = np.array([[0.4,0.5,0.7]])%1 + >>> wasserstein_circle(u.T, v.T) + array([0.1]) + + References + ---------- + .. [44] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. "The statistics of circular optimal transport." Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82. + .. [45] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258. + """ + assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p) + + if p == 1: + return wasserstein1_circle(u_values, v_values, u_weights, v_weights, require_sort) + + return binary_search_circle(u_values, v_values, u_weights, v_weights, + p=p, Lm=Lm, Lp=Lp, tm=tm, tp=tp, eps=eps, + require_sort=require_sort) + + +def semidiscrete_wasserstein2_unif_circle(u_values, u_weights=None): + r"""Computes the closed-form for the 2-Wasserstein distance between samples and a uniform distribution on :math:`S^1` + Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`, + takes the value modulo 1. + If the values are on :math:`S^1\subset\mathbb{R}^2`, it is required to first find the coordinates + using e.g. the atan2 function. + + .. math:: + W_2^2(\mu_n, \nu) = \sum_{i=1}^n \alpha_i x_i^2 - \left(\sum_{i=1}^n \alpha_i x_i\right)^2 + \sum_{i=1}^n \alpha_i x_i \left(1-\alpha_i-2\sum_{k=1}^{i-1}\alpha_k\right) + \frac{1}{12} + + where: + + - :math:`\nu=\mathrm{Unif}(S^1)` and :math:`\mu_n = \sum_{i=1}^n \alpha_i \delta_{x_i}` + + For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with + + .. math:: + u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi}, + + using e.g. ot.utils.get_coordinate_circle(x) + + Parameters + ---------- + u_values: ndarray, shape (n, ...) + Samples + u_weights : ndarray, shape (n, ...), optional + samples weights in the source domain + + Returns + ------- + loss: float + Cost associated to the optimal transportation + + Examples + -------- + >>> x0 = np.array([[0], [0.2], [0.4]]) + >>> semidiscrete_wasserstein2_unif_circle(x0) + array([0.02111111]) + + References + ---------- + .. [46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). Spherical sliced-wasserstein. International Conference on Learning Representations. + """ + + if u_weights is not None: + nx = get_backend(u_values, u_weights) + else: + nx = get_backend(u_values) + + n = u_values.shape[0] + + u_values = u_values % 1 + + if len(u_values.shape) == 1: + u_values = nx.reshape(u_values, (n, 1)) + + if u_weights is None: + u_weights = nx.full(u_values.shape, 1. / n, type_as=u_values) + elif u_weights.ndim != u_values.ndim: + u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) + + u_values = nx.sort(u_values, 0) + u_cdf = nx.cumsum(u_weights, 0) + u_cdf = nx.zero_pad(u_cdf, [(1, 0), (0, 0)]) + + cpt1 = nx.sum(u_weights * u_values**2, axis=0) + u_mean = nx.sum(u_weights * u_values, axis=0) + + ns = 1 - u_weights - 2 * u_cdf[:-1] + cpt2 = nx.sum(u_values * u_weights * ns, axis=0) + + return cpt1 - u_mean**2 + cpt2 + 1 / 12 diff --git a/ot/sliced.py b/ot/sliced.py index 20891a4..077ff0b 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -12,7 +12,8 @@ Sliced OT Distances import numpy as np from .backend import get_backend, NumpyBackend -from .utils import list_to_array +from .utils import list_to_array, get_coordinate_circle +from .lp import wasserstein_circle, semidiscrete_wasserstein2_unif_circle def get_random_projections(d, n_projections, seed=None, backend=None, type_as=None): @@ -107,7 +108,6 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2, -------- >>> n_samples_a = 20 - >>> reg = 0.1 >>> X = np.random.normal(0., 1., (n_samples_a, 5)) >>> sliced_wasserstein_distance(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE 0.0 @@ -208,7 +208,6 @@ def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, -------- >>> n_samples_a = 20 - >>> reg = 0.1 >>> X = np.random.normal(0., 1., (n_samples_a, 5)) >>> sliced_wasserstein_distance(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE 0.0 @@ -258,3 +257,183 @@ def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, if log: return res, {"projections": projections, "projected_emds": projected_emd} return res + + +def sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50, + p=2, seed=None, log=False): + r""" + Compute the spherical sliced-Wasserstein discrepancy. + + .. math:: + SSW_p(\mu,\nu) = \left(\int_{\mathbb{V}_{d,2}} W_p^p(P^U_\#\mu, P^U_\#\nu)\ \mathrm{d}\sigma(U)\right)^{\frac{1}{p}} + + where: + + - :math:`P^U_\# \mu` stands for the pushforwards of the projection :math:`\forall x\in S^{d-1},\ P^U(x) = \frac{U^Tx}{\|U^Tx\|_2}` + + The function runs on backend but tensorflow is not supported. + + Parameters + ---------- + X_s: ndarray, shape (n_samples_a, dim) + Samples in the source domain + X_t: ndarray, shape (n_samples_b, dim) + Samples in the target domain + a : ndarray, shape (n_samples_a,), optional + samples weights in the source domain + b : ndarray, shape (n_samples_b,), optional + samples weights in the target domain + n_projections : int, optional + Number of projections used for the Monte-Carlo approximation + p: float, optional (default=2) + Power p used for computing the spherical sliced Wasserstein + seed: int or RandomState or None, optional + Seed used for random number generator + log: bool, optional + if True, sliced_wasserstein_sphere returns the projections used and their associated EMD. + + Returns + ------- + cost: float + Spherical Sliced Wasserstein Cost + log: dict, optional + log dictionary return only if log==True in parameters + + Examples + -------- + >>> n_samples_a = 20 + >>> X = np.random.normal(0., 1., (n_samples_a, 5)) + >>> X = X / np.sqrt(np.sum(X**2, -1, keepdims=True)) + >>> sliced_wasserstein_sphere(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE + 0.0 + + References + ---------- + .. [46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). Spherical sliced-wasserstein. International Conference on Learning Representations. + """ + if a is not None and b is not None: + nx = get_backend(X_s, X_t, a, b) + else: + nx = get_backend(X_s, X_t) + + n, d = X_s.shape + m, _ = X_t.shape + + if X_s.shape[1] != X_t.shape[1]: + raise ValueError( + "X_s and X_t must have the same number of dimensions {} and {} respectively given".format(X_s.shape[1], + X_t.shape[1])) + if nx.any(nx.abs(nx.sum(X_s**2, axis=-1) - 1) > 10**(-4)): + raise ValueError("X_s is not on the sphere.") + if nx.any(nx.abs(nx.sum(X_t**2, axis=-1) - 1) > 10**(-4)): + raise ValueError("Xt is not on the sphere.") + + # Uniforms and independent samples on the Stiefel manifold V_{d,2} + if isinstance(seed, np.random.RandomState) and str(nx) == 'numpy': + Z = seed.randn(n_projections, d, 2) + else: + if seed is not None: + nx.seed(seed) + Z = nx.randn(n_projections, d, 2, type_as=X_s) + + projections, _ = nx.qr(Z) + + # Projection on S^1 + # Projection on plane + Xps = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, (0, 2, 1))[:, None], X_s[:, :, None]), (n_projections, 2, n)), (0, 2, 1)) + Xpt = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, (0, 2, 1))[:, None], X_t[:, :, None]), (n_projections, 2, m)), (0, 2, 1)) + + # Projection on sphere + Xps = Xps / nx.sqrt(nx.sum(Xps**2, -1, keepdims=True)) + Xpt = Xpt / nx.sqrt(nx.sum(Xpt**2, -1, keepdims=True)) + + # Get coordinates on [0,1[ + Xps_coords = nx.reshape(get_coordinate_circle(nx.reshape(Xps, (-1, 2))), (n_projections, n)) + Xpt_coords = nx.reshape(get_coordinate_circle(nx.reshape(Xpt, (-1, 2))), (n_projections, m)) + + projected_emd = wasserstein_circle(Xps_coords.T, Xpt_coords.T, u_weights=a, v_weights=b, p=p) + res = nx.mean(projected_emd) ** (1 / p) + + if log: + return res, {"projections": projections, "projected_emds": projected_emd} + return res + + +def sliced_wasserstein_sphere_unif(X_s, a=None, n_projections=50, seed=None, log=False): + r"""Compute the 2-spherical sliced wasserstein w.r.t. a uniform distribution. + + .. math:: + SSW_2(\mu_n, \nu) + + where + + - :math:`\mu_n=\sum_{i=1}^n \alpha_i \delta_{x_i}` + - :math:`\nu=\mathrm{Unif}(S^1)` + + Parameters + ---------- + X_s: ndarray, shape (n_samples_a, dim) + Samples in the source domain + a : ndarray, shape (n_samples_a,), optional + samples weights in the source domain + n_projections : int, optional + Number of projections used for the Monte-Carlo approximation + seed: int or RandomState or None, optional + Seed used for random number generator + log: bool, optional + if True, sliced_wasserstein_distance returns the projections used and their associated EMD. + + Returns + ------- + cost: float + Spherical Sliced Wasserstein Cost + log: dict, optional + log dictionary return only if log==True in parameters + + Examples + --------- + >>> np.random.seed(42) + >>> x0 = np.random.randn(500,3) + >>> x0 = x0 / np.sqrt(np.sum(x0**2, -1, keepdims=True)) + >>> ssw = sliced_wasserstein_sphere_unif(x0, seed=42) + >>> np.allclose(sliced_wasserstein_sphere_unif(x0, seed=42), 0.01734, atol=1e-3) + True + + References: + ----------- + .. [46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). Spherical sliced-wasserstein. International Conference on Learning Representations. + """ + if a is not None: + nx = get_backend(X_s, a) + else: + nx = get_backend(X_s) + + n, d = X_s.shape + + if nx.any(nx.abs(nx.sum(X_s**2, axis=-1) - 1) > 10**(-4)): + raise ValueError("X_s is not on the sphere.") + + # Uniforms and independent samples on the Stiefel manifold V_{d,2} + if isinstance(seed, np.random.RandomState) and str(nx) == 'numpy': + Z = seed.randn(n_projections, d, 2) + else: + if seed is not None: + nx.seed(seed) + Z = nx.randn(n_projections, d, 2, type_as=X_s) + + projections, _ = nx.qr(Z) + + # Projection on S^1 + # Projection on plane + Xps = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, (0, 2, 1))[:, None], X_s[:, :, None]), (n_projections, 2, n)), (0, 2, 1)) + # Projection on sphere + Xps = Xps / nx.sqrt(nx.sum(Xps**2, -1, keepdims=True)) + # Get coordinates on [0,1[ + Xps_coords = nx.reshape(get_coordinate_circle(nx.reshape(Xps, (-1, 2))), (n_projections, n)) + + projected_emd = semidiscrete_wasserstein2_unif_circle(Xps_coords.T, u_weights=a) + res = nx.mean(projected_emd) ** (1 / 2) + + if log: + return res, {"projections": projections, "projected_emds": projected_emd} + return res diff --git a/ot/utils.py b/ot/utils.py index 9093f09..3423a7e 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -375,6 +375,36 @@ def check_random_state(seed): ' instance'.format(seed)) +def get_coordinate_circle(x): + r"""For :math:`x\in S^1 \subset \mathbb{R}^2`, returns the coordinates in + turn (in [0,1[). + + .. math:: + u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi} + + Parameters + ---------- + x: ndarray, shape (n, 2) + Samples on the circle with ambient coordinates + + Returns + ------- + x_t: ndarray, shape (n,) + Coordinates on [0,1[ + + Examples + -------- + >>> u = np.array([[0.2,0.5,0.8]]) * (2 * np.pi) + >>> x1, y1 = np.cos(u), np.sin(u) + >>> x = np.concatenate([x1, y1]).T + >>> get_coordinate_circle(x) + array([0.2, 0.5, 0.8]) + """ + nx = get_backend(x) + x_t = (nx.atan2(-x[:, 1], -x[:, 0]) + np.pi) / (2 * np.pi) + return x_t + + class deprecated(object): r"""Decorator to mark a function or class as deprecated. diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py index 20f307a..21abd1d 100644 --- a/test/test_1d_solver.py +++ b/test/test_1d_solver.py @@ -218,3 +218,130 @@ def test_emd1d_device_tf(): nx.assert_same_dtype_device(xb, emd) nx.assert_same_dtype_device(xb, emd2) assert nx.dtype_device(emd)[1].startswith("GPU") + + +def test_wasserstein_1d_circle(): + # test binary_search_circle and wasserstein_circle give similar results as emd + n = 20 + m = 30 + rng = np.random.RandomState(0) + u = rng.rand(n,) + v = rng.rand(m,) + + w_u = rng.uniform(0., 1., n) + w_u = w_u / w_u.sum() + + w_v = rng.uniform(0., 1., m) + w_v = w_v / w_v.sum() + + M1 = np.minimum(np.abs(u[:, None] - v[None]), 1 - np.abs(u[:, None] - v[None])) + + wass1 = ot.emd2(w_u, w_v, M1) + + wass1_bsc = ot.binary_search_circle(u, v, w_u, w_v, p=1) + w1_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=1) + + M2 = M1**2 + wass2 = ot.emd2(w_u, w_v, M2) + wass2_bsc = ot.binary_search_circle(u, v, w_u, w_v, p=2) + w2_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=2) + + # check loss is similar + np.testing.assert_allclose(wass1, wass1_bsc) + np.testing.assert_allclose(wass1, w1_circle, rtol=1e-2) + np.testing.assert_allclose(wass2, wass2_bsc) + np.testing.assert_allclose(wass2, w2_circle) + + +@pytest.skip_backend("tf") +def test_wasserstein1d_circle_devices(nx): + rng = np.random.RandomState(0) + + n = 10 + x = np.linspace(0, 1, n) + rho_u = np.abs(rng.randn(n)) + rho_u /= rho_u.sum() + rho_v = np.abs(rng.randn(n)) + rho_v /= rho_v.sum() + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp) + + w1 = ot.wasserstein_circle(xb, xb, rho_ub, rho_vb, p=1) + w2_bsc = ot.wasserstein_circle(xb, xb, rho_ub, rho_vb, p=2) + + nx.assert_same_dtype_device(xb, w1) + nx.assert_same_dtype_device(xb, w2_bsc) + + +def test_wasserstein_1d_unif_circle(): + # test semidiscrete_wasserstein2_unif_circle versus wasserstein_circle + n = 20 + m = 50000 + + rng = np.random.RandomState(0) + u = rng.rand(n,) + v = rng.rand(m,) + + # w_u = rng.uniform(0., 1., n) + # w_u = w_u / w_u.sum() + + w_u = ot.utils.unif(n) + w_v = ot.utils.unif(m) + + M1 = np.minimum(np.abs(u[:, None] - v[None]), 1 - np.abs(u[:, None] - v[None])) + wass2 = ot.emd2(w_u, w_v, M1**2) + + wass2_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=2, eps=1e-15) + wass2_unif_circle = ot.semidiscrete_wasserstein2_unif_circle(u, w_u) + + # check loss is similar + np.testing.assert_allclose(wass2, wass2_unif_circle, atol=1e-3) + np.testing.assert_allclose(wass2_circle, wass2_unif_circle, atol=1e-3) + + +def test_wasserstein1d_unif_circle_devices(nx): + rng = np.random.RandomState(0) + + n = 10 + x = np.linspace(0, 1, n) + rho_u = np.abs(rng.randn(n)) + rho_u /= rho_u.sum() + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + xb, rho_ub = nx.from_numpy(x, rho_u, type_as=tp) + + w2 = ot.semidiscrete_wasserstein2_unif_circle(xb, rho_ub) + + nx.assert_same_dtype_device(xb, w2) + + +def test_binary_search_circle_log(): + n = 20 + m = 30 + rng = np.random.RandomState(0) + u = rng.rand(n,) + v = rng.rand(m,) + + wass2_bsc, log = ot.binary_search_circle(u, v, p=2, log=True) + optimal_thetas = log["optimal_theta"] + + assert optimal_thetas.shape[0] == 1 + + +def test_wasserstein_circle_bad_shape(): + n = 20 + m = 30 + rng = np.random.RandomState(0) + u = rng.rand(n, 2) + v = rng.rand(m, 1) + + with pytest.raises(ValueError): + _ = ot.wasserstein_circle(u, v, p=2) + + with pytest.raises(ValueError): + _ = ot.wasserstein_circle(u, v, p=1) diff --git a/test/test_backend.py b/test/test_backend.py index 3628f61..fd9a761 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -282,6 +282,20 @@ def test_empty_backend(): nx.array_equal(M, M) with pytest.raises(NotImplementedError): nx.is_floating_point(M) + with pytest.raises(NotImplementedError): + nx.tile(M, (10, 1)) + with pytest.raises(NotImplementedError): + nx.floor(M) + with pytest.raises(NotImplementedError): + nx.prod(M) + with pytest.raises(NotImplementedError): + nx.sort2(M) + with pytest.raises(NotImplementedError): + nx.qr(M) + with pytest.raises(NotImplementedError): + nx.atan2(v, v) + with pytest.raises(NotImplementedError): + nx.transpose(M) def test_func_backends(nx): @@ -603,6 +617,38 @@ def test_func_backends(nx): lst_b.append(nx.to_numpy(A)) lst_name.append("isfinite") + A = nx.tile(vb, (10, 1)) + lst_b.append(nx.to_numpy(A)) + lst_name.append("tile") + + A = nx.floor(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append("floor") + + A = nx.prod(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append("prod") + + A, B = nx.sort2(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append("sort2 sort") + lst_b.append(nx.to_numpy(B)) + lst_name.append("sort2 argsort") + + A, B = nx.qr(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append("QR Q") + lst_b.append(nx.to_numpy(B)) + lst_name.append("QR R") + + A = nx.atan2(vb, vb) + lst_b.append(nx.to_numpy(A)) + lst_name.append("atan2") + + A = nx.transpose(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append("transpose") + assert not nx.array_equal(Mb, vb), "array_equal (shape)" assert nx.array_equal(Mb, Mb), "array_equal (elements) - expected true" assert not nx.array_equal( diff --git a/test/test_sliced.py b/test/test_sliced.py index eb13469..f54c799 100644 --- a/test/test_sliced.py +++ b/test/test_sliced.py @@ -266,3 +266,189 @@ def test_max_sliced_backend_device_tf(): valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb) nx.assert_same_dtype_device(xb, valb) assert nx.dtype_device(valb)[1].startswith("GPU") + + +def test_projections_stiefel(): + rng = np.random.RandomState(0) + + n_projs = 500 + x = np.random.randn(100, 3) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + + ssw, log = ot.sliced_wasserstein_sphere(x, x, n_projections=n_projs, + seed=rng, log=True) + + P = log["projections"] + P_T = np.transpose(P, [0, 2, 1]) + np.testing.assert_almost_equal(np.matmul(P_T, P), np.array([np.eye(2) for k in range(n_projs)])) + + +def test_sliced_sphere_same_dist(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 3) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + u = ot.utils.unif(n) + + res = ot.sliced_wasserstein_sphere(x, x, u, u, 10, seed=rng) + np.testing.assert_almost_equal(res, 0.) + + +def test_sliced_sphere_bad_shapes(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 3) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + + y = rng.randn(n, 4) + y = y / np.sqrt(np.sum(x**2, -1, keepdims=True)) + + u = ot.utils.unif(n) + + with pytest.raises(ValueError): + _ = ot.sliced_wasserstein_sphere(x, y, u, u, 10, seed=rng) + + +def test_sliced_sphere_values_on_the_sphere(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 3) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + + y = rng.randn(n, 4) + + u = ot.utils.unif(n) + + with pytest.raises(ValueError): + _ = ot.sliced_wasserstein_sphere(x, y, u, u, 10, seed=rng) + + +def test_sliced_sphere_log(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 4) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + y = rng.randn(n, 4) + y = y / np.sqrt(np.sum(y**2, -1, keepdims=True)) + u = ot.utils.unif(n) + + res, log = ot.sliced_wasserstein_sphere(x, y, u, u, 10, p=1, seed=rng, log=True) + assert len(log) == 2 + projections = log["projections"] + projected_emds = log["projected_emds"] + + assert projections.shape[0] == len(projected_emds) == 10 + for emd in projected_emds: + assert emd > 0 + + +def test_sliced_sphere_different_dists(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 3) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + + u = ot.utils.unif(n) + y = rng.randn(n, 3) + y = y / np.sqrt(np.sum(y**2, -1, keepdims=True)) + + res = ot.sliced_wasserstein_sphere(x, y, u, u, 10, seed=rng) + assert res > 0. + + +def test_1d_sliced_sphere_equals_emd(): + n = 100 + m = 120 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + x_coords = (np.arctan2(-x[:, 1], -x[:, 0]) + np.pi) / (2 * np.pi) + a = rng.uniform(0, 1, n) + a /= a.sum() + + y = rng.randn(m, 2) + y = y / np.sqrt(np.sum(y**2, -1, keepdims=True)) + y_coords = (np.arctan2(-y[:, 1], -y[:, 0]) + np.pi) / (2 * np.pi) + u = ot.utils.unif(m) + + res = ot.sliced_wasserstein_sphere(x, y, a, u, 10, seed=42, p=2) + expected = ot.binary_search_circle(x_coords.T, y_coords.T, a, u, p=2) + + res1 = ot.sliced_wasserstein_sphere(x, y, a, u, 10, seed=42, p=1) + expected1 = ot.binary_search_circle(x_coords.T, y_coords.T, a, u, p=1) + + np.testing.assert_almost_equal(res ** 2, expected) + np.testing.assert_almost_equal(res1, expected1, decimal=3) + + +@pytest.skip_backend("tf") +def test_sliced_sphere_backend_type_devices(nx): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 3) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + + y = rng.randn(2 * n, 3) + y = y / np.sqrt(np.sum(y**2, -1, keepdims=True)) + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + xb, yb = nx.from_numpy(x, y, type_as=tp) + + valb = ot.sliced_wasserstein_sphere(xb, yb) + + nx.assert_same_dtype_device(xb, valb) + + +def test_sliced_sphere_unif_values_on_the_sphere(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 3) + u = ot.utils.unif(n) + + with pytest.raises(ValueError): + _ = ot.sliced_wasserstein_sphere_unif(x, u, 10, seed=rng) + + +def test_sliced_sphere_unif_log(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 4) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + u = ot.utils.unif(n) + + res, log = ot.sliced_wasserstein_sphere_unif(x, u, 10, seed=rng, log=True) + assert len(log) == 2 + projections = log["projections"] + projected_emds = log["projected_emds"] + + assert projections.shape[0] == len(projected_emds) == 10 + for emd in projected_emds: + assert emd > 0 + + +def test_sliced_sphere_unif_backend_type_devices(nx): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 3) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + xb = nx.from_numpy(x, type_as=tp) + + valb = ot.sliced_wasserstein_sphere_unif(xb) + + nx.assert_same_dtype_device(xb, valb) diff --git a/test/test_utils.py b/test/test_utils.py index 666c157..31b12ef 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -330,3 +330,13 @@ def test_OTResult(): for at in lst_attributes: with pytest.raises(NotImplementedError): getattr(res, at) + + +def test_get_coordinate_circle(): + + u = np.random.rand(1, 100) + x1, y1 = np.cos(u * (2 * np.pi)), np.sin(u * (2 * np.pi)) + x = np.concatenate([x1, y1]).T + x_p = ot.utils.get_coordinate_circle(x) + + np.testing.assert_allclose(u[0], x_p) -- cgit v1.2.3 From a5930d3b3a446bf860d6dfacc1e17151fae1dd1d Mon Sep 17 00:00:00 2001 From: Cédric Vincent-Cuaz Date: Thu, 9 Mar 2023 14:21:33 +0100 Subject: [MRG] Semi-relaxed (fused) gromov-wasserstein divergence and improvements of gromov-wasserstein solvers (#431) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * maj gw/ srgw/ generic cg solver * correct pep8 on current state * fix bug previous tests * fix pep8 * fix bug srGW constC in loss and gradient * fix doc html * fix doc html * start updating test_optim.py * update tests gromov and optim - plus fix gromov dependencies * add symmetry feature to entropic gw * add symmetry feature to entropic gw * add exemple for sr(F)GW matchings * small stuff * remove (reg,M) from line-search/ complete srgw tests with backend * remove backend repetitions / rename fG to costG/ fix innerlog to True * fix pep8 * take comments into account / new nx parameters still to test * factor (f)gw2 + test new backend parameters in ot.gromov + harmonize stopping criterions * split gromov.py in ot/gromov/ + update test_gromov with helper_backend functions * manual documentaion gromov * remove circular autosummary * trying stuff * debug documentation * alphabetic ordering of module * merge into branch * add note in entropic gw solvers --------- Co-authored-by: Rémi Flamary --- CONTRIBUTORS.md | 2 +- README.md | 7 +- RELEASES.md | 4 +- docs/cache_nbrun | 1 - docs/source/_templates/module.rst | 2 + docs/source/all.rst | 32 +- docs/source/conf.py | 5 +- examples/gromov/plot_gromov.py | 1 - examples/gromov/plot_semirelaxed_fgw.py | 300 ++++ ot/__init__.py | 1 - ot/bregman.py | 2 +- ot/dr.py | 2 +- ot/gromov.py | 2838 ------------------------------- ot/gromov/__init__.py | 48 + ot/gromov/_bregman.py | 351 ++++ ot/gromov/_dictionary.py | 1008 +++++++++++ ot/gromov/_estimators.py | 425 +++++ ot/gromov/_gw.py | 978 +++++++++++ ot/gromov/_semirelaxed.py | 543 ++++++ ot/gromov/_utils.py | 413 +++++ ot/lp/__init__.py | 2 +- ot/optim.py | 460 ++--- test/test_gromov.py | 621 ++++++- test/test_optim.py | 8 +- 24 files changed, 4964 insertions(+), 3090 deletions(-) delete mode 100644 docs/cache_nbrun create mode 100644 examples/gromov/plot_semirelaxed_fgw.py delete mode 100644 ot/gromov.py create mode 100644 ot/gromov/__init__.py create mode 100644 ot/gromov/_bregman.py create mode 100644 ot/gromov/_dictionary.py create mode 100644 ot/gromov/_estimators.py create mode 100644 ot/gromov/_gw.py create mode 100644 ot/gromov/_semirelaxed.py create mode 100644 ot/gromov/_utils.py (limited to 'CONTRIBUTORS.md') diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 1437821..6b35653 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -36,7 +36,7 @@ The contributors to this library are: * [Tanguy Kerdoncuff](https://hv0nnus.github.io/) (Sampled Gromov Wasserstein) * [Minhui Huang](https://mhhuang95.github.io) (Projection Robust Wasserstein Distance) * [Nathan Cassereau](https://github.com/ncassereau-idris) (Backends) -* [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning) +* [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning, semi-relaxed FGW) * [Eloi Tanguy](https://github.com/eloitanguy) (Generalized Wasserstein Barycenters) * [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug) * [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter) diff --git a/README.md b/README.md index d5e6854..e7241b8 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,7 @@ POT provides the following generic OT solvers (links to examples): * [Wasserstein distance on the circle](https://pythonot.github.io/auto_examples/plot_compute_wasserstein_circle.html) [44, 45] * [Spherical Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance_ssw.html) [46] * [Graph Dictionary Learning solvers](https://pythonot.github.io/auto_examples/gromov/plot_gromov_wasserstein_dictionary_learning.html) [38]. +* [Semi-relaxed (Fused) Gromov-Wasserstein divergences](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_fgw.html) [48]. * [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays. POT provides the following Machine Learning related solvers: @@ -300,4 +301,8 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer [45] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. [The statistics of circular optimal transport.](https://arxiv.org/abs/2103.15426) Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82. -[46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). [Spherical Sliced-Wasserstein](https://openreview.net/forum?id=jXQ0ipgMdU). International Conference on Learning Representations. \ No newline at end of file +[46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). [Spherical Sliced-Wasserstein](https://openreview.net/forum?id=jXQ0ipgMdU). International Conference on Learning Representations. + +[47] Chowdhury, S., & Mémoli, F. (2019). [The gromov–wasserstein distance between networks and stable network invariants](https://academic.oup.com/imaiai/article/8/4/757/5627736). Information and Inference: A Journal of the IMA, 8(4), 757-787. + +[48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty (2022). [Semi-relaxed Gromov-Wasserstein divergence and applications on graphs](https://openreview.net/pdf?id=RShaMexjc-x). International Conference on Learning Representations (ICLR), 2022. \ No newline at end of file diff --git a/RELEASES.md b/RELEASES.md index bf2ce2e..b51409b 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -3,7 +3,9 @@ ## 0.8.3dev #### New features - +- Added feature to (Fused) Gromov-Wasserstein solvers herited from `ot.optim` to support relative and absolute loss variations as stopping criterions (PR #431) +- Added feature to (Fused) Gromov-Wasserstein solvers to handle asymmetric matrices (PR #431) +- Added semi-relaxed (Fused) Gromov-Wasserstein solvers in `ot.gromov` + examples (PR #431) - Added the spherical sliced-Wasserstein discrepancy in `ot.sliced.sliced_wasserstein_sphere` and `ot.sliced.sliced_wasserstein_sphere_unif` + examples (PR #434) - Added the Wasserstein distance on the circle in ``ot.lp.solver_1d.wasserstein_circle`` (PR #434) - Added the Wasserstein distance on the circle (for p>=1) in `ot.lp.solver_1d.binary_search_circle` + examples (PR #434) diff --git a/docs/cache_nbrun b/docs/cache_nbrun deleted file mode 100644 index ac49515..0000000 --- a/docs/cache_nbrun +++ /dev/null @@ -1 +0,0 @@ -{"plot_otda_color_images.ipynb": "128d0435c08ebcf788913e4adcd7dd00", "plot_partial_wass_and_gromov.ipynb": "82242f8390df1d04806b333b745c72cf", "plot_WDA.ipynb": "27f8de4c6d7db46497076523673eedfb", "plot_screenkhorn_1D.ipynb": "af7b8a74a1be0f16f2c3908f5a178de0", "plot_otda_laplacian.ipynb": "d92cc0e528b9277f550daaa6f9d18415", "plot_OT_L1_vs_L2.ipynb": "288230c4e679d752a511353c96c134cb", "plot_otda_semi_supervised.ipynb": "568b39ffbdf6621dd6de162df42f4f21", "plot_fgw.ipynb": "f4de8e6939ce2b1339b3badc1fef0f37", "plot_otda_d2.ipynb": "07ef3212ff3123f16c32a5670e0167f8", "plot_compute_emd.ipynb": "299f6fffcdbf48b7c3268c0136e284f8", "plot_barycenter_fgw.ipynb": "9e813d3b07b7c0c0fcc35a778ca1243f", "plot_convolutional_barycenter.ipynb": "fdd259bfcd6d5fe8001efb4345795d2f", "plot_optim_OTreg.ipynb": "bddd8e49f092873d8980d41ae4974e19", "plot_UOT_1D.ipynb": "2658d5164165941b07539dae3cb80a0f", "plot_OT_1D_smooth.ipynb": "f3e1f0e362c9a78071a40c02b85d2305", "plot_barycenter_1D.ipynb": "f6fa5bc13d9811f09792f73b4de70aa0", "plot_otda_mapping.ipynb": "1bb321763f670fc945d77cfc91471e5e", "plot_OT_1D.ipynb": "0346a8c862606d11f36d0aa087ecab0d", "plot_gromov_barycenter.ipynb": "a7999fcc236d90a0adeb8da2c6370db3", "plot_UOT_barycenter_1D.ipynb": "dd9b857a8c66d71d0124d4a2c30a51dd", "plot_otda_mapping_colors_images.ipynb": "16faae80d6ea8b37d6b1f702149a10de", "plot_stochastic.ipynb": "64f23a8dcbab9823ae92f0fd6c3aceab", "plot_otda_linear_mapping.ipynb": "82417d9141e310bf1f2c2ecdb550094b", "plot_otda_classes.ipynb": "8836a924c9b562ef397af12034fa1abb", "plot_free_support_barycenter.ipynb": "be9d0823f9d7774a289311b9f14548eb", "plot_gromov.ipynb": "de06b1dbe8de99abae51c2e0b64b485d", "plot_otda_jcpot.ipynb": "65482cbfef5c6c1e5e73998aeb5f4b10", "plot_OT_2D_samples.ipynb": "9a9496792fa4216b1059fc70abca851a", "plot_barycenter_lp_vs_entropic.ipynb": "334840b69a86898813e50a6db0f3d0de"} \ No newline at end of file diff --git a/docs/source/_templates/module.rst b/docs/source/_templates/module.rst index 5ad89be..495995e 100644 --- a/docs/source/_templates/module.rst +++ b/docs/source/_templates/module.rst @@ -2,6 +2,7 @@ {{ underline }} .. automodule:: {{ fullname }} + :members: {% block functions %} {% if functions %} @@ -12,6 +13,7 @@ {% for item in functions %} .. autofunction:: {{ item }} + .. include:: backreferences/{{fullname}}.{{item}}.examples diff --git a/docs/source/all.rst b/docs/source/all.rst index 60cc85c..41d8e06 100644 --- a/docs/source/all.rst +++ b/docs/source/all.rst @@ -13,29 +13,33 @@ API and modules :toctree: gen_modules/ :template: module.rst - lp + backend bregman - smooth - gromov - optim da - dr - utils datasets + dr + factored + gaussian + gromov + lp + optim + partial plot - stochastic - unbalanced regpath - partial sliced + smooth + stochastic + unbalanced + utils weak - factored - gaussian + -.. autosummary:: - :toctree: ../modules/generated/ - :template: module.rst +Main :py:mod:`ot` functions +-------------- .. automodule:: ot :members: + + + diff --git a/docs/source/conf.py b/docs/source/conf.py index 3bec150..6e76291 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -119,7 +119,7 @@ release = __version__ # # This is also used if you do content translation via gettext catalogs. # Usually you set "language" from the command line for these cases. -language = None +language = "en" # There are two options for replacing |today|: either, you set today to some # non-false value, then it is used: @@ -341,6 +341,9 @@ texinfo_documents = [ # If true, do not generate a @detailmenu in the "Top" node's menu. #texinfo_no_detailmenu = False +autodoc_default_options = {'autosummary': True, + 'autosummary_imported_members': True} + # Example configuration for intersphinx: refer to the Python standard library. intersphinx_mapping = {'python': ('https://docs.python.org/3', None), diff --git a/examples/gromov/plot_gromov.py b/examples/gromov/plot_gromov.py index 5a362cf..05074dc 100644 --- a/examples/gromov/plot_gromov.py +++ b/examples/gromov/plot_gromov.py @@ -3,7 +3,6 @@ ========================== Gromov-Wasserstein example ========================== - This example is designed to show how to use the Gromov-Wassertsein distance computation in POT. """ diff --git a/examples/gromov/plot_semirelaxed_fgw.py b/examples/gromov/plot_semirelaxed_fgw.py new file mode 100644 index 0000000..8f879d4 --- /dev/null +++ b/examples/gromov/plot_semirelaxed_fgw.py @@ -0,0 +1,300 @@ +# -*- coding: utf-8 -*- +""" +========================== +Semi-relaxed (Fused) Gromov-Wasserstein example +========================== + +This example is designed to show how to use the semi-relaxed Gromov-Wasserstein +and the semi-relaxed Fused Gromov-Wasserstein divergences. + +sr(F)GW between two graphs G1 and G2 searches for a reweighing of the nodes of +G2 at a minimal (F)GW distance from G1. + +First, we generate two graphs following Stochastic Block Models, then show +how to compute their srGW matchings and illustrate them. These graphs are then +endowed with node features and we follow the same process with srFGW. + +[48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty. +"Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" +International Conference on Learning Representations (ICLR), 2021. +""" + +# Author: Cédric Vincent-Cuaz +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 1 + +import numpy as np +import matplotlib.pylab as pl +from ot.gromov import semirelaxed_gromov_wasserstein, semirelaxed_fused_gromov_wasserstein, gromov_wasserstein, fused_gromov_wasserstein +import networkx +from networkx.generators.community import stochastic_block_model as sbm + +# %% +# ============================================================================= +# Generate two graphs following Stochastic Block models of 2 and 3 clusters. +# ============================================================================= + + +N2 = 20 # 2 communities +N3 = 30 # 3 communities +p2 = [[1., 0.1], + [0.1, 0.9]] +p3 = [[1., 0.1, 0.], + [0.1, 0.95, 0.1], + [0., 0.1, 0.9]] +G2 = sbm(seed=0, sizes=[N2 // 2, N2 // 2], p=p2) +G3 = sbm(seed=0, sizes=[N3 // 3, N3 // 3, N3 // 3], p=p3) + + +C2 = networkx.to_numpy_array(G2) +C3 = networkx.to_numpy_array(G3) + +h2 = np.ones(C2.shape[0]) / C2.shape[0] +h3 = np.ones(C3.shape[0]) / C3.shape[0] + +# Add weights on the edges for visualization later on +weight_intra_G2 = 5 +weight_inter_G2 = 0.5 +weight_intra_G3 = 1. +weight_inter_G3 = 1.5 + +weightedG2 = networkx.Graph() +part_G2 = [G2.nodes[i]['block'] for i in range(N2)] + +for node in G2.nodes(): + weightedG2.add_node(node) +for i, j in G2.edges(): + if part_G2[i] == part_G2[j]: + weightedG2.add_edge(i, j, weight=weight_intra_G2) + else: + weightedG2.add_edge(i, j, weight=weight_inter_G2) + +weightedG3 = networkx.Graph() +part_G3 = [G3.nodes[i]['block'] for i in range(N3)] + +for node in G3.nodes(): + weightedG3.add_node(node) +for i, j in G3.edges(): + if part_G3[i] == part_G3[j]: + weightedG3.add_edge(i, j, weight=weight_intra_G3) + else: + weightedG3.add_edge(i, j, weight=weight_inter_G3) +# %% +# ============================================================================= +# Compute their semi-relaxed Gromov-Wasserstein divergences +# ============================================================================= + +# 0) GW(C2, h2, C3, h3) for reference +OT, log = gromov_wasserstein(C2, C3, h2, h3, symmetric=True, log=True) +gw = log['gw_dist'] + +# 1) srGW(C2, h2, C3) +OT_23, log_23 = semirelaxed_gromov_wasserstein(C2, C3, h2, symmetric=True, + log=True, G0=None) +srgw_23 = log_23['srgw_dist'] + +# 2) srGW(C3, h3, C2) + +OT_32, log_32 = semirelaxed_gromov_wasserstein(C3, C2, h3, symmetric=None, + log=True, G0=OT.T) +srgw_32 = log_32['srgw_dist'] + +print('GW(C2, C3) = ', gw) +print('srGW(C2, h2, C3) = ', srgw_23) +print('srGW(C3, h3, C2) = ', srgw_32) + + +# %% +# ============================================================================= +# Visualization of the semi-relaxed Gromov-Wasserstein matchings +# ============================================================================= + +# We color nodes of the graph on the right - then project its node colors +# based on the optimal transport plan from the srGW matching + + +def draw_graph(G, C, nodes_color_part, Gweights=None, + pos=None, edge_color='black', node_size=None, + shiftx=0, seed=0): + + if (pos is None): + pos = networkx.spring_layout(G, scale=1., seed=seed) + + if shiftx != 0: + for k, v in pos.items(): + v[0] = v[0] + shiftx + + alpha_edge = 0.7 + width_edge = 1.8 + if Gweights is None: + networkx.draw_networkx_edges(G, pos, width=width_edge, alpha=alpha_edge, edge_color=edge_color) + else: + # We make more visible connections between activated nodes + n = len(Gweights) + edgelist_activated = [] + edgelist_deactivated = [] + for i in range(n): + for j in range(n): + if Gweights[i] * Gweights[j] * C[i, j] > 0: + edgelist_activated.append((i, j)) + elif C[i, j] > 0: + edgelist_deactivated.append((i, j)) + + networkx.draw_networkx_edges(G, pos, edgelist=edgelist_activated, + width=width_edge, alpha=alpha_edge, + edge_color=edge_color) + networkx.draw_networkx_edges(G, pos, edgelist=edgelist_deactivated, + width=width_edge, alpha=0.1, + edge_color=edge_color) + + if Gweights is None: + for node, node_color in enumerate(nodes_color_part): + networkx.draw_networkx_nodes(G, pos, nodelist=[node], + node_size=node_size, alpha=1, + node_color=node_color) + else: + scaled_Gweights = Gweights / (0.5 * Gweights.max()) + nodes_size = node_size * scaled_Gweights + for node, node_color in enumerate(nodes_color_part): + networkx.draw_networkx_nodes(G, pos, nodelist=[node], + node_size=nodes_size[node], alpha=1, + node_color=node_color) + return pos + + +def draw_transp_colored_srGW(G1, C1, G2, C2, part_G1, + p1, p2, T, pos1=None, pos2=None, + shiftx=4, switchx=False, node_size=70, + seed_G1=0, seed_G2=0): + starting_color = 0 + # get graphs partition and their coloring + part1 = part_G1.copy() + unique_colors = ['C%s' % (starting_color + i) for i in np.unique(part1)] + nodes_color_part1 = [] + for cluster in part1: + nodes_color_part1.append(unique_colors[cluster]) + + nodes_color_part2 = [] + # T: getting colors assignment from argmin of columns + for i in range(len(G2.nodes())): + j = np.argmax(T[:, i]) + nodes_color_part2.append(nodes_color_part1[j]) + pos1 = draw_graph(G1, C1, nodes_color_part1, Gweights=p1, + pos=pos1, node_size=node_size, shiftx=0, seed=seed_G1) + pos2 = draw_graph(G2, C2, nodes_color_part2, Gweights=p2, pos=pos2, + node_size=node_size, shiftx=shiftx, seed=seed_G2) + for k1, v1 in pos1.items(): + for k2, v2 in pos2.items(): + if (T[k1, k2] > 0): + pl.plot([pos1[k1][0], pos2[k2][0]], + [pos1[k1][1], pos2[k2][1]], + '-', lw=0.8, alpha=0.5, + color=nodes_color_part1[k1]) + return pos1, pos2 + + +node_size = 40 +fontsize = 10 +seed_G2 = 0 +seed_G3 = 4 + +pl.figure(1, figsize=(8, 2.5)) +pl.clf() +pl.subplot(121) +pl.axis('off') +pl.axis +pl.title(r'srGW$(\mathbf{C_2},\mathbf{h_2},\mathbf{C_3}) =%s$' % (np.round(srgw_23, 3)), fontsize=fontsize) + +hbar2 = OT_23.sum(axis=0) +pos1, pos2 = draw_transp_colored_srGW( + weightedG2, C2, weightedG3, C3, part_G2, p1=None, p2=hbar2, T=OT_23, + shiftx=1.5, node_size=node_size, seed_G1=seed_G2, seed_G2=seed_G3) +pl.subplot(122) +pl.axis('off') +hbar3 = OT_32.sum(axis=0) +pl.title(r'srGW$(\mathbf{C_3}, \mathbf{h_3},\mathbf{C_2}) =%s$' % (np.round(srgw_32, 3)), fontsize=fontsize) +pos1, pos2 = draw_transp_colored_srGW( + weightedG3, C3, weightedG2, C2, part_G3, p1=None, p2=hbar3, T=OT_32, + pos1=pos2, pos2=pos1, shiftx=3., node_size=node_size, seed_G1=0, seed_G2=0) +pl.tight_layout() + +pl.show() + +# %% +# ============================================================================= +# Add node features +# ============================================================================= + +# We add node features with given mean - by clusters +# and inversely proportional to clusters' intra-connectivity + +F2 = np.zeros((N2, 1)) +for i, c in enumerate(part_G2): + F2[i, 0] = np.random.normal(loc=c, scale=0.01) + +F3 = np.zeros((N3, 1)) +for i, c in enumerate(part_G3): + F3[i, 0] = np.random.normal(loc=2. - c, scale=0.01) + +# %% +# ============================================================================= +# Compute their semi-relaxed Fused Gromov-Wasserstein divergences +# ============================================================================= + +alpha = 0.5 +# Compute pairwise euclidean distance between node features +M = (F2 ** 2).dot(np.ones((1, N3))) + np.ones((N2, 1)).dot((F3 ** 2).T) - 2 * F2.dot(F3.T) + +# 0) FGW_alpha(C2, F2, h2, C3, F3, h3) for reference + +OT, log = fused_gromov_wasserstein( + M, C2, C3, h2, h3, symmetric=True, alpha=alpha, log=True) +fgw = log['fgw_dist'] + +# 1) srFGW(C2, F2, h2, C3, F3) +OT_23, log_23 = semirelaxed_fused_gromov_wasserstein( + M, C2, C3, h2, symmetric=True, alpha=0.5, log=True, G0=None) +srfgw_23 = log_23['srfgw_dist'] + +# 2) srFGW(C3, F3, h3, C2, F2) + +OT_32, log_32 = semirelaxed_fused_gromov_wasserstein( + M.T, C3, C2, h3, symmetric=None, alpha=alpha, log=True, G0=None) +srfgw_32 = log_32['srfgw_dist'] + +print('FGW(C2, F2, C3, F3) = ', fgw) +print('srGW(C2, F2, h2, C3, F3) = ', srfgw_23) +print('srGW(C3, F3, h3, C2, F2) = ', srfgw_32) + +# %% +# ============================================================================= +# Visualization of the semi-relaxed Fused Gromov-Wasserstein matchings +# ============================================================================= + +# We color nodes of the graph on the right - then project its node colors +# based on the optimal transport plan from the srFGW matching +# NB: colors refer to clusters - not to node features + +pl.figure(2, figsize=(8, 2.5)) +pl.clf() +pl.subplot(121) +pl.axis('off') +pl.axis +pl.title(r'srFGW$(\mathbf{C_2},\mathbf{F_2},\mathbf{h_2},\mathbf{C_3},\mathbf{F_3}) =%s$' % (np.round(srfgw_23, 3)), fontsize=fontsize) + +hbar2 = OT_23.sum(axis=0) +pos1, pos2 = draw_transp_colored_srGW( + weightedG2, C2, weightedG3, C3, part_G2, p1=None, p2=hbar2, T=OT_23, + shiftx=1.5, node_size=node_size, seed_G1=seed_G2, seed_G2=seed_G3) +pl.subplot(122) +pl.axis('off') +hbar3 = OT_32.sum(axis=0) +pl.title(r'srFGW$(\mathbf{C_3}, \mathbf{F_3}, \mathbf{h_3}, \mathbf{C_2}, \mathbf{F_2}) =%s$' % (np.round(srfgw_32, 3)), fontsize=fontsize) +pos1, pos2 = draw_transp_colored_srGW( + weightedG3, C3, weightedG2, C2, part_G3, p1=None, p2=hbar3, T=OT_32, + pos1=pos2, pos2=pos1, shiftx=3., node_size=node_size, seed_G1=0, seed_G2=0) +pl.tight_layout() + +pl.show() diff --git a/ot/__init__.py b/ot/__init__.py index 45d5cfa..0e36459 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -8,7 +8,6 @@ , :py:mod:`ot.unbalanced`. The following sub-modules are not imported due to additional dependencies: - :any:`ot.dr` : depends on :code:`pymanopt` and :code:`autograd`. - - :any:`ot.gpu` : depends on :code:`cupy` and a CUDA GPU. - :any:`ot.plot` : depends on :code:`matplotlib` """ diff --git a/ot/bregman.py b/ot/bregman.py index 192a9e2..20bef7e 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -3048,7 +3048,7 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', M = nx.from_numpy(M, type_as=a) m1_cols.append( nx.sum(nx.exp(f[i:i + bs, None] + - g[None, :] - M / reg), axis=1) + g[None, :] - M / reg), axis=1) ) m1 = nx.concatenate(m1_cols, axis=0) err = nx.sum(nx.abs(m1 - a)) diff --git a/ot/dr.py b/ot/dr.py index 1b97841..b92cd14 100644 --- a/ot/dr.py +++ b/ot/dr.py @@ -167,7 +167,7 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, sinkhorn_method='sinkhorn', maxiter Size of dimensionnality reduction. reg : float, optional Regularization term >0 (entropic regularization) - solver : None | str, optional + solver : None | str, optional None for steepest descent or 'TrustRegions' for trust regions algorithm else should be a pymanopt.solvers sinkhorn_method : str diff --git a/ot/gromov.py b/ot/gromov.py deleted file mode 100644 index bc1c8e5..0000000 --- a/ot/gromov.py +++ /dev/null @@ -1,2838 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Gromov-Wasserstein and Fused-Gromov-Wasserstein solvers -""" - -# Author: Erwan Vautier -# Nicolas Courty -# Rémi Flamary -# Titouan Vayer -# Cédric Vincent-Cuaz -# -# License: MIT License - -import numpy as np - - -from .bregman import sinkhorn -from .utils import dist, UndefinedParameter, list_to_array -from .optim import cg -from .lp import emd_1d, emd -from .utils import check_random_state, unif -from .backend import get_backend - - -def init_matrix(C1, C2, p, q, loss_fun='square_loss'): - r"""Return loss matrices and tensors for Gromov-Wasserstein fast computation - - Returns the value of :math:`\mathcal{L}(\mathbf{C_1}, \mathbf{C_2}) \otimes \mathbf{T}` with the - selected loss function as the loss function of Gromow-Wasserstein discrepancy. - - The matrices are computed as described in Proposition 1 in :ref:`[12] ` - - Where : - - - :math:`\mathbf{C_1}`: Metric cost matrix in the source space - - :math:`\mathbf{C_2}`: Metric cost matrix in the target space - - :math:`\mathbf{T}`: A coupling between those two spaces - - The square-loss function :math:`L(a, b) = |a - b|^2` is read as : - - .. math:: - - L(a, b) = f_1(a) + f_2(b) - h_1(a) h_2(b) - - \mathrm{with} \ f_1(a) &= a^2 - - f_2(b) &= b^2 - - h_1(a) &= a - - h_2(b) &= 2b - - The kl-loss function :math:`L(a, b) = a \log\left(\frac{a}{b}\right) - a + b` is read as : - - .. math:: - - L(a, b) = f_1(a) + f_2(b) - h_1(a) h_2(b) - - \mathrm{with} \ f_1(a) &= a \log(a) - a - - f_2(b) &= b - - h_1(a) &= a - - h_2(b) &= \log(b) - - Parameters - ---------- - C1 : array-like, shape (ns, ns) - Metric cost matrix in the source space - C2 : array-like, shape (nt, nt) - Metric cost matrix in the target space - p : array-like, shape (ns,) - Probability distribution in the source space - q : array-like, shape (nt,) - Probability distribution in the target space - loss_fun : str, optional - Name of loss function to use: either 'square_loss' or 'kl_loss' (default='square_loss') - - Returns - ------- - constC : array-like, shape (ns, nt) - Constant :math:`\mathbf{C}` matrix in Eq. (6) - hC1 : array-like, shape (ns, ns) - :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6) - hC2 : array-like, shape (nt, nt) - :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6) - - - .. _references-init-matrix: - References - ---------- - .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, - "Gromov-Wasserstein averaging of kernel and distance matrices." - International Conference on Machine Learning (ICML). 2016. - - """ - C1, C2, p, q = list_to_array(C1, C2, p, q) - nx = get_backend(C1, C2, p, q) - - if loss_fun == 'square_loss': - def f1(a): - return (a**2) - - def f2(b): - return (b**2) - - def h1(a): - return a - - def h2(b): - return 2 * b - elif loss_fun == 'kl_loss': - def f1(a): - return a * nx.log(a + 1e-15) - a - - def f2(b): - return b - - def h1(a): - return a - - def h2(b): - return nx.log(b + 1e-15) - - constC1 = nx.dot( - nx.dot(f1(C1), nx.reshape(p, (-1, 1))), - nx.ones((1, len(q)), type_as=q) - ) - constC2 = nx.dot( - nx.ones((len(p), 1), type_as=p), - nx.dot(nx.reshape(q, (1, -1)), f2(C2).T) - ) - constC = constC1 + constC2 - hC1 = h1(C1) - hC2 = h2(C2) - - return constC, hC1, hC2 - - -def tensor_product(constC, hC1, hC2, T): - r"""Return the tensor for Gromov-Wasserstein fast computation - - The tensor is computed as described in Proposition 1 Eq. (6) in :ref:`[12] ` - - Parameters - ---------- - constC : array-like, shape (ns, nt) - Constant :math:`\mathbf{C}` matrix in Eq. (6) - hC1 : array-like, shape (ns, ns) - :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6) - hC2 : array-like, shape (nt, nt) - :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6) - - Returns - ------- - tens : array-like, shape (`ns`, `nt`) - :math:`\mathcal{L}(\mathbf{C_1}, \mathbf{C_2}) \otimes \mathbf{T}` tensor-matrix multiplication result - - - .. _references-tensor-product: - References - ---------- - .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, - "Gromov-Wasserstein averaging of kernel and distance matrices." - International Conference on Machine Learning (ICML). 2016. - - """ - constC, hC1, hC2, T = list_to_array(constC, hC1, hC2, T) - nx = get_backend(constC, hC1, hC2, T) - - A = - nx.dot( - nx.dot(hC1, T), hC2.T - ) - tens = constC + A - # tens -= tens.min() - return tens - - -def gwloss(constC, hC1, hC2, T): - r"""Return the Loss for Gromov-Wasserstein - - The loss is computed as described in Proposition 1 Eq. (6) in :ref:`[12] ` - - Parameters - ---------- - constC : array-like, shape (ns, nt) - Constant :math:`\mathbf{C}` matrix in Eq. (6) - hC1 : array-like, shape (ns, ns) - :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6) - hC2 : array-like, shape (nt, nt) - :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6) - T : array-like, shape (ns, nt) - Current value of transport matrix :math:`\mathbf{T}` - - Returns - ------- - loss : float - Gromov Wasserstein loss - - - .. _references-gwloss: - References - ---------- - .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, - "Gromov-Wasserstein averaging of kernel and distance matrices." - International Conference on Machine Learning (ICML). 2016. - - """ - - tens = tensor_product(constC, hC1, hC2, T) - - tens, T = list_to_array(tens, T) - nx = get_backend(tens, T) - - return nx.sum(tens * T) - - -def gwggrad(constC, hC1, hC2, T): - r"""Return the gradient for Gromov-Wasserstein - - The gradient is computed as described in Proposition 2 in :ref:`[12] ` - - Parameters - ---------- - constC : array-like, shape (ns, nt) - Constant :math:`\mathbf{C}` matrix in Eq. (6) - hC1 : array-like, shape (ns, ns) - :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6) - hC2 : array-like, shape (nt, nt) - :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6) - T : array-like, shape (ns, nt) - Current value of transport matrix :math:`\mathbf{T}` - - Returns - ------- - grad : array-like, shape (`ns`, `nt`) - Gromov Wasserstein gradient - - - .. _references-gwggrad: - References - ---------- - .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, - "Gromov-Wasserstein averaging of kernel and distance matrices." - International Conference on Machine Learning (ICML). 2016. - - """ - return 2 * tensor_product(constC, hC1, hC2, - T) # [12] Prop. 2 misses a 2 factor - - -def update_square_loss(p, lambdas, T, Cs): - r""" - Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s` - couplings calculated at each iteration - - Parameters - ---------- - p : array-like, shape (N,) - Masses in the targeted barycenter. - lambdas : list of float - List of the `S` spaces' weights. - T : list of S array-like of shape (ns,N) - The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration. - Cs : list of S array-like, shape(ns,ns) - Metric cost matrices. - - Returns - ---------- - C : array-like, shape (`nt`, `nt`) - Updated :math:`\mathbf{C}` matrix. - """ - T = list_to_array(*T) - Cs = list_to_array(*Cs) - p = list_to_array(p) - nx = get_backend(p, *T, *Cs) - - tmpsum = sum([ - lambdas[s] * nx.dot( - nx.dot(T[s].T, Cs[s]), - T[s] - ) for s in range(len(T)) - ]) - ppt = nx.outer(p, p) - - return tmpsum / ppt - - -def update_kl_loss(p, lambdas, T, Cs): - r""" - Updates :math:`\mathbf{C}` according to the KL Loss kernel with the `S` :math:`\mathbf{T}_s` couplings calculated at each iteration - - - Parameters - ---------- - p : array-like, shape (N,) - Weights in the targeted barycenter. - lambdas : list of float - List of the `S` spaces' weights - T : list of S array-like of shape (ns,N) - The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration. - Cs : list of S array-like, shape(ns,ns) - Metric cost matrices. - - Returns - ---------- - C : array-like, shape (`ns`, `ns`) - updated :math:`\mathbf{C}` matrix - """ - Cs = list_to_array(*Cs) - T = list_to_array(*T) - p = list_to_array(p) - nx = get_backend(p, *T, *Cs) - - tmpsum = sum([ - lambdas[s] * nx.dot( - nx.dot(T[s].T, Cs[s]), - T[s] - ) for s in range(len(T)) - ]) - ppt = nx.outer(p, p) - - return nx.exp(tmpsum / ppt) - - -def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, G0=None, **kwargs): - r""" - Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` - - The function solves the following optimization problem: - - .. math:: - \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l} - L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - - Where : - - - :math:`\mathbf{C_1}`: Metric cost matrix in the source space - - :math:`\mathbf{C_2}`: Metric cost matrix in the target space - - :math:`\mathbf{p}`: distribution in the source space - - :math:`\mathbf{q}`: distribution in the target space - - `L`: loss function to account for the misfit between the similarity matrices - - .. note:: This function is backend-compatible and will work on arrays - from all compatible backends. But the algorithm uses the C++ CPU backend - which can lead to copy overhead on GPU arrays. - - Parameters - ---------- - C1 : array-like, shape (ns, ns) - Metric cost matrix in the source space - C2 : array-like, shape (nt, nt) - Metric cost matrix in the target space - p : array-like, shape (ns,) - Distribution in the source space - q : array-like, shape (nt,) - Distribution in the target space - loss_fun : str - loss function used for the solver either 'square_loss' or 'kl_loss' - max_iter : int, optional - Max number of iterations - tol : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - armijo : bool, optional - If True the step of the line-search is found via an armijo research. Else closed form is used. - If there are convergence issues use False. - G0: array-like, shape (ns,nt), optional - If None the initial transport plan of the solver is pq^T. - Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. - **kwargs : dict - parameters can be directly passed to the ot.optim.cg solver - - Returns - ------- - T : array-like, shape (`ns`, `nt`) - Coupling between the two spaces that minimizes: - - :math:`\sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}` - log : dict - Convergence information and loss. - - References - ---------- - .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, - "Gromov-Wasserstein averaging of kernel and distance matrices." - International Conference on Machine Learning (ICML). 2016. - - .. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the - metric approach to object matching. Foundations of computational - mathematics 11.4 (2011): 417-487. - - """ - p, q = list_to_array(p, q) - p0, q0, C10, C20 = p, q, C1, C2 - if G0 is None: - nx = get_backend(p0, q0, C10, C20) - else: - G0_ = G0 - nx = get_backend(p0, q0, C10, C20, G0_) - p = nx.to_numpy(p) - q = nx.to_numpy(q) - C1 = nx.to_numpy(C10) - C2 = nx.to_numpy(C20) - - if G0 is None: - G0 = p[:, None] * q[None, :] - else: - G0 = nx.to_numpy(G0_) - # Check marginals of G0 - np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08) - np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08) - - constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) - - def f(G): - return gwloss(constC, hC1, hC2, G) - - def df(G): - return gwggrad(constC, hC1, hC2, G) - - if log: - res, log = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) - log['gw_dist'] = nx.from_numpy(gwloss(constC, hC1, hC2, res), type_as=C10) - log['u'] = nx.from_numpy(log['u'], type_as=C10) - log['v'] = nx.from_numpy(log['v'], type_as=C10) - return nx.from_numpy(res, type_as=C10), log - else: - return nx.from_numpy(cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=False, **kwargs), type_as=C10) - - -def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, G0=None, **kwargs): - r""" - Returns the gromov-wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` - - The function solves the following optimization problem: - - .. math:: - GW = \min_\mathbf{T} \quad \sum_{i,j,k,l} - L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - - Where : - - - :math:`\mathbf{C_1}`: Metric cost matrix in the source space - - :math:`\mathbf{C_2}`: Metric cost matrix in the target space - - :math:`\mathbf{p}`: distribution in the source space - - :math:`\mathbf{q}`: distribution in the target space - - `L`: loss function to account for the misfit between the similarity - matrices - - Note that when using backends, this loss function is differentiable wrt the - marices and weights for quadratic loss using the gradients from [38]_. - - .. note:: This function is backend-compatible and will work on arrays - from all compatible backends. But the algorithm uses the C++ CPU backend - which can lead to copy overhead on GPU arrays. - - Parameters - ---------- - C1 : array-like, shape (ns, ns) - Metric cost matrix in the source space - C2 : array-like, shape (nt, nt) - Metric cost matrix in the target space - p : array-like, shape (ns,) - Distribution in the source space. - q : array-like, shape (nt,) - Distribution in the target space. - loss_fun : str - loss function used for the solver either 'square_loss' or 'kl_loss' - max_iter : int, optional - Max number of iterations - tol : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - armijo : bool, optional - If True the step of the line-search is found via an armijo research. Else closed form is used. - If there are convergence issues use False. - G0: array-like, shape (ns,nt), optional - If None the initial transport plan of the solver is pq^T. - Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. - - Returns - ------- - gw_dist : float - Gromov-Wasserstein distance - log : dict - convergence information and Coupling marix - - References - ---------- - .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, - "Gromov-Wasserstein averaging of kernel and distance matrices." - International Conference on Machine Learning (ICML). 2016. - - .. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the - metric approach to object matching. Foundations of computational - mathematics 11.4 (2011): 417-487. - - .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online - Graph Dictionary Learning, International Conference on Machine Learning - (ICML), 2021. - - """ - p, q = list_to_array(p, q) - p0, q0, C10, C20 = p, q, C1, C2 - if G0 is None: - nx = get_backend(p0, q0, C10, C20) - else: - G0_ = G0 - nx = get_backend(p0, q0, C10, C20, G0_) - - p = nx.to_numpy(p) - q = nx.to_numpy(q) - C1 = nx.to_numpy(C10) - C2 = nx.to_numpy(C20) - - constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) - - if G0 is None: - G0 = p[:, None] * q[None, :] - else: - G0 = nx.to_numpy(G0_) - # Check marginals of G0 - np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08) - np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08) - - def f(G): - return gwloss(constC, hC1, hC2, G) - - def df(G): - return gwggrad(constC, hC1, hC2, G) - - T, log_gw = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) - - T0 = nx.from_numpy(T, type_as=C10) - - log_gw['gw_dist'] = nx.from_numpy(gwloss(constC, hC1, hC2, T), type_as=C10) - log_gw['u'] = nx.from_numpy(log_gw['u'], type_as=C10) - log_gw['v'] = nx.from_numpy(log_gw['v'], type_as=C10) - log_gw['T'] = T0 - - gw = log_gw['gw_dist'] - - if loss_fun == 'square_loss': - gC1 = 2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T) - gC2 = 2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T) - gC1 = nx.from_numpy(gC1, type_as=C10) - gC2 = nx.from_numpy(gC2, type_as=C10) - gw = nx.set_gradients(gw, (p0, q0, C10, C20), - (log_gw['u'] - nx.mean(log_gw['u']), - log_gw['v'] - nx.mean(log_gw['v']), gC1, gC2)) - - if log: - return gw, log_gw - else: - return gw - - -def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, G0=None, log=False, **kwargs): - r""" - Computes the FGW transport between two graphs (see :ref:`[24] `) - - .. math:: - \gamma = \mathop{\arg \min}_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F + - \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - - s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p} - - \mathbf{\gamma}^T \mathbf{1} &= \mathbf{q} - - \mathbf{\gamma} &\geq 0 - - where : - - - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix - - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1) - - `L` is a loss function to account for the misfit between the similarity matrices - - .. note:: This function is backend-compatible and will work on arrays - from all compatible backends. But the algorithm uses the C++ CPU backend - which can lead to copy overhead on GPU arrays. - - The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[24] ` - - Parameters - ---------- - M : array-like, shape (ns, nt) - Metric cost matrix between features across domains - C1 : array-like, shape (ns, ns) - Metric cost matrix representative of the structure in the source space - C2 : array-like, shape (nt, nt) - Metric cost matrix representative of the structure in the target space - p : array-like, shape (ns,) - Distribution in the source space - q : array-like, shape (nt,) - Distribution in the target space - loss_fun : str, optional - Loss function used for the solver - alpha : float, optional - Trade-off parameter (0 < alpha < 1) - armijo : bool, optional - If True the step of the line-search is found via an armijo research. Else closed form is used. - If there are convergence issues use False. - G0: array-like, shape (ns,nt), optional - If None the initial transport plan of the solver is pq^T. - Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. - log : bool, optional - record log if True - **kwargs : dict - parameters can be directly passed to the ot.optim.cg solver - - Returns - ------- - gamma : array-like, shape (`ns`, `nt`) - Optimal transportation matrix for the given parameters. - log : dict - Log dictionary return only if log==True in parameters. - - - .. _references-fused-gromov-wasserstein: - References - ---------- - .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain - and Courty Nicolas "Optimal Transport for structured data with - application on graphs", International Conference on Machine Learning - (ICML). 2019. - """ - p, q = list_to_array(p, q) - p0, q0, C10, C20, M0 = p, q, C1, C2, M - if G0 is None: - nx = get_backend(p0, q0, C10, C20, M0) - else: - G0_ = G0 - nx = get_backend(p0, q0, C10, C20, M0, G0_) - - p = nx.to_numpy(p) - q = nx.to_numpy(q) - C1 = nx.to_numpy(C10) - C2 = nx.to_numpy(C20) - M = nx.to_numpy(M0) - if G0 is None: - G0 = p[:, None] * q[None, :] - else: - G0 = nx.to_numpy(G0_) - # Check marginals of G0 - np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08) - np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08) - - constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) - - def f(G): - return gwloss(constC, hC1, hC2, G) - - def df(G): - return gwggrad(constC, hC1, hC2, G) - - if log: - res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs) - fgw_dist = nx.from_numpy(log['loss'][-1], type_as=C10) - log['fgw_dist'] = fgw_dist - log['u'] = nx.from_numpy(log['u'], type_as=C10) - log['v'] = nx.from_numpy(log['v'], type_as=C10) - return nx.from_numpy(res, type_as=C10), log - else: - return nx.from_numpy(cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs), type_as=C10) - - -def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, G0=None, log=False, **kwargs): - r""" - Computes the FGW distance between two graphs see (see :ref:`[24] `) - - .. math:: - \min_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F + \alpha \sum_{i,j,k,l} - L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - - s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p} - - \mathbf{\gamma}^T \mathbf{1} &= \mathbf{q} - - \mathbf{\gamma} &\geq 0 - - where : - - - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix - - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1) - - `L` is a loss function to account for the misfit between the similarity matrices - - The algorithm used for solving the problem is conditional gradient as - discussed in :ref:`[24] ` - - .. note:: This function is backend-compatible and will work on arrays - from all compatible backends. But the algorithm uses the C++ CPU backend - which can lead to copy overhead on GPU arrays. - - Note that when using backends, this loss function is differentiable wrt the - marices and weights for quadratic loss using the gradients from [38]_. - - Parameters - ---------- - M : array-like, shape (ns, nt) - Metric cost matrix between features across domains - C1 : array-like, shape (ns, ns) - Metric cost matrix representative of the structure in the source space. - C2 : array-like, shape (nt, nt) - Metric cost matrix representative of the structure in the target space. - p : array-like, shape (ns,) - Distribution in the source space. - q : array-like, shape (nt,) - Distribution in the target space. - loss_fun : str, optional - Loss function used for the solver. - alpha : float, optional - Trade-off parameter (0 < alpha < 1) - armijo : bool, optional - If True the step of the line-search is found via an armijo research. - Else closed form is used. If there are convergence issues use False. - G0: array-like, shape (ns,nt), optional - If None the initial transport plan of the solver is pq^T. - Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. - log : bool, optional - Record log if True. - **kwargs : dict - Parameters can be directly passed to the ot.optim.cg solver. - - Returns - ------- - fgw-distance : float - Fused gromov wasserstein distance for the given parameters. - log : dict - Log dictionary return only if log==True in parameters. - - - .. _references-fused-gromov-wasserstein2: - References - ---------- - .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain - and Courty Nicolas - "Optimal Transport for structured data with application on graphs" - International Conference on Machine Learning (ICML). 2019. - - .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online - Graph Dictionary Learning, International Conference on Machine Learning - (ICML), 2021. - """ - p, q = list_to_array(p, q) - - p0, q0, C10, C20, M0 = p, q, C1, C2, M - if G0 is None: - nx = get_backend(p0, q0, C10, C20, M0) - else: - G0_ = G0 - nx = get_backend(p0, q0, C10, C20, M0, G0_) - - p = nx.to_numpy(p) - q = nx.to_numpy(q) - C1 = nx.to_numpy(C10) - C2 = nx.to_numpy(C20) - M = nx.to_numpy(M0) - - constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) - - if G0 is None: - G0 = p[:, None] * q[None, :] - else: - G0 = nx.to_numpy(G0_) - # Check marginals of G0 - np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08) - np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08) - - def f(G): - return gwloss(constC, hC1, hC2, G) - - def df(G): - return gwggrad(constC, hC1, hC2, G) - - T, log_fgw = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs) - - fgw_dist = nx.from_numpy(log_fgw['loss'][-1], type_as=C10) - - T0 = nx.from_numpy(T, type_as=C10) - - log_fgw['fgw_dist'] = fgw_dist - log_fgw['u'] = nx.from_numpy(log_fgw['u'], type_as=C10) - log_fgw['v'] = nx.from_numpy(log_fgw['v'], type_as=C10) - log_fgw['T'] = T0 - - if loss_fun == 'square_loss': - gC1 = 2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T) - gC2 = 2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T) - gC1 = nx.from_numpy(gC1, type_as=C10) - gC2 = nx.from_numpy(gC2, type_as=C10) - fgw_dist = nx.set_gradients(fgw_dist, (p0, q0, C10, C20, M0), - (log_fgw['u'] - nx.mean(log_fgw['u']), - log_fgw['v'] - nx.mean(log_fgw['v']), - alpha * gC1, alpha * gC2, (1 - alpha) * T0)) - - if log: - return fgw_dist, log_fgw - else: - return fgw_dist - - -def GW_distance_estimation(C1, C2, p, q, loss_fun, T, - nb_samples_p=None, nb_samples_q=None, std=True, random_state=None): - r""" - Returns an approximation of the gromov-wasserstein cost between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` - with a fixed transport plan :math:`\mathbf{T}`. - - The function gives an unbiased approximation of the following equation: - - .. math:: - - GW = \sum_{i,j,k,l} L(\mathbf{C_{1}}_{i,k}, \mathbf{C_{2}}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - - Where : - - - :math:`\mathbf{C_1}`: Metric cost matrix in the source space - - :math:`\mathbf{C_2}`: Metric cost matrix in the target space - - `L` : Loss function to account for the misfit between the similarity matrices - - :math:`\mathbf{T}`: Matrix with marginal :math:`\mathbf{p}` and :math:`\mathbf{q}` - - Parameters - ---------- - C1 : array-like, shape (ns, ns) - Metric cost matrix in the source space - C2 : array-like, shape (nt, nt) - Metric cost matrix in the target space - p : array-like, shape (ns,) - Distribution in the source space - q : array-like, shape (nt,) - Distribution in the target space - loss_fun : function: :math:`\mathbb{R} \times \mathbb{R} \mapsto \mathbb{R}` - Loss function used for the distance, the transport plan does not depend on the loss function - T : csr or array-like, shape (ns, nt) - Transport plan matrix, either a sparse csr or a dense matrix - nb_samples_p : int, optional - `nb_samples_p` is the number of samples (without replacement) along the first dimension of :math:`\mathbf{T}` - nb_samples_q : int, optional - `nb_samples_q` is the number of samples along the second dimension of :math:`\mathbf{T}`, for each sample along the first - std : bool, optional - Standard deviation associated with the prediction of the gromov-wasserstein cost - random_state : int or RandomState instance, optional - Fix the seed for reproducibility - - Returns - ------- - : float - Gromov-wasserstein cost - - References - ---------- - .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc - "Sampled Gromov Wasserstein." - Machine Learning Journal (MLJ). 2021. - - """ - C1, C2, p, q = list_to_array(C1, C2, p, q) - nx = get_backend(C1, C2, p, q) - - generator = check_random_state(random_state) - - len_p = p.shape[0] - len_q = q.shape[0] - - # It is always better to sample from the biggest distribution first. - if len_p < len_q: - p, q = q, p - len_p, len_q = len_q, len_p - C1, C2 = C2, C1 - T = T.T - - if nb_samples_p is None: - if nx.issparse(T): - # If T is sparse, it probably mean that PoGroW was used, thus the number of sample is reduced - nb_samples_p = min(int(5 * (len_p * np.log(len_p)) ** 0.5), len_p) - else: - nb_samples_p = len_p - else: - # The number of sample along the first dimension is without replacement. - nb_samples_p = min(nb_samples_p, len_p) - if nb_samples_q is None: - nb_samples_q = 1 - if std: - nb_samples_q = max(2, nb_samples_q) - - index_k = np.zeros((nb_samples_p, nb_samples_q), dtype=int) - index_l = np.zeros((nb_samples_p, nb_samples_q), dtype=int) - - index_i = generator.choice( - len_p, size=nb_samples_p, p=nx.to_numpy(p), replace=False - ) - index_j = generator.choice( - len_p, size=nb_samples_p, p=nx.to_numpy(p), replace=False - ) - - for i in range(nb_samples_p): - if nx.issparse(T): - T_indexi = nx.reshape(nx.todense(T[index_i[i], :]), (-1,)) - T_indexj = nx.reshape(nx.todense(T[index_j[i], :]), (-1,)) - else: - T_indexi = T[index_i[i], :] - T_indexj = T[index_j[i], :] - # For each of the row sampled, the column is sampled. - index_k[i] = generator.choice( - len_q, - size=nb_samples_q, - p=nx.to_numpy(T_indexi / nx.sum(T_indexi)), - replace=True - ) - index_l[i] = generator.choice( - len_q, - size=nb_samples_q, - p=nx.to_numpy(T_indexj / nx.sum(T_indexj)), - replace=True - ) - - list_value_sample = nx.stack([ - loss_fun( - C1[np.ix_(index_i, index_j)], - C2[np.ix_(index_k[:, n], index_l[:, n])] - ) for n in range(nb_samples_q) - ], axis=2) - - if std: - std_value = nx.sum(nx.std(list_value_sample, axis=2) ** 2) ** 0.5 - return nx.mean(list_value_sample), std_value / (nb_samples_p * nb_samples_p) - else: - return nx.mean(list_value_sample) - - -def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun, - alpha=1, max_iter=100, threshold_plan=0, log=False, verbose=False, random_state=None): - r""" - Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` using a stochastic Frank-Wolfe. - This method has a :math:`\mathcal{O}(\mathrm{max\_iter} \times PN^2)` time complexity with `P` the number of Sinkhorn iterations. - - The function solves the following optimization problem: - - .. math:: - \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l} - L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - - s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} - - \mathbf{T}^T \mathbf{1} &= \mathbf{q} - - \mathbf{T} &\geq 0 - - Where : - - - :math:`\mathbf{C_1}`: Metric cost matrix in the source space - - :math:`\mathbf{C_2}`: Metric cost matrix in the target space - - :math:`\mathbf{p}`: distribution in the source space - - :math:`\mathbf{q}`: distribution in the target space - - `L`: loss function to account for the misfit between the similarity matrices - - Parameters - ---------- - C1 : array-like, shape (ns, ns) - Metric cost matrix in the source space - C2 : array-like, shape (nt, nt) - Metric cost matrix in the target space - p : array-like, shape (ns,) - Distribution in the source space - q : array-like, shape (nt,) - Distribution in the target space - loss_fun : function: :math:`\mathbb{R} \times \mathbb{R} \mapsto \mathbb{R}` - Loss function used for the distance, the transport plan does not depend on the loss function - alpha : float - Step of the Frank-Wolfe algorithm, should be between 0 and 1 - max_iter : int, optional - Max number of iterations - threshold_plan : float, optional - Deleting very small values in the transport plan. If above zero, it violates the marginal constraints. - verbose : bool, optional - Print information along iterations - log : bool, optional - Gives the distance estimated and the standard deviation - random_state : int or RandomState instance, optional - Fix the seed for reproducibility - - Returns - ------- - T : array-like, shape (`ns`, `nt`) - Optimal coupling between the two spaces - - References - ---------- - .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc - "Sampled Gromov Wasserstein." - Machine Learning Journal (MLJ). 2021. - - """ - C1, C2, p, q = list_to_array(C1, C2, p, q) - nx = get_backend(C1, C2, p, q) - - len_p = p.shape[0] - len_q = q.shape[0] - - generator = check_random_state(random_state) - - index = np.zeros(2, dtype=int) - - # Initialize with default marginal - index[0] = generator.choice(len_p, size=1, p=nx.to_numpy(p)) - index[1] = generator.choice(len_q, size=1, p=nx.to_numpy(q)) - T = nx.tocsr(emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False)) - - best_gw_dist_estimated = np.inf - for cpt in range(max_iter): - index[0] = generator.choice(len_p, size=1, p=nx.to_numpy(p)) - T_index0 = nx.reshape(nx.todense(T[index[0], :]), (-1,)) - index[1] = generator.choice( - len_q, size=1, p=nx.to_numpy(T_index0 / nx.sum(T_index0)) - ) - - if alpha == 1: - T = nx.tocsr( - emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False) - ) - else: - new_T = nx.tocsr( - emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False) - ) - T = (1 - alpha) * T + alpha * new_T - # To limit the number of non 0, the values below the threshold are set to 0. - T = nx.eliminate_zeros(T, threshold=threshold_plan) - - if cpt % 10 == 0 or cpt == (max_iter - 1): - gw_dist_estimated = GW_distance_estimation( - C1=C1, C2=C2, loss_fun=loss_fun, - p=p, q=q, T=T, std=False, random_state=generator - ) - - if gw_dist_estimated < best_gw_dist_estimated: - best_gw_dist_estimated = gw_dist_estimated - best_T = nx.copy(T) - - if verbose: - if cpt % 200 == 0: - print('{:5s}|{:12s}'.format('It.', 'Best gw estimated') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, best_gw_dist_estimated)) - - if log: - log = {} - log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation( - C1=C1, C2=C2, loss_fun=loss_fun, - p=p, q=q, T=best_T, random_state=generator - ) - return best_T, log - return best_T - - -def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun, - nb_samples_grad=100, epsilon=1, max_iter=500, log=False, verbose=False, - random_state=None): - r""" - Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` using a 1-stochastic Frank-Wolfe. - This method has a :math:`\mathcal{O}(\mathrm{max\_iter} \times N \log(N))` time complexity by relying on the 1D Optimal Transport solver. - - The function solves the following optimization problem: - - .. math:: - \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l} - L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - - s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} - - \mathbf{T}^T \mathbf{1} &= \mathbf{q} - - \mathbf{T} &\geq 0 - - Where : - - - :math:`\mathbf{C_1}`: Metric cost matrix in the source space - - :math:`\mathbf{C_2}`: Metric cost matrix in the target space - - :math:`\mathbf{p}`: distribution in the source space - - :math:`\mathbf{q}`: distribution in the target space - - `L`: loss function to account for the misfit between the similarity matrices - - Parameters - ---------- - C1 : array-like, shape (ns, ns) - Metric cost matrix in the source space - C2 : array-like, shape (nt, nt) - Metric cost matrix in the target space - p : array-like, shape (ns,) - Distribution in the source space - q : array-like, shape (nt,) - Distribution in the target space - loss_fun : function: :math:`\mathbb{R} \times \mathbb{R} \mapsto \mathbb{R}` - Loss function used for the distance, the transport plan does not depend on the loss function - nb_samples_grad : int - Number of samples to approximate the gradient - epsilon : float - Weight of the Kullback-Leibler regularization - max_iter : int, optional - Max number of iterations - verbose : bool, optional - Print information along iterations - log : bool, optional - Gives the distance estimated and the standard deviation - random_state : int or RandomState instance, optional - Fix the seed for reproducibility - - Returns - ------- - T : array-like, shape (`ns`, `nt`) - Optimal coupling between the two spaces - - References - ---------- - .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc - "Sampled Gromov Wasserstein." - Machine Learning Journal (MLJ). 2021. - - """ - C1, C2, p, q = list_to_array(C1, C2, p, q) - nx = get_backend(C1, C2, p, q) - - len_p = p.shape[0] - len_q = q.shape[0] - - generator = check_random_state(random_state) - - # The most natural way to define nb_sample is with a simple integer. - if isinstance(nb_samples_grad, int): - if nb_samples_grad > len_p: - # As the sampling along the first dimension is done without replacement, the rest is reported to the second - # dimension. - nb_samples_grad_p, nb_samples_grad_q = len_p, nb_samples_grad // len_p - else: - nb_samples_grad_p, nb_samples_grad_q = nb_samples_grad, 1 - else: - nb_samples_grad_p, nb_samples_grad_q = nb_samples_grad - T = nx.outer(p, q) - # continue_loop allows to stop the loop if there is several successive small modification of T. - continue_loop = 0 - - # The gradient of GW is more complex if the two matrices are not symmetric. - C_are_symmetric = nx.allclose(C1, C1.T, rtol=1e-10, atol=1e-10) and nx.allclose(C2, C2.T, rtol=1e-10, atol=1e-10) - - for cpt in range(max_iter): - index0 = generator.choice( - len_p, size=nb_samples_grad_p, p=nx.to_numpy(p), replace=False - ) - Lik = 0 - for i, index0_i in enumerate(index0): - index1 = generator.choice( - len_q, size=nb_samples_grad_q, - p=nx.to_numpy(T[index0_i, :] / nx.sum(T[index0_i, :])), - replace=False - ) - # If the matrices C are not symmetric, the gradient has 2 terms, thus the term is chosen randomly. - if (not C_are_symmetric) and generator.rand(1) > 0.5: - Lik += nx.mean(loss_fun( - C1[:, [index0[i]] * nb_samples_grad_q][:, None, :], - C2[:, index1][None, :, :] - ), axis=2) - else: - Lik += nx.mean(loss_fun( - C1[[index0[i]] * nb_samples_grad_q, :][:, :, None], - C2[index1, :][:, None, :] - ), axis=0) - - max_Lik = nx.max(Lik) - if max_Lik == 0: - continue - # This division by the max is here to facilitate the choice of epsilon. - Lik /= max_Lik - - if epsilon > 0: - # Set to infinity all the numbers below exp(-200) to avoid log of 0. - log_T = nx.log(nx.clip(T, np.exp(-200), 1)) - log_T = nx.where(log_T == -200, -np.inf, log_T) - Lik = Lik - epsilon * log_T - - try: - new_T = sinkhorn(a=p, b=q, M=Lik, reg=epsilon) - except (RuntimeWarning, UserWarning): - print("Warning catched in Sinkhorn: Return last stable T") - break - else: - new_T = emd(a=p, b=q, M=Lik) - - change_T = nx.mean((T - new_T) ** 2) - if change_T <= 10e-20: - continue_loop += 1 - if continue_loop > 100: # Number max of low modifications of T - T = nx.copy(new_T) - break - else: - continue_loop = 0 - - if verbose and cpt % 10 == 0: - if cpt % 200 == 0: - print('{:5s}|{:12s}'.format('It.', '||T_n - T_{n+1}||') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, change_T)) - T = nx.copy(new_T) - - if log: - log = {} - log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation( - C1=C1, C2=C2, loss_fun=loss_fun, - p=p, q=q, T=T, random_state=generator - ) - return T, log - return T - - -def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, - max_iter=1000, tol=1e-9, verbose=False, log=False): - r""" - Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` - - The function solves the following optimization problem: - - .. math:: - \mathbf{GW} = \mathop{\arg\min}_\mathbf{T} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - \epsilon(H(\mathbf{T})) - - s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} - - \mathbf{T}^T \mathbf{1} &= \mathbf{q} - - \mathbf{T} &\geq 0 - - Where : - - - :math:`\mathbf{C_1}`: Metric cost matrix in the source space - - :math:`\mathbf{C_2}`: Metric cost matrix in the target space - - :math:`\mathbf{p}`: distribution in the source space - - :math:`\mathbf{q}`: distribution in the target space - - `L`: loss function to account for the misfit between the similarity matrices - - `H`: entropy - - Parameters - ---------- - C1 : array-like, shape (ns, ns) - Metric cost matrix in the source space - C2 : array-like, shape (nt, nt) - Metric cost matrix in the target space - p : array-like, shape (ns,) - Distribution in the source space - q : array-like, shape (nt,) - Distribution in the target space - loss_fun : string - Loss function used for the solver either 'square_loss' or 'kl_loss' - epsilon : float - Regularization term >0 - max_iter : int, optional - Max number of iterations - tol : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - Record log if True. - - Returns - ------- - T : array-like, shape (`ns`, `nt`) - Optimal coupling between the two spaces - - References - ---------- - .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, - "Gromov-Wasserstein averaging of kernel and distance matrices." - International Conference on Machine Learning (ICML). 2016. - - """ - C1, C2, p, q = list_to_array(C1, C2, p, q) - nx = get_backend(C1, C2, p, q) - - T = nx.outer(p, q) - - constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) - - cpt = 0 - err = 1 - - if log: - log = {'err': []} - - while (err > tol and cpt < max_iter): - - Tprev = T - - # compute the gradient - tens = gwggrad(constC, hC1, hC2, T) - - T = sinkhorn(p, q, tens, epsilon, method='sinkhorn') - - if cpt % 10 == 0: - # we can speed up the process by checking for the error only all - # the 10th iterations - err = nx.norm(T - Tprev) - - if log: - log['err'].append(err) - - if verbose: - if cpt % 200 == 0: - print('{:5s}|{:12s}'.format( - 'It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) - - cpt += 1 - - if log: - log['gw_dist'] = gwloss(constC, hC1, hC2, T) - return T, log - else: - return T - - -def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, - max_iter=1000, tol=1e-9, verbose=False, log=False): - r""" - Returns the entropic gromov-wasserstein discrepancy between the two measured similarity matrices :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` - - The function solves the following optimization problem: - - .. math:: - GW = \min_\mathbf{T} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) - \mathbf{T}_{i,j} \mathbf{T}_{k,l} - \epsilon(H(\mathbf{T})) - - Where : - - - :math:`\mathbf{C_1}`: Metric cost matrix in the source space - - :math:`\mathbf{C_2}`: Metric cost matrix in the target space - - :math:`\mathbf{p}`: distribution in the source space - - :math:`\mathbf{q}`: distribution in the target space - - `L`: loss function to account for the misfit between the similarity matrices - - `H`: entropy - - Parameters - ---------- - C1 : array-like, shape (ns, ns) - Metric cost matrix in the source space - C2 : array-like, shape (nt, nt) - Metric cost matrix in the target space - p : array-like, shape (ns,) - Distribution in the source space - q : array-like, shape (nt,) - Distribution in the target space - loss_fun : str - Loss function used for the solver either 'square_loss' or 'kl_loss' - epsilon : float - Regularization term >0 - max_iter : int, optional - Max number of iterations - tol : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - Record log if True. - - Returns - ------- - gw_dist : float - Gromov-Wasserstein distance - - References - ---------- - .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, - "Gromov-Wasserstein averaging of kernel and distance matrices." - International Conference on Machine Learning (ICML). 2016. - - """ - gw, logv = entropic_gromov_wasserstein( - C1, C2, p, q, loss_fun, epsilon, max_iter, tol, verbose, log=True) - - logv['T'] = gw - - if log: - return logv['gw_dist'], logv - else: - return logv['gw_dist'] - - -def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, - max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None, random_state=None): - r""" - Returns the gromov-wasserstein barycenters of `S` measured similarity matrices :math:`(\mathbf{C}_s)_{1 \leq s \leq S}` - - The function solves the following optimization problem: - - .. math:: - - \mathbf{C} = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \quad \sum_s \lambda_s \mathrm{GW}(\mathbf{C}, \mathbf{C}_s, \mathbf{p}, \mathbf{p}_s) - - Where : - - - :math:`\mathbf{C}_s`: metric cost matrix - - :math:`\mathbf{p}_s`: distribution - - Parameters - ---------- - N : int - Size of the targeted barycenter - Cs : list of S array-like of shape (ns,ns) - Metric cost matrices - ps : list of S array-like of shape (ns,) - Sample weights in the `S` spaces - p : array-like, shape(N,) - Weights in the targeted barycenter - lambdas : list of float - List of the `S` spaces' weights. - loss_fun : callable - Tensor-matrix multiplication function based on specific loss function. - update : callable - function(:math:`\mathbf{p}`, lambdas, :math:`\mathbf{T}`, :math:`\mathbf{Cs}`) that updates - :math:`\mathbf{C}` according to a specific Kernel with the `S` :math:`\mathbf{T}_s` couplings - calculated at each iteration - epsilon : float - Regularization term >0 - max_iter : int, optional - Max number of iterations - tol : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information along iterations. - log : bool, optional - Record log if True. - init_C : bool | array-like, shape (N, N) - Random initial value for the :math:`\mathbf{C}` matrix provided by user. - random_state : int or RandomState instance, optional - Fix the seed for reproducibility - - Returns - ------- - C : array-like, shape (`N`, `N`) - Similarity matrix in the barycenter space (permutated arbitrarily) - log : dict - Log dictionary of error during iterations. Return only if `log=True` in parameters. - - References - ---------- - .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, - "Gromov-Wasserstein averaging of kernel and distance matrices." - International Conference on Machine Learning (ICML). 2016. - """ - Cs = list_to_array(*Cs) - ps = list_to_array(*ps) - p = list_to_array(p) - nx = get_backend(*Cs, *ps, p) - - S = len(Cs) - - # Initialization of C : random SPD matrix (if not provided by user) - if init_C is None: - generator = check_random_state(random_state) - xalea = generator.randn(N, 2) - C = dist(xalea, xalea) - C /= C.max() - C = nx.from_numpy(C, type_as=p) - else: - C = init_C - - cpt = 0 - err = 1 - - error = [] - - while (err > tol) and (cpt < max_iter): - Cprev = C - - T = [entropic_gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon, - max_iter, 1e-4, verbose, log=False) for s in range(S)] - if loss_fun == 'square_loss': - C = update_square_loss(p, lambdas, T, Cs) - - elif loss_fun == 'kl_loss': - C = update_kl_loss(p, lambdas, T, Cs) - - if cpt % 10 == 0: - # we can speed up the process by checking for the error only all - # the 10th iterations - err = nx.norm(C - Cprev) - error.append(err) - - if verbose: - if cpt % 200 == 0: - print('{:5s}|{:12s}'.format( - 'It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) - - cpt += 1 - - if log: - return C, {"err": error} - else: - return C - - -def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, - max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None, random_state=None): - r""" - Returns the gromov-wasserstein barycenters of `S` measured similarity matrices :math:`(\mathbf{C}_s)_{1 \leq s \leq S}` - - The function solves the following optimization problem with block coordinate descent: - - .. math:: - - \mathbf{C} = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \quad \sum_s \lambda_s \mathrm{GW}(\mathbf{C}, \mathbf{C}_s, \mathbf{p}, \mathbf{p}_s) - - Where : - - - :math:`\mathbf{C}_s`: metric cost matrix - - :math:`\mathbf{p}_s`: distribution - - Parameters - ---------- - N : int - Size of the targeted barycenter - Cs : list of S array-like of shape (ns, ns) - Metric cost matrices - ps : list of S array-like of shape (ns,) - Sample weights in the `S` spaces - p : array-like, shape (N,) - Weights in the targeted barycenter - lambdas : list of float - List of the `S` spaces' weights - loss_fun : callable - tensor-matrix multiplication function based on specific loss function - update : callable - function(:math:`\mathbf{p}`, lambdas, :math:`\mathbf{T}`, :math:`\mathbf{Cs}`) that updates - :math:`\mathbf{C}` according to a specific Kernel with the `S` :math:`\mathbf{T}_s` couplings - calculated at each iteration - max_iter : int, optional - Max number of iterations - tol : float, optional - Stop threshold on error (>0). - verbose : bool, optional - Print information along iterations. - log : bool, optional - Record log if True. - init_C : bool | array-like, shape(N,N) - Random initial value for the :math:`\mathbf{C}` matrix provided by user. - random_state : int or RandomState instance, optional - Fix the seed for reproducibility - - Returns - ------- - C : array-like, shape (`N`, `N`) - Similarity matrix in the barycenter space (permutated arbitrarily) - log : dict - Log dictionary of error during iterations. Return only if `log=True` in parameters. - - References - ---------- - .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, - "Gromov-Wasserstein averaging of kernel and distance matrices." - International Conference on Machine Learning (ICML). 2016. - - """ - Cs = list_to_array(*Cs) - ps = list_to_array(*ps) - p = list_to_array(p) - nx = get_backend(*Cs, *ps, p) - - S = len(Cs) - - # Initialization of C : random SPD matrix (if not provided by user) - if init_C is None: - generator = check_random_state(random_state) - xalea = generator.randn(N, 2) - C = dist(xalea, xalea) - C /= C.max() - C = nx.from_numpy(C, type_as=p) - else: - C = init_C - - cpt = 0 - err = 1 - - error = [] - - while (err > tol and cpt < max_iter): - Cprev = C - - T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, - numItermax=max_iter, stopThr=1e-5, verbose=verbose, log=False) for s in range(S)] - if loss_fun == 'square_loss': - C = update_square_loss(p, lambdas, T, Cs) - - elif loss_fun == 'kl_loss': - C = update_kl_loss(p, lambdas, T, Cs) - - if cpt % 10 == 0: - # we can speed up the process by checking for the error only all - # the 10th iterations - err = nx.norm(C - Cprev) - error.append(err) - - if verbose: - if cpt % 200 == 0: - print('{:5s}|{:12s}'.format( - 'It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) - - cpt += 1 - - if log: - return C, {"err": error} - else: - return C - - -def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False, - p=None, loss_fun='square_loss', max_iter=100, tol=1e-9, - verbose=False, log=False, init_C=None, init_X=None, random_state=None): - r"""Compute the fgw barycenter as presented eq (5) in :ref:`[24] ` - - Parameters - ---------- - N : int - Desired number of samples of the target barycenter - Ys: list of array-like, each element has shape (ns,d) - Features of all samples - Cs : list of array-like, each element has shape (ns,ns) - Structure matrices of all samples - ps : list of array-like, each element has shape (ns,) - Masses of all samples. - lambdas : list of float - List of the `S` spaces' weights - alpha : float - Alpha parameter for the fgw distance - fixed_structure : bool - Whether to fix the structure of the barycenter during the updates - fixed_features : bool - Whether to fix the feature of the barycenter during the updates - loss_fun : str - Loss function used for the solver either 'square_loss' or 'kl_loss' - max_iter : int, optional - Max number of iterations - tol : float, optional - Stop threshold on error (>0). - verbose : bool, optional - Print information along iterations. - log : bool, optional - Record log if True. - init_C : array-like, shape (N,N), optional - Initialization for the barycenters' structure matrix. If not set - a random init is used. - init_X : array-like, shape (N,d), optional - Initialization for the barycenters' features. If not set a - random init is used. - random_state : int or RandomState instance, optional - Fix the seed for reproducibility - - Returns - ------- - X : array-like, shape (`N`, `d`) - Barycenters' features - C : array-like, shape (`N`, `N`) - Barycenters' structure matrix - log : dict - Only returned when log=True. It contains the keys: - - - :math:`\mathbf{T}`: list of (`N`, `ns`) transport matrices - - :math:`(\mathbf{M}_s)_s`: all distance matrices between the feature of the barycenter and the other features :math:`(dist(\mathbf{X}, \mathbf{Y}_s))_s` shape (`N`, `ns`) - - - .. _references-fgw-barycenters: - References - ---------- - .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain - and Courty Nicolas - "Optimal Transport for structured data with application on graphs" - International Conference on Machine Learning (ICML). 2019. - """ - Cs = list_to_array(*Cs) - ps = list_to_array(*ps) - Ys = list_to_array(*Ys) - p = list_to_array(p) - nx = get_backend(*Cs, *Ys, *ps) - - S = len(Cs) - d = Ys[0].shape[1] # dimension on the node features - if p is None: - p = nx.ones(N, type_as=Cs[0]) / N - - if fixed_structure: - if init_C is None: - raise UndefinedParameter('If C is fixed it must be initialized') - else: - C = init_C - else: - if init_C is None: - generator = check_random_state(random_state) - xalea = generator.randn(N, 2) - C = dist(xalea, xalea) - C = nx.from_numpy(C, type_as=ps[0]) - else: - C = init_C - - if fixed_features: - if init_X is None: - raise UndefinedParameter('If X is fixed it must be initialized') - else: - X = init_X - else: - if init_X is None: - X = nx.zeros((N, d), type_as=ps[0]) - else: - X = init_X - - T = [nx.outer(p, q) for q in ps] - - Ms = [dist(X, Ys[s]) for s in range(len(Ys))] - - cpt = 0 - err_feature = 1 - err_structure = 1 - - if log: - log_ = {} - log_['err_feature'] = [] - log_['err_structure'] = [] - log_['Ts_iter'] = [] - - while ((err_feature > tol or err_structure > tol) and cpt < max_iter): - Cprev = C - Xprev = X - - if not fixed_features: - Ys_temp = [y.T for y in Ys] - X = update_feature_matrix(lambdas, Ys_temp, T, p).T - - Ms = [dist(X, Ys[s]) for s in range(len(Ys))] - - if not fixed_structure: - if loss_fun == 'square_loss': - T_temp = [t.T for t in T] - C = update_structure_matrix(p, lambdas, T_temp, Cs) - - T = [fused_gromov_wasserstein(Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha, - numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)] - - # T is N,ns - err_feature = nx.norm(X - nx.reshape(Xprev, (N, d))) - err_structure = nx.norm(C - Cprev) - if log: - log_['err_feature'].append(err_feature) - log_['err_structure'].append(err_structure) - log_['Ts_iter'].append(T) - - if verbose: - if cpt % 200 == 0: - print('{:5s}|{:12s}'.format( - 'It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err_structure)) - print('{:5d}|{:8e}|'.format(cpt, err_feature)) - - cpt += 1 - - if log: - log_['T'] = T # from target to Ys - log_['p'] = p - log_['Ms'] = Ms - - if log: - return X, C, log_ - else: - return X, C - - -def update_structure_matrix(p, lambdas, T, Cs): - r"""Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s` couplings. - - It is calculated at each iteration - - Parameters - ---------- - p : array-like, shape (N,) - Masses in the targeted barycenter. - lambdas : list of float - List of the `S` spaces' weights. - T : list of S array-like of shape (ns, N) - The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration. - Cs : list of S array-like, shape (ns, ns) - Metric cost matrices. - - Returns - ------- - C : array-like, shape (`nt`, `nt`) - Updated :math:`\mathbf{C}` matrix. - """ - p = list_to_array(p) - T = list_to_array(*T) - Cs = list_to_array(*Cs) - nx = get_backend(*Cs, *T, p) - - tmpsum = sum([ - lambdas[s] * nx.dot( - nx.dot(T[s].T, Cs[s]), - T[s] - ) for s in range(len(T)) - ]) - ppt = nx.outer(p, p) - return tmpsum / ppt - - -def update_feature_matrix(lambdas, Ys, Ts, p): - r"""Updates the feature with respect to the `S` :math:`\mathbf{T}_s` couplings. - - - See "Solving the barycenter problem with Block Coordinate Descent (BCD)" - in :ref:`[24] ` calculated at each iteration - - Parameters - ---------- - p : array-like, shape (N,) - masses in the targeted barycenter - lambdas : list of float - List of the `S` spaces' weights - Ts : list of S array-like, shape (ns,N) - The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration - Ys : list of S array-like, shape (d,ns) - The features. - - Returns - ------- - X : array-like, shape (`d`, `N`) - - - .. _references-update-feature-matrix: - References - ---------- - .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas - "Optimal Transport for structured data with application on graphs" - International Conference on Machine Learning (ICML). 2019. - """ - p = list_to_array(p) - Ts = list_to_array(*Ts) - Ys = list_to_array(*Ys) - nx = get_backend(*Ys, *Ts, p) - - p = 1. / p - tmpsum = sum([ - lambdas[s] * nx.dot(Ys[s], Ts[s].T) * p[None, :] - for s in range(len(Ts)) - ]) - return tmpsum - - -def gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0., ps=None, q=None, epochs=20, batch_size=32, learning_rate=1., Cdict_init=None, projection='nonnegative_symmetric', use_log=True, - tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, **kwargs): - r""" - Infer Gromov-Wasserstein linear dictionary :math:`\{ (\mathbf{C_{dict}[d]}, q) \}_{d \in [D]}` from the list of structures :math:`\{ (\mathbf{C_s},\mathbf{p_s}) \}_s` - - .. math:: - \min_{\mathbf{C_{dict}}, \{\mathbf{w_s} \}_{s \leq S}} \sum_{s=1}^S GW_2(\mathbf{C_s}, \sum_{d=1}^D w_{s,d}\mathbf{C_{dict}[d]}, \mathbf{p_s}, \mathbf{q}) - reg\| \mathbf{w_s} \|_2^2 - - such that, :math:`\forall s \leq S` : - - - :math:`\mathbf{w_s}^\top \mathbf{1}_D = 1` - - :math:`\mathbf{w_s} \geq \mathbf{0}_D` - - Where : - - - :math:`\forall s \leq S, \mathbf{C_s}` is a (ns,ns) pairwise similarity matrix of variable size ns. - - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrix of fixed size nt. - - :math:`\forall s \leq S, \mathbf{p_s}` is the source distribution corresponding to :math:`\mathbf{C_s}` - - :math:`\mathbf{q}` is the target distribution assigned to every structures in the embedding space. - - reg is the regularization coefficient. - - The stochastic algorithm used for estimating the graph dictionary atoms as proposed in [38] - - Parameters - ---------- - Cs : list of S symmetric array-like, shape (ns, ns) - List of Metric/Graph cost matrices of variable size (ns, ns). - D: int - Number of dictionary atoms to learn - nt: int - Number of samples within each dictionary atoms - reg : float, optional - Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0. - ps : list of S array-like, shape (ns,), optional - Distribution in each source space C of Cs. Default is None and corresponds to uniform distibutions. - q : array-like, shape (nt,), optional - Distribution in the embedding space whose structure will be learned. Default is None and corresponds to uniform distributions. - epochs: int, optional - Number of epochs used to learn the dictionary. Default is 32. - batch_size: int, optional - Batch size for each stochastic gradient update of the dictionary. Set to the dataset size if the provided batch_size is higher than the dataset size. Default is 32. - learning_rate: float, optional - Learning rate used for the stochastic gradient descent. Default is 1. - Cdict_init: list of D array-like with shape (nt, nt), optional - Used to initialize the dictionary. - If set to None (Default), the dictionary will be initialized randomly. - Else Cdict must have shape (D, nt, nt) i.e match provided shape features. - projection: str , optional - If 'nonnegative' and/or 'symmetric' is in projection, the corresponding projection will be performed at each stochastic update of the dictionary - Else the set of atoms is :math:`R^{nt * nt}`. Default is 'nonnegative_symmetric' - log: bool, optional - If set to True, losses evolution by batches and epochs are tracked. Default is False. - use_adam_optimizer: bool, optional - If set to True, adam optimizer with default settings is used as adaptative learning rate strategy. - Else perform SGD with fixed learning rate. Default is True. - tol_outer : float, optional - Solver precision for the BCD algorithm, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`. - tol_inner : float, optional - Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`. - max_iter_outer : int, optional - Maximum number of iterations for the BCD. Default is 20. - max_iter_inner : int, optional - Maximum number of iterations for the Conjugate Gradient. Default is 200. - verbose : bool, optional - Print the reconstruction loss every epoch. Default is False. - - Returns - ------- - - Cdict_best_state : D array-like, shape (D,nt,nt) - Metric/Graph cost matrices composing the dictionary. - The dictionary leading to the best loss over an epoch is saved and returned. - log: dict - If use_log is True, contains loss evolutions by batches and epochs. - References - ------- - - ..[38] Cédric Vincent-Cuaz, Titouan Vayer, Rémi Flamary, Marco Corneli, Nicolas Courty. - "Online Graph Dictionary Learning" - International Conference on Machine Learning (ICML). 2021. - """ - # Handle backend of non-optional arguments - Cs0 = Cs - nx = get_backend(*Cs0) - Cs = [nx.to_numpy(C) for C in Cs0] - dataset_size = len(Cs) - # Handle backend of optional arguments - if ps is None: - ps = [unif(C.shape[0]) for C in Cs] - else: - ps = [nx.to_numpy(p) for p in ps] - if q is None: - q = unif(nt) - else: - q = nx.to_numpy(q) - if Cdict_init is None: - # Initialize randomly structures of dictionary atoms based on samples - dataset_means = [C.mean() for C in Cs] - Cdict = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt)) - else: - Cdict = nx.to_numpy(Cdict_init).copy() - assert Cdict.shape == (D, nt, nt) - - if 'symmetric' in projection: - Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1))) - if 'nonnegative' in projection: - Cdict[Cdict < 0.] = 0 - if use_adam_optimizer: - adam_moments = _initialize_adam_optimizer(Cdict) - - log = {'loss_batches': [], 'loss_epochs': []} - const_q = q[:, None] * q[None, :] - Cdict_best_state = Cdict.copy() - loss_best_state = np.inf - if batch_size > dataset_size: - batch_size = dataset_size - iter_by_epoch = dataset_size // batch_size + int((dataset_size % batch_size) > 0) - - for epoch in range(epochs): - cumulated_loss_over_epoch = 0. - - for _ in range(iter_by_epoch): - # batch sampling - batch = np.random.choice(range(dataset_size), size=batch_size, replace=False) - cumulated_loss_over_batch = 0. - unmixings = np.zeros((batch_size, D)) - Cs_embedded = np.zeros((batch_size, nt, nt)) - Ts = [None] * batch_size - - for batch_idx, C_idx in enumerate(batch): - # BCD solver for Gromov-Wassersteisn linear unmixing used independently on each structure of the sampled batch - unmixings[batch_idx], Cs_embedded[batch_idx], Ts[batch_idx], current_loss = gromov_wasserstein_linear_unmixing( - Cs[C_idx], Cdict, reg=reg, p=ps[C_idx], q=q, tol_outer=tol_outer, tol_inner=tol_inner, - max_iter_outer=max_iter_outer, max_iter_inner=max_iter_inner - ) - cumulated_loss_over_batch += current_loss - cumulated_loss_over_epoch += cumulated_loss_over_batch - - if use_log: - log['loss_batches'].append(cumulated_loss_over_batch) - - # Stochastic projected gradient step over dictionary atoms - grad_Cdict = np.zeros_like(Cdict) - for batch_idx, C_idx in enumerate(batch): - shared_term_structures = Cs_embedded[batch_idx] * const_q - (Cs[C_idx].dot(Ts[batch_idx])).T.dot(Ts[batch_idx]) - grad_Cdict += unmixings[batch_idx][:, None, None] * shared_term_structures[None, :, :] - grad_Cdict *= 2 / batch_size - if use_adam_optimizer: - Cdict, adam_moments = _adam_stochastic_updates(Cdict, grad_Cdict, learning_rate, adam_moments) - else: - Cdict -= learning_rate * grad_Cdict - if 'symmetric' in projection: - Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1))) - if 'nonnegative' in projection: - Cdict[Cdict < 0.] = 0. - - if use_log: - log['loss_epochs'].append(cumulated_loss_over_epoch) - if loss_best_state > cumulated_loss_over_epoch: - loss_best_state = cumulated_loss_over_epoch - Cdict_best_state = Cdict.copy() - if verbose: - print('--- epoch =', epoch, ' cumulated reconstruction error: ', cumulated_loss_over_epoch) - - return nx.from_numpy(Cdict_best_state), log - - -def _initialize_adam_optimizer(variable): - - # Initialization for our numpy implementation of adam optimizer - atoms_adam_m = np.zeros_like(variable) # Initialize first moment tensor - atoms_adam_v = np.zeros_like(variable) # Initialize second moment tensor - atoms_adam_count = 1 - - return {'mean': atoms_adam_m, 'var': atoms_adam_v, 'count': atoms_adam_count} - - -def _adam_stochastic_updates(variable, grad, learning_rate, adam_moments, beta_1=0.9, beta_2=0.99, eps=1e-09): - - adam_moments['mean'] = beta_1 * adam_moments['mean'] + (1 - beta_1) * grad - adam_moments['var'] = beta_2 * adam_moments['var'] + (1 - beta_2) * (grad**2) - unbiased_m = adam_moments['mean'] / (1 - beta_1**adam_moments['count']) - unbiased_v = adam_moments['var'] / (1 - beta_2**adam_moments['count']) - variable -= learning_rate * unbiased_m / (np.sqrt(unbiased_v) + eps) - adam_moments['count'] += 1 - - return variable, adam_moments - - -def gromov_wasserstein_linear_unmixing(C, Cdict, reg=0., p=None, q=None, tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, **kwargs): - r""" - Returns the Gromov-Wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary :math:`\{ (\mathbf{C_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}`. - - .. math:: - \min_{ \mathbf{w}} GW_2(\mathbf{C}, \sum_{d=1}^D w_d\mathbf{C_{dict}[d]}, \mathbf{p}, \mathbf{q}) - reg \| \mathbf{w} \|_2^2 - - such that: - - - :math:`\mathbf{w}^\top \mathbf{1}_D = 1` - - :math:`\mathbf{w} \geq \mathbf{0}_D` - - Where : - - - :math:`\mathbf{C}` is the (ns,ns) pairwise similarity matrix. - - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrices of size nt. - - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights. - - reg is the regularization coefficient. - - The algorithm used for solving the problem is a Block Coordinate Descent as discussed in [38], algorithm 1. - - Parameters - ---------- - C : array-like, shape (ns, ns) - Metric/Graph cost matrix. - Cdict : D array-like, shape (D,nt,nt) - Metric/Graph cost matrices composing the dictionary on which to embed C. - reg : float, optional. - Coefficient of the negative quadratic regularization used to promote sparsity of w. Default is 0. - p : array-like, shape (ns,), optional - Distribution in the source space C. Default is None and corresponds to uniform distribution. - q : array-like, shape (nt,), optional - Distribution in the space depicted by the dictionary. Default is None and corresponds to uniform distribution. - tol_outer : float, optional - Solver precision for the BCD algorithm. - tol_inner : float, optional - Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport. Default is :math:`10^{-5}`. - max_iter_outer : int, optional - Maximum number of iterations for the BCD. Default is 20. - max_iter_inner : int, optional - Maximum number of iterations for the Conjugate Gradient. Default is 200. - - Returns - ------- - w: array-like, shape (D,) - gromov-wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{p})` onto the span of the dictionary. - Cembedded: array-like, shape (nt,nt) - embedded structure of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary, :math:`\sum_d w_d\mathbf{C_{dict}[d]}`. - T: array-like (ns, nt) - Gromov-Wasserstein transport plan between :math:`(\mathbf{C},\mathbf{p})` and :math:`(\sum_d w_d\mathbf{C_{dict}[d]}, \mathbf{q})` - current_loss: float - reconstruction error - References - ------- - - ..[38] Cédric Vincent-Cuaz, Titouan Vayer, Rémi Flamary, Marco Corneli, Nicolas Courty. - "Online Graph Dictionary Learning" - International Conference on Machine Learning (ICML). 2021. - """ - C0, Cdict0 = C, Cdict - nx = get_backend(C0, Cdict0) - C = nx.to_numpy(C0) - Cdict = nx.to_numpy(Cdict0) - if p is None: - p = unif(C.shape[0]) - else: - p = nx.to_numpy(p) - - if q is None: - q = unif(Cdict.shape[-1]) - else: - q = nx.to_numpy(q) - - T = p[:, None] * q[None, :] - D = len(Cdict) - - w = unif(D) # Initialize uniformly the unmixing w - Cembedded = np.sum(w[:, None, None] * Cdict, axis=0) - - const_q = q[:, None] * q[None, :] - # Trackers for BCD convergence - convergence_criterion = np.inf - current_loss = 10**15 - outer_count = 0 - - while (convergence_criterion > tol_outer) and (outer_count < max_iter_outer): - previous_loss = current_loss - # 1. Solve GW transport between (C,p) and (\sum_d Cdictionary[d],q) fixing the unmixing w - T, log = gromov_wasserstein(C1=C, C2=Cembedded, p=p, q=q, loss_fun='square_loss', G0=T, log=True, armijo=False, **kwargs) - current_loss = log['gw_dist'] - if reg != 0: - current_loss -= reg * np.sum(w**2) - - # 2. Solve linear unmixing problem over w with a fixed transport plan T - w, Cembedded, current_loss = _cg_gromov_wasserstein_unmixing( - C=C, Cdict=Cdict, Cembedded=Cembedded, w=w, const_q=const_q, T=T, - starting_loss=current_loss, reg=reg, tol=tol_inner, max_iter=max_iter_inner, **kwargs - ) - - if previous_loss != 0: - convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss) - else: # handle numerical issues around 0 - convergence_criterion = abs(previous_loss - current_loss) / 10**(-15) - outer_count += 1 - - return nx.from_numpy(w), nx.from_numpy(Cembedded), nx.from_numpy(T), nx.from_numpy(current_loss) - - -def _cg_gromov_wasserstein_unmixing(C, Cdict, Cembedded, w, const_q, T, starting_loss, reg=0., tol=10**(-5), max_iter=200, **kwargs): - r""" - Returns for a fixed admissible transport plan, - the linear unmixing w minimizing the Gromov-Wasserstein cost between :math:`(\mathbf{C},\mathbf{p})` and :math:`(\sum_d w[d]*\mathbf{C_{dict}[d]}, \mathbf{q})` - - .. math:: - \min_{\mathbf{w}} \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D w_d*C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} - reg* \| \mathbf{w} \|_2^2 - - - Such that: - - - :math:`\mathbf{w}^\top \mathbf{1}_D = 1` - - :math:`\mathbf{w} \geq \mathbf{0}_D` - - Where : - - - :math:`\mathbf{C}` is the (ns,ns) pairwise similarity matrix. - - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrices of nt points. - - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights. - - :math:`\mathbf{w}` is the linear unmixing of :math:`(\mathbf{C}, \mathbf{p})` onto :math:`(\sum_d w_d \mathbf{Cdict[d]}, \mathbf{q})`. - - :math:`\mathbf{T}` is the optimal transport plan conditioned by the current state of :math:`\mathbf{w}`. - - reg is the regularization coefficient. - - The algorithm used for solving the problem is a Conditional Gradient Descent as discussed in [38] - - Parameters - ---------- - - C : array-like, shape (ns, ns) - Metric/Graph cost matrix. - Cdict : list of D array-like, shape (nt,nt) - Metric/Graph cost matrices composing the dictionary on which to embed C. - Each matrix in the dictionary must have the same size (nt,nt). - Cembedded: array-like, shape (nt,nt) - Embedded structure :math:`(\sum_d w[d]*Cdict[d],q)` of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary. Used to avoid redundant computations. - w: array-like, shape (D,) - Linear unmixing of the input structure onto the dictionary - const_q: array-like, shape (nt,nt) - product matrix :math:`\mathbf{q}\mathbf{q}^\top` where q is the target space distribution. Used to avoid redundant computations. - T: array-like, shape (ns,nt) - fixed transport plan between the input structure and its representation in the dictionary. - p : array-like, shape (ns,) - Distribution in the source space. - q : array-like, shape (nt,) - Distribution in the embedding space depicted by the dictionary. - reg : float, optional. - Coefficient of the negative quadratic regularization used to promote sparsity of w. Default is 0. - - Returns - ------- - w: ndarray (D,) - optimal unmixing of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary span given OT starting from previously optimal unmixing. - """ - convergence_criterion = np.inf - current_loss = starting_loss - count = 0 - const_TCT = np.transpose(C.dot(T)).dot(T) - - while (convergence_criterion > tol) and (count < max_iter): - - previous_loss = current_loss - # 1) Compute gradient at current point w - grad_w = 2 * np.sum(Cdict * (Cembedded[None, :, :] * const_q[None, :, :] - const_TCT[None, :, :]), axis=(1, 2)) - grad_w -= 2 * reg * w - - # 2) Conditional gradient direction finding: x= \argmin_x x^T.grad_w - min_ = np.min(grad_w) - x = (grad_w == min_).astype(np.float64) - x /= np.sum(x) - - # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c - gamma, a, b, Cembedded_diff = _linesearch_gromov_wasserstein_unmixing(w, grad_w, x, Cdict, Cembedded, const_q, const_TCT, reg) - - # 4) Updates: w <-- (1-gamma)*w + gamma*x - w += gamma * (x - w) - Cembedded += gamma * Cembedded_diff - current_loss += a * (gamma**2) + b * gamma - - if previous_loss != 0: # not that the loss can be negative if reg >0 - convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss) - else: # handle numerical issues around 0 - convergence_criterion = abs(previous_loss - current_loss) / 10**(-15) - count += 1 - - return w, Cembedded, current_loss - - -def _linesearch_gromov_wasserstein_unmixing(w, grad_w, x, Cdict, Cembedded, const_q, const_TCT, reg, **kwargs): - r""" - Compute optimal steps for the line search problem of Gromov-Wasserstein linear unmixing - .. math:: - \min_{\gamma \in [0,1]} \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D z_d(\gamma)C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} - reg\| \mathbf{z}(\gamma) \|_2^2 - - - Such that: - - - :math:`\mathbf{z}(\gamma) = (1- \gamma)\mathbf{w} + \gamma \mathbf{x}` - - Parameters - ---------- - - w : array-like, shape (D,) - Unmixing. - grad_w : array-like, shape (D, D) - Gradient of the reconstruction loss with respect to w. - x: array-like, shape (D,) - Conditional gradient direction. - Cdict : list of D array-like, shape (nt,nt) - Metric/Graph cost matrices composing the dictionary on which to embed C. - Each matrix in the dictionary must have the same size (nt,nt). - Cembedded: array-like, shape (nt,nt) - Embedded structure :math:`(\sum_d w_dCdict[d],q)` of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary. Used to avoid redundant computations. - const_q: array-like, shape (nt,nt) - product matrix :math:`\mathbf{q}\mathbf{q}^\top` where q is the target space distribution. Used to avoid redundant computations. - const_TCT: array-like, shape (nt, nt) - :math:`\mathbf{T}^\top \mathbf{C}^\top \mathbf{T}`. Used to avoid redundant computations. - Returns - ------- - gamma: float - Optimal value for the line-search step - a: float - Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss - b: float - Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss - Cembedded_diff: numpy array, shape (nt, nt) - Difference between models evaluated in :math:`\mathbf{w}` and in :math:`\mathbf{w}`. - reg : float, optional. - Coefficient of the negative quadratic regularization used to promote sparsity of :math:`\mathbf{w}`. - """ - - # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c - Cembedded_x = np.sum(x[:, None, None] * Cdict, axis=0) - Cembedded_diff = Cembedded_x - Cembedded - trace_diffx = np.sum(Cembedded_diff * Cembedded_x * const_q) - trace_diffw = np.sum(Cembedded_diff * Cembedded * const_q) - a = trace_diffx - trace_diffw - b = 2 * (trace_diffw - np.sum(Cembedded_diff * const_TCT)) - if reg != 0: - a -= reg * np.sum((x - w)**2) - b -= 2 * reg * np.sum(w * (x - w)) - - if a > 0: - gamma = min(1, max(0, - b / (2 * a))) - elif a + b < 0: - gamma = 1 - else: - gamma = 0 - - return gamma, a, b, Cembedded_diff - - -def fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, alpha, reg=0., ps=None, q=None, epochs=20, batch_size=32, learning_rate_C=1., learning_rate_Y=1., - Cdict_init=None, Ydict_init=None, projection='nonnegative_symmetric', use_log=False, - tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, **kwargs): - r""" - Infer Fused Gromov-Wasserstein linear dictionary :math:`\{ (\mathbf{C_{dict}[d]}, \mathbf{Y_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}` from the list of S attributed structures :math:`\{ (\mathbf{C_s}, \mathbf{Y_s},\mathbf{p_s}) \}_s` - - .. math:: - \min_{\mathbf{C_{dict}},\mathbf{Y_{dict}}, \{\mathbf{w_s}\}_{s}} \sum_{s=1}^S FGW_{2,\alpha}(\mathbf{C_s}, \mathbf{Y_s}, \sum_{d=1}^D w_{s,d}\mathbf{C_{dict}[d]},\sum_{d=1}^D w_{s,d}\mathbf{Y_{dict}[d]}, \mathbf{p_s}, \mathbf{q}) \\ - reg\| \mathbf{w_s} \|_2^2 - - - Such that :math:`\forall s \leq S` : - - - :math:`\mathbf{w_s}^\top \mathbf{1}_D = 1` - - :math:`\mathbf{w_s} \geq \mathbf{0}_D` - - Where : - - - :math:`\forall s \leq S, \mathbf{C_s}` is a (ns,ns) pairwise similarity matrix of variable size ns. - - :math:`\forall s \leq S, \mathbf{Y_s}` is a (ns,d) features matrix of variable size ns and fixed dimension d. - - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrix of fixed size nt. - - :math:`\mathbf{Y_{dict}}` is a (D, nt, d) tensor of D features matrix of fixed size nt and fixed dimension d. - - :math:`\forall s \leq S, \mathbf{p_s}` is the source distribution corresponding to :math:`\mathbf{C_s}` - - :math:`\mathbf{q}` is the target distribution assigned to every structures in the embedding space. - - :math:`\alpha` is the trade-off parameter of Fused Gromov-Wasserstein - - reg is the regularization coefficient. - - - The stochastic algorithm used for estimating the attributed graph dictionary atoms as proposed in [38] - - Parameters - ---------- - Cs : list of S symmetric array-like, shape (ns, ns) - List of Metric/Graph cost matrices of variable size (ns,ns). - Ys : list of S array-like, shape (ns, d) - List of feature matrix of variable size (ns,d) with d fixed. - D: int - Number of dictionary atoms to learn - nt: int - Number of samples within each dictionary atoms - alpha : float - Trade-off parameter of Fused Gromov-Wasserstein - reg : float, optional - Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0. - ps : list of S array-like, shape (ns,), optional - Distribution in each source space C of Cs. Default is None and corresponds to uniform distibutions. - q : array-like, shape (nt,), optional - Distribution in the embedding space whose structure will be learned. Default is None and corresponds to uniform distributions. - epochs: int, optional - Number of epochs used to learn the dictionary. Default is 32. - batch_size: int, optional - Batch size for each stochastic gradient update of the dictionary. Set to the dataset size if the provided batch_size is higher than the dataset size. Default is 32. - learning_rate_C: float, optional - Learning rate used for the stochastic gradient descent on Cdict. Default is 1. - learning_rate_Y: float, optional - Learning rate used for the stochastic gradient descent on Ydict. Default is 1. - Cdict_init: list of D array-like with shape (nt, nt), optional - Used to initialize the dictionary structures Cdict. - If set to None (Default), the dictionary will be initialized randomly. - Else Cdict must have shape (D, nt, nt) i.e match provided shape features. - Ydict_init: list of D array-like with shape (nt, d), optional - Used to initialize the dictionary features Ydict. - If set to None, the dictionary features will be initialized randomly. - Else Ydict must have shape (D, nt, d) where d is the features dimension of inputs Ys and also match provided shape features. - projection: str, optional - If 'nonnegative' and/or 'symmetric' is in projection, the corresponding projection will be performed at each stochastic update of the dictionary - Else the set of atoms is :math:`R^{nt * nt}`. Default is 'nonnegative_symmetric' - log: bool, optional - If set to True, losses evolution by batches and epochs are tracked. Default is False. - use_adam_optimizer: bool, optional - If set to True, adam optimizer with default settings is used as adaptative learning rate strategy. - Else perform SGD with fixed learning rate. Default is True. - tol_outer : float, optional - Solver precision for the BCD algorithm, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`. - tol_inner : float, optional - Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`. - max_iter_outer : int, optional - Maximum number of iterations for the BCD. Default is 20. - max_iter_inner : int, optional - Maximum number of iterations for the Conjugate Gradient. Default is 200. - verbose : bool, optional - Print the reconstruction loss every epoch. Default is False. - - Returns - ------- - - Cdict_best_state : D array-like, shape (D,nt,nt) - Metric/Graph cost matrices composing the dictionary. - The dictionary leading to the best loss over an epoch is saved and returned. - Ydict_best_state : D array-like, shape (D,nt,d) - Feature matrices composing the dictionary. - The dictionary leading to the best loss over an epoch is saved and returned. - log: dict - If use_log is True, contains loss evolutions by batches and epoches. - References - ------- - - ..[38] Cédric Vincent-Cuaz, Titouan Vayer, Rémi Flamary, Marco Corneli, Nicolas Courty. - "Online Graph Dictionary Learning" - International Conference on Machine Learning (ICML). 2021. - """ - Cs0, Ys0 = Cs, Ys - nx = get_backend(*Cs0, *Ys0) - Cs = [nx.to_numpy(C) for C in Cs0] - Ys = [nx.to_numpy(Y) for Y in Ys0] - - d = Ys[0].shape[-1] - dataset_size = len(Cs) - - if ps is None: - ps = [unif(C.shape[0]) for C in Cs] - else: - ps = [nx.to_numpy(p) for p in ps] - if q is None: - q = unif(nt) - else: - q = nx.to_numpy(q) - - if Cdict_init is None: - # Initialize randomly structures of dictionary atoms based on samples - dataset_means = [C.mean() for C in Cs] - Cdict = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt)) - else: - Cdict = nx.to_numpy(Cdict_init).copy() - assert Cdict.shape == (D, nt, nt) - if Ydict_init is None: - # Initialize randomly features of dictionary atoms based on samples distribution by feature component - dataset_feature_means = np.stack([F.mean(axis=0) for F in Ys]) - Ydict = np.random.normal(loc=dataset_feature_means.mean(axis=0), scale=dataset_feature_means.std(axis=0), size=(D, nt, d)) - else: - Ydict = nx.to_numpy(Ydict_init).copy() - assert Ydict.shape == (D, nt, d) - - if 'symmetric' in projection: - Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1))) - if 'nonnegative' in projection: - Cdict[Cdict < 0.] = 0. - - if use_adam_optimizer: - adam_moments_C = _initialize_adam_optimizer(Cdict) - adam_moments_Y = _initialize_adam_optimizer(Ydict) - - log = {'loss_batches': [], 'loss_epochs': []} - const_q = q[:, None] * q[None, :] - diag_q = np.diag(q) - Cdict_best_state = Cdict.copy() - Ydict_best_state = Ydict.copy() - loss_best_state = np.inf - if batch_size > dataset_size: - batch_size = dataset_size - iter_by_epoch = dataset_size // batch_size + int((dataset_size % batch_size) > 0) - - for epoch in range(epochs): - cumulated_loss_over_epoch = 0. - - for _ in range(iter_by_epoch): - - # Batch iterations - batch = np.random.choice(range(dataset_size), size=batch_size, replace=False) - cumulated_loss_over_batch = 0. - unmixings = np.zeros((batch_size, D)) - Cs_embedded = np.zeros((batch_size, nt, nt)) - Ys_embedded = np.zeros((batch_size, nt, d)) - Ts = [None] * batch_size - - for batch_idx, C_idx in enumerate(batch): - # BCD solver for Gromov-Wassersteisn linear unmixing used independently on each structure of the sampled batch - unmixings[batch_idx], Cs_embedded[batch_idx], Ys_embedded[batch_idx], Ts[batch_idx], current_loss = fused_gromov_wasserstein_linear_unmixing( - Cs[C_idx], Ys[C_idx], Cdict, Ydict, alpha, reg=reg, p=ps[C_idx], q=q, - tol_outer=tol_outer, tol_inner=tol_inner, max_iter_outer=max_iter_outer, max_iter_inner=max_iter_inner - ) - cumulated_loss_over_batch += current_loss - cumulated_loss_over_epoch += cumulated_loss_over_batch - if use_log: - log['loss_batches'].append(cumulated_loss_over_batch) - - # Stochastic projected gradient step over dictionary atoms - grad_Cdict = np.zeros_like(Cdict) - grad_Ydict = np.zeros_like(Ydict) - - for batch_idx, C_idx in enumerate(batch): - shared_term_structures = Cs_embedded[batch_idx] * const_q - (Cs[C_idx].dot(Ts[batch_idx])).T.dot(Ts[batch_idx]) - shared_term_features = diag_q.dot(Ys_embedded[batch_idx]) - Ts[batch_idx].T.dot(Ys[C_idx]) - grad_Cdict += alpha * unmixings[batch_idx][:, None, None] * shared_term_structures[None, :, :] - grad_Ydict += (1 - alpha) * unmixings[batch_idx][:, None, None] * shared_term_features[None, :, :] - grad_Cdict *= 2 / batch_size - grad_Ydict *= 2 / batch_size - - if use_adam_optimizer: - Cdict, adam_moments_C = _adam_stochastic_updates(Cdict, grad_Cdict, learning_rate_C, adam_moments_C) - Ydict, adam_moments_Y = _adam_stochastic_updates(Ydict, grad_Ydict, learning_rate_Y, adam_moments_Y) - else: - Cdict -= learning_rate_C * grad_Cdict - Ydict -= learning_rate_Y * grad_Ydict - - if 'symmetric' in projection: - Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1))) - if 'nonnegative' in projection: - Cdict[Cdict < 0.] = 0. - - if use_log: - log['loss_epochs'].append(cumulated_loss_over_epoch) - if loss_best_state > cumulated_loss_over_epoch: - loss_best_state = cumulated_loss_over_epoch - Cdict_best_state = Cdict.copy() - Ydict_best_state = Ydict.copy() - if verbose: - print('--- epoch: ', epoch, ' cumulated reconstruction error: ', cumulated_loss_over_epoch) - - return nx.from_numpy(Cdict_best_state), nx.from_numpy(Ydict_best_state), log - - -def fused_gromov_wasserstein_linear_unmixing(C, Y, Cdict, Ydict, alpha, reg=0., p=None, q=None, tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, **kwargs): - r""" - Returns the Fused Gromov-Wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` onto the attributed dictionary atoms :math:`\{ (\mathbf{C_{dict}[d]},\mathbf{Y_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}` - - .. math:: - \min_{\mathbf{w}} FGW_{2,\alpha}(\mathbf{C},\mathbf{Y}, \sum_{d=1}^D w_d\mathbf{C_{dict}[d]},\sum_{d=1}^D w_d\mathbf{Y_{dict}[d]}, \mathbf{p}, \mathbf{q}) - reg \| \mathbf{w} \|_2^2 - - such that, :math:`\forall s \leq S` : - - - :math:`\mathbf{w_s}^\top \mathbf{1}_D = 1` - - :math:`\mathbf{w_s} \geq \mathbf{0}_D` - - Where : - - - :math:`\mathbf{C}` is a (ns,ns) pairwise similarity matrix of variable size ns. - - :math:`\mathbf{Y}` is a (ns,d) features matrix of variable size ns and fixed dimension d. - - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrix of fixed size nt. - - :math:`\mathbf{Y_{dict}}` is a (D, nt, d) tensor of D features matrix of fixed size nt and fixed dimension d. - - :math:`\mathbf{p}` is the source distribution corresponding to :math:`\mathbf{C_s}` - - :math:`\mathbf{q}` is the target distribution assigned to every structures in the embedding space. - - :math:`\alpha` is the trade-off parameter of Fused Gromov-Wasserstein - - reg is the regularization coefficient. - - The algorithm used for solving the problem is a Block Coordinate Descent as discussed in [38], algorithm 6. - - Parameters - ---------- - C : array-like, shape (ns, ns) - Metric/Graph cost matrix. - Y : array-like, shape (ns, d) - Feature matrix. - Cdict : D array-like, shape (D,nt,nt) - Metric/Graph cost matrices composing the dictionary on which to embed (C,Y). - Ydict : D array-like, shape (D,nt,d) - Feature matrices composing the dictionary on which to embed (C,Y). - alpha: float, - Trade-off parameter of Fused Gromov-Wasserstein. - reg : float, optional - Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0. - p : array-like, shape (ns,), optional - Distribution in the source space C. Default is None and corresponds to uniform distribution. - q : array-like, shape (nt,), optional - Distribution in the space depicted by the dictionary. Default is None and corresponds to uniform distribution. - tol_outer : float, optional - Solver precision for the BCD algorithm. - tol_inner : float, optional - Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport. Default is :math:`10^{-5}`. - max_iter_outer : int, optional - Maximum number of iterations for the BCD. Default is 20. - max_iter_inner : int, optional - Maximum number of iterations for the Conjugate Gradient. Default is 200. - - Returns - ------- - w: array-like, shape (D,) - fused gromov-wasserstein linear unmixing of (C,Y,p) onto the span of the dictionary. - Cembedded: array-like, shape (nt,nt) - embedded structure of :math:`(\mathbf{C},\mathbf{Y}, \mathbf{p})` onto the dictionary, :math:`\sum_d w_d\mathbf{C_{dict}[d]}`. - Yembedded: array-like, shape (nt,d) - embedded features of :math:`(\mathbf{C},\mathbf{Y}, \mathbf{p})` onto the dictionary, :math:`\sum_d w_d\mathbf{Y_{dict}[d]}`. - T: array-like (ns,nt) - Fused Gromov-Wasserstein transport plan between :math:`(\mathbf{C},\mathbf{p})` and :math:`(\sum_d w_d\mathbf{C_{dict}[d]}, \sum_d w_d\mathbf{Y_{dict}[d]},\mathbf{q})`. - current_loss: float - reconstruction error - References - ------- - - ..[38] Cédric Vincent-Cuaz, Titouan Vayer, Rémi Flamary, Marco Corneli, Nicolas Courty. - "Online Graph Dictionary Learning" - International Conference on Machine Learning (ICML). 2021. - """ - C0, Y0, Cdict0, Ydict0 = C, Y, Cdict, Ydict - nx = get_backend(C0, Y0, Cdict0, Ydict0) - C = nx.to_numpy(C0) - Y = nx.to_numpy(Y0) - Cdict = nx.to_numpy(Cdict0) - Ydict = nx.to_numpy(Ydict0) - - if p is None: - p = unif(C.shape[0]) - else: - p = nx.to_numpy(p) - if q is None: - q = unif(Cdict.shape[-1]) - else: - q = nx.to_numpy(q) - - T = p[:, None] * q[None, :] - D = len(Cdict) - d = Y.shape[-1] - w = unif(D) # Initialize with uniform weights - ns = C.shape[-1] - nt = Cdict.shape[-1] - - # modeling (C,Y) - Cembedded = np.sum(w[:, None, None] * Cdict, axis=0) - Yembedded = np.sum(w[:, None, None] * Ydict, axis=0) - - # constants depending on q - const_q = q[:, None] * q[None, :] - diag_q = np.diag(q) - # Trackers for BCD convergence - convergence_criterion = np.inf - current_loss = 10**15 - outer_count = 0 - Ys_constM = (Y**2).dot(np.ones((d, nt))) # constant in computing euclidean pairwise feature matrix - - while (convergence_criterion > tol_outer) and (outer_count < max_iter_outer): - previous_loss = current_loss - - # 1. Solve GW transport between (C,p) and (\sum_d Cdictionary[d],q) fixing the unmixing w - Yt_varM = (np.ones((ns, d))).dot((Yembedded**2).T) - M = Ys_constM + Yt_varM - 2 * Y.dot(Yembedded.T) # euclidean distance matrix between features - T, log = fused_gromov_wasserstein(M, C, Cembedded, p, q, loss_fun='square_loss', alpha=alpha, armijo=False, G0=T, log=True) - current_loss = log['fgw_dist'] - if reg != 0: - current_loss -= reg * np.sum(w**2) - - # 2. Solve linear unmixing problem over w with a fixed transport plan T - w, Cembedded, Yembedded, current_loss = _cg_fused_gromov_wasserstein_unmixing(C, Y, Cdict, Ydict, Cembedded, Yembedded, w, - T, p, q, const_q, diag_q, current_loss, alpha, reg, - tol=tol_inner, max_iter=max_iter_inner, **kwargs) - if previous_loss != 0: - convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss) - else: - convergence_criterion = abs(previous_loss - current_loss) / 10**(-12) - outer_count += 1 - - return nx.from_numpy(w), nx.from_numpy(Cembedded), nx.from_numpy(Yembedded), nx.from_numpy(T), nx.from_numpy(current_loss) - - -def _cg_fused_gromov_wasserstein_unmixing(C, Y, Cdict, Ydict, Cembedded, Yembedded, w, T, p, q, const_q, diag_q, starting_loss, alpha, reg, tol=10**(-6), max_iter=200, **kwargs): - r""" - Returns for a fixed admissible transport plan, - the optimal linear unmixing :math:`\mathbf{w}` minimizing the Fused Gromov-Wasserstein cost between :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` and :math:`(\sum_d w_d \mathbf{C_{dict}[d]},\sum_d w_d*\mathbf{Y_{dict}[d]}, \mathbf{q})` - - .. math:: - \min_{\mathbf{w}} \alpha \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D w_d C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} \\+ (1-\alpha) \sum_{ij} \| \mathbf{Y_i} - \sum_d w_d \mathbf{Y_{dict}[d]_j} \|_2^2 T_{ij}- reg \| \mathbf{w} \|_2^2 - - Such that : - - - :math:`\mathbf{w}^\top \mathbf{1}_D = 1` - - :math:`\mathbf{w} \geq \mathbf{0}_D` - - Where : - - - :math:`\mathbf{C}` is a (ns,ns) pairwise similarity matrix of variable size ns. - - :math:`\mathbf{Y}` is a (ns,d) features matrix of variable size ns and fixed dimension d. - - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrix of fixed size nt. - - :math:`\mathbf{Y_{dict}}` is a (D, nt, d) tensor of D features matrix of fixed size nt and fixed dimension d. - - :math:`\mathbf{p}` is the source distribution corresponding to :math:`\mathbf{C_s}` - - :math:`\mathbf{q}` is the target distribution assigned to every structures in the embedding space. - - :math:`\mathbf{T}` is the optimal transport plan conditioned by the previous state of :math:`\mathbf{w}` - - :math:`\alpha` is the trade-off parameter of Fused Gromov-Wasserstein - - reg is the regularization coefficient. - - The algorithm used for solving the problem is a Conditional Gradient Descent as discussed in [38], algorithm 7. - - Parameters - ---------- - - C : array-like, shape (ns, ns) - Metric/Graph cost matrix. - Y : array-like, shape (ns, d) - Feature matrix. - Cdict : list of D array-like, shape (nt,nt) - Metric/Graph cost matrices composing the dictionary on which to embed (C,Y). - Each matrix in the dictionary must have the same size (nt,nt). - Ydict : list of D array-like, shape (nt,d) - Feature matrices composing the dictionary on which to embed (C,Y). - Each matrix in the dictionary must have the same size (nt,d). - Cembedded: array-like, shape (nt,nt) - Embedded structure of (C,Y) onto the dictionary - Yembedded: array-like, shape (nt,d) - Embedded features of (C,Y) onto the dictionary - w: array-like, shape (n_D,) - Linear unmixing of (C,Y) onto (Cdict,Ydict) - const_q: array-like, shape (nt,nt) - product matrix :math:`\mathbf{qq}^\top` where :math:`\mathbf{q}` is the target space distribution. - diag_q: array-like, shape (nt,nt) - diagonal matrix with values of q on the diagonal. - T: array-like, shape (ns,nt) - fixed transport plan between (C,Y) and its model - p : array-like, shape (ns,) - Distribution in the source space (C,Y). - q : array-like, shape (nt,) - Distribution in the embedding space depicted by the dictionary. - alpha: float, - Trade-off parameter of Fused Gromov-Wasserstein. - reg : float, optional - Coefficient of the negative quadratic regularization used to promote sparsity of w. - - Returns - ------- - w: ndarray (D,) - linear unmixing of :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` onto the span of :math:`(C_{dict},Y_{dict})` given OT corresponding to previous unmixing. - """ - convergence_criterion = np.inf - current_loss = starting_loss - count = 0 - const_TCT = np.transpose(C.dot(T)).dot(T) - ones_ns_d = np.ones(Y.shape) - - while (convergence_criterion > tol) and (count < max_iter): - previous_loss = current_loss - - # 1) Compute gradient at current point w - # structure - grad_w = alpha * np.sum(Cdict * (Cembedded[None, :, :] * const_q[None, :, :] - const_TCT[None, :, :]), axis=(1, 2)) - # feature - grad_w += (1 - alpha) * np.sum(Ydict * (diag_q.dot(Yembedded)[None, :, :] - T.T.dot(Y)[None, :, :]), axis=(1, 2)) - grad_w -= reg * w - grad_w *= 2 - - # 2) Conditional gradient direction finding: x= \argmin_x x^T.grad_w - min_ = np.min(grad_w) - x = (grad_w == min_).astype(np.float64) - x /= np.sum(x) - - # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c - gamma, a, b, Cembedded_diff, Yembedded_diff = _linesearch_fused_gromov_wasserstein_unmixing(w, grad_w, x, Y, Cdict, Ydict, Cembedded, Yembedded, T, const_q, const_TCT, ones_ns_d, alpha, reg) - - # 4) Updates: w <-- (1-gamma)*w + gamma*x - w += gamma * (x - w) - Cembedded += gamma * Cembedded_diff - Yembedded += gamma * Yembedded_diff - current_loss += a * (gamma**2) + b * gamma - - if previous_loss != 0: - convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss) - else: - convergence_criterion = abs(previous_loss - current_loss) / 10**(-12) - count += 1 - - return w, Cembedded, Yembedded, current_loss - - -def _linesearch_fused_gromov_wasserstein_unmixing(w, grad_w, x, Y, Cdict, Ydict, Cembedded, Yembedded, T, const_q, const_TCT, ones_ns_d, alpha, reg, **kwargs): - r""" - Compute optimal steps for the line search problem of Fused Gromov-Wasserstein linear unmixing - .. math:: - \min_{\gamma \in [0,1]} \alpha \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D z_d(\gamma)C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} \\ + (1-\alpha) \sum_{ij} \| \mathbf{Y_i} - \sum_d z_d(\gamma) \mathbf{Y_{dict}[d]_j} \|_2^2 - reg\| \mathbf{z}(\gamma) \|_2^2 - - - Such that : - - - :math:`\mathbf{z}(\gamma) = (1- \gamma)\mathbf{w} + \gamma \mathbf{x}` - - Parameters - ---------- - - w : array-like, shape (D,) - Unmixing. - grad_w : array-like, shape (D, D) - Gradient of the reconstruction loss with respect to w. - x: array-like, shape (D,) - Conditional gradient direction. - Y: arrat-like, shape (ns,d) - Feature matrix of the input space - Cdict : list of D array-like, shape (nt, nt) - Metric/Graph cost matrices composing the dictionary on which to embed (C,Y). - Each matrix in the dictionary must have the same size (nt,nt). - Ydict : list of D array-like, shape (nt, d) - Feature matrices composing the dictionary on which to embed (C,Y). - Each matrix in the dictionary must have the same size (nt,d). - Cembedded: array-like, shape (nt, nt) - Embedded structure of (C,Y) onto the dictionary - Yembedded: array-like, shape (nt, d) - Embedded features of (C,Y) onto the dictionary - T: array-like, shape (ns, nt) - Fixed transport plan between (C,Y) and its current model. - const_q: array-like, shape (nt,nt) - product matrix :math:`\mathbf{q}\mathbf{q}^\top` where q is the target space distribution. Used to avoid redundant computations. - const_TCT: array-like, shape (nt, nt) - :math:`\mathbf{T}^\top \mathbf{C}^\top \mathbf{T}`. Used to avoid redundant computations. - ones_ns_d: array-like, shape (ns, d) - :math:`\mathbf{1}_{ ns \times d}`. Used to avoid redundant computations. - alpha: float, - Trade-off parameter of Fused Gromov-Wasserstein. - reg : float, optional - Coefficient of the negative quadratic regularization used to promote sparsity of w. - - Returns - ------- - gamma: float - Optimal value for the line-search step - a: float - Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss - b: float - Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss - Cembedded_diff: numpy array, shape (nt, nt) - Difference between structure matrix of models evaluated in :math:`\mathbf{w}` and in :math:`\mathbf{w}`. - Yembedded_diff: numpy array, shape (nt, nt) - Difference between feature matrix of models evaluated in :math:`\mathbf{w}` and in :math:`\mathbf{w}`. - """ - # polynomial coefficients from quadratic objective (with respect to w) on structures - Cembedded_x = np.sum(x[:, None, None] * Cdict, axis=0) - Cembedded_diff = Cembedded_x - Cembedded - trace_diffx = np.sum(Cembedded_diff * Cembedded_x * const_q) - trace_diffw = np.sum(Cembedded_diff * Cembedded * const_q) - # Constant factor appearing in the factorization a*gamma^2 + b*g + c of the Gromov-Wasserstein reconstruction loss - a_gw = trace_diffx - trace_diffw - b_gw = 2 * (trace_diffw - np.sum(Cembedded_diff * const_TCT)) - - # polynomial coefficient from quadratic objective (with respect to w) on features - Yembedded_x = np.sum(x[:, None, None] * Ydict, axis=0) - Yembedded_diff = Yembedded_x - Yembedded - # Constant factor appearing in the factorization a*gamma^2 + b*g + c of the Gromov-Wasserstein reconstruction loss - a_w = np.sum(ones_ns_d.dot((Yembedded_diff**2).T) * T) - b_w = 2 * np.sum(T * (ones_ns_d.dot((Yembedded * Yembedded_diff).T) - Y.dot(Yembedded_diff.T))) - - a = alpha * a_gw + (1 - alpha) * a_w - b = alpha * b_gw + (1 - alpha) * b_w - if reg != 0: - a -= reg * np.sum((x - w)**2) - b -= 2 * reg * np.sum(w * (x - w)) - if a > 0: - gamma = min(1, max(0, -b / (2 * a))) - elif a + b < 0: - gamma = 1 - else: - gamma = 0 - - return gamma, a, b, Cembedded_diff, Yembedded_diff diff --git a/ot/gromov/__init__.py b/ot/gromov/__init__.py new file mode 100644 index 0000000..6184edf --- /dev/null +++ b/ot/gromov/__init__.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- +""" +Solvers related to Gromov-Wasserstein problems. + +""" + +# Author: Remi Flamary +# Cedric Vincent-Cuaz +# +# License: MIT License + +# All submodules and packages +from ._utils import (init_matrix, tensor_product, gwloss, gwggrad, + update_square_loss, update_kl_loss, + init_matrix_semirelaxed) +from ._gw import (gromov_wasserstein, gromov_wasserstein2, + fused_gromov_wasserstein, fused_gromov_wasserstein2, + solve_gromov_linesearch, gromov_barycenters, fgw_barycenters, + update_structure_matrix, update_feature_matrix) +from ._bregman import (entropic_gromov_wasserstein, + entropic_gromov_wasserstein2, + entropic_gromov_barycenters) +from ._estimators import (GW_distance_estimation, pointwise_gromov_wasserstein, + sampled_gromov_wasserstein) +from ._semirelaxed import (semirelaxed_gromov_wasserstein, + semirelaxed_gromov_wasserstein2, + semirelaxed_fused_gromov_wasserstein, + semirelaxed_fused_gromov_wasserstein2, + solve_semirelaxed_gromov_linesearch) +from ._dictionary import (gromov_wasserstein_dictionary_learning, + gromov_wasserstein_linear_unmixing, + fused_gromov_wasserstein_dictionary_learning, + fused_gromov_wasserstein_linear_unmixing) + + +__all__ = ['init_matrix', 'tensor_product', 'gwloss', 'gwggrad', + 'update_square_loss', 'update_kl_loss', 'init_matrix_semirelaxed', + 'gromov_wasserstein', 'gromov_wasserstein2', 'fused_gromov_wasserstein', + 'fused_gromov_wasserstein2', 'solve_gromov_linesearch', 'gromov_barycenters', + 'fgw_barycenters', 'update_structure_matrix', 'update_feature_matrix', + 'entropic_gromov_wasserstein', 'entropic_gromov_wasserstein2', + 'entropic_gromov_barycenters', 'GW_distance_estimation', + 'pointwise_gromov_wasserstein', 'sampled_gromov_wasserstein', + 'semirelaxed_gromov_wasserstein', 'semirelaxed_gromov_wasserstein2', + 'semirelaxed_fused_gromov_wasserstein', 'semirelaxed_fused_gromov_wasserstein2', + 'solve_semirelaxed_gromov_linesearch', 'gromov_wasserstein_dictionary_learning', + 'gromov_wasserstein_linear_unmixing', 'fused_gromov_wasserstein_dictionary_learning', + 'fused_gromov_wasserstein_linear_unmixing'] diff --git a/ot/gromov/_bregman.py b/ot/gromov/_bregman.py new file mode 100644 index 0000000..5b2f959 --- /dev/null +++ b/ot/gromov/_bregman.py @@ -0,0 +1,351 @@ +# -*- coding: utf-8 -*- +""" +Bregman projections solvers for entropic Gromov-Wasserstein +""" + +# Author: Erwan Vautier +# Nicolas Courty +# Rémi Flamary +# Titouan Vayer +# Cédric Vincent-Cuaz +# +# License: MIT License + +import numpy as np + + +from ..bregman import sinkhorn +from ..utils import dist, list_to_array, check_random_state +from ..backend import get_backend + +from ._utils import init_matrix, gwloss, gwggrad +from ._utils import update_square_loss, update_kl_loss + + +def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, symmetric=None, G0=None, + max_iter=1000, tol=1e-9, verbose=False, log=False): + r""" + Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` + + The function solves the following optimization problem: + + .. math:: + \mathbf{GW} = \mathop{\arg\min}_\mathbf{T} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - \epsilon(H(\mathbf{T})) + + s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} + + \mathbf{T}^T \mathbf{1} &= \mathbf{q} + + \mathbf{T} &\geq 0 + + Where : + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{p}`: distribution in the source space + - :math:`\mathbf{q}`: distribution in the target space + - `L`: loss function to account for the misfit between the similarity matrices + - `H`: entropy + + .. note:: If the inner solver `ot.sinkhorn` did not convergence, the + optimal coupling :math:`\mathbf{T}` returned by this function does not + necessarily satisfy the marginal constraints + :math:`\mathbf{T}\mathbf{1}=\mathbf{p}` and + :math:`\mathbf{T}^T\mathbf{1}=\mathbf{q}`. So the returned + Gromov-Wasserstein loss does not necessarily satisfy distance + properties and may be negative. + + Parameters + ---------- + C1 : array-like, shape (ns, ns) + Metric cost matrix in the source space + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,) + Distribution in the source space + q : array-like, shape (nt,) + Distribution in the target space + loss_fun : string + Loss function used for the solver either 'square_loss' or 'kl_loss' + epsilon : float + Regularization term >0 + symmetric : bool, optional + Either C1 and C2 are to be assumed symmetric or not. + If let to its default None value, a symmetry test will be conducted. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric). + G0: array-like, shape (ns,nt), optional + If None the initial transport plan of the solver is pq^T. + Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. + max_iter : int, optional + Max number of iterations + tol : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + Record log if True. + + Returns + ------- + T : array-like, shape (`ns`, `nt`) + Optimal coupling between the two spaces + + References + ---------- + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, + "Gromov-Wasserstein averaging of kernel and distance matrices." + International Conference on Machine Learning (ICML). 2016. + + .. [47] Chowdhury, S., & Mémoli, F. (2019). The gromov–wasserstein + distance between networks and stable network invariants. + Information and Inference: A Journal of the IMA, 8(4), 757-787. + """ + C1, C2, p, q = list_to_array(C1, C2, p, q) + if G0 is None: + nx = get_backend(p, q, C1, C2) + G0 = nx.outer(p, q) + else: + nx = get_backend(p, q, C1, C2, G0) + T = G0 + constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun, nx) + if symmetric is None: + symmetric = np.allclose(C1, C1.T, atol=1e-10) and np.allclose(C2, C2.T, atol=1e-10) + if not symmetric: + constCt, hC1t, hC2t = init_matrix(C1.T, C2.T, p, q, loss_fun, nx) + cpt = 0 + err = 1 + + if log: + log = {'err': []} + + while (err > tol and cpt < max_iter): + + Tprev = T + + # compute the gradient + if symmetric: + tens = gwggrad(constC, hC1, hC2, T, nx) + else: + tens = 0.5 * (gwggrad(constC, hC1, hC2, T, nx) + gwggrad(constCt, hC1t, hC2t, T, nx)) + T = sinkhorn(p, q, tens, epsilon, method='sinkhorn') + + if cpt % 10 == 0: + # we can speed up the process by checking for the error only all + # the 10th iterations + err = nx.norm(T - Tprev) + + if log: + log['err'].append(err) + + if verbose: + if cpt % 200 == 0: + print('{:5s}|{:12s}'.format( + 'It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) + + cpt += 1 + + if log: + log['gw_dist'] = gwloss(constC, hC1, hC2, T, nx) + return T, log + else: + return T + + +def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, symmetric=None, G0=None, + max_iter=1000, tol=1e-9, verbose=False, log=False): + r""" + Returns the entropic gromov-wasserstein discrepancy between the two measured similarity matrices :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` + + The function solves the following optimization problem: + + .. math:: + GW = \min_\mathbf{T} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) + \mathbf{T}_{i,j} \mathbf{T}_{k,l} - \epsilon(H(\mathbf{T})) + + Where : + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{p}`: distribution in the source space + - :math:`\mathbf{q}`: distribution in the target space + - `L`: loss function to account for the misfit between the similarity matrices + - `H`: entropy + + .. note:: If the inner solver `ot.sinkhorn` did not convergence, the + optimal coupling :math:`\mathbf{T}` returned by this function does not + necessarily satisfy the marginal constraints + :math:`\mathbf{T}\mathbf{1}=\mathbf{p}` and + :math:`\mathbf{T}^T\mathbf{1}=\mathbf{q}`. So the returned + Gromov-Wasserstein loss does not necessarily satisfy distance + properties and may be negative. + + Parameters + ---------- + C1 : array-like, shape (ns, ns) + Metric cost matrix in the source space + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,) + Distribution in the source space + q : array-like, shape (nt,) + Distribution in the target space + loss_fun : str + Loss function used for the solver either 'square_loss' or 'kl_loss' + epsilon : float + Regularization term >0 + symmetric : bool, optional + Either C1 and C2 are to be assumed symmetric or not. + If let to its default None value, a symmetry test will be conducted. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric). + G0: array-like, shape (ns,nt), optional + If None the initial transport plan of the solver is pq^T. + Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. + max_iter : int, optional + Max number of iterations + tol : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + Record log if True. + + Returns + ------- + gw_dist : float + Gromov-Wasserstein distance + + References + ---------- + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, + "Gromov-Wasserstein averaging of kernel and distance matrices." + International Conference on Machine Learning (ICML). 2016. + + """ + gw, logv = entropic_gromov_wasserstein( + C1, C2, p, q, loss_fun, epsilon, symmetric, G0, max_iter, tol, verbose, log=True) + + logv['T'] = gw + + if log: + return logv['gw_dist'], logv + else: + return logv['gw_dist'] + + +def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, symmetric=True, + max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None, random_state=None): + r""" + Returns the gromov-wasserstein barycenters of `S` measured similarity matrices :math:`(\mathbf{C}_s)_{1 \leq s \leq S}` + + The function solves the following optimization problem: + + .. math:: + + \mathbf{C} = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \quad \sum_s \lambda_s \mathrm{GW}(\mathbf{C}, \mathbf{C}_s, \mathbf{p}, \mathbf{p}_s) + + Where : + + - :math:`\mathbf{C}_s`: metric cost matrix + - :math:`\mathbf{p}_s`: distribution + + Parameters + ---------- + N : int + Size of the targeted barycenter + Cs : list of S array-like of shape (ns,ns) + Metric cost matrices + ps : list of S array-like of shape (ns,) + Sample weights in the `S` spaces + p : array-like, shape(N,) + Weights in the targeted barycenter + lambdas : list of float + List of the `S` spaces' weights. + loss_fun : callable + Tensor-matrix multiplication function based on specific loss function. + update : callable + function(:math:`\mathbf{p}`, lambdas, :math:`\mathbf{T}`, :math:`\mathbf{Cs}`) that updates + :math:`\mathbf{C}` according to a specific Kernel with the `S` :math:`\mathbf{T}_s` couplings + calculated at each iteration + epsilon : float + Regularization term >0 + symmetric : bool, optional. + Either structures are to be assumed symmetric or not. Default value is True. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric). + max_iter : int, optional + Max number of iterations + tol : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations. + log : bool, optional + Record log if True. + init_C : bool | array-like, shape (N, N) + Random initial value for the :math:`\mathbf{C}` matrix provided by user. + random_state : int or RandomState instance, optional + Fix the seed for reproducibility + + Returns + ------- + C : array-like, shape (`N`, `N`) + Similarity matrix in the barycenter space (permutated arbitrarily) + log : dict + Log dictionary of error during iterations. Return only if `log=True` in parameters. + + References + ---------- + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, + "Gromov-Wasserstein averaging of kernel and distance matrices." + International Conference on Machine Learning (ICML). 2016. + """ + Cs = list_to_array(*Cs) + ps = list_to_array(*ps) + p = list_to_array(p) + nx = get_backend(*Cs, *ps, p) + + S = len(Cs) + + # Initialization of C : random SPD matrix (if not provided by user) + if init_C is None: + generator = check_random_state(random_state) + xalea = generator.randn(N, 2) + C = dist(xalea, xalea) + C /= C.max() + C = nx.from_numpy(C, type_as=p) + else: + C = init_C + + cpt = 0 + err = 1 + + error = [] + + while (err > tol) and (cpt < max_iter): + Cprev = C + + T = [entropic_gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon, symmetric, None, + max_iter, 1e-4, verbose, log=False) for s in range(S)] + if loss_fun == 'square_loss': + C = update_square_loss(p, lambdas, T, Cs) + + elif loss_fun == 'kl_loss': + C = update_kl_loss(p, lambdas, T, Cs) + + if cpt % 10 == 0: + # we can speed up the process by checking for the error only all + # the 10th iterations + err = nx.norm(C - Cprev) + error.append(err) + + if verbose: + if cpt % 200 == 0: + print('{:5s}|{:12s}'.format( + 'It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) + + cpt += 1 + + if log: + return C, {"err": error} + else: + return C diff --git a/ot/gromov/_dictionary.py b/ot/gromov/_dictionary.py new file mode 100644 index 0000000..5b32671 --- /dev/null +++ b/ot/gromov/_dictionary.py @@ -0,0 +1,1008 @@ +# -*- coding: utf-8 -*- +""" +(Fused) Gromov-Wasserstein dictionary learning. +""" + +# Author: Rémi Flamary +# Cédric Vincent-Cuaz +# +# License: MIT License + +import numpy as np + + +from ..utils import unif +from ..backend import get_backend +from ._gw import gromov_wasserstein, fused_gromov_wasserstein + + +def gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0., ps=None, q=None, epochs=20, batch_size=32, learning_rate=1., Cdict_init=None, projection='nonnegative_symmetric', use_log=True, + tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, **kwargs): + r""" + Infer Gromov-Wasserstein linear dictionary :math:`\{ (\mathbf{C_{dict}[d]}, q) \}_{d \in [D]}` from the list of structures :math:`\{ (\mathbf{C_s},\mathbf{p_s}) \}_s` + + .. math:: + \min_{\mathbf{C_{dict}}, \{\mathbf{w_s} \}_{s \leq S}} \sum_{s=1}^S GW_2(\mathbf{C_s}, \sum_{d=1}^D w_{s,d}\mathbf{C_{dict}[d]}, \mathbf{p_s}, \mathbf{q}) - reg\| \mathbf{w_s} \|_2^2 + + such that, :math:`\forall s \leq S` : + + - :math:`\mathbf{w_s}^\top \mathbf{1}_D = 1` + - :math:`\mathbf{w_s} \geq \mathbf{0}_D` + + Where : + + - :math:`\forall s \leq S, \mathbf{C_s}` is a (ns,ns) pairwise similarity matrix of variable size ns. + - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrix of fixed size nt. + - :math:`\forall s \leq S, \mathbf{p_s}` is the source distribution corresponding to :math:`\mathbf{C_s}` + - :math:`\mathbf{q}` is the target distribution assigned to every structures in the embedding space. + - reg is the regularization coefficient. + + The stochastic algorithm used for estimating the graph dictionary atoms as proposed in [38]_ + + Parameters + ---------- + Cs : list of S symmetric array-like, shape (ns, ns) + List of Metric/Graph cost matrices of variable size (ns, ns). + D: int + Number of dictionary atoms to learn + nt: int + Number of samples within each dictionary atoms + reg : float, optional + Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0. + ps : list of S array-like, shape (ns,), optional + Distribution in each source space C of Cs. Default is None and corresponds to uniform distibutions. + q : array-like, shape (nt,), optional + Distribution in the embedding space whose structure will be learned. Default is None and corresponds to uniform distributions. + epochs: int, optional + Number of epochs used to learn the dictionary. Default is 32. + batch_size: int, optional + Batch size for each stochastic gradient update of the dictionary. Set to the dataset size if the provided batch_size is higher than the dataset size. Default is 32. + learning_rate: float, optional + Learning rate used for the stochastic gradient descent. Default is 1. + Cdict_init: list of D array-like with shape (nt, nt), optional + Used to initialize the dictionary. + If set to None (Default), the dictionary will be initialized randomly. + Else Cdict must have shape (D, nt, nt) i.e match provided shape features. + projection: str , optional + If 'nonnegative' and/or 'symmetric' is in projection, the corresponding projection will be performed at each stochastic update of the dictionary + Else the set of atoms is :math:`R^{nt * nt}`. Default is 'nonnegative_symmetric' + log: bool, optional + If set to True, losses evolution by batches and epochs are tracked. Default is False. + use_adam_optimizer: bool, optional + If set to True, adam optimizer with default settings is used as adaptative learning rate strategy. + Else perform SGD with fixed learning rate. Default is True. + tol_outer : float, optional + Solver precision for the BCD algorithm, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`. + tol_inner : float, optional + Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`. + max_iter_outer : int, optional + Maximum number of iterations for the BCD. Default is 20. + max_iter_inner : int, optional + Maximum number of iterations for the Conjugate Gradient. Default is 200. + verbose : bool, optional + Print the reconstruction loss every epoch. Default is False. + + Returns + ------- + + Cdict_best_state : D array-like, shape (D,nt,nt) + Metric/Graph cost matrices composing the dictionary. + The dictionary leading to the best loss over an epoch is saved and returned. + log: dict + If use_log is True, contains loss evolutions by batches and epochs. + References + ------- + .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online + Graph Dictionary Learning, International Conference on Machine Learning + (ICML), 2021. + """ + # Handle backend of non-optional arguments + Cs0 = Cs + nx = get_backend(*Cs0) + Cs = [nx.to_numpy(C) for C in Cs0] + dataset_size = len(Cs) + # Handle backend of optional arguments + if ps is None: + ps = [unif(C.shape[0]) for C in Cs] + else: + ps = [nx.to_numpy(p) for p in ps] + if q is None: + q = unif(nt) + else: + q = nx.to_numpy(q) + if Cdict_init is None: + # Initialize randomly structures of dictionary atoms based on samples + dataset_means = [C.mean() for C in Cs] + Cdict = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt)) + else: + Cdict = nx.to_numpy(Cdict_init).copy() + assert Cdict.shape == (D, nt, nt) + + if 'symmetric' in projection: + Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1))) + symmetric = True + else: + symmetric = False + if 'nonnegative' in projection: + Cdict[Cdict < 0.] = 0 + if use_adam_optimizer: + adam_moments = _initialize_adam_optimizer(Cdict) + + log = {'loss_batches': [], 'loss_epochs': []} + const_q = q[:, None] * q[None, :] + Cdict_best_state = Cdict.copy() + loss_best_state = np.inf + if batch_size > dataset_size: + batch_size = dataset_size + iter_by_epoch = dataset_size // batch_size + int((dataset_size % batch_size) > 0) + + for epoch in range(epochs): + cumulated_loss_over_epoch = 0. + + for _ in range(iter_by_epoch): + # batch sampling + batch = np.random.choice(range(dataset_size), size=batch_size, replace=False) + cumulated_loss_over_batch = 0. + unmixings = np.zeros((batch_size, D)) + Cs_embedded = np.zeros((batch_size, nt, nt)) + Ts = [None] * batch_size + + for batch_idx, C_idx in enumerate(batch): + # BCD solver for Gromov-Wassersteisn linear unmixing used independently on each structure of the sampled batch + unmixings[batch_idx], Cs_embedded[batch_idx], Ts[batch_idx], current_loss = gromov_wasserstein_linear_unmixing( + Cs[C_idx], Cdict, reg=reg, p=ps[C_idx], q=q, tol_outer=tol_outer, tol_inner=tol_inner, + max_iter_outer=max_iter_outer, max_iter_inner=max_iter_inner, symmetric=symmetric, **kwargs + ) + cumulated_loss_over_batch += current_loss + cumulated_loss_over_epoch += cumulated_loss_over_batch + + if use_log: + log['loss_batches'].append(cumulated_loss_over_batch) + + # Stochastic projected gradient step over dictionary atoms + grad_Cdict = np.zeros_like(Cdict) + for batch_idx, C_idx in enumerate(batch): + shared_term_structures = Cs_embedded[batch_idx] * const_q - (Cs[C_idx].dot(Ts[batch_idx])).T.dot(Ts[batch_idx]) + grad_Cdict += unmixings[batch_idx][:, None, None] * shared_term_structures[None, :, :] + grad_Cdict *= 2 / batch_size + if use_adam_optimizer: + Cdict, adam_moments = _adam_stochastic_updates(Cdict, grad_Cdict, learning_rate, adam_moments) + else: + Cdict -= learning_rate * grad_Cdict + if 'symmetric' in projection: + Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1))) + if 'nonnegative' in projection: + Cdict[Cdict < 0.] = 0. + + if use_log: + log['loss_epochs'].append(cumulated_loss_over_epoch) + if loss_best_state > cumulated_loss_over_epoch: + loss_best_state = cumulated_loss_over_epoch + Cdict_best_state = Cdict.copy() + if verbose: + print('--- epoch =', epoch, ' cumulated reconstruction error: ', cumulated_loss_over_epoch) + + return nx.from_numpy(Cdict_best_state), log + + +def _initialize_adam_optimizer(variable): + + # Initialization for our numpy implementation of adam optimizer + atoms_adam_m = np.zeros_like(variable) # Initialize first moment tensor + atoms_adam_v = np.zeros_like(variable) # Initialize second moment tensor + atoms_adam_count = 1 + + return {'mean': atoms_adam_m, 'var': atoms_adam_v, 'count': atoms_adam_count} + + +def _adam_stochastic_updates(variable, grad, learning_rate, adam_moments, beta_1=0.9, beta_2=0.99, eps=1e-09): + + adam_moments['mean'] = beta_1 * adam_moments['mean'] + (1 - beta_1) * grad + adam_moments['var'] = beta_2 * adam_moments['var'] + (1 - beta_2) * (grad**2) + unbiased_m = adam_moments['mean'] / (1 - beta_1**adam_moments['count']) + unbiased_v = adam_moments['var'] / (1 - beta_2**adam_moments['count']) + variable -= learning_rate * unbiased_m / (np.sqrt(unbiased_v) + eps) + adam_moments['count'] += 1 + + return variable, adam_moments + + +def gromov_wasserstein_linear_unmixing(C, Cdict, reg=0., p=None, q=None, tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, symmetric=None, **kwargs): + r""" + Returns the Gromov-Wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary :math:`\{ (\mathbf{C_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}`. + + .. math:: + \min_{ \mathbf{w}} GW_2(\mathbf{C}, \sum_{d=1}^D w_d\mathbf{C_{dict}[d]}, \mathbf{p}, \mathbf{q}) - reg \| \mathbf{w} \|_2^2 + + such that: + + - :math:`\mathbf{w}^\top \mathbf{1}_D = 1` + - :math:`\mathbf{w} \geq \mathbf{0}_D` + + Where : + + - :math:`\mathbf{C}` is the (ns,ns) pairwise similarity matrix. + - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrices of size nt. + - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights. + - reg is the regularization coefficient. + + The algorithm used for solving the problem is a Block Coordinate Descent as discussed in [38]_ , algorithm 1. + + Parameters + ---------- + C : array-like, shape (ns, ns) + Metric/Graph cost matrix. + Cdict : D array-like, shape (D,nt,nt) + Metric/Graph cost matrices composing the dictionary on which to embed C. + reg : float, optional. + Coefficient of the negative quadratic regularization used to promote sparsity of w. Default is 0. + p : array-like, shape (ns,), optional + Distribution in the source space C. Default is None and corresponds to uniform distribution. + q : array-like, shape (nt,), optional + Distribution in the space depicted by the dictionary. Default is None and corresponds to uniform distribution. + tol_outer : float, optional + Solver precision for the BCD algorithm. + tol_inner : float, optional + Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport. Default is :math:`10^{-5}`. + max_iter_outer : int, optional + Maximum number of iterations for the BCD. Default is 20. + max_iter_inner : int, optional + Maximum number of iterations for the Conjugate Gradient. Default is 200. + + Returns + ------- + w: array-like, shape (D,) + gromov-wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{p})` onto the span of the dictionary. + Cembedded: array-like, shape (nt,nt) + embedded structure of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary, :math:`\sum_d w_d\mathbf{C_{dict}[d]}`. + T: array-like (ns, nt) + Gromov-Wasserstein transport plan between :math:`(\mathbf{C},\mathbf{p})` and :math:`(\sum_d w_d\mathbf{C_{dict}[d]}, \mathbf{q})` + current_loss: float + reconstruction error + References + ------- + .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online + Graph Dictionary Learning, International Conference on Machine Learning + (ICML), 2021. + """ + C0, Cdict0 = C, Cdict + nx = get_backend(C0, Cdict0) + C = nx.to_numpy(C0) + Cdict = nx.to_numpy(Cdict0) + if p is None: + p = unif(C.shape[0]) + else: + p = nx.to_numpy(p) + + if q is None: + q = unif(Cdict.shape[-1]) + else: + q = nx.to_numpy(q) + + T = p[:, None] * q[None, :] + D = len(Cdict) + + w = unif(D) # Initialize uniformly the unmixing w + Cembedded = np.sum(w[:, None, None] * Cdict, axis=0) + + const_q = q[:, None] * q[None, :] + # Trackers for BCD convergence + convergence_criterion = np.inf + current_loss = 10**15 + outer_count = 0 + + while (convergence_criterion > tol_outer) and (outer_count < max_iter_outer): + previous_loss = current_loss + # 1. Solve GW transport between (C,p) and (\sum_d Cdictionary[d],q) fixing the unmixing w + T, log = gromov_wasserstein( + C1=C, C2=Cembedded, p=p, q=q, loss_fun='square_loss', G0=T, + max_iter=max_iter_inner, tol_rel=tol_inner, tol_abs=0., log=True, armijo=False, symmetric=symmetric, **kwargs) + current_loss = log['gw_dist'] + if reg != 0: + current_loss -= reg * np.sum(w**2) + + # 2. Solve linear unmixing problem over w with a fixed transport plan T + w, Cembedded, current_loss = _cg_gromov_wasserstein_unmixing( + C=C, Cdict=Cdict, Cembedded=Cembedded, w=w, const_q=const_q, T=T, + starting_loss=current_loss, reg=reg, tol=tol_inner, max_iter=max_iter_inner, **kwargs + ) + + if previous_loss != 0: + convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss) + else: # handle numerical issues around 0 + convergence_criterion = abs(previous_loss - current_loss) / 10**(-15) + outer_count += 1 + + return nx.from_numpy(w), nx.from_numpy(Cembedded), nx.from_numpy(T), nx.from_numpy(current_loss) + + +def _cg_gromov_wasserstein_unmixing(C, Cdict, Cembedded, w, const_q, T, starting_loss, reg=0., tol=10**(-5), max_iter=200, **kwargs): + r""" + Returns for a fixed admissible transport plan, + the linear unmixing w minimizing the Gromov-Wasserstein cost between :math:`(\mathbf{C},\mathbf{p})` and :math:`(\sum_d w[d]*\mathbf{C_{dict}[d]}, \mathbf{q})` + + .. math:: + \min_{\mathbf{w}} \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D w_d*C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} - reg* \| \mathbf{w} \|_2^2 + + + Such that: + + - :math:`\mathbf{w}^\top \mathbf{1}_D = 1` + - :math:`\mathbf{w} \geq \mathbf{0}_D` + + Where : + + - :math:`\mathbf{C}` is the (ns,ns) pairwise similarity matrix. + - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrices of nt points. + - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights. + - :math:`\mathbf{w}` is the linear unmixing of :math:`(\mathbf{C}, \mathbf{p})` onto :math:`(\sum_d w_d \mathbf{Cdict[d]}, \mathbf{q})`. + - :math:`\mathbf{T}` is the optimal transport plan conditioned by the current state of :math:`\mathbf{w}`. + - reg is the regularization coefficient. + + The algorithm used for solving the problem is a Conditional Gradient Descent as discussed in [38]_ + + Parameters + ---------- + + C : array-like, shape (ns, ns) + Metric/Graph cost matrix. + Cdict : list of D array-like, shape (nt,nt) + Metric/Graph cost matrices composing the dictionary on which to embed C. + Each matrix in the dictionary must have the same size (nt,nt). + Cembedded: array-like, shape (nt,nt) + Embedded structure :math:`(\sum_d w[d]*Cdict[d],q)` of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary. Used to avoid redundant computations. + w: array-like, shape (D,) + Linear unmixing of the input structure onto the dictionary + const_q: array-like, shape (nt,nt) + product matrix :math:`\mathbf{q}\mathbf{q}^\top` where q is the target space distribution. Used to avoid redundant computations. + T: array-like, shape (ns,nt) + fixed transport plan between the input structure and its representation in the dictionary. + p : array-like, shape (ns,) + Distribution in the source space. + q : array-like, shape (nt,) + Distribution in the embedding space depicted by the dictionary. + reg : float, optional. + Coefficient of the negative quadratic regularization used to promote sparsity of w. Default is 0. + + Returns + ------- + w: ndarray (D,) + optimal unmixing of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary span given OT starting from previously optimal unmixing. + """ + convergence_criterion = np.inf + current_loss = starting_loss + count = 0 + const_TCT = np.transpose(C.dot(T)).dot(T) + + while (convergence_criterion > tol) and (count < max_iter): + + previous_loss = current_loss + # 1) Compute gradient at current point w + grad_w = 2 * np.sum(Cdict * (Cembedded[None, :, :] * const_q[None, :, :] - const_TCT[None, :, :]), axis=(1, 2)) + grad_w -= 2 * reg * w + + # 2) Conditional gradient direction finding: x= \argmin_x x^T.grad_w + min_ = np.min(grad_w) + x = (grad_w == min_).astype(np.float64) + x /= np.sum(x) + + # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c + gamma, a, b, Cembedded_diff = _linesearch_gromov_wasserstein_unmixing(w, grad_w, x, Cdict, Cembedded, const_q, const_TCT, reg) + + # 4) Updates: w <-- (1-gamma)*w + gamma*x + w += gamma * (x - w) + Cembedded += gamma * Cembedded_diff + current_loss += a * (gamma**2) + b * gamma + + if previous_loss != 0: # not that the loss can be negative if reg >0 + convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss) + else: # handle numerical issues around 0 + convergence_criterion = abs(previous_loss - current_loss) / 10**(-15) + count += 1 + + return w, Cembedded, current_loss + + +def _linesearch_gromov_wasserstein_unmixing(w, grad_w, x, Cdict, Cembedded, const_q, const_TCT, reg, **kwargs): + r""" + Compute optimal steps for the line search problem of Gromov-Wasserstein linear unmixing + .. math:: + \min_{\gamma \in [0,1]} \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D z_d(\gamma)C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} - reg\| \mathbf{z}(\gamma) \|_2^2 + + + Such that: + + - :math:`\mathbf{z}(\gamma) = (1- \gamma)\mathbf{w} + \gamma \mathbf{x}` + + Parameters + ---------- + + w : array-like, shape (D,) + Unmixing. + grad_w : array-like, shape (D, D) + Gradient of the reconstruction loss with respect to w. + x: array-like, shape (D,) + Conditional gradient direction. + Cdict : list of D array-like, shape (nt,nt) + Metric/Graph cost matrices composing the dictionary on which to embed C. + Each matrix in the dictionary must have the same size (nt,nt). + Cembedded: array-like, shape (nt,nt) + Embedded structure :math:`(\sum_d w_dCdict[d],q)` of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary. Used to avoid redundant computations. + const_q: array-like, shape (nt,nt) + product matrix :math:`\mathbf{q}\mathbf{q}^\top` where q is the target space distribution. Used to avoid redundant computations. + const_TCT: array-like, shape (nt, nt) + :math:`\mathbf{T}^\top \mathbf{C}^\top \mathbf{T}`. Used to avoid redundant computations. + Returns + ------- + gamma: float + Optimal value for the line-search step + a: float + Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss + b: float + Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss + Cembedded_diff: numpy array, shape (nt, nt) + Difference between models evaluated in :math:`\mathbf{w}` and in :math:`\mathbf{w}`. + reg : float, optional. + Coefficient of the negative quadratic regularization used to promote sparsity of :math:`\mathbf{w}`. + """ + + # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c + Cembedded_x = np.sum(x[:, None, None] * Cdict, axis=0) + Cembedded_diff = Cembedded_x - Cembedded + trace_diffx = np.sum(Cembedded_diff * Cembedded_x * const_q) + trace_diffw = np.sum(Cembedded_diff * Cembedded * const_q) + a = trace_diffx - trace_diffw + b = 2 * (trace_diffw - np.sum(Cembedded_diff * const_TCT)) + if reg != 0: + a -= reg * np.sum((x - w)**2) + b -= 2 * reg * np.sum(w * (x - w)) + + if a > 0: + gamma = min(1, max(0, - b / (2 * a))) + elif a + b < 0: + gamma = 1 + else: + gamma = 0 + + return gamma, a, b, Cembedded_diff + + +def fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, alpha, reg=0., ps=None, q=None, epochs=20, batch_size=32, learning_rate_C=1., learning_rate_Y=1., + Cdict_init=None, Ydict_init=None, projection='nonnegative_symmetric', use_log=False, + tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, **kwargs): + r""" + Infer Fused Gromov-Wasserstein linear dictionary :math:`\{ (\mathbf{C_{dict}[d]}, \mathbf{Y_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}` from the list of S attributed structures :math:`\{ (\mathbf{C_s}, \mathbf{Y_s},\mathbf{p_s}) \}_s` + + .. math:: + \min_{\mathbf{C_{dict}},\mathbf{Y_{dict}}, \{\mathbf{w_s}\}_{s}} \sum_{s=1}^S FGW_{2,\alpha}(\mathbf{C_s}, \mathbf{Y_s}, \sum_{d=1}^D w_{s,d}\mathbf{C_{dict}[d]},\sum_{d=1}^D w_{s,d}\mathbf{Y_{dict}[d]}, \mathbf{p_s}, \mathbf{q}) \\ - reg\| \mathbf{w_s} \|_2^2 + + + Such that :math:`\forall s \leq S` : + + - :math:`\mathbf{w_s}^\top \mathbf{1}_D = 1` + - :math:`\mathbf{w_s} \geq \mathbf{0}_D` + + Where : + + - :math:`\forall s \leq S, \mathbf{C_s}` is a (ns,ns) pairwise similarity matrix of variable size ns. + - :math:`\forall s \leq S, \mathbf{Y_s}` is a (ns,d) features matrix of variable size ns and fixed dimension d. + - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrix of fixed size nt. + - :math:`\mathbf{Y_{dict}}` is a (D, nt, d) tensor of D features matrix of fixed size nt and fixed dimension d. + - :math:`\forall s \leq S, \mathbf{p_s}` is the source distribution corresponding to :math:`\mathbf{C_s}` + - :math:`\mathbf{q}` is the target distribution assigned to every structures in the embedding space. + - :math:`\alpha` is the trade-off parameter of Fused Gromov-Wasserstein + - reg is the regularization coefficient. + + + The stochastic algorithm used for estimating the attributed graph dictionary atoms as proposed in [38]_ + + Parameters + ---------- + Cs : list of S symmetric array-like, shape (ns, ns) + List of Metric/Graph cost matrices of variable size (ns,ns). + Ys : list of S array-like, shape (ns, d) + List of feature matrix of variable size (ns,d) with d fixed. + D: int + Number of dictionary atoms to learn + nt: int + Number of samples within each dictionary atoms + alpha : float + Trade-off parameter of Fused Gromov-Wasserstein + reg : float, optional + Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0. + ps : list of S array-like, shape (ns,), optional + Distribution in each source space C of Cs. Default is None and corresponds to uniform distibutions. + q : array-like, shape (nt,), optional + Distribution in the embedding space whose structure will be learned. Default is None and corresponds to uniform distributions. + epochs: int, optional + Number of epochs used to learn the dictionary. Default is 32. + batch_size: int, optional + Batch size for each stochastic gradient update of the dictionary. Set to the dataset size if the provided batch_size is higher than the dataset size. Default is 32. + learning_rate_C: float, optional + Learning rate used for the stochastic gradient descent on Cdict. Default is 1. + learning_rate_Y: float, optional + Learning rate used for the stochastic gradient descent on Ydict. Default is 1. + Cdict_init: list of D array-like with shape (nt, nt), optional + Used to initialize the dictionary structures Cdict. + If set to None (Default), the dictionary will be initialized randomly. + Else Cdict must have shape (D, nt, nt) i.e match provided shape features. + Ydict_init: list of D array-like with shape (nt, d), optional + Used to initialize the dictionary features Ydict. + If set to None, the dictionary features will be initialized randomly. + Else Ydict must have shape (D, nt, d) where d is the features dimension of inputs Ys and also match provided shape features. + projection: str, optional + If 'nonnegative' and/or 'symmetric' is in projection, the corresponding projection will be performed at each stochastic update of the dictionary + Else the set of atoms is :math:`R^{nt * nt}`. Default is 'nonnegative_symmetric' + log: bool, optional + If set to True, losses evolution by batches and epochs are tracked. Default is False. + use_adam_optimizer: bool, optional + If set to True, adam optimizer with default settings is used as adaptative learning rate strategy. + Else perform SGD with fixed learning rate. Default is True. + tol_outer : float, optional + Solver precision for the BCD algorithm, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`. + tol_inner : float, optional + Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`. + max_iter_outer : int, optional + Maximum number of iterations for the BCD. Default is 20. + max_iter_inner : int, optional + Maximum number of iterations for the Conjugate Gradient. Default is 200. + verbose : bool, optional + Print the reconstruction loss every epoch. Default is False. + + Returns + ------- + + Cdict_best_state : D array-like, shape (D,nt,nt) + Metric/Graph cost matrices composing the dictionary. + The dictionary leading to the best loss over an epoch is saved and returned. + Ydict_best_state : D array-like, shape (D,nt,d) + Feature matrices composing the dictionary. + The dictionary leading to the best loss over an epoch is saved and returned. + log: dict + If use_log is True, contains loss evolutions by batches and epoches. + References + ------- + .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online + Graph Dictionary Learning, International Conference on Machine Learning + (ICML), 2021. + """ + Cs0, Ys0 = Cs, Ys + nx = get_backend(*Cs0, *Ys0) + Cs = [nx.to_numpy(C) for C in Cs0] + Ys = [nx.to_numpy(Y) for Y in Ys0] + + d = Ys[0].shape[-1] + dataset_size = len(Cs) + + if ps is None: + ps = [unif(C.shape[0]) for C in Cs] + else: + ps = [nx.to_numpy(p) for p in ps] + if q is None: + q = unif(nt) + else: + q = nx.to_numpy(q) + + if Cdict_init is None: + # Initialize randomly structures of dictionary atoms based on samples + dataset_means = [C.mean() for C in Cs] + Cdict = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt)) + else: + Cdict = nx.to_numpy(Cdict_init).copy() + assert Cdict.shape == (D, nt, nt) + if Ydict_init is None: + # Initialize randomly features of dictionary atoms based on samples distribution by feature component + dataset_feature_means = np.stack([F.mean(axis=0) for F in Ys]) + Ydict = np.random.normal(loc=dataset_feature_means.mean(axis=0), scale=dataset_feature_means.std(axis=0), size=(D, nt, d)) + else: + Ydict = nx.to_numpy(Ydict_init).copy() + assert Ydict.shape == (D, nt, d) + + if 'symmetric' in projection: + Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1))) + symmetric = True + else: + symmetric = False + if 'nonnegative' in projection: + Cdict[Cdict < 0.] = 0. + + if use_adam_optimizer: + adam_moments_C = _initialize_adam_optimizer(Cdict) + adam_moments_Y = _initialize_adam_optimizer(Ydict) + + log = {'loss_batches': [], 'loss_epochs': []} + const_q = q[:, None] * q[None, :] + diag_q = np.diag(q) + Cdict_best_state = Cdict.copy() + Ydict_best_state = Ydict.copy() + loss_best_state = np.inf + if batch_size > dataset_size: + batch_size = dataset_size + iter_by_epoch = dataset_size // batch_size + int((dataset_size % batch_size) > 0) + + for epoch in range(epochs): + cumulated_loss_over_epoch = 0. + + for _ in range(iter_by_epoch): + + # Batch iterations + batch = np.random.choice(range(dataset_size), size=batch_size, replace=False) + cumulated_loss_over_batch = 0. + unmixings = np.zeros((batch_size, D)) + Cs_embedded = np.zeros((batch_size, nt, nt)) + Ys_embedded = np.zeros((batch_size, nt, d)) + Ts = [None] * batch_size + + for batch_idx, C_idx in enumerate(batch): + # BCD solver for Gromov-Wassersteisn linear unmixing used independently on each structure of the sampled batch + unmixings[batch_idx], Cs_embedded[batch_idx], Ys_embedded[batch_idx], Ts[batch_idx], current_loss = fused_gromov_wasserstein_linear_unmixing( + Cs[C_idx], Ys[C_idx], Cdict, Ydict, alpha, reg=reg, p=ps[C_idx], q=q, + tol_outer=tol_outer, tol_inner=tol_inner, max_iter_outer=max_iter_outer, max_iter_inner=max_iter_inner, symmetric=symmetric, **kwargs + ) + cumulated_loss_over_batch += current_loss + cumulated_loss_over_epoch += cumulated_loss_over_batch + if use_log: + log['loss_batches'].append(cumulated_loss_over_batch) + + # Stochastic projected gradient step over dictionary atoms + grad_Cdict = np.zeros_like(Cdict) + grad_Ydict = np.zeros_like(Ydict) + + for batch_idx, C_idx in enumerate(batch): + shared_term_structures = Cs_embedded[batch_idx] * const_q - (Cs[C_idx].dot(Ts[batch_idx])).T.dot(Ts[batch_idx]) + shared_term_features = diag_q.dot(Ys_embedded[batch_idx]) - Ts[batch_idx].T.dot(Ys[C_idx]) + grad_Cdict += alpha * unmixings[batch_idx][:, None, None] * shared_term_structures[None, :, :] + grad_Ydict += (1 - alpha) * unmixings[batch_idx][:, None, None] * shared_term_features[None, :, :] + grad_Cdict *= 2 / batch_size + grad_Ydict *= 2 / batch_size + + if use_adam_optimizer: + Cdict, adam_moments_C = _adam_stochastic_updates(Cdict, grad_Cdict, learning_rate_C, adam_moments_C) + Ydict, adam_moments_Y = _adam_stochastic_updates(Ydict, grad_Ydict, learning_rate_Y, adam_moments_Y) + else: + Cdict -= learning_rate_C * grad_Cdict + Ydict -= learning_rate_Y * grad_Ydict + + if 'symmetric' in projection: + Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1))) + if 'nonnegative' in projection: + Cdict[Cdict < 0.] = 0. + + if use_log: + log['loss_epochs'].append(cumulated_loss_over_epoch) + if loss_best_state > cumulated_loss_over_epoch: + loss_best_state = cumulated_loss_over_epoch + Cdict_best_state = Cdict.copy() + Ydict_best_state = Ydict.copy() + if verbose: + print('--- epoch: ', epoch, ' cumulated reconstruction error: ', cumulated_loss_over_epoch) + + return nx.from_numpy(Cdict_best_state), nx.from_numpy(Ydict_best_state), log + + +def fused_gromov_wasserstein_linear_unmixing(C, Y, Cdict, Ydict, alpha, reg=0., p=None, q=None, tol_outer=10**(-5), + tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, symmetric=True, **kwargs): + r""" + Returns the Fused Gromov-Wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` onto the attributed dictionary atoms :math:`\{ (\mathbf{C_{dict}[d]},\mathbf{Y_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}` + + .. math:: + \min_{\mathbf{w}} FGW_{2,\alpha}(\mathbf{C},\mathbf{Y}, \sum_{d=1}^D w_d\mathbf{C_{dict}[d]},\sum_{d=1}^D w_d\mathbf{Y_{dict}[d]}, \mathbf{p}, \mathbf{q}) - reg \| \mathbf{w} \|_2^2 + + such that, :math:`\forall s \leq S` : + + - :math:`\mathbf{w_s}^\top \mathbf{1}_D = 1` + - :math:`\mathbf{w_s} \geq \mathbf{0}_D` + + Where : + + - :math:`\mathbf{C}` is a (ns,ns) pairwise similarity matrix of variable size ns. + - :math:`\mathbf{Y}` is a (ns,d) features matrix of variable size ns and fixed dimension d. + - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrix of fixed size nt. + - :math:`\mathbf{Y_{dict}}` is a (D, nt, d) tensor of D features matrix of fixed size nt and fixed dimension d. + - :math:`\mathbf{p}` is the source distribution corresponding to :math:`\mathbf{C_s}` + - :math:`\mathbf{q}` is the target distribution assigned to every structures in the embedding space. + - :math:`\alpha` is the trade-off parameter of Fused Gromov-Wasserstein + - reg is the regularization coefficient. + + The algorithm used for solving the problem is a Block Coordinate Descent as discussed in [38]_, algorithm 6. + + Parameters + ---------- + C : array-like, shape (ns, ns) + Metric/Graph cost matrix. + Y : array-like, shape (ns, d) + Feature matrix. + Cdict : D array-like, shape (D,nt,nt) + Metric/Graph cost matrices composing the dictionary on which to embed (C,Y). + Ydict : D array-like, shape (D,nt,d) + Feature matrices composing the dictionary on which to embed (C,Y). + alpha: float, + Trade-off parameter of Fused Gromov-Wasserstein. + reg : float, optional + Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0. + p : array-like, shape (ns,), optional + Distribution in the source space C. Default is None and corresponds to uniform distribution. + q : array-like, shape (nt,), optional + Distribution in the space depicted by the dictionary. Default is None and corresponds to uniform distribution. + tol_outer : float, optional + Solver precision for the BCD algorithm. + tol_inner : float, optional + Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport. Default is :math:`10^{-5}`. + max_iter_outer : int, optional + Maximum number of iterations for the BCD. Default is 20. + max_iter_inner : int, optional + Maximum number of iterations for the Conjugate Gradient. Default is 200. + + Returns + ------- + w: array-like, shape (D,) + fused gromov-wasserstein linear unmixing of (C,Y,p) onto the span of the dictionary. + Cembedded: array-like, shape (nt,nt) + embedded structure of :math:`(\mathbf{C},\mathbf{Y}, \mathbf{p})` onto the dictionary, :math:`\sum_d w_d\mathbf{C_{dict}[d]}`. + Yembedded: array-like, shape (nt,d) + embedded features of :math:`(\mathbf{C},\mathbf{Y}, \mathbf{p})` onto the dictionary, :math:`\sum_d w_d\mathbf{Y_{dict}[d]}`. + T: array-like (ns,nt) + Fused Gromov-Wasserstein transport plan between :math:`(\mathbf{C},\mathbf{p})` and :math:`(\sum_d w_d\mathbf{C_{dict}[d]}, \sum_d w_d\mathbf{Y_{dict}[d]},\mathbf{q})`. + current_loss: float + reconstruction error + References + ------- + .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online + Graph Dictionary Learning, International Conference on Machine Learning + (ICML), 2021. + """ + C0, Y0, Cdict0, Ydict0 = C, Y, Cdict, Ydict + nx = get_backend(C0, Y0, Cdict0, Ydict0) + C = nx.to_numpy(C0) + Y = nx.to_numpy(Y0) + Cdict = nx.to_numpy(Cdict0) + Ydict = nx.to_numpy(Ydict0) + + if p is None: + p = unif(C.shape[0]) + else: + p = nx.to_numpy(p) + if q is None: + q = unif(Cdict.shape[-1]) + else: + q = nx.to_numpy(q) + + T = p[:, None] * q[None, :] + D = len(Cdict) + d = Y.shape[-1] + w = unif(D) # Initialize with uniform weights + ns = C.shape[-1] + nt = Cdict.shape[-1] + + # modeling (C,Y) + Cembedded = np.sum(w[:, None, None] * Cdict, axis=0) + Yembedded = np.sum(w[:, None, None] * Ydict, axis=0) + + # constants depending on q + const_q = q[:, None] * q[None, :] + diag_q = np.diag(q) + # Trackers for BCD convergence + convergence_criterion = np.inf + current_loss = 10**15 + outer_count = 0 + Ys_constM = (Y**2).dot(np.ones((d, nt))) # constant in computing euclidean pairwise feature matrix + + while (convergence_criterion > tol_outer) and (outer_count < max_iter_outer): + previous_loss = current_loss + + # 1. Solve GW transport between (C,p) and (\sum_d Cdictionary[d],q) fixing the unmixing w + Yt_varM = (np.ones((ns, d))).dot((Yembedded**2).T) + M = Ys_constM + Yt_varM - 2 * Y.dot(Yembedded.T) # euclidean distance matrix between features + T, log = fused_gromov_wasserstein( + M, C, Cembedded, p, q, loss_fun='square_loss', alpha=alpha, + max_iter=max_iter_inner, tol_rel=tol_inner, tol_abs=0., armijo=False, G0=T, log=True, symmetric=symmetric, **kwargs) + current_loss = log['fgw_dist'] + if reg != 0: + current_loss -= reg * np.sum(w**2) + + # 2. Solve linear unmixing problem over w with a fixed transport plan T + w, Cembedded, Yembedded, current_loss = _cg_fused_gromov_wasserstein_unmixing(C, Y, Cdict, Ydict, Cembedded, Yembedded, w, + T, p, q, const_q, diag_q, current_loss, alpha, reg, + tol=tol_inner, max_iter=max_iter_inner, **kwargs) + if previous_loss != 0: + convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss) + else: + convergence_criterion = abs(previous_loss - current_loss) / 10**(-12) + outer_count += 1 + + return nx.from_numpy(w), nx.from_numpy(Cembedded), nx.from_numpy(Yembedded), nx.from_numpy(T), nx.from_numpy(current_loss) + + +def _cg_fused_gromov_wasserstein_unmixing(C, Y, Cdict, Ydict, Cembedded, Yembedded, w, T, p, q, const_q, diag_q, starting_loss, alpha, reg, tol=10**(-6), max_iter=200, **kwargs): + r""" + Returns for a fixed admissible transport plan, + the optimal linear unmixing :math:`\mathbf{w}` minimizing the Fused Gromov-Wasserstein cost between :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` and :math:`(\sum_d w_d \mathbf{C_{dict}[d]},\sum_d w_d*\mathbf{Y_{dict}[d]}, \mathbf{q})` + + .. math:: + \min_{\mathbf{w}} \alpha \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D w_d C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} \\+ (1-\alpha) \sum_{ij} \| \mathbf{Y_i} - \sum_d w_d \mathbf{Y_{dict}[d]_j} \|_2^2 T_{ij}- reg \| \mathbf{w} \|_2^2 + + Such that : + + - :math:`\mathbf{w}^\top \mathbf{1}_D = 1` + - :math:`\mathbf{w} \geq \mathbf{0}_D` + + Where : + + - :math:`\mathbf{C}` is a (ns,ns) pairwise similarity matrix of variable size ns. + - :math:`\mathbf{Y}` is a (ns,d) features matrix of variable size ns and fixed dimension d. + - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrix of fixed size nt. + - :math:`\mathbf{Y_{dict}}` is a (D, nt, d) tensor of D features matrix of fixed size nt and fixed dimension d. + - :math:`\mathbf{p}` is the source distribution corresponding to :math:`\mathbf{C_s}` + - :math:`\mathbf{q}` is the target distribution assigned to every structures in the embedding space. + - :math:`\mathbf{T}` is the optimal transport plan conditioned by the previous state of :math:`\mathbf{w}` + - :math:`\alpha` is the trade-off parameter of Fused Gromov-Wasserstein + - reg is the regularization coefficient. + + The algorithm used for solving the problem is a Conditional Gradient Descent as discussed in [38]_, algorithm 7. + + Parameters + ---------- + + C : array-like, shape (ns, ns) + Metric/Graph cost matrix. + Y : array-like, shape (ns, d) + Feature matrix. + Cdict : list of D array-like, shape (nt,nt) + Metric/Graph cost matrices composing the dictionary on which to embed (C,Y). + Each matrix in the dictionary must have the same size (nt,nt). + Ydict : list of D array-like, shape (nt,d) + Feature matrices composing the dictionary on which to embed (C,Y). + Each matrix in the dictionary must have the same size (nt,d). + Cembedded: array-like, shape (nt,nt) + Embedded structure of (C,Y) onto the dictionary + Yembedded: array-like, shape (nt,d) + Embedded features of (C,Y) onto the dictionary + w: array-like, shape (n_D,) + Linear unmixing of (C,Y) onto (Cdict,Ydict) + const_q: array-like, shape (nt,nt) + product matrix :math:`\mathbf{qq}^\top` where :math:`\mathbf{q}` is the target space distribution. + diag_q: array-like, shape (nt,nt) + diagonal matrix with values of q on the diagonal. + T: array-like, shape (ns,nt) + fixed transport plan between (C,Y) and its model + p : array-like, shape (ns,) + Distribution in the source space (C,Y). + q : array-like, shape (nt,) + Distribution in the embedding space depicted by the dictionary. + alpha: float, + Trade-off parameter of Fused Gromov-Wasserstein. + reg : float, optional + Coefficient of the negative quadratic regularization used to promote sparsity of w. + + Returns + ------- + w: ndarray (D,) + linear unmixing of :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` onto the span of :math:`(C_{dict},Y_{dict})` given OT corresponding to previous unmixing. + """ + convergence_criterion = np.inf + current_loss = starting_loss + count = 0 + const_TCT = np.transpose(C.dot(T)).dot(T) + ones_ns_d = np.ones(Y.shape) + + while (convergence_criterion > tol) and (count < max_iter): + previous_loss = current_loss + + # 1) Compute gradient at current point w + # structure + grad_w = alpha * np.sum(Cdict * (Cembedded[None, :, :] * const_q[None, :, :] - const_TCT[None, :, :]), axis=(1, 2)) + # feature + grad_w += (1 - alpha) * np.sum(Ydict * (diag_q.dot(Yembedded)[None, :, :] - T.T.dot(Y)[None, :, :]), axis=(1, 2)) + grad_w -= reg * w + grad_w *= 2 + + # 2) Conditional gradient direction finding: x= \argmin_x x^T.grad_w + min_ = np.min(grad_w) + x = (grad_w == min_).astype(np.float64) + x /= np.sum(x) + + # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c + gamma, a, b, Cembedded_diff, Yembedded_diff = _linesearch_fused_gromov_wasserstein_unmixing(w, grad_w, x, Y, Cdict, Ydict, Cembedded, Yembedded, T, const_q, const_TCT, ones_ns_d, alpha, reg) + + # 4) Updates: w <-- (1-gamma)*w + gamma*x + w += gamma * (x - w) + Cembedded += gamma * Cembedded_diff + Yembedded += gamma * Yembedded_diff + current_loss += a * (gamma**2) + b * gamma + + if previous_loss != 0: + convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss) + else: + convergence_criterion = abs(previous_loss - current_loss) / 10**(-12) + count += 1 + + return w, Cembedded, Yembedded, current_loss + + +def _linesearch_fused_gromov_wasserstein_unmixing(w, grad_w, x, Y, Cdict, Ydict, Cembedded, Yembedded, T, const_q, const_TCT, ones_ns_d, alpha, reg, **kwargs): + r""" + Compute optimal steps for the line search problem of Fused Gromov-Wasserstein linear unmixing + .. math:: + \min_{\gamma \in [0,1]} \alpha \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D z_d(\gamma)C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} \\ + (1-\alpha) \sum_{ij} \| \mathbf{Y_i} - \sum_d z_d(\gamma) \mathbf{Y_{dict}[d]_j} \|_2^2 - reg\| \mathbf{z}(\gamma) \|_2^2 + + + Such that : + + - :math:`\mathbf{z}(\gamma) = (1- \gamma)\mathbf{w} + \gamma \mathbf{x}` + + Parameters + ---------- + + w : array-like, shape (D,) + Unmixing. + grad_w : array-like, shape (D, D) + Gradient of the reconstruction loss with respect to w. + x: array-like, shape (D,) + Conditional gradient direction. + Y: arrat-like, shape (ns,d) + Feature matrix of the input space + Cdict : list of D array-like, shape (nt, nt) + Metric/Graph cost matrices composing the dictionary on which to embed (C,Y). + Each matrix in the dictionary must have the same size (nt,nt). + Ydict : list of D array-like, shape (nt, d) + Feature matrices composing the dictionary on which to embed (C,Y). + Each matrix in the dictionary must have the same size (nt,d). + Cembedded: array-like, shape (nt, nt) + Embedded structure of (C,Y) onto the dictionary + Yembedded: array-like, shape (nt, d) + Embedded features of (C,Y) onto the dictionary + T: array-like, shape (ns, nt) + Fixed transport plan between (C,Y) and its current model. + const_q: array-like, shape (nt,nt) + product matrix :math:`\mathbf{q}\mathbf{q}^\top` where q is the target space distribution. Used to avoid redundant computations. + const_TCT: array-like, shape (nt, nt) + :math:`\mathbf{T}^\top \mathbf{C}^\top \mathbf{T}`. Used to avoid redundant computations. + ones_ns_d: array-like, shape (ns, d) + :math:`\mathbf{1}_{ ns \times d}`. Used to avoid redundant computations. + alpha: float, + Trade-off parameter of Fused Gromov-Wasserstein. + reg : float, optional + Coefficient of the negative quadratic regularization used to promote sparsity of w. + + Returns + ------- + gamma: float + Optimal value for the line-search step + a: float + Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss + b: float + Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss + Cembedded_diff: numpy array, shape (nt, nt) + Difference between structure matrix of models evaluated in :math:`\mathbf{w}` and in :math:`\mathbf{w}`. + Yembedded_diff: numpy array, shape (nt, nt) + Difference between feature matrix of models evaluated in :math:`\mathbf{w}` and in :math:`\mathbf{w}`. + """ + # polynomial coefficients from quadratic objective (with respect to w) on structures + Cembedded_x = np.sum(x[:, None, None] * Cdict, axis=0) + Cembedded_diff = Cembedded_x - Cembedded + trace_diffx = np.sum(Cembedded_diff * Cembedded_x * const_q) + trace_diffw = np.sum(Cembedded_diff * Cembedded * const_q) + # Constant factor appearing in the factorization a*gamma^2 + b*g + c of the Gromov-Wasserstein reconstruction loss + a_gw = trace_diffx - trace_diffw + b_gw = 2 * (trace_diffw - np.sum(Cembedded_diff * const_TCT)) + + # polynomial coefficient from quadratic objective (with respect to w) on features + Yembedded_x = np.sum(x[:, None, None] * Ydict, axis=0) + Yembedded_diff = Yembedded_x - Yembedded + # Constant factor appearing in the factorization a*gamma^2 + b*g + c of the Gromov-Wasserstein reconstruction loss + a_w = np.sum(ones_ns_d.dot((Yembedded_diff**2).T) * T) + b_w = 2 * np.sum(T * (ones_ns_d.dot((Yembedded * Yembedded_diff).T) - Y.dot(Yembedded_diff.T))) + + a = alpha * a_gw + (1 - alpha) * a_w + b = alpha * b_gw + (1 - alpha) * b_w + if reg != 0: + a -= reg * np.sum((x - w)**2) + b -= 2 * reg * np.sum(w * (x - w)) + if a > 0: + gamma = min(1, max(0, -b / (2 * a))) + elif a + b < 0: + gamma = 1 + else: + gamma = 0 + + return gamma, a, b, Cembedded_diff, Yembedded_diff diff --git a/ot/gromov/_estimators.py b/ot/gromov/_estimators.py new file mode 100644 index 0000000..0a29a91 --- /dev/null +++ b/ot/gromov/_estimators.py @@ -0,0 +1,425 @@ +# -*- coding: utf-8 -*- +""" +Gromov-Wasserstein and Fused-Gromov-Wasserstein stochastic estimators. +""" + +# Author: Rémi Flamary +# Tanguy Kerdoncuff +# +# License: MIT License + +import numpy as np + + +from ..bregman import sinkhorn +from ..utils import list_to_array, check_random_state +from ..lp import emd_1d, emd +from ..backend import get_backend + + +def GW_distance_estimation(C1, C2, p, q, loss_fun, T, + nb_samples_p=None, nb_samples_q=None, std=True, random_state=None): + r""" + Returns an approximation of the gromov-wasserstein cost between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` + with a fixed transport plan :math:`\mathbf{T}`. + + The function gives an unbiased approximation of the following equation: + + .. math:: + + GW = \sum_{i,j,k,l} L(\mathbf{C_{1}}_{i,k}, \mathbf{C_{2}}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + Where : + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - `L` : Loss function to account for the misfit between the similarity matrices + - :math:`\mathbf{T}`: Matrix with marginal :math:`\mathbf{p}` and :math:`\mathbf{q}` + + Parameters + ---------- + C1 : array-like, shape (ns, ns) + Metric cost matrix in the source space + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,) + Distribution in the source space + q : array-like, shape (nt,) + Distribution in the target space + loss_fun : function: :math:`\mathbb{R} \times \mathbb{R} \mapsto \mathbb{R}` + Loss function used for the distance, the transport plan does not depend on the loss function + T : csr or array-like, shape (ns, nt) + Transport plan matrix, either a sparse csr or a dense matrix + nb_samples_p : int, optional + `nb_samples_p` is the number of samples (without replacement) along the first dimension of :math:`\mathbf{T}` + nb_samples_q : int, optional + `nb_samples_q` is the number of samples along the second dimension of :math:`\mathbf{T}`, for each sample along the first + std : bool, optional + Standard deviation associated with the prediction of the gromov-wasserstein cost + random_state : int or RandomState instance, optional + Fix the seed for reproducibility + + Returns + ------- + : float + Gromov-wasserstein cost + + References + ---------- + .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc + "Sampled Gromov Wasserstein." + Machine Learning Journal (MLJ). 2021. + + """ + C1, C2, p, q = list_to_array(C1, C2, p, q) + nx = get_backend(C1, C2, p, q) + + generator = check_random_state(random_state) + + len_p = p.shape[0] + len_q = q.shape[0] + + # It is always better to sample from the biggest distribution first. + if len_p < len_q: + p, q = q, p + len_p, len_q = len_q, len_p + C1, C2 = C2, C1 + T = T.T + + if nb_samples_p is None: + if nx.issparse(T): + # If T is sparse, it probably mean that PoGroW was used, thus the number of sample is reduced + nb_samples_p = min(int(5 * (len_p * np.log(len_p)) ** 0.5), len_p) + else: + nb_samples_p = len_p + else: + # The number of sample along the first dimension is without replacement. + nb_samples_p = min(nb_samples_p, len_p) + if nb_samples_q is None: + nb_samples_q = 1 + if std: + nb_samples_q = max(2, nb_samples_q) + + index_k = np.zeros((nb_samples_p, nb_samples_q), dtype=int) + index_l = np.zeros((nb_samples_p, nb_samples_q), dtype=int) + + index_i = generator.choice( + len_p, size=nb_samples_p, p=nx.to_numpy(p), replace=False + ) + index_j = generator.choice( + len_p, size=nb_samples_p, p=nx.to_numpy(p), replace=False + ) + + for i in range(nb_samples_p): + if nx.issparse(T): + T_indexi = nx.reshape(nx.todense(T[index_i[i], :]), (-1,)) + T_indexj = nx.reshape(nx.todense(T[index_j[i], :]), (-1,)) + else: + T_indexi = T[index_i[i], :] + T_indexj = T[index_j[i], :] + # For each of the row sampled, the column is sampled. + index_k[i] = generator.choice( + len_q, + size=nb_samples_q, + p=nx.to_numpy(T_indexi / nx.sum(T_indexi)), + replace=True + ) + index_l[i] = generator.choice( + len_q, + size=nb_samples_q, + p=nx.to_numpy(T_indexj / nx.sum(T_indexj)), + replace=True + ) + + list_value_sample = nx.stack([ + loss_fun( + C1[np.ix_(index_i, index_j)], + C2[np.ix_(index_k[:, n], index_l[:, n])] + ) for n in range(nb_samples_q) + ], axis=2) + + if std: + std_value = nx.sum(nx.std(list_value_sample, axis=2) ** 2) ** 0.5 + return nx.mean(list_value_sample), std_value / (nb_samples_p * nb_samples_p) + else: + return nx.mean(list_value_sample) + + +def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun, + alpha=1, max_iter=100, threshold_plan=0, log=False, verbose=False, random_state=None): + r""" + Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` using a stochastic Frank-Wolfe. + This method has a :math:`\mathcal{O}(\mathrm{max\_iter} \times PN^2)` time complexity with `P` the number of Sinkhorn iterations. + + The function solves the following optimization problem: + + .. math:: + \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l} + L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} + + \mathbf{T}^T \mathbf{1} &= \mathbf{q} + + \mathbf{T} &\geq 0 + + Where : + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{p}`: distribution in the source space + - :math:`\mathbf{q}`: distribution in the target space + - `L`: loss function to account for the misfit between the similarity matrices + + Parameters + ---------- + C1 : array-like, shape (ns, ns) + Metric cost matrix in the source space + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,) + Distribution in the source space + q : array-like, shape (nt,) + Distribution in the target space + loss_fun : function: :math:`\mathbb{R} \times \mathbb{R} \mapsto \mathbb{R}` + Loss function used for the distance, the transport plan does not depend on the loss function + alpha : float + Step of the Frank-Wolfe algorithm, should be between 0 and 1 + max_iter : int, optional + Max number of iterations + threshold_plan : float, optional + Deleting very small values in the transport plan. If above zero, it violates the marginal constraints. + verbose : bool, optional + Print information along iterations + log : bool, optional + Gives the distance estimated and the standard deviation + random_state : int or RandomState instance, optional + Fix the seed for reproducibility + + Returns + ------- + T : array-like, shape (`ns`, `nt`) + Optimal coupling between the two spaces + + References + ---------- + .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc + "Sampled Gromov Wasserstein." + Machine Learning Journal (MLJ). 2021. + + """ + C1, C2, p, q = list_to_array(C1, C2, p, q) + nx = get_backend(C1, C2, p, q) + + len_p = p.shape[0] + len_q = q.shape[0] + + generator = check_random_state(random_state) + + index = np.zeros(2, dtype=int) + + # Initialize with default marginal + index[0] = generator.choice(len_p, size=1, p=nx.to_numpy(p)) + index[1] = generator.choice(len_q, size=1, p=nx.to_numpy(q)) + T = nx.tocsr(emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False)) + + best_gw_dist_estimated = np.inf + for cpt in range(max_iter): + index[0] = generator.choice(len_p, size=1, p=nx.to_numpy(p)) + T_index0 = nx.reshape(nx.todense(T[index[0], :]), (-1,)) + index[1] = generator.choice( + len_q, size=1, p=nx.to_numpy(T_index0 / nx.sum(T_index0)) + ) + + if alpha == 1: + T = nx.tocsr( + emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False) + ) + else: + new_T = nx.tocsr( + emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False) + ) + T = (1 - alpha) * T + alpha * new_T + # To limit the number of non 0, the values below the threshold are set to 0. + T = nx.eliminate_zeros(T, threshold=threshold_plan) + + if cpt % 10 == 0 or cpt == (max_iter - 1): + gw_dist_estimated = GW_distance_estimation( + C1=C1, C2=C2, loss_fun=loss_fun, + p=p, q=q, T=T, std=False, random_state=generator + ) + + if gw_dist_estimated < best_gw_dist_estimated: + best_gw_dist_estimated = gw_dist_estimated + best_T = nx.copy(T) + + if verbose: + if cpt % 200 == 0: + print('{:5s}|{:12s}'.format('It.', 'Best gw estimated') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, best_gw_dist_estimated)) + + if log: + log = {} + log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation( + C1=C1, C2=C2, loss_fun=loss_fun, + p=p, q=q, T=best_T, random_state=generator + ) + return best_T, log + return best_T + + +def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun, + nb_samples_grad=100, epsilon=1, max_iter=500, log=False, verbose=False, + random_state=None): + r""" + Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` using a 1-stochastic Frank-Wolfe. + This method has a :math:`\mathcal{O}(\mathrm{max\_iter} \times N \log(N))` time complexity by relying on the 1D Optimal Transport solver. + + The function solves the following optimization problem: + + .. math:: + \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l} + L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} + + \mathbf{T}^T \mathbf{1} &= \mathbf{q} + + \mathbf{T} &\geq 0 + + Where : + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{p}`: distribution in the source space + - :math:`\mathbf{q}`: distribution in the target space + - `L`: loss function to account for the misfit between the similarity matrices + + Parameters + ---------- + C1 : array-like, shape (ns, ns) + Metric cost matrix in the source space + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,) + Distribution in the source space + q : array-like, shape (nt,) + Distribution in the target space + loss_fun : function: :math:`\mathbb{R} \times \mathbb{R} \mapsto \mathbb{R}` + Loss function used for the distance, the transport plan does not depend on the loss function + nb_samples_grad : int + Number of samples to approximate the gradient + epsilon : float + Weight of the Kullback-Leibler regularization + max_iter : int, optional + Max number of iterations + verbose : bool, optional + Print information along iterations + log : bool, optional + Gives the distance estimated and the standard deviation + random_state : int or RandomState instance, optional + Fix the seed for reproducibility + + Returns + ------- + T : array-like, shape (`ns`, `nt`) + Optimal coupling between the two spaces + + References + ---------- + .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc + "Sampled Gromov Wasserstein." + Machine Learning Journal (MLJ). 2021. + + """ + C1, C2, p, q = list_to_array(C1, C2, p, q) + nx = get_backend(C1, C2, p, q) + + len_p = p.shape[0] + len_q = q.shape[0] + + generator = check_random_state(random_state) + + # The most natural way to define nb_sample is with a simple integer. + if isinstance(nb_samples_grad, int): + if nb_samples_grad > len_p: + # As the sampling along the first dimension is done without replacement, the rest is reported to the second + # dimension. + nb_samples_grad_p, nb_samples_grad_q = len_p, nb_samples_grad // len_p + else: + nb_samples_grad_p, nb_samples_grad_q = nb_samples_grad, 1 + else: + nb_samples_grad_p, nb_samples_grad_q = nb_samples_grad + T = nx.outer(p, q) + # continue_loop allows to stop the loop if there is several successive small modification of T. + continue_loop = 0 + + # The gradient of GW is more complex if the two matrices are not symmetric. + C_are_symmetric = nx.allclose(C1, C1.T, rtol=1e-10, atol=1e-10) and nx.allclose(C2, C2.T, rtol=1e-10, atol=1e-10) + + for cpt in range(max_iter): + index0 = generator.choice( + len_p, size=nb_samples_grad_p, p=nx.to_numpy(p), replace=False + ) + Lik = 0 + for i, index0_i in enumerate(index0): + index1 = generator.choice( + len_q, size=nb_samples_grad_q, + p=nx.to_numpy(T[index0_i, :] / nx.sum(T[index0_i, :])), + replace=False + ) + # If the matrices C are not symmetric, the gradient has 2 terms, thus the term is chosen randomly. + if (not C_are_symmetric) and generator.rand(1) > 0.5: + Lik += nx.mean(loss_fun( + C1[:, [index0[i]] * nb_samples_grad_q][:, None, :], + C2[:, index1][None, :, :] + ), axis=2) + else: + Lik += nx.mean(loss_fun( + C1[[index0[i]] * nb_samples_grad_q, :][:, :, None], + C2[index1, :][:, None, :] + ), axis=0) + + max_Lik = nx.max(Lik) + if max_Lik == 0: + continue + # This division by the max is here to facilitate the choice of epsilon. + Lik /= max_Lik + + if epsilon > 0: + # Set to infinity all the numbers below exp(-200) to avoid log of 0. + log_T = nx.log(nx.clip(T, np.exp(-200), 1)) + log_T = nx.where(log_T == -200, -np.inf, log_T) + Lik = Lik - epsilon * log_T + + try: + new_T = sinkhorn(a=p, b=q, M=Lik, reg=epsilon) + except (RuntimeWarning, UserWarning): + print("Warning catched in Sinkhorn: Return last stable T") + break + else: + new_T = emd(a=p, b=q, M=Lik) + + change_T = nx.mean((T - new_T) ** 2) + if change_T <= 10e-20: + continue_loop += 1 + if continue_loop > 100: # Number max of low modifications of T + T = nx.copy(new_T) + break + else: + continue_loop = 0 + + if verbose and cpt % 10 == 0: + if cpt % 200 == 0: + print('{:5s}|{:12s}'.format('It.', '||T_n - T_{n+1}||') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, change_T)) + T = nx.copy(new_T) + + if log: + log = {} + log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation( + C1=C1, C2=C2, loss_fun=loss_fun, + p=p, q=q, T=T, random_state=generator + ) + return T, log + return T diff --git a/ot/gromov/_gw.py b/ot/gromov/_gw.py new file mode 100644 index 0000000..c6e4076 --- /dev/null +++ b/ot/gromov/_gw.py @@ -0,0 +1,978 @@ +# -*- coding: utf-8 -*- +""" +Gromov-Wasserstein and Fused-Gromov-Wasserstein conditional gradient solvers. +""" + +# Author: Erwan Vautier +# Nicolas Courty +# Rémi Flamary +# Titouan Vayer +# Cédric Vincent-Cuaz +# +# License: MIT License + +import numpy as np + + +from ..utils import dist, UndefinedParameter, list_to_array +from ..optim import cg, line_search_armijo, solve_1d_linesearch_quad +from ..utils import check_random_state +from ..backend import get_backend, NumpyBackend + +from ._utils import init_matrix, gwloss, gwggrad +from ._utils import update_square_loss, update_kl_loss + + +def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', symmetric=None, log=False, armijo=False, G0=None, + max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): + r""" + Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` + + The function solves the following optimization problem: + + .. math:: + \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l} + L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p} + + \mathbf{\gamma}^T \mathbf{1} &= \mathbf{q} + + \mathbf{\gamma} &\geq 0 + Where : + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{p}`: distribution in the source space + - :math:`\mathbf{q}`: distribution in the target space + - `L`: loss function to account for the misfit between the similarity matrices + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + .. note:: All computations in the conjugate gradient solver are done with + numpy to limit memory overhead. + + Parameters + ---------- + C1 : array-like, shape (ns, ns) + Metric cost matrix in the source space + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,) + Distribution in the source space + q : array-like, shape (nt,) + Distribution in the target space + loss_fun : str + loss function used for the solver either 'square_loss' or 'kl_loss' + symmetric : bool, optional + Either C1 and C2 are to be assumed symmetric or not. + If let to its default None value, a symmetry test will be conducted. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric). + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + armijo : bool, optional + If True the step of the line-search is found via an armijo research. Else closed form is used. + If there are convergence issues use False. + G0: array-like, shape (ns,nt), optional + If None the initial transport plan of the solver is pq^T. + Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. + max_iter : int, optional + Max number of iterations + tol_rel : float, optional + Stop threshold on relative error (>0) + tol_abs : float, optional + Stop threshold on absolute error (>0) + **kwargs : dict + parameters can be directly passed to the ot.optim.cg solver + + Returns + ------- + T : array-like, shape (`ns`, `nt`) + Coupling between the two spaces that minimizes: + + :math:`\sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}` + log : dict + Convergence information and loss. + + References + ---------- + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, + "Gromov-Wasserstein averaging of kernel and distance matrices." + International Conference on Machine Learning (ICML). 2016. + + .. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the + metric approach to object matching. Foundations of computational + mathematics 11.4 (2011): 417-487. + + .. [47] Chowdhury, S., & Mémoli, F. (2019). The gromov–wasserstein + distance between networks and stable network invariants. + Information and Inference: A Journal of the IMA, 8(4), 757-787. + """ + p, q = list_to_array(p, q) + p0, q0, C10, C20 = p, q, C1, C2 + if G0 is None: + nx = get_backend(p0, q0, C10, C20) + else: + G0_ = G0 + nx = get_backend(p0, q0, C10, C20, G0_) + p = nx.to_numpy(p) + q = nx.to_numpy(q) + C1 = nx.to_numpy(C10) + C2 = nx.to_numpy(C20) + if symmetric is None: + symmetric = np.allclose(C1, C1.T, atol=1e-10) and np.allclose(C2, C2.T, atol=1e-10) + + if G0 is None: + G0 = p[:, None] * q[None, :] + else: + G0 = nx.to_numpy(G0_) + # Check marginals of G0 + np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08) + np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08) + # cg for GW is implemented using numpy on CPU + np_ = NumpyBackend() + + constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun, np_) + + def f(G): + return gwloss(constC, hC1, hC2, G, np_) + + if symmetric: + def df(G): + return gwggrad(constC, hC1, hC2, G, np_) + else: + constCt, hC1t, hC2t = init_matrix(C1.T, C2.T, p, q, loss_fun, np_) + + def df(G): + return 0.5 * (gwggrad(constC, hC1, hC2, G, np_) + gwggrad(constCt, hC1t, hC2t, G, np_)) + if loss_fun == 'kl_loss': + armijo = True # there is no closed form line-search with KL + + if armijo: + def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): + return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=np_, **kwargs) + else: + def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): + return solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M=0., reg=1., nx=np_, **kwargs) + if log: + res, log = cg(p, q, 0., 1., f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs) + log['gw_dist'] = nx.from_numpy(log['loss'][-1], type_as=C10) + log['u'] = nx.from_numpy(log['u'], type_as=C10) + log['v'] = nx.from_numpy(log['v'], type_as=C10) + return nx.from_numpy(res, type_as=C10), log + else: + return nx.from_numpy(cg(p, q, 0., 1., f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs), type_as=C10) + + +def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', symmetric=None, log=False, armijo=False, G0=None, + max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): + r""" + Returns the gromov-wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` + + The function solves the following optimization problem: + + .. math:: + GW = \min_\mathbf{T} \quad \sum_{i,j,k,l} + L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p} + + \mathbf{\gamma}^T \mathbf{1} &= \mathbf{q} + + \mathbf{\gamma} &\geq 0 + Where : + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{p}`: distribution in the source space + - :math:`\mathbf{q}`: distribution in the target space + - `L`: loss function to account for the misfit between the similarity + matrices + + Note that when using backends, this loss function is differentiable wrt the + matrices (C1, C2) and weights (p, q) for quadratic loss using the gradients from [38]_. + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + .. note:: All computations in the conjugate gradient solver are done with + numpy to limit memory overhead. + + Parameters + ---------- + C1 : array-like, shape (ns, ns) + Metric cost matrix in the source space + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,) + Distribution in the source space. + q : array-like, shape (nt,) + Distribution in the target space. + loss_fun : str + loss function used for the solver either 'square_loss' or 'kl_loss' + symmetric : bool, optional + Either C1 and C2 are to be assumed symmetric or not. + If let to its default None value, a symmetry test will be conducted. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric). + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + armijo : bool, optional + If True the step of the line-search is found via an armijo research. Else closed form is used. + If there are convergence issues use False. + G0: array-like, shape (ns,nt), optional + If None the initial transport plan of the solver is pq^T. + Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. + max_iter : int, optional + Max number of iterations + tol_rel : float, optional + Stop threshold on relative error (>0) + tol_abs : float, optional + Stop threshold on absolute error (>0) + **kwargs : dict + parameters can be directly passed to the ot.optim.cg solver + + Returns + ------- + gw_dist : float + Gromov-Wasserstein distance + log : dict + convergence information and Coupling marix + + References + ---------- + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, + "Gromov-Wasserstein averaging of kernel and distance matrices." + International Conference on Machine Learning (ICML). 2016. + + .. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the + metric approach to object matching. Foundations of computational + mathematics 11.4 (2011): 417-487. + + .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online + Graph Dictionary Learning, International Conference on Machine Learning + (ICML), 2021. + + .. [47] Chowdhury, S., & Mémoli, F. (2019). The gromov–wasserstein + distance between networks and stable network invariants. + Information and Inference: A Journal of the IMA, 8(4), 757-787. + """ + # simple get_backend as the full one will be handled in gromov_wasserstein + nx = get_backend(C1, C2) + + T, log_gw = gromov_wasserstein( + C1, C2, p, q, loss_fun, symmetric, log=True, armijo=armijo, G0=G0, + max_iter=max_iter, tol_rel=tol_rel, tol_abs=tol_abs, **kwargs) + + log_gw['T'] = T + gw = log_gw['gw_dist'] + + if loss_fun == 'square_loss': + gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T)) + gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T)) + gw = nx.set_gradients(gw, (p, q, C1, C2), + (log_gw['u'] - nx.mean(log_gw['u']), + log_gw['v'] - nx.mean(log_gw['v']), gC1, gC2)) + + if log: + return gw, log_gw + else: + return gw + + +def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', symmetric=None, alpha=0.5, + armijo=False, G0=None, log=False, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): + r""" + Computes the FGW transport between two graphs (see :ref:`[24] `) + + .. math:: + \gamma = \mathop{\arg \min}_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F + + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p} + + \mathbf{\gamma}^T \mathbf{1} &= \mathbf{q} + + \mathbf{\gamma} &\geq 0 + + where : + + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix + - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1) + - `L` is a loss function to account for the misfit between the similarity matrices + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + .. note:: All computations in the conjugate gradient solver are done with + numpy to limit memory overhead. + The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[24] ` + + Parameters + ---------- + M : array-like, shape (ns, nt) + Metric cost matrix between features across domains + C1 : array-like, shape (ns, ns) + Metric cost matrix representative of the structure in the source space + C2 : array-like, shape (nt, nt) + Metric cost matrix representative of the structure in the target space + p : array-like, shape (ns,) + Distribution in the source space + q : array-like, shape (nt,) + Distribution in the target space + loss_fun : str, optional + Loss function used for the solver + symmetric : bool, optional + Either C1 and C2 are to be assumed symmetric or not. + If let to its default None value, a symmetry test will be conducted. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric). + alpha : float, optional + Trade-off parameter (0 < alpha < 1) + armijo : bool, optional + If True the step of the line-search is found via an armijo research. Else closed form is used. + If there are convergence issues use False. + G0: array-like, shape (ns,nt), optional + If None the initial transport plan of the solver is pq^T. + Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. + log : bool, optional + record log if True + max_iter : int, optional + Max number of iterations + tol_rel : float, optional + Stop threshold on relative error (>0) + tol_abs : float, optional + Stop threshold on absolute error (>0) + **kwargs : dict + parameters can be directly passed to the ot.optim.cg solver + + Returns + ------- + gamma : array-like, shape (`ns`, `nt`) + Optimal transportation matrix for the given parameters. + log : dict + Log dictionary return only if log==True in parameters. + + + .. _references-fused-gromov-wasserstein: + References + ---------- + .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain + and Courty Nicolas "Optimal Transport for structured data with + application on graphs", International Conference on Machine Learning + (ICML). 2019. + + .. [47] Chowdhury, S., & Mémoli, F. (2019). The gromov–wasserstein + distance between networks and stable network invariants. + Information and Inference: A Journal of the IMA, 8(4), 757-787. + """ + p, q = list_to_array(p, q) + p0, q0, C10, C20, M0 = p, q, C1, C2, M + if G0 is None: + nx = get_backend(p0, q0, C10, C20, M0) + else: + G0_ = G0 + nx = get_backend(p0, q0, C10, C20, M0, G0_) + + p = nx.to_numpy(p) + q = nx.to_numpy(q) + C1 = nx.to_numpy(C10) + C2 = nx.to_numpy(C20) + M = nx.to_numpy(M0) + + if symmetric is None: + symmetric = np.allclose(C1, C1.T, atol=1e-10) and np.allclose(C2, C2.T, atol=1e-10) + + if G0 is None: + G0 = p[:, None] * q[None, :] + else: + G0 = nx.to_numpy(G0_) + # Check marginals of G0 + np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08) + np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08) + # cg for GW is implemented using numpy on CPU + np_ = NumpyBackend() + + constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun, np_) + + def f(G): + return gwloss(constC, hC1, hC2, G, np_) + + if symmetric: + def df(G): + return gwggrad(constC, hC1, hC2, G, np_) + else: + constCt, hC1t, hC2t = init_matrix(C1.T, C2.T, p, q, loss_fun, np_) + + def df(G): + return 0.5 * (gwggrad(constC, hC1, hC2, G, np_) + gwggrad(constCt, hC1t, hC2t, G, np_)) + + if loss_fun == 'kl_loss': + armijo = True # there is no closed form line-search with KL + + if armijo: + def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): + return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=np_, **kwargs) + else: + def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): + return solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M=(1 - alpha) * M, reg=alpha, nx=np_, **kwargs) + if log: + res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs) + log['fgw_dist'] = nx.from_numpy(log['loss'][-1], type_as=C10) + log['u'] = nx.from_numpy(log['u'], type_as=C10) + log['v'] = nx.from_numpy(log['v'], type_as=C10) + return nx.from_numpy(res, type_as=C10), log + else: + return nx.from_numpy(cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs), type_as=C10) + + +def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', symmetric=None, alpha=0.5, + armijo=False, G0=None, log=False, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): + r""" + Computes the FGW distance between two graphs see (see :ref:`[24] `) + + .. math:: + \min_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F + \alpha \sum_{i,j,k,l} + L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p} + + \mathbf{\gamma}^T \mathbf{1} &= \mathbf{q} + + \mathbf{\gamma} &\geq 0 + + where : + + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix + - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1) + - `L` is a loss function to account for the misfit between the similarity matrices + + The algorithm used for solving the problem is conditional gradient as + discussed in :ref:`[24] ` + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + .. note:: All computations in the conjugate gradient solver are done with + numpy to limit memory overhead. + + Note that when using backends, this loss function is differentiable wrt the + matrices (C1, C2, M) and weights (p, q) for quadratic loss using the gradients from [38]_. + + Parameters + ---------- + M : array-like, shape (ns, nt) + Metric cost matrix between features across domains + C1 : array-like, shape (ns, ns) + Metric cost matrix representative of the structure in the source space. + C2 : array-like, shape (nt, nt) + Metric cost matrix representative of the structure in the target space. + p : array-like, shape (ns,) + Distribution in the source space. + q : array-like, shape (nt,) + Distribution in the target space. + loss_fun : str, optional + Loss function used for the solver. + symmetric : bool, optional + Either C1 and C2 are to be assumed symmetric or not. + If let to its default None value, a symmetry test will be conducted. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric). + alpha : float, optional + Trade-off parameter (0 < alpha < 1) + armijo : bool, optional + If True the step of the line-search is found via an armijo research. + Else closed form is used. If there are convergence issues use False. + G0: array-like, shape (ns,nt), optional + If None the initial transport plan of the solver is pq^T. + Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. + log : bool, optional + Record log if True. + max_iter : int, optional + Max number of iterations + tol_rel : float, optional + Stop threshold on relative error (>0) + tol_abs : float, optional + Stop threshold on absolute error (>0) + **kwargs : dict + Parameters can be directly passed to the ot.optim.cg solver. + + Returns + ------- + fgw-distance : float + Fused gromov wasserstein distance for the given parameters. + log : dict + Log dictionary return only if log==True in parameters. + + + .. _references-fused-gromov-wasserstein2: + References + ---------- + .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain + and Courty Nicolas + "Optimal Transport for structured data with application on graphs" + International Conference on Machine Learning (ICML). 2019. + + .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online + Graph Dictionary Learning, International Conference on Machine Learning + (ICML), 2021. + + .. [47] Chowdhury, S., & Mémoli, F. (2019). The gromov–wasserstein + distance between networks and stable network invariants. + Information and Inference: A Journal of the IMA, 8(4), 757-787. + """ + nx = get_backend(C1, C2, M) + + T, log_fgw = fused_gromov_wasserstein( + M, C1, C2, p, q, loss_fun, symmetric, alpha, armijo, G0, log=True, + max_iter=max_iter, tol_rel=tol_rel, tol_abs=tol_abs, **kwargs) + + fgw_dist = log_fgw['fgw_dist'] + log_fgw['T'] = T + + if loss_fun == 'square_loss': + gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T)) + gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T)) + fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M), + (log_fgw['u'] - nx.mean(log_fgw['u']), + log_fgw['v'] - nx.mean(log_fgw['v']), + alpha * gC1, alpha * gC2, (1 - alpha) * T)) + + if log: + return fgw_dist, log_fgw + else: + return fgw_dist + + +def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg, + alpha_min=None, alpha_max=None, nx=None, **kwargs): + """ + Solve the linesearch in the FW iterations + + Parameters + ---------- + + G : array-like, shape(ns,nt) + The transport map at a given iteration of the FW + deltaG : array-like (ns,nt) + Difference between the optimal map found by linearization in the FW algorithm and the value at a given iteration + cost_G : float + Value of the cost at `G` + C1 : array-like (ns,ns), optional + Structure matrix in the source domain. + C2 : array-like (nt,nt), optional + Structure matrix in the target domain. + M : array-like (ns,nt) + Cost matrix between the features. + reg : float + Regularization parameter. + alpha_min : float, optional + Minimum value for alpha + alpha_max : float, optional + Maximum value for alpha + nx : backend, optional + If let to its default value None, a backend test will be conducted. + Returns + ------- + alpha : float + The optimal step size of the FW + fc : int + nb of function call. Useless here + cost_G : float + The value of the cost for the next iteration + + + .. _references-solve-linesearch: + References + ---------- + .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas + "Optimal Transport for structured data with application on graphs" + International Conference on Machine Learning (ICML). 2019. + """ + if nx is None: + G, deltaG, C1, C2, M = list_to_array(G, deltaG, C1, C2, M) + + if isinstance(M, int) or isinstance(M, float): + nx = get_backend(G, deltaG, C1, C2) + else: + nx = get_backend(G, deltaG, C1, C2, M) + + dot = nx.dot(nx.dot(C1, deltaG), C2.T) + a = -2 * reg * nx.sum(dot * deltaG) + b = nx.sum(M * deltaG) - 2 * reg * (nx.sum(dot * G) + nx.sum(nx.dot(nx.dot(C1, G), C2.T) * deltaG)) + + alpha = solve_1d_linesearch_quad(a, b) + if alpha_min is not None or alpha_max is not None: + alpha = np.clip(alpha, alpha_min, alpha_max) + + # the new cost is deduced from the line search quadratic function + cost_G = cost_G + a * (alpha ** 2) + b * alpha + + return alpha, 1, cost_G + + +def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, symmetric=True, armijo=False, + max_iter=1000, tol=1e-9, verbose=False, log=False, + init_C=None, random_state=None, **kwargs): + r""" + Returns the gromov-wasserstein barycenters of `S` measured similarity matrices :math:`(\mathbf{C}_s)_{1 \leq s \leq S}` + + The function solves the following optimization problem with block coordinate descent: + + .. math:: + + \mathbf{C} = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \quad \sum_s \lambda_s \mathrm{GW}(\mathbf{C}, \mathbf{C}_s, \mathbf{p}, \mathbf{p}_s) + + Where : + + - :math:`\mathbf{C}_s`: metric cost matrix + - :math:`\mathbf{p}_s`: distribution + + Parameters + ---------- + N : int + Size of the targeted barycenter + Cs : list of S array-like of shape (ns, ns) + Metric cost matrices + ps : list of S array-like of shape (ns,) + Sample weights in the `S` spaces + p : array-like, shape (N,) + Weights in the targeted barycenter + lambdas : list of float + List of the `S` spaces' weights + loss_fun : callable + tensor-matrix multiplication function based on specific loss function + symmetric : bool, optional. + Either structures are to be assumed symmetric or not. Default value is True. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric). + update : callable + function(:math:`\mathbf{p}`, lambdas, :math:`\mathbf{T}`, :math:`\mathbf{Cs}`) that updates + :math:`\mathbf{C}` according to a specific Kernel with the `S` :math:`\mathbf{T}_s` couplings + calculated at each iteration + max_iter : int, optional + Max number of iterations + tol : float, optional + Stop threshold on relative error (>0) + verbose : bool, optional + Print information along iterations. + log : bool, optional + Record log if True. + init_C : bool | array-like, shape(N,N) + Random initial value for the :math:`\mathbf{C}` matrix provided by user. + random_state : int or RandomState instance, optional + Fix the seed for reproducibility + + Returns + ------- + C : array-like, shape (`N`, `N`) + Similarity matrix in the barycenter space (permutated arbitrarily) + log : dict + Log dictionary of error during iterations. Return only if `log=True` in parameters. + + References + ---------- + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, + "Gromov-Wasserstein averaging of kernel and distance matrices." + International Conference on Machine Learning (ICML). 2016. + + """ + Cs = list_to_array(*Cs) + ps = list_to_array(*ps) + p = list_to_array(p) + nx = get_backend(*Cs, *ps, p) + + S = len(Cs) + + # Initialization of C : random SPD matrix (if not provided by user) + if init_C is None: + generator = check_random_state(random_state) + xalea = generator.randn(N, 2) + C = dist(xalea, xalea) + C /= C.max() + C = nx.from_numpy(C, type_as=p) + else: + C = init_C + + if loss_fun == 'kl_loss': + armijo = True + + cpt = 0 + err = 1 + + error = [] + + while (err > tol and cpt < max_iter): + Cprev = C + + T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, symmetric=symmetric, armijo=armijo, + max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, log=False, **kwargs) for s in range(S)] + if loss_fun == 'square_loss': + C = update_square_loss(p, lambdas, T, Cs) + + elif loss_fun == 'kl_loss': + C = update_kl_loss(p, lambdas, T, Cs) + + if cpt % 10 == 0: + # we can speed up the process by checking for the error only all + # the 10th iterations + err = nx.norm(C - Cprev) + error.append(err) + + if verbose: + if cpt % 200 == 0: + print('{:5s}|{:12s}'.format( + 'It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) + + cpt += 1 + + if log: + return C, {"err": error} + else: + return C + + +def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False, + p=None, loss_fun='square_loss', armijo=False, symmetric=True, max_iter=100, tol=1e-9, + verbose=False, log=False, init_C=None, init_X=None, random_state=None, **kwargs): + r"""Compute the fgw barycenter as presented eq (5) in :ref:`[24] ` + + Parameters + ---------- + N : int + Desired number of samples of the target barycenter + Ys: list of array-like, each element has shape (ns,d) + Features of all samples + Cs : list of array-like, each element has shape (ns,ns) + Structure matrices of all samples + ps : list of array-like, each element has shape (ns,) + Masses of all samples. + lambdas : list of float + List of the `S` spaces' weights + alpha : float + Alpha parameter for the fgw distance + fixed_structure : bool + Whether to fix the structure of the barycenter during the updates + fixed_features : bool + Whether to fix the feature of the barycenter during the updates + loss_fun : str + Loss function used for the solver either 'square_loss' or 'kl_loss' + symmetric : bool, optional + Either structures are to be assumed symmetric or not. Default value is True. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric). + max_iter : int, optional + Max number of iterations + tol : float, optional + Stop threshold on relative error (>0) + verbose : bool, optional + Print information along iterations. + log : bool, optional + Record log if True. + init_C : array-like, shape (N,N), optional + Initialization for the barycenters' structure matrix. If not set + a random init is used. + init_X : array-like, shape (N,d), optional + Initialization for the barycenters' features. If not set a + random init is used. + random_state : int or RandomState instance, optional + Fix the seed for reproducibility + + Returns + ------- + X : array-like, shape (`N`, `d`) + Barycenters' features + C : array-like, shape (`N`, `N`) + Barycenters' structure matrix + log : dict + Only returned when log=True. It contains the keys: + + - :math:`\mathbf{T}`: list of (`N`, `ns`) transport matrices + - :math:`(\mathbf{M}_s)_s`: all distance matrices between the feature of the barycenter and the other features :math:`(dist(\mathbf{X}, \mathbf{Y}_s))_s` shape (`N`, `ns`) + + + .. _references-fgw-barycenters: + References + ---------- + .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain + and Courty Nicolas + "Optimal Transport for structured data with application on graphs" + International Conference on Machine Learning (ICML). 2019. + """ + Cs = list_to_array(*Cs) + ps = list_to_array(*ps) + Ys = list_to_array(*Ys) + p = list_to_array(p) + nx = get_backend(*Cs, *Ys, *ps) + + S = len(Cs) + d = Ys[0].shape[1] # dimension on the node features + if p is None: + p = nx.ones(N, type_as=Cs[0]) / N + + if fixed_structure: + if init_C is None: + raise UndefinedParameter('If C is fixed it must be initialized') + else: + C = init_C + else: + if init_C is None: + generator = check_random_state(random_state) + xalea = generator.randn(N, 2) + C = dist(xalea, xalea) + C = nx.from_numpy(C, type_as=ps[0]) + else: + C = init_C + + if fixed_features: + if init_X is None: + raise UndefinedParameter('If X is fixed it must be initialized') + else: + X = init_X + else: + if init_X is None: + X = nx.zeros((N, d), type_as=ps[0]) + else: + X = init_X + + T = [nx.outer(p, q) for q in ps] + + Ms = [dist(X, Ys[s]) for s in range(len(Ys))] + + if loss_fun == 'kl_loss': + armijo = True + + cpt = 0 + err_feature = 1 + err_structure = 1 + + if log: + log_ = {} + log_['err_feature'] = [] + log_['err_structure'] = [] + log_['Ts_iter'] = [] + + while ((err_feature > tol or err_structure > tol) and cpt < max_iter): + Cprev = C + Xprev = X + + if not fixed_features: + Ys_temp = [y.T for y in Ys] + X = update_feature_matrix(lambdas, Ys_temp, T, p).T + + Ms = [dist(X, Ys[s]) for s in range(len(Ys))] + + if not fixed_structure: + if loss_fun == 'square_loss': + T_temp = [t.T for t in T] + C = update_structure_matrix(p, lambdas, T_temp, Cs) + + T = [fused_gromov_wasserstein(Ms[s], C, Cs[s], p, ps[s], loss_fun=loss_fun, alpha=alpha, armijo=armijo, symmetric=symmetric, + max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, **kwargs) for s in range(S)] + + # T is N,ns + err_feature = nx.norm(X - nx.reshape(Xprev, (N, d))) + err_structure = nx.norm(C - Cprev) + if log: + log_['err_feature'].append(err_feature) + log_['err_structure'].append(err_structure) + log_['Ts_iter'].append(T) + + if verbose: + if cpt % 200 == 0: + print('{:5s}|{:12s}'.format( + 'It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err_structure)) + print('{:5d}|{:8e}|'.format(cpt, err_feature)) + + cpt += 1 + + if log: + log_['T'] = T # from target to Ys + log_['p'] = p + log_['Ms'] = Ms + + if log: + return X, C, log_ + else: + return X, C + + +def update_structure_matrix(p, lambdas, T, Cs): + r"""Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s` couplings. + + It is calculated at each iteration + + Parameters + ---------- + p : array-like, shape (N,) + Masses in the targeted barycenter. + lambdas : list of float + List of the `S` spaces' weights. + T : list of S array-like of shape (ns, N) + The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration. + Cs : list of S array-like, shape (ns, ns) + Metric cost matrices. + + Returns + ------- + C : array-like, shape (`nt`, `nt`) + Updated :math:`\mathbf{C}` matrix. + """ + p = list_to_array(p) + T = list_to_array(*T) + Cs = list_to_array(*Cs) + nx = get_backend(*Cs, *T, p) + + tmpsum = sum([ + lambdas[s] * nx.dot( + nx.dot(T[s].T, Cs[s]), + T[s] + ) for s in range(len(T)) + ]) + ppt = nx.outer(p, p) + return tmpsum / ppt + + +def update_feature_matrix(lambdas, Ys, Ts, p): + r"""Updates the feature with respect to the `S` :math:`\mathbf{T}_s` couplings. + + + See "Solving the barycenter problem with Block Coordinate Descent (BCD)" + in :ref:`[24] ` calculated at each iteration + + Parameters + ---------- + p : array-like, shape (N,) + masses in the targeted barycenter + lambdas : list of float + List of the `S` spaces' weights + Ts : list of S array-like, shape (ns,N) + The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration + Ys : list of S array-like, shape (d,ns) + The features. + + Returns + ------- + X : array-like, shape (`d`, `N`) + + + .. _references-update-feature-matrix: + References + ---------- + .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas + "Optimal Transport for structured data with application on graphs" + International Conference on Machine Learning (ICML). 2019. + """ + p = list_to_array(p) + Ts = list_to_array(*Ts) + Ys = list_to_array(*Ys) + nx = get_backend(*Ys, *Ts, p) + + p = 1. / p + tmpsum = sum([ + lambdas[s] * nx.dot(Ys[s], Ts[s].T) * p[None, :] + for s in range(len(Ts)) + ]) + return tmpsum diff --git a/ot/gromov/_semirelaxed.py b/ot/gromov/_semirelaxed.py new file mode 100644 index 0000000..638bb1c --- /dev/null +++ b/ot/gromov/_semirelaxed.py @@ -0,0 +1,543 @@ +# -*- coding: utf-8 -*- +""" +Semi-relaxed Gromov-Wasserstein and Fused-Gromov-Wasserstein solvers. +""" + +# Author: Rémi Flamary +# Cédric Vincent-Cuaz +# +# License: MIT License + +import numpy as np + + +from ..utils import list_to_array, unif +from ..optim import semirelaxed_cg, solve_1d_linesearch_quad +from ..backend import get_backend + +from ._utils import init_matrix_semirelaxed, gwloss, gwggrad + + +def semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', symmetric=None, log=False, G0=None, + max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): + r""" + Returns the semi-relaxed gromov-wasserstein divergence transport from :math:`(\mathbf{C_1}, \mathbf{p})` to :math:`\mathbf{C_2}` + + The function solves the following optimization problem: + + .. math:: + \mathbf{srGW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l} + L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p} + + \mathbf{\gamma} &\geq 0 + Where : + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{p}`: distribution in the source space + + - `L`: loss function to account for the misfit between the similarity matrices + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. However all the steps in the conditional + gradient are not differentiable. + + Parameters + ---------- + C1 : array-like, shape (ns, ns) + Metric cost matrix in the source space + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,) + Distribution in the source space + loss_fun : str + loss function used for the solver either 'square_loss' or 'kl_loss'. + 'kl_loss' is not implemented yet and will raise an error. + symmetric : bool, optional + Either C1 and C2 are to be assumed symmetric or not. + If let to its default None value, a symmetry test will be conducted. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric). + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + G0: array-like, shape (ns,nt), optional + If None the initial transport plan of the solver is pq^T. + Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. + max_iter : int, optional + Max number of iterations + tol_rel : float, optional + Stop threshold on relative error (>0) + tol_abs : float, optional + Stop threshold on absolute error (>0) + **kwargs : dict + parameters can be directly passed to the ot.optim.cg solver + + Returns + ------- + T : array-like, shape (`ns`, `nt`) + Coupling between the two spaces that minimizes: + + :math:`\sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}` + log : dict + Convergence information and loss. + + References + ---------- + .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty. + "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" + International Conference on Learning Representations (ICLR), 2022. + """ + if loss_fun == 'kl_loss': + raise NotImplementedError() + p = list_to_array(p) + if G0 is None: + nx = get_backend(p, C1, C2) + else: + nx = get_backend(p, C1, C2, G0) + + if symmetric is None: + symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10) + if G0 is None: + q = unif(C2.shape[0], type_as=p) + G0 = nx.outer(p, q) + else: + q = nx.sum(G0, 0) + # Check first marginal of G0 + np.testing.assert_allclose(nx.sum(G0, 1), p, atol=1e-08) + + constC, hC1, hC2, fC2t = init_matrix_semirelaxed(C1, C2, p, loss_fun, nx) + + ones_p = nx.ones(p.shape[0], type_as=p) + + def f(G): + qG = nx.sum(G, 0) + marginal_product = nx.outer(ones_p, nx.dot(qG, fC2t)) + return gwloss(constC + marginal_product, hC1, hC2, G, nx) + + if symmetric: + def df(G): + qG = nx.sum(G, 0) + marginal_product = nx.outer(ones_p, nx.dot(qG, fC2t)) + return gwggrad(constC + marginal_product, hC1, hC2, G, nx) + else: + constCt, hC1t, hC2t, fC2 = init_matrix_semirelaxed(C1.T, C2.T, p, loss_fun, nx) + + def df(G): + qG = nx.sum(G, 0) + marginal_product_1 = nx.outer(ones_p, nx.dot(qG, fC2t)) + marginal_product_2 = nx.outer(ones_p, nx.dot(qG, fC2)) + return 0.5 * (gwggrad(constC + marginal_product_1, hC1, hC2, G, nx) + gwggrad(constCt + marginal_product_2, hC1t, hC2t, G, nx)) + + def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): + return solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p, M=0., reg=1., nx=nx, **kwargs) + + if log: + res, log = semirelaxed_cg(p, q, 0., 1., f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs) + log['srgw_dist'] = log['loss'][-1] + return res, log + else: + return semirelaxed_cg(p, q, 0., 1., f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs) + + +def semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric=None, log=False, G0=None, + max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): + r""" + Returns the semi-relaxed gromov-wasserstein divergence from :math:`(\mathbf{C_1}, \mathbf{p})` to :math:`\mathbf{C_2}` + + The function solves the following optimization problem: + + .. math:: + srGW = \min_\mathbf{T} \quad \sum_{i,j,k,l} + L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p} + + \mathbf{\gamma} &\geq 0 + Where : + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{p}`: distribution in the source space + - `L`: loss function to account for the misfit between the similarity + matrices + + Note that when using backends, this loss function is differentiable wrt the + matrices (C1, C2) but not yet for the weights p. + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. However all the steps in the conditional + gradient are not differentiable. + + Parameters + ---------- + C1 : array-like, shape (ns, ns) + Metric cost matrix in the source space + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,) + Distribution in the source space. + loss_fun : str + loss function used for the solver either 'square_loss' or 'kl_loss'. + 'kl_loss' is not implemented yet and will raise an error. + symmetric : bool, optional + Either C1 and C2 are to be assumed symmetric or not. + If let to its default None value, a symmetry test will be conducted. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric). + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + G0: array-like, shape (ns,nt), optional + If None the initial transport plan of the solver is pq^T. + Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. + max_iter : int, optional + Max number of iterations + tol_rel : float, optional + Stop threshold on relative error (>0) + tol_abs : float, optional + Stop threshold on absolute error (>0) + **kwargs : dict + parameters can be directly passed to the ot.optim.cg solver + + Returns + ------- + srgw : float + Semi-relaxed Gromov-Wasserstein divergence + log : dict + convergence information and Coupling matrix + + References + ---------- + + .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty. + "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" + International Conference on Learning Representations (ICLR), 2022. + """ + nx = get_backend(p, C1, C2) + + T, log_srgw = semirelaxed_gromov_wasserstein( + C1, C2, p, loss_fun, symmetric, log=True, G0=G0, + max_iter=max_iter, tol_rel=tol_rel, tol_abs=tol_abs, **kwargs) + + q = nx.sum(T, 0) + log_srgw['T'] = T + srgw = log_srgw['srgw_dist'] + + if loss_fun == 'square_loss': + gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T)) + gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T)) + srgw = nx.set_gradients(srgw, (C1, C2), (gC1, gC2)) + + if log: + return srgw, log_srgw + else: + return srgw + + +def semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun='square_loss', symmetric=None, alpha=0.5, G0=None, log=False, + max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): + r""" + Computes the semi-relaxed FGW transport between two graphs (see :ref:`[48] `) + + .. math:: + \gamma = \mathop{\arg \min}_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F + + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p} + + \mathbf{\gamma} &\geq 0 + + where : + + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix + - :math:`\mathbf{p}` source weights (sum to 1) + - `L` is a loss function to account for the misfit between the similarity matrices + + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. However all the steps in the conditional + gradient are not differentiable. + + The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[48] ` + + Parameters + ---------- + M : array-like, shape (ns, nt) + Metric cost matrix between features across domains + C1 : array-like, shape (ns, ns) + Metric cost matrix representative of the structure in the source space + C2 : array-like, shape (nt, nt) + Metric cost matrix representative of the structure in the target space + p : array-like, shape (ns,) + Distribution in the source space + loss_fun : str + loss function used for the solver either 'square_loss' or 'kl_loss'. + 'kl_loss' is not implemented yet and will raise an error. + symmetric : bool, optional + Either C1 and C2 are to be assumed symmetric or not. + If let to its default None value, a symmetry test will be conducted. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric). + alpha : float, optional + Trade-off parameter (0 < alpha < 1) + G0: array-like, shape (ns,nt), optional + If None the initial transport plan of the solver is pq^T. + Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. + log : bool, optional + record log if True + max_iter : int, optional + Max number of iterations + tol_rel : float, optional + Stop threshold on relative error (>0) + tol_abs : float, optional + Stop threshold on absolute error (>0) + **kwargs : dict + parameters can be directly passed to the ot.optim.cg solver + + Returns + ------- + gamma : array-like, shape (`ns`, `nt`) + Optimal transportation matrix for the given parameters. + log : dict + Log dictionary return only if log==True in parameters. + + + .. _references-semirelaxed-fused-gromov-wasserstein: + References + ---------- + .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain + and Courty Nicolas "Optimal Transport for structured data with + application on graphs", International Conference on Machine Learning + (ICML). 2019. + + .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty. + "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" + International Conference on Learning Representations (ICLR), 2022. + """ + if loss_fun == 'kl_loss': + raise NotImplementedError() + + p = list_to_array(p) + if G0 is None: + nx = get_backend(p, C1, C2, M) + else: + nx = get_backend(p, C1, C2, M, G0) + + if symmetric is None: + symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10) + + if G0 is None: + q = unif(C2.shape[0], type_as=p) + G0 = nx.outer(p, q) + else: + q = nx.sum(G0, 0) + # Check marginals of G0 + np.testing.assert_allclose(nx.sum(G0, 1), p, atol=1e-08) + + constC, hC1, hC2, fC2t = init_matrix_semirelaxed(C1, C2, p, loss_fun, nx) + + ones_p = nx.ones(p.shape[0], type_as=p) + + def f(G): + qG = nx.sum(G, 0) + marginal_product = nx.outer(ones_p, nx.dot(qG, fC2t)) + return gwloss(constC + marginal_product, hC1, hC2, G, nx) + + if symmetric: + def df(G): + qG = nx.sum(G, 0) + marginal_product = nx.outer(ones_p, nx.dot(qG, fC2t)) + return gwggrad(constC + marginal_product, hC1, hC2, G, nx) + else: + constCt, hC1t, hC2t, fC2 = init_matrix_semirelaxed(C1.T, C2.T, p, loss_fun, nx) + + def df(G): + qG = nx.sum(G, 0) + marginal_product_1 = nx.outer(ones_p, nx.dot(qG, fC2t)) + marginal_product_2 = nx.outer(ones_p, nx.dot(qG, fC2)) + return 0.5 * (gwggrad(constC + marginal_product_1, hC1, hC2, G, nx) + gwggrad(constCt + marginal_product_2, hC1t, hC2t, G, nx)) + + def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): + return solve_semirelaxed_gromov_linesearch( + G, deltaG, cost_G, C1, C2, ones_p, M=(1 - alpha) * M, reg=alpha, nx=nx, **kwargs) + + if log: + res, log = semirelaxed_cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs) + log['srfgw_dist'] = log['loss'][-1] + return res, log + else: + return semirelaxed_cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs) + + +def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', symmetric=None, alpha=0.5, G0=None, log=False, + max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): + r""" + Computes the semi-relaxed FGW divergence between two graphs (see :ref:`[48] `) + + .. math:: + \min_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F + \alpha \sum_{i,j,k,l} + L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p} + + \mathbf{\gamma} &\geq 0 + + where : + + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix + - :math:`\mathbf{p}` source weights (sum to 1) + - `L` is a loss function to account for the misfit between the similarity matrices + + The algorithm used for solving the problem is conditional gradient as + discussed in :ref:`[48] ` + + Note that when using backends, this loss function is differentiable wrt the + matrices (C1, C2) but not yet for the weights p. + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. However all the steps in the conditional + gradient are not differentiable. + + Parameters + ---------- + M : array-like, shape (ns, nt) + Metric cost matrix between features across domains + C1 : array-like, shape (ns, ns) + Metric cost matrix representative of the structure in the source space. + C2 : array-like, shape (nt, nt) + Metric cost matrix representative of the structure in the target space. + p : array-like, shape (ns,) + Distribution in the source space. + loss_fun : str, optional + loss function used for the solver either 'square_loss' or 'kl_loss'. + 'kl_loss' is not implemented yet and will raise an error. + symmetric : bool, optional + Either C1 and C2 are to be assumed symmetric or not. + If let to its default None value, a symmetry test will be conducted. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric). + alpha : float, optional + Trade-off parameter (0 < alpha < 1) + G0: array-like, shape (ns,nt), optional + If None the initial transport plan of the solver is pq^T. + Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. + log : bool, optional + Record log if True. + max_iter : int, optional + Max number of iterations + tol_rel : float, optional + Stop threshold on relative error (>0) + tol_abs : float, optional + Stop threshold on absolute error (>0) + **kwargs : dict + Parameters can be directly passed to the ot.optim.cg solver. + + Returns + ------- + srfgw-divergence : float + Semi-relaxed Fused gromov wasserstein divergence for the given parameters. + log : dict + Log dictionary return only if log==True in parameters. + + + .. _references-semirelaxed-fused-gromov-wasserstein2: + References + ---------- + .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain + and Courty Nicolas "Optimal Transport for structured data with + application on graphs", International Conference on Machine Learning + (ICML). 2019. + + .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty. + "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" + International Conference on Learning Representations (ICLR), 2022. + """ + nx = get_backend(p, C1, C2, M) + + T, log_fgw = semirelaxed_fused_gromov_wasserstein( + M, C1, C2, p, loss_fun, symmetric, alpha, G0, log=True, + max_iter=max_iter, tol_rel=tol_rel, tol_abs=tol_abs, **kwargs) + q = nx.sum(T, 0) + srfgw_dist = log_fgw['srfgw_dist'] + log_fgw['T'] = T + + if loss_fun == 'square_loss': + gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T)) + gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T)) + srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M), + (alpha * gC1, alpha * gC2, (1 - alpha) * T)) + + if log: + return srfgw_dist, log_fgw + else: + return srfgw_dist + + +def solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p, + M, reg, alpha_min=None, alpha_max=None, nx=None, **kwargs): + """ + Solve the linesearch in the FW iterations + + Parameters + ---------- + + G : array-like, shape(ns,nt) + The transport map at a given iteration of the FW + deltaG : array-like (ns,nt) + Difference between the optimal map found by linearization in the FW algorithm and the value at a given iteration + cost_G : float + Value of the cost at `G` + C1 : array-like (ns,ns) + Structure matrix in the source domain. + C2 : array-like (nt,nt) + Structure matrix in the target domain. + ones_p: array-like (ns,1) + Array of ones of size ns + M : array-like (ns,nt) + Cost matrix between the features. + reg : float + Regularization parameter. + alpha_min : float, optional + Minimum value for alpha + alpha_max : float, optional + Maximum value for alpha + nx : backend, optional + If let to its default value None, a backend test will be conducted. + Returns + ------- + alpha : float + The optimal step size of the FW + fc : int + nb of function call. Useless here + cost_G : float + The value of the cost for the next iteration + + References + ---------- + .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty. + "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" + International Conference on Learning Representations (ICLR), 2021. + """ + if nx is None: + G, deltaG, C1, C2, M = list_to_array(G, deltaG, C1, C2, M) + + if isinstance(M, int) or isinstance(M, float): + nx = get_backend(G, deltaG, C1, C2) + else: + nx = get_backend(G, deltaG, C1, C2, M) + + qG, qdeltaG = nx.sum(G, 0), nx.sum(deltaG, 0) + dot = nx.dot(nx.dot(C1, deltaG), C2.T) + C2t_square = C2.T ** 2 + dot_qG = nx.dot(nx.outer(ones_p, qG), C2t_square) + dot_qdeltaG = nx.dot(nx.outer(ones_p, qdeltaG), C2t_square) + a = reg * nx.sum((dot_qdeltaG - 2 * dot) * deltaG) + b = nx.sum(M * deltaG) + reg * (nx.sum((dot_qdeltaG - 2 * dot) * G) + nx.sum((dot_qG - 2 * nx.dot(nx.dot(C1, G), C2.T)) * deltaG)) + alpha = solve_1d_linesearch_quad(a, b) + if alpha_min is not None or alpha_max is not None: + alpha = np.clip(alpha, alpha_min, alpha_max) + + # the new cost can be deduced from the line search quadratic function + cost_G = cost_G + a * (alpha ** 2) + b * alpha + + return alpha, 1, cost_G diff --git a/ot/gromov/_utils.py b/ot/gromov/_utils.py new file mode 100644 index 0000000..e842250 --- /dev/null +++ b/ot/gromov/_utils.py @@ -0,0 +1,413 @@ +# -*- coding: utf-8 -*- +""" +Gromov-Wasserstein and Fused-Gromov-Wasserstein utils. +""" + +# Author: Erwan Vautier +# Nicolas Courty +# Rémi Flamary +# Titouan Vayer +# Cédric Vincent-Cuaz +# +# License: MIT License + + +from ..utils import list_to_array +from ..backend import get_backend + + +def init_matrix(C1, C2, p, q, loss_fun='square_loss', nx=None): + r"""Return loss matrices and tensors for Gromov-Wasserstein fast computation + + Returns the value of :math:`\mathcal{L}(\mathbf{C_1}, \mathbf{C_2}) \otimes \mathbf{T}` with the + selected loss function as the loss function of Gromow-Wasserstein discrepancy. + + The matrices are computed as described in Proposition 1 in :ref:`[12] ` + + Where : + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{T}`: A coupling between those two spaces + + The square-loss function :math:`L(a, b) = |a - b|^2` is read as : + + .. math:: + + L(a, b) = f_1(a) + f_2(b) - h_1(a) h_2(b) + + \mathrm{with} \ f_1(a) &= a^2 + + f_2(b) &= b^2 + + h_1(a) &= a + + h_2(b) &= 2b + + The kl-loss function :math:`L(a, b) = a \log\left(\frac{a}{b}\right) - a + b` is read as : + + .. math:: + + L(a, b) = f_1(a) + f_2(b) - h_1(a) h_2(b) + + \mathrm{with} \ f_1(a) &= a \log(a) - a + + f_2(b) &= b + + h_1(a) &= a + + h_2(b) &= \log(b) + + Parameters + ---------- + C1 : array-like, shape (ns, ns) + Metric cost matrix in the source space + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,) + Probability distribution in the source space + q : array-like, shape (nt,) + Probability distribution in the target space + loss_fun : str, optional + Name of loss function to use: either 'square_loss' or 'kl_loss' (default='square_loss') + nx : backend, optional + If let to its default value None, a backend test will be conducted. + Returns + ------- + constC : array-like, shape (ns, nt) + Constant :math:`\mathbf{C}` matrix in Eq. (6) + hC1 : array-like, shape (ns, ns) + :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6) + hC2 : array-like, shape (nt, nt) + :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6) + + + .. _references-init-matrix: + References + ---------- + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, + "Gromov-Wasserstein averaging of kernel and distance matrices." + International Conference on Machine Learning (ICML). 2016. + + """ + if nx is None: + C1, C2, p, q = list_to_array(C1, C2, p, q) + nx = get_backend(C1, C2, p, q) + + if loss_fun == 'square_loss': + def f1(a): + return (a**2) + + def f2(b): + return (b**2) + + def h1(a): + return a + + def h2(b): + return 2 * b + elif loss_fun == 'kl_loss': + def f1(a): + return a * nx.log(a + 1e-15) - a + + def f2(b): + return b + + def h1(a): + return a + + def h2(b): + return nx.log(b + 1e-15) + + constC1 = nx.dot( + nx.dot(f1(C1), nx.reshape(p, (-1, 1))), + nx.ones((1, len(q)), type_as=q) + ) + constC2 = nx.dot( + nx.ones((len(p), 1), type_as=p), + nx.dot(nx.reshape(q, (1, -1)), f2(C2).T) + ) + constC = constC1 + constC2 + hC1 = h1(C1) + hC2 = h2(C2) + + return constC, hC1, hC2 + + +def tensor_product(constC, hC1, hC2, T, nx=None): + r"""Return the tensor for Gromov-Wasserstein fast computation + + The tensor is computed as described in Proposition 1 Eq. (6) in :ref:`[12] ` + + Parameters + ---------- + constC : array-like, shape (ns, nt) + Constant :math:`\mathbf{C}` matrix in Eq. (6) + hC1 : array-like, shape (ns, ns) + :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6) + hC2 : array-like, shape (nt, nt) + :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6) + nx : backend, optional + If let to its default value None, a backend test will be conducted. + Returns + ------- + tens : array-like, shape (`ns`, `nt`) + :math:`\mathcal{L}(\mathbf{C_1}, \mathbf{C_2}) \otimes \mathbf{T}` tensor-matrix multiplication result + + + .. _references-tensor-product: + References + ---------- + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, + "Gromov-Wasserstein averaging of kernel and distance matrices." + International Conference on Machine Learning (ICML). 2016. + + """ + if nx is None: + constC, hC1, hC2, T = list_to_array(constC, hC1, hC2, T) + nx = get_backend(constC, hC1, hC2, T) + + A = - nx.dot( + nx.dot(hC1, T), hC2.T + ) + tens = constC + A + # tens -= tens.min() + return tens + + +def gwloss(constC, hC1, hC2, T, nx=None): + r"""Return the Loss for Gromov-Wasserstein + + The loss is computed as described in Proposition 1 Eq. (6) in :ref:`[12] ` + + Parameters + ---------- + constC : array-like, shape (ns, nt) + Constant :math:`\mathbf{C}` matrix in Eq. (6) + hC1 : array-like, shape (ns, ns) + :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6) + hC2 : array-like, shape (nt, nt) + :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6) + T : array-like, shape (ns, nt) + Current value of transport matrix :math:`\mathbf{T}` + nx : backend, optional + If let to its default value None, a backend test will be conducted. + Returns + ------- + loss : float + Gromov Wasserstein loss + + + .. _references-gwloss: + References + ---------- + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, + "Gromov-Wasserstein averaging of kernel and distance matrices." + International Conference on Machine Learning (ICML). 2016. + + """ + + tens = tensor_product(constC, hC1, hC2, T, nx) + if nx is None: + tens, T = list_to_array(tens, T) + nx = get_backend(tens, T) + + return nx.sum(tens * T) + + +def gwggrad(constC, hC1, hC2, T, nx=None): + r"""Return the gradient for Gromov-Wasserstein + + The gradient is computed as described in Proposition 2 in :ref:`[12] ` + + Parameters + ---------- + constC : array-like, shape (ns, nt) + Constant :math:`\mathbf{C}` matrix in Eq. (6) + hC1 : array-like, shape (ns, ns) + :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6) + hC2 : array-like, shape (nt, nt) + :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6) + T : array-like, shape (ns, nt) + Current value of transport matrix :math:`\mathbf{T}` + nx : backend, optional + If let to its default value None, a backend test will be conducted. + Returns + ------- + grad : array-like, shape (`ns`, `nt`) + Gromov Wasserstein gradient + + + .. _references-gwggrad: + References + ---------- + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, + "Gromov-Wasserstein averaging of kernel and distance matrices." + International Conference on Machine Learning (ICML). 2016. + + """ + return 2 * tensor_product(constC, hC1, hC2, + T, nx) # [12] Prop. 2 misses a 2 factor + + +def update_square_loss(p, lambdas, T, Cs): + r""" + Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s` + couplings calculated at each iteration + + Parameters + ---------- + p : array-like, shape (N,) + Masses in the targeted barycenter. + lambdas : list of float + List of the `S` spaces' weights. + T : list of S array-like of shape (ns,N) + The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration. + Cs : list of S array-like, shape(ns,ns) + Metric cost matrices. + + Returns + ---------- + C : array-like, shape (`nt`, `nt`) + Updated :math:`\mathbf{C}` matrix. + """ + T = list_to_array(*T) + Cs = list_to_array(*Cs) + p = list_to_array(p) + nx = get_backend(p, *T, *Cs) + + tmpsum = sum([ + lambdas[s] * nx.dot( + nx.dot(T[s].T, Cs[s]), + T[s] + ) for s in range(len(T)) + ]) + ppt = nx.outer(p, p) + + return tmpsum / ppt + + +def update_kl_loss(p, lambdas, T, Cs): + r""" + Updates :math:`\mathbf{C}` according to the KL Loss kernel with the `S` :math:`\mathbf{T}_s` couplings calculated at each iteration + + + Parameters + ---------- + p : array-like, shape (N,) + Weights in the targeted barycenter. + lambdas : list of float + List of the `S` spaces' weights + T : list of S array-like of shape (ns,N) + The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration. + Cs : list of S array-like, shape(ns,ns) + Metric cost matrices. + + Returns + ---------- + C : array-like, shape (`ns`, `ns`) + updated :math:`\mathbf{C}` matrix + """ + Cs = list_to_array(*Cs) + T = list_to_array(*T) + p = list_to_array(p) + nx = get_backend(p, *T, *Cs) + + tmpsum = sum([ + lambdas[s] * nx.dot( + nx.dot(T[s].T, Cs[s]), + T[s] + ) for s in range(len(T)) + ]) + ppt = nx.outer(p, p) + + return nx.exp(tmpsum / ppt) + + +def init_matrix_semirelaxed(C1, C2, p, loss_fun='square_loss', nx=None): + r"""Return loss matrices and tensors for semi-relaxed Gromov-Wasserstein fast computation + + Returns the value of :math:`\mathcal{L}(\mathbf{C_1}, \mathbf{C_2}) \otimes \mathbf{T}` with the + selected loss function as the loss function of semi-relaxed Gromow-Wasserstein discrepancy. + + The matrices are computed as described in Proposition 1 in :ref:`[12] ` + and adapted to the semi-relaxed problem where the second marginal is not a constant anymore. + + Where : + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{T}`: A coupling between those two spaces + + The square-loss function :math:`L(a, b) = |a - b|^2` is read as : + + .. math:: + + L(a, b) = f_1(a) + f_2(b) - h_1(a) h_2(b) + + \mathrm{with} \ f_1(a) &= a^2 + + f_2(b) &= b^2 + + h_1(a) &= a + + h_2(b) &= 2b + + Parameters + ---------- + C1 : array-like, shape (ns, ns) + Metric cost matrix in the source space + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + T : array-like, shape (ns, nt) + Coupling between source and target spaces + p : array-like, shape (ns,) + nx : backend, optional + If let to its default value None, a backend test will be conducted. + Returns + ------- + constC : array-like, shape (ns, nt) + Constant :math:`\mathbf{C}` matrix in Eq. (6) adapted to srGW + hC1 : array-like, shape (ns, ns) + :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6) + hC2 : array-like, shape (nt, nt) + :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6) + fC2t: array-like, shape (nt, nt) + :math:`\mathbf{f2}(\mathbf{C2})^\top` matrix in Eq. (6) + + + .. _references-init-matrix: + References + ---------- + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, + "Gromov-Wasserstein averaging of kernel and distance matrices." + International Conference on Machine Learning (ICML). 2016. + + .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty. + "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" + International Conference on Learning Representations (ICLR), 2022. + """ + if nx is None: + C1, C2, p = list_to_array(C1, C2, p) + nx = get_backend(C1, C2, p) + + if loss_fun == 'square_loss': + def f1(a): + return (a**2) + + def f2(b): + return (b**2) + + def h1(a): + return a + + def h2(b): + return 2 * b + + constC = nx.dot(nx.dot(f1(C1), nx.reshape(p, (-1, 1))), + nx.ones((1, C2.shape[0]), type_as=p)) + + hC1 = h1(C1) + hC2 = h2(C2) + fC2t = f2(C2).T + return constC, hC1, hC2, fC2t diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 7d0640f..2ff02ab 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- """ -Solvers for the original linear program OT problem +Solvers for the original linear program OT problem. """ diff --git a/ot/optim.py b/ot/optim.py index 5a1d605..201f898 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -1,11 +1,11 @@ # -*- coding: utf-8 -*- """ -Generic solvers for regularized OT +Generic solvers for regularized OT or its semi-relaxed version. """ # Author: Remi Flamary # Titouan Vayer -# +# Cédric Vincent-Cuaz # License: MIT License import numpy as np @@ -27,7 +27,7 @@ with warnings.catch_warnings(): def line_search_armijo( f, xk, pk, gfk, old_fval, args=(), c1=1e-4, - alpha0=0.99, alpha_min=None, alpha_max=None + alpha0=0.99, alpha_min=None, alpha_max=None, nx=None, **kwargs ): r""" Armijo linesearch function that works with matrices @@ -57,7 +57,8 @@ def line_search_armijo( minimum value for alpha alpha_max : float, optional maximum value for alpha - + nx : backend, optional + If let to its default value None, a backend test will be conducted. Returns ------- alpha : float @@ -68,9 +69,9 @@ def line_search_armijo( loss value at step alpha """ - - xk, pk, gfk = list_to_array(xk, pk, gfk) - nx = get_backend(xk, pk) + if nx is None: + xk, pk, gfk = list_to_array(xk, pk, gfk) + nx = get_backend(xk, pk) if len(xk.shape) == 0: xk = nx.reshape(xk, (-1,)) @@ -98,97 +99,38 @@ def line_search_armijo( return float(alpha), fc[0], phi1 -def solve_linesearch( - cost, G, deltaG, Mi, f_val, armijo=True, C1=None, C2=None, - reg=None, Gc=None, constC=None, M=None, alpha_min=None, alpha_max=None -): - """ - Solve the linesearch in the FW iterations - - Parameters - ---------- - cost : method - Cost in the FW for the linesearch - G : array-like, shape(ns,nt) - The transport map at a given iteration of the FW - deltaG : array-like (ns,nt) - Difference between the optimal map found by linearization in the FW algorithm and the value at a given iteration - Mi : array-like (ns,nt) - Cost matrix of the linearized transport problem. Corresponds to the gradient of the cost - f_val : float - Value of the cost at `G` - armijo : bool, optional - If True the steps of the line-search is found via an armijo research. Else closed form is used. - If there is convergence issues use False. - C1 : array-like (ns,ns), optional - Structure matrix in the source domain. Only used and necessary when armijo=False - C2 : array-like (nt,nt), optional - Structure matrix in the target domain. Only used and necessary when armijo=False - reg : float, optional - Regularization parameter. Only used and necessary when armijo=False - Gc : array-like (ns,nt) - Optimal map found by linearization in the FW algorithm. Only used and necessary when armijo=False - constC : array-like (ns,nt) - Constant for the gromov cost. See :ref:`[24] `. Only used and necessary when armijo=False - M : array-like (ns,nt), optional - Cost matrix between the features. Only used and necessary when armijo=False - alpha_min : float, optional - Minimum value for alpha - alpha_max : float, optional - Maximum value for alpha - - Returns - ------- - alpha : float - The optimal step size of the FW - fc : int - nb of function call. Useless here - f_val : float - The value of the cost for the next iteration +def generic_conditional_gradient(a, b, M, f, df, reg1, reg2, lp_solver, line_search, G0=None, + numItermax=200, stopThr=1e-9, + stopThr2=1e-9, verbose=False, log=False, **kwargs): + r""" + Solve the general regularized OT problem or its semi-relaxed version with + conditional gradient or generalized conditional gradient depending on the + provided linear program solver. + The function solves the following optimization problem if set as a conditional gradient: - .. _references-solve-linesearch: - References - ---------- - .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas - "Optimal Transport for structured data with application on graphs" - International Conference on Machine Learning (ICML). 2019. - """ - if armijo: - alpha, fc, f_val = line_search_armijo( - cost, G, deltaG, Mi, f_val, alpha_min=alpha_min, alpha_max=alpha_max - ) - else: # requires symetric matrices - G, deltaG, C1, C2, constC, M = list_to_array(G, deltaG, C1, C2, constC, M) - if isinstance(M, int) or isinstance(M, float): - nx = get_backend(G, deltaG, C1, C2, constC) - else: - nx = get_backend(G, deltaG, C1, C2, constC, M) + .. math:: + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg_1} \cdot f(\gamma) - dot = nx.dot(nx.dot(C1, deltaG), C2) - a = -2 * reg * nx.sum(dot * deltaG) - b = nx.sum((M + reg * constC) * deltaG) - 2 * reg * (nx.sum(dot * G) + nx.sum(nx.dot(nx.dot(C1, G), C2) * deltaG)) - c = cost(G) + s.t. \ \gamma \mathbf{1} &= \mathbf{a} - alpha = solve_1d_linesearch_quad(a, b, c) - if alpha_min is not None or alpha_max is not None: - alpha = np.clip(alpha, alpha_min, alpha_max) - fc = None - f_val = cost(G + alpha * deltaG) + \gamma^T \mathbf{1} &= \mathbf{b} (optional constraint) - return alpha, fc, f_val + \gamma &\geq 0 + where : + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix + - :math:`f` is the regularization term (and `df` is its gradient) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) -def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, - stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs): - r""" - Solve the general regularized OT problem with conditional gradient + The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[1] ` - The function solves the following optimization problem: + The function solves the following optimization problem if set a generalized conditional gradient: .. math:: \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + - \mathrm{reg} \cdot f(\gamma) + \mathrm{reg_1}\cdot f(\gamma) + \mathrm{reg_2}\cdot\Omega(\gamma) s.t. \ \gamma \mathbf{1} &= \mathbf{a} @@ -197,29 +139,39 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, \gamma &\geq 0 where : - - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix - - :math:`f` is the regularization term (and `df` is its gradient) - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) - - The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[1] ` + - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + The algorithm used for solving the problem is the generalized conditional gradient as discussed in :ref:`[5, 7] ` Parameters ---------- a : array-like, shape (ns,) samples weights in the source domain b : array-like, shape (nt,) - samples in the target domain + samples weights in the target domain M : array-like, shape (ns, nt) loss matrix - reg : float + f : function + Regularization function taking a transportation matrix as argument + df: function + Gradient of the regularization function taking a transportation matrix as argument + reg1 : float Regularization term >0 + reg2 : float, + Entropic Regularization term >0. Ignored if set to None. + lp_solver: function, + linear program solver for direction finding of the (generalized) conditional gradient. + If set to emd will solve the general regularized OT problem using cg. + If set to lp_semi_relaxed_OT will solve the general regularized semi-relaxed OT problem using cg. + If set to sinkhorn will solve the general regularized OT problem using generalized cg. + line_search: function, + Function to find the optimal step. Currently used instances are: + line_search_armijo (generic solver). solve_gromov_linesearch for (F)GW problem. + solve_semirelaxed_gromov_linesearch for sr(F)GW problem. gcg_linesearch for the Generalized cg. G0 : array-like, shape (ns,nt), optional initial guess (default is indep joint density) numItermax : int, optional Max number of iterations - numItermaxEmd : int, optional - Max number of iterations for emd stopThr : float, optional Stop threshold on the relative variation (>0) stopThr2 : float, optional @@ -240,16 +192,20 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, .. _references-cg: + .. _references_gcg: References ---------- .. [1] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882. + .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 + + .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567. + See Also -------- ot.lp.emd : Unregularized optimal ransport ot.bregman.sinkhorn : Entropic regularized optimal transport - """ a, b, M, G0 = list_to_array(a, b, M, G0) if isinstance(M, int) or isinstance(M, float): @@ -265,42 +221,45 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, if G0 is None: G = nx.outer(a, b) else: - G = G0 - - def cost(G): - return nx.sum(M * G) + reg * f(G) + # to not change G0 in place. + G = nx.copy(G0) - f_val = cost(G) + if reg2 is None: + def cost(G): + return nx.sum(M * G) + reg1 * f(G) + else: + def cost(G): + return nx.sum(M * G) + reg1 * f(G) + reg2 * nx.sum(G * nx.log(G)) + cost_G = cost(G) if log: - log['loss'].append(f_val) + log['loss'].append(cost_G) it = 0 if verbose: print('{:5s}|{:12s}|{:8s}|{:8s}'.format( 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48) - print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, 0, 0)) + print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, cost_G, 0, 0)) while loop: it += 1 - old_fval = f_val - + old_cost_G = cost_G # problem linearization - Mi = M + reg * df(G) + Mi = M + reg1 * df(G) + + if not (reg2 is None): + Mi = Mi + reg2 * (1 + nx.log(G)) # set M positive - Mi += nx.min(Mi) + Mi = Mi + nx.min(Mi) # solve linear program - Gc, logemd = emd(a, b, Mi, numItermax=numItermaxEmd, log=True) + Gc, innerlog_ = lp_solver(a, b, Mi, **kwargs) + # line search deltaG = Gc - G - # line search - alpha, fc, f_val = solve_linesearch( - cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, - alpha_min=0., alpha_max=1., **kwargs - ) + alpha, fc, cost_G = line_search(cost, G, deltaG, Mi, cost_G, **kwargs) G = G + alpha * deltaG @@ -308,29 +267,197 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, if it >= numItermax: loop = 0 - abs_delta_fval = abs(f_val - old_fval) - relative_delta_fval = abs_delta_fval / abs(f_val) - if relative_delta_fval < stopThr or abs_delta_fval < stopThr2: + abs_delta_cost_G = abs(cost_G - old_cost_G) + relative_delta_cost_G = abs_delta_cost_G / abs(cost_G) + if relative_delta_cost_G < stopThr or abs_delta_cost_G < stopThr2: loop = 0 if log: - log['loss'].append(f_val) + log['loss'].append(cost_G) if verbose: if it % 20 == 0: print('{:5s}|{:12s}|{:8s}|{:8s}'.format( 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48) - print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval)) + print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, cost_G, relative_delta_cost_G, abs_delta_cost_G)) if log: - log.update(logemd) + log.update(innerlog_) return G, log else: return G +def cg(a, b, M, reg, f, df, G0=None, line_search=line_search_armijo, + numItermax=200, numItermaxEmd=100000, stopThr=1e-9, stopThr2=1e-9, + verbose=False, log=False, **kwargs): + r""" + Solve the general regularized OT problem with conditional gradient + + The function solves the following optimization problem: + + .. math:: + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg} \cdot f(\gamma) + + s.t. \ \gamma \mathbf{1} &= \mathbf{a} + + \gamma^T \mathbf{1} &= \mathbf{b} + + \gamma &\geq 0 + where : + + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix + - :math:`f` is the regularization term (and `df` is its gradient) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) + + The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[1] ` + + + Parameters + ---------- + a : array-like, shape (ns,) + samples weights in the source domain + b : array-like, shape (nt,) + samples in the target domain + M : array-like, shape (ns, nt) + loss matrix + reg : float + Regularization term >0 + G0 : array-like, shape (ns,nt), optional + initial guess (default is indep joint density) + line_search: function, + Function to find the optimal step. + Default is line_search_armijo. + numItermax : int, optional + Max number of iterations + numItermaxEmd : int, optional + Max number of iterations for emd + stopThr : float, optional + Stop threshold on the relative variation (>0) + stopThr2 : float, optional + Stop threshold on the absolute variation (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + **kwargs : dict + Parameters for linesearch + + Returns + ------- + gamma : (ns x nt) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + + .. _references-cg: + References + ---------- + + .. [1] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882. + + See Also + -------- + ot.lp.emd : Unregularized optimal ransport + ot.bregman.sinkhorn : Entropic regularized optimal transport + + """ + + def lp_solver(a, b, M, **kwargs): + return emd(a, b, M, numItermaxEmd, log=True) + + return generic_conditional_gradient(a, b, M, f, df, reg, None, lp_solver, line_search, G0=G0, + numItermax=numItermax, stopThr=stopThr, + stopThr2=stopThr2, verbose=verbose, log=log, **kwargs) + + +def semirelaxed_cg(a, b, M, reg, f, df, G0=None, line_search=line_search_armijo, + numItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs): + r""" + Solve the general regularized and semi-relaxed OT problem with conditional gradient + + The function solves the following optimization problem: + + .. math:: + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg} \cdot f(\gamma) + + s.t. \ \gamma \mathbf{1} &= \mathbf{a} + + \gamma &\geq 0 + where : + + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix + - :math:`f` is the regularization term (and `df` is its gradient) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) + + The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[1] ` + + + Parameters + ---------- + a : array-like, shape (ns,) + samples weights in the source domain + b : array-like, shape (nt,) + currently estimated samples weights in the target domain + M : array-like, shape (ns, nt) + loss matrix + reg : float + Regularization term >0 + G0 : array-like, shape (ns,nt), optional + initial guess (default is indep joint density) + line_search: function, + Function to find the optimal step. + Default is the armijo line-search. + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on the relative variation (>0) + stopThr2 : float, optional + Stop threshold on the absolute variation (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + **kwargs : dict + Parameters for linesearch + + Returns + ------- + gamma : (ns x nt) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + + .. _references-cg: + References + ---------- + + .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty. + "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" + International Conference on Learning Representations (ICLR), 2021. + + """ + + nx = get_backend(a, b) + + def lp_solver(a, b, Mi, **kwargs): + # get minimum by rows as binary mask + Gc = nx.ones(1, type_as=a) * (Mi == nx.reshape(nx.min(Mi, axis=1), (-1, 1))) + Gc *= nx.reshape((a / nx.sum(Gc, axis=1)), (-1, 1)) + # return by default an empty inner_log + return Gc, {} + + return generic_conditional_gradient(a, b, M, f, df, reg, None, lp_solver, line_search, G0=G0, + numItermax=numItermax, stopThr=stopThr, + stopThr2=stopThr2, verbose=verbose, log=log, **kwargs) + + def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, - numInnerItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False): + numInnerItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs): r""" Solve the general regularized OT problem with the generalized conditional gradient @@ -403,81 +530,18 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, ot.optim.cg : conditional gradient """ - a, b, M, G0 = list_to_array(a, b, M, G0) - nx = get_backend(a, b, M) - - loop = 1 - - if log: - log = {'loss': []} - - if G0 is None: - G = nx.outer(a, b) - else: - G = G0 - - def cost(G): - return nx.sum(M * G) + reg1 * nx.sum(G * nx.log(G)) + reg2 * f(G) - - f_val = cost(G) - if log: - log['loss'].append(f_val) - - it = 0 - - if verbose: - print('{:5s}|{:12s}|{:8s}|{:8s}'.format( - 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48) - print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, 0, 0)) - while loop: - - it += 1 - old_fval = f_val - - # problem linearization - Mi = M + reg2 * df(G) - - # solve linear program with Sinkhorn - # Gc = sinkhorn_stabilized(a,b, Mi, reg1, numItermax = numInnerItermax) - Gc = sinkhorn(a, b, Mi, reg1, numItermax=numInnerItermax) - - deltaG = Gc - G - - # line search - dcost = Mi + reg1 * (1 + nx.log(G)) # ?? - alpha, fc, f_val = line_search_armijo( - cost, G, deltaG, dcost, f_val, alpha_min=0., alpha_max=1. - ) - - G = G + alpha * deltaG - - # test convergence - if it >= numItermax: - loop = 0 - - abs_delta_fval = abs(f_val - old_fval) - relative_delta_fval = abs_delta_fval / abs(f_val) + def lp_solver(a, b, Mi, **kwargs): + return sinkhorn(a, b, Mi, reg1, numItermax=numInnerItermax, log=True, **kwargs) - if relative_delta_fval < stopThr or abs_delta_fval < stopThr2: - loop = 0 + def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): + return line_search_armijo(cost, G, deltaG, Mi, cost_G, **kwargs) - if log: - log['loss'].append(f_val) + return generic_conditional_gradient(a, b, M, f, df, reg2, reg1, lp_solver, line_search, G0=G0, + numItermax=numItermax, stopThr=stopThr, stopThr2=stopThr2, verbose=verbose, log=log, **kwargs) - if verbose: - if it % 20 == 0: - print('{:5s}|{:12s}|{:8s}|{:8s}'.format( - 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48) - print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval)) - if log: - return G, log - else: - return G - - -def solve_1d_linesearch_quad(a, b, c): +def solve_1d_linesearch_quad(a, b): r""" For any convex or non-convex 1d quadratic function `f`, solve the following problem: @@ -487,7 +551,7 @@ def solve_1d_linesearch_quad(a, b, c): Parameters ---------- - a,b,c : float + a,b : float or tensors (1,) The coefficients of the quadratic function Returns @@ -495,15 +559,11 @@ def solve_1d_linesearch_quad(a, b, c): x : float The optimal value which leads to the minimal cost """ - f0 = c - df0 = b - f1 = a + f0 + df0 - if a > 0: # convex - minimum = min(1, max(0, np.divide(-b, 2.0 * a))) + minimum = min(1., max(0., -b / (2.0 * a))) return minimum else: # non convex - if f0 > f1: - return 1 + if a + b < 0: + return 1. else: - return 0 + return 0. diff --git a/test/test_gromov.py b/test/test_gromov.py index 9c85b92..cfccce7 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -3,7 +3,7 @@ # Author: Erwan Vautier # Nicolas Courty # Titouan Vayer -# Cédric Vincent-Cuaz +# Cédric Vincent-Cuaz # # License: MIT License @@ -11,18 +11,15 @@ import numpy as np import ot from ot.backend import NumpyBackend from ot.backend import torch, tf - import pytest def test_gromov(nx): n_samples = 50 # nb samples - mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) - xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) - + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=1) xt = xs[::-1].copy() p = ot.unif(n_samples) @@ -38,7 +35,7 @@ def test_gromov(nx): C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', G0=G0, verbose=True) - Gb = nx.to_numpy(ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True)) + Gb = nx.to_numpy(ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', symmetric=True, G0=G0b, verbose=True)) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) @@ -51,13 +48,13 @@ def test_gromov(nx): np.testing.assert_allclose(Gb, np.flipud(Id), atol=1e-04) - gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=True) - gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=True) + gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', armijo=True, log=True) + gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, log=True) gwb = nx.to_numpy(gwb) - gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', G0=G0, log=False) + gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', armijo=True, G0=G0, log=False) gw_valb = nx.to_numpy( - ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', G0=G0b, log=False) + ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, G0=G0b, log=False) ) G = log['T'] @@ -77,6 +74,49 @@ def test_gromov(nx): q, Gb.sum(0), atol=1e-04) # cf convergence gromov +def test_asymmetric_gromov(nx): + n_samples = 30 # nb samples + np.random.seed(0) + C1 = np.random.uniform(low=0., high=10, size=(n_samples, n_samples)) + idx = np.arange(n_samples) + np.random.shuffle(idx) + C2 = C1[idx, :][:, idx] + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + + C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) + + G, log = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', G0=G0, log=True, symmetric=False, verbose=True) + Gb, logb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', log=True, symmetric=False, G0=G0b, verbose=True) + Gb = nx.to_numpy(Gb) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + np.testing.assert_allclose(log['gw_dist'], 0., atol=1e-04) + np.testing.assert_allclose(logb['gw_dist'], 0., atol=1e-04) + + gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'square_loss', G0=G0, log=True, symmetric=False, verbose=True) + gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'square_loss', log=True, symmetric=False, G0=G0b, verbose=True) + + G = log['T'] + Gb = nx.to_numpy(logb['T']) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + np.testing.assert_allclose(log['gw_dist'], 0., atol=1e-04) + np.testing.assert_allclose(logb['gw_dist'], 0., atol=1e-04) + + def test_gromov_dtype_device(nx): # setup n_samples = 50 # nb samples @@ -104,7 +144,7 @@ def test_gromov_dtype_device(nx): C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0, type_as=tp) Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True) - gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', G0=G0b, log=False) + gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, G0=G0b, log=False) nx.assert_same_dtype_device(C1b, Gb) nx.assert_same_dtype_device(C1b, gw_valb) @@ -130,7 +170,7 @@ def test_gromov_device_tf(): with tf.device("/CPU:0"): C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True) - gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', G0=G0b, log=False) + gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, G0=G0b, log=False) nx.assert_same_dtype_device(C1b, Gb) nx.assert_same_dtype_device(C1b, gw_valb) @@ -138,7 +178,7 @@ def test_gromov_device_tf(): # Check that everything happens on the GPU C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True) - gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False) + gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, log=False) nx.assert_same_dtype_device(C1b, Gb) nx.assert_same_dtype_device(C1b, gw_valb) assert nx.dtype_device(Gb)[1].startswith("GPU") @@ -185,6 +225,45 @@ def test_gromov2_gradients(): assert C12.shape == C12.grad.shape +def test_gw_helper_backend(nx): + n_samples = 20 # nb samples + + mu = np.array([0, 0]) + cov = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=0) + xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=1) + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) + Gb, logb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', armijo=False, symmetric=True, G0=G0b, log=True) + + # calls with nx=None + constCb, hC1b, hC2b = ot.gromov.init_matrix(C1b, C2b, pb, qb, loss_fun='square_loss') + + def f(G): + return ot.gromov.gwloss(constCb, hC1b, hC2b, G, None) + + def df(G): + return ot.gromov.gwggrad(constCb, hC1b, hC2b, G, None) + + def line_search(cost, G, deltaG, Mi, cost_G): + return ot.gromov.solve_gromov_linesearch(G, deltaG, cost_G, C1b, C2b, M=0., reg=1., nx=None) + # feed the precomputed local optimum Gb to cg + res, log = ot.optim.cg(pb, qb, 0., 1., f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9) + # check constraints + np.testing.assert_allclose(res, Gb, atol=1e-06) + + @pytest.skip_backend("jax", reason="test very slow with jax backend") @pytest.skip_backend("tf", reason="test very slow with tf backend") def test_entropic_gromov(nx): @@ -199,19 +278,21 @@ def test_entropic_gromov(nx): p = ot.unif(n_samples) q = ot.unif(n_samples) - + G0 = p[:, None] * q[None, :] C1 = ot.dist(xs, xs) C2 = ot.dist(xt, xt) C1 /= C1.max() C2 /= C2.max() - C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q) + C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) - G = ot.gromov.entropic_gromov_wasserstein( - C1, C2, p, q, 'square_loss', epsilon=5e-4, verbose=True) + G, log = ot.gromov.entropic_gromov_wasserstein( + C1, C2, p, q, 'square_loss', symmetric=None, G0=G0, + epsilon=1e-2, verbose=True, log=True) Gb = nx.to_numpy(ot.gromov.entropic_gromov_wasserstein( - C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True + C1b, C2b, pb, qb, 'square_loss', symmetric=True, G0=None, + epsilon=1e-2, verbose=True, log=False )) # check constraints @@ -222,9 +303,11 @@ def test_entropic_gromov(nx): q, Gb.sum(0), atol=1e-04) # cf convergence gromov gw, log = ot.gromov.entropic_gromov_wasserstein2( - C1, C2, p, q, 'kl_loss', max_iter=10, epsilon=1e-2, log=True) + C1, C2, p, q, 'kl_loss', symmetric=True, G0=None, + max_iter=10, epsilon=1e-2, log=True) gwb, logb = ot.gromov.entropic_gromov_wasserstein2( - C1b, C2b, pb, qb, 'kl_loss', max_iter=10, epsilon=1e-2, log=True) + C1b, C2b, pb, qb, 'kl_loss', symmetric=None, G0=G0b, + max_iter=10, epsilon=1e-2, log=True) gwb = nx.to_numpy(gwb) G = log['T'] @@ -241,6 +324,45 @@ def test_entropic_gromov(nx): q, Gb.sum(0), atol=1e-04) # cf convergence gromov +def test_asymmetric_entropic_gromov(nx): + n_samples = 10 # nb samples + np.random.seed(0) + C1 = np.random.uniform(low=0., high=10, size=(n_samples, n_samples)) + idx = np.arange(n_samples) + np.random.shuffle(idx) + C2 = C1[idx, :][:, idx] + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + + C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) + G = ot.gromov.entropic_gromov_wasserstein( + C1, C2, p, q, 'square_loss', symmetric=None, G0=G0, + epsilon=1e-1, verbose=True, log=False) + Gb = nx.to_numpy(ot.gromov.entropic_gromov_wasserstein( + C1b, C2b, pb, qb, 'square_loss', symmetric=False, G0=None, + epsilon=1e-1, verbose=True, log=False + )) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + gw = ot.gromov.entropic_gromov_wasserstein2( + C1, C2, p, q, 'kl_loss', symmetric=False, G0=None, + max_iter=10, epsilon=1e-1, log=False) + gwb = ot.gromov.entropic_gromov_wasserstein2( + C1b, C2b, pb, qb, 'kl_loss', symmetric=None, G0=G0b, + max_iter=10, epsilon=1e-1, log=False) + gwb = nx.to_numpy(gwb) + + np.testing.assert_allclose(gw, gwb, atol=1e-06) + np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1) + + @pytest.skip_backend("jax", reason="test very slow with jax backend") @pytest.skip_backend("tf", reason="test very slow with tf backend") def test_entropic_gromov_dtype_device(nx): @@ -539,8 +661,8 @@ def test_fgw(nx): Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0) - G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, G0=G0, log=True) - Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, G0=G0b, log=True) + G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, armijo=True, symmetric=None, G0=G0, log=True) + Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, armijo=True, symmetric=True, G0=G0b, log=True) Gb = nx.to_numpy(Gb) # check constraints @@ -555,8 +677,8 @@ def test_fgw(nx): np.testing.assert_allclose( Gb, np.flipud(Id), atol=1e-04) # cf convergence gromov - fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', G0=None, alpha=0.5, log=True) - fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', G0=G0b, alpha=0.5, log=True) + fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', armijo=True, symmetric=True, G0=None, alpha=0.5, log=True) + fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', armijo=True, symmetric=None, G0=G0b, alpha=0.5, log=True) fgwb = nx.to_numpy(fgwb) G = log['T'] @@ -573,6 +695,82 @@ def test_fgw(nx): q, Gb.sum(0), atol=1e-04) # cf convergence gromov +def test_asymmetric_fgw(nx): + n_samples = 50 # nb samples + np.random.seed(0) + C1 = np.random.uniform(low=0., high=10, size=(n_samples, n_samples)) + idx = np.arange(n_samples) + np.random.shuffle(idx) + C2 = C1[idx, :][:, idx] + + # add features + F1 = np.random.uniform(low=0., high=10, size=(n_samples, 1)) + F2 = F1[idx, :] + p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + + M = ot.dist(F1, F2) + Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0) + + G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, G0=G0, log=True, symmetric=False, verbose=True) + Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True, symmetric=None, G0=G0b, verbose=True) + Gb = nx.to_numpy(Gb) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04) + np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04) + + fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', alpha=0.5, G0=G0, log=True, symmetric=None, verbose=True) + fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True, symmetric=False, G0=G0b, verbose=True) + + G = log['T'] + Gb = nx.to_numpy(logb['T']) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04) + np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04) + + # Tests with kl-loss: + G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'kl_loss', alpha=0.5, G0=G0, log=True, symmetric=False, verbose=True) + Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'kl_loss', alpha=0.5, log=True, symmetric=None, G0=G0b, verbose=True) + Gb = nx.to_numpy(Gb) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04) + np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04) + + fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'kl_loss', alpha=0.5, G0=G0, log=True, symmetric=None, verbose=True) + fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'kl_loss', alpha=0.5, log=True, symmetric=False, G0=G0b, verbose=True) + + G = log['T'] + Gb = nx.to_numpy(logb['T']) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04) + np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04) + + def test_fgw2_gradients(): n_samples = 20 # nb samples @@ -617,6 +815,57 @@ def test_fgw2_gradients(): assert M1.shape == M1.grad.shape +def test_fgw_helper_backend(nx): + n_samples = 20 # nb samples + + mu = np.array([0, 0]) + cov = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=0) + ys = np.random.randn(xs.shape[0], 2) + xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=1) + yt = np.random.randn(xt.shape[0], 2) + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + M = ot.dist(ys, yt) + M /= M.max() + + Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0) + alpha = 0.5 + Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, armijo=False, symmetric=True, G0=G0b, log=True) + + # calls with nx=None + constCb, hC1b, hC2b = ot.gromov.init_matrix(C1b, C2b, pb, qb, loss_fun='square_loss') + + def f(G): + return ot.gromov.gwloss(constCb, hC1b, hC2b, G, None) + + def df(G): + return ot.gromov.gwggrad(constCb, hC1b, hC2b, G, None) + + def line_search(cost, G, deltaG, Mi, cost_G): + return ot.gromov.solve_gromov_linesearch(G, deltaG, cost_G, C1b, C2b, M=(1 - alpha) * Mb, reg=alpha, nx=None) + # feed the precomputed local optimum Gb to cg + res, log = ot.optim.cg(pb, qb, (1 - alpha) * Mb, alpha, f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9) + + def line_search(cost, G, deltaG, Mi, cost_G): + return ot.optim.line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=None) + # feed the precomputed local optimum Gb to cg + res_armijo, log_armijo = ot.optim.cg(pb, qb, (1 - alpha) * Mb, alpha, f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9) + # check constraints + np.testing.assert_allclose(res, Gb, atol=1e-06) + np.testing.assert_allclose(res_armijo, Gb, atol=1e-06) + + def test_fgw_barycenter(nx): np.random.seed(42) @@ -1186,3 +1435,327 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx): # > Compare results with/without backend total_reconstruction_b_bis2 = nx.to_numpy(total_reconstruction_b_bis2) np.testing.assert_allclose(total_reconstruction_bis2, total_reconstruction_b_bis2, atol=1e-05) + + +def test_semirelaxed_gromov(nx): + np.random.seed(0) + # unbalanced proportions + list_n = [30, 15] + nt = 2 + ns = np.sum(list_n) + # create directed sbm with C2 as connectivity matrix + C1 = np.zeros((ns, ns), dtype=np.float64) + C2 = np.array([[0.8, 0.05], + [0.05, 1.]], dtype=np.float64) + for i in range(nt): + for j in range(nt): + ni, nj = list_n[i], list_n[j] + xij = np.random.binomial(size=(ni, nj), n=1, p=C2[i, j]) + C1[i * ni: (i + 1) * ni, j * nj: (j + 1) * nj] = xij + p = ot.unif(ns, type_as=C1) + q0 = ot.unif(C2.shape[0], type_as=C1) + G0 = p[:, None] * q0[None, :] + # asymmetric + C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0) + + G, log = ot.gromov.semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', symmetric=None, log=True, G0=G0) + Gb, logb = ot.gromov.semirelaxed_gromov_wasserstein(C1b, C2b, pb, loss_fun='square_loss', symmetric=False, log=True, G0=None) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) + np.testing.assert_allclose(list_n / ns, np.sum(G, axis=0), atol=1e-01) + np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-01) + + srgw, log2 = ot.gromov.semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric=False, log=True, G0=G0) + srgwb, logb2 = ot.gromov.semirelaxed_gromov_wasserstein2(C1b, C2b, pb, loss_fun='square_loss', symmetric=None, log=True, G0=None) + + G = log2['T'] + Gb = nx.to_numpy(logb2['T']) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(list_n / ns, Gb.sum(0), atol=1e-04) # cf convergence gromov + + np.testing.assert_allclose(log2['srgw_dist'], logb['srgw_dist'], atol=1e-07) + np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07) + + # symmetric + C1 = 0.5 * (C1 + C1.T) + C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0) + + G, log = ot.gromov.semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', symmetric=None, log=True, G0=None) + Gb = ot.gromov.semirelaxed_gromov_wasserstein(C1b, C2b, pb, loss_fun='square_loss', symmetric=True, log=False, G0=G0b) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov + + srgw, log2 = ot.gromov.semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric=True, log=True, G0=G0) + srgwb, logb2 = ot.gromov.semirelaxed_gromov_wasserstein2(C1b, C2b, pb, loss_fun='square_loss', symmetric=None, log=True, G0=None) + + srgw_ = ot.gromov.semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric=True, log=False, G0=G0) + + G = log2['T'] + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, nx.sum(Gb, 1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(list_n / ns, np.sum(G, axis=0), atol=1e-01) + np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-01) + + np.testing.assert_allclose(log2['srgw_dist'], log['srgw_dist'], atol=1e-07) + np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07) + np.testing.assert_allclose(srgw, srgw_, atol=1e-07) + + +def test_semirelaxed_gromov2_gradients(): + n_samples = 50 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) + + xt = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=5) + + p = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + if torch: + + devices = [torch.device("cpu")] + if torch.cuda.is_available(): + devices.append(torch.device("cuda")) + for device in devices: + # semirelaxed solvers do not support gradients over masses yet. + p1 = torch.tensor(p, requires_grad=False, device=device) + C11 = torch.tensor(C1, requires_grad=True, device=device) + C12 = torch.tensor(C2, requires_grad=True, device=device) + + val = ot.gromov.semirelaxed_gromov_wasserstein2(C11, C12, p1) + + val.backward() + + assert val.device == p1.device + assert p1.grad is None + assert C11.shape == C11.grad.shape + assert C12.shape == C12.grad.shape + + +def test_srgw_helper_backend(nx): + n_samples = 20 # nb samples + + mu = np.array([0, 0]) + cov = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=0) + xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=1) + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q) + Gb, logb = ot.gromov.semirelaxed_gromov_wasserstein(C1b, C2b, pb, 'square_loss', armijo=False, symmetric=True, G0=None, log=True) + + # calls with nx=None + constCb, hC1b, hC2b, fC2tb = ot.gromov.init_matrix_semirelaxed(C1b, C2b, pb, loss_fun='square_loss') + ones_pb = nx.ones(pb.shape[0], type_as=pb) + + def f(G): + qG = nx.sum(G, 0) + marginal_product = nx.outer(ones_pb, nx.dot(qG, fC2tb)) + return ot.gromov.gwloss(constCb + marginal_product, hC1b, hC2b, G, nx=None) + + def df(G): + qG = nx.sum(G, 0) + marginal_product = nx.outer(ones_pb, nx.dot(qG, fC2tb)) + return ot.gromov.gwggrad(constCb + marginal_product, hC1b, hC2b, G, nx=None) + + def line_search(cost, G, deltaG, Mi, cost_G): + return ot.gromov.solve_semirelaxed_gromov_linesearch( + G, deltaG, cost_G, C1b, C2b, ones_pb, 0., 1., nx=None) + # feed the precomputed local optimum Gb to semirelaxed_cg + res, log = ot.optim.semirelaxed_cg(pb, qb, 0., 1., f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9) + # check constraints + np.testing.assert_allclose(res, Gb, atol=1e-06) + + +def test_semirelaxed_fgw(nx): + np.random.seed(0) + list_n = [16, 8] + nt = 2 + ns = 24 + # create directed sbm with C2 as connectivity matrix + C1 = np.zeros((ns, ns)) + C2 = np.array([[0.7, 0.05], + [0.05, 0.9]]) + for i in range(nt): + for j in range(nt): + ni, nj = list_n[i], list_n[j] + xij = np.random.binomial(size=(ni, nj), n=1, p=C2[i, j]) + C1[i * ni: (i + 1) * ni, j * nj: (j + 1) * nj] = xij + F1 = np.zeros((ns, 1)) + F1[:16] = np.random.normal(loc=0., scale=0.01, size=(16, 1)) + F1[16:] = np.random.normal(loc=1., scale=0.01, size=(8, 1)) + F2 = np.zeros((2, 1)) + F2[1, :] = 1. + M = (F1 ** 2).dot(np.ones((1, nt))) + np.ones((ns, 1)).dot((F2 ** 2).T) - 2 * F1.dot(F2.T) + + p = ot.unif(ns) + q0 = ot.unif(C2.shape[0]) + G0 = p[:, None] * q0[None, :] + + # asymmetric + Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0) + G, log = ot.gromov.semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None) + Gb, logb = ot.gromov.semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun='square_loss', alpha=0.5, symmetric=False, log=True, G0=G0b) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov + + srgw, log2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=False, log=True, G0=G0) + srgwb, logb2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, pb, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None) + + G = log2['T'] + Gb = nx.to_numpy(logb2['T']) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose([2 / 3, 1 / 3], Gb.sum(0), atol=1e-04) # cf convergence gromov + + np.testing.assert_allclose(log2['srfgw_dist'], logb['srfgw_dist'], atol=1e-07) + np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) + + # symmetric + C1 = 0.5 * (C1 + C1.T) + Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0) + + G, log = ot.gromov.semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None) + Gb = ot.gromov.semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun='square_loss', alpha=0.5, symmetric=True, log=False, G0=G0b) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov + + srgw, log2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=True, log=True, G0=G0) + srgwb, logb2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, pb, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None) + + srgw_ = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=True, log=False, G0=G0) + + G = log2['T'] + Gb = nx.to_numpy(logb2['T']) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose([2 / 3, 1 / 3], Gb.sum(0), atol=1e-04) # cf convergence gromov + + np.testing.assert_allclose(log2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) + np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) + np.testing.assert_allclose(srgw, srgw_, atol=1e-07) + + +def test_semirelaxed_fgw2_gradients(): + n_samples = 20 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) + + xt = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=5) + + p = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + M = ot.dist(xs, xt) + + C1 /= C1.max() + C2 /= C2.max() + + if torch: + + devices = [torch.device("cpu")] + if torch.cuda.is_available(): + devices.append(torch.device("cuda")) + for device in devices: + # semirelaxed solvers do not support gradients over masses yet. + p1 = torch.tensor(p, requires_grad=False, device=device) + C11 = torch.tensor(C1, requires_grad=True, device=device) + C12 = torch.tensor(C2, requires_grad=True, device=device) + M1 = torch.tensor(M, requires_grad=True, device=device) + + val = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M1, C11, C12, p1) + + val.backward() + + assert val.device == p1.device + assert p1.grad is None + assert C11.shape == C11.grad.shape + assert C12.shape == C12.grad.shape + assert M1.shape == M1.grad.shape + + +def test_srfgw_helper_backend(nx): + n_samples = 20 # nb samples + + mu = np.array([0, 0]) + cov = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=0) + ys = np.random.randn(xs.shape[0], 2) + xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=1) + yt = np.random.randn(xt.shape[0], 2) + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + M = ot.dist(ys, yt) + M /= M.max() + + Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0) + alpha = 0.5 + Gb, logb = ot.gromov.semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, 'square_loss', alpha=0.5, armijo=False, symmetric=True, G0=G0b, log=True) + + # calls with nx=None + constCb, hC1b, hC2b, fC2tb = ot.gromov.init_matrix_semirelaxed(C1b, C2b, pb, loss_fun='square_loss') + ones_pb = nx.ones(pb.shape[0], type_as=pb) + + def f(G): + qG = nx.sum(G, 0) + marginal_product = nx.outer(ones_pb, nx.dot(qG, fC2tb)) + return ot.gromov.gwloss(constCb + marginal_product, hC1b, hC2b, G, nx=None) + + def df(G): + qG = nx.sum(G, 0) + marginal_product = nx.outer(ones_pb, nx.dot(qG, fC2tb)) + return ot.gromov.gwggrad(constCb + marginal_product, hC1b, hC2b, G, nx=None) + + def line_search(cost, G, deltaG, Mi, cost_G): + return ot.gromov.solve_semirelaxed_gromov_linesearch( + G, deltaG, cost_G, C1b, C2b, ones_pb, M=(1 - alpha) * Mb, reg=alpha, nx=None) + # feed the precomputed local optimum Gb to semirelaxed_cg + res, log = ot.optim.semirelaxed_cg(pb, qb, (1 - alpha) * Mb, alpha, f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9) + # check constraints + np.testing.assert_allclose(res, Gb, atol=1e-06) diff --git a/test/test_optim.py b/test/test_optim.py index 67e9d13..129fe22 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -120,15 +120,15 @@ def test_generalized_conditional_gradient(nx): Gb, log = ot.optim.gcg(ab, bb, Mb, reg1, reg2, fb, df, verbose=True, log=True) Gb = nx.to_numpy(Gb) - np.testing.assert_allclose(Gb, G) + np.testing.assert_allclose(Gb, G, atol=1e-12) np.testing.assert_allclose(a, Gb.sum(1), atol=1e-05) np.testing.assert_allclose(b, Gb.sum(0), atol=1e-05) def test_solve_1d_linesearch_quad_funct(): - np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(1, -1, 0), 0.5) - np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 5, 0), 0) - np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 0.5, 0), 1) + np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(1, -1), 0.5) + np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 5), 0) + np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 0.5), 1) def test_line_search_armijo(nx): -- cgit v1.2.3