diff options
author | Eduardo Fernandes Montesuma <edumontesuma@gmail.com> | 2022-07-27 11:16:14 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-07-27 11:16:14 +0200 |
commit | 818c7ace20da36d8042b0d7ad7a712b27f7afd59 (patch) | |
tree | 58dd4e0c9f990ea0c851712d85748de99ce6b236 | |
parent | 7c2a9523747c90aebfef711fdf34b5bbdb6f2f4d (diff) |
[MRG] Free support Sinkhorn barycenters (#387)
* Adding function for computing Sinkhorn Free Support barycenters
* Adding exampel on Free Support Sinkhorn Barycenter
* Fixing typo on free support sinkhorn barycenter example
* Adding info on new Free Support Barycenter solver
* Removing extra line so that code follows pep8
* Fixing issues with pep8 in example
* Correcting issues with pep8 standards
* Adding tests for free support sinkhorn barycenter
* Adding section on Sinkhorn barycenter to the example
* Changing distributions for the Sinkhorn barycenter example
* Removing file that should not be on the last commit
* Adding PR number to REALEASES.md
* Adding new contributors
* Update CONTRIBUTORS.md
Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
-rw-r--r-- | CONTRIBUTORS.md | 1 | ||||
-rw-r--r-- | RELEASES.md | 1 | ||||
-rw-r--r-- | examples/barycenters/plot_free_support_barycenter.py | 28 | ||||
-rw-r--r-- | examples/barycenters/plot_free_support_sinkhorn_barycenter.py | 151 | ||||
-rw-r--r-- | ot/bregman.py | 120 | ||||
-rw-r--r-- | test/test_bregman.py | 26 |
6 files changed, 324 insertions, 3 deletions
diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index c535c09..0524151 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -39,6 +39,7 @@ The contributors to this library are: * [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning) * [Eloi Tanguy](https://github.com/eloitanguy) (Generalized Wasserstein Barycenters) * [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug) +* [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter) ## Acknowledgments diff --git a/RELEASES.md b/RELEASES.md index 78a7d9e..14d11c4 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -5,6 +5,7 @@ #### New features - Added Generalized Wasserstein Barycenter solver + example (PR #372), fixed graphical details on the example (PR #376) +- Added Free Support Sinkhorn Barycenter + example (PR #387) #### Closed issues diff --git a/examples/barycenters/plot_free_support_barycenter.py b/examples/barycenters/plot_free_support_barycenter.py index 226dfeb..f4a13dd 100644 --- a/examples/barycenters/plot_free_support_barycenter.py +++ b/examples/barycenters/plot_free_support_barycenter.py @@ -4,13 +4,14 @@ 2D free support Wasserstein barycenters of distributions ======================================================== -Illustration of 2D Wasserstein barycenters if distributions are weighted +Illustration of 2D Wasserstein and Sinkhorn barycenters if distributions are weighted sum of diracs. """ # Authors: Vivien Seguy <vivien.seguy@iip.ist.i.kyoto-u.ac.jp> # Rémi Flamary <remi.flamary@polytechnique.edu> +# Eduardo Fernandes Montesuma <eduardo.fernandes-montesuma@universite-paris-saclay.fr> # # License: MIT License @@ -48,7 +49,7 @@ pl.title('Distributions') # %% -# Compute free support barycenter +# Compute free support Wasserstein barycenter # ------------------------------- k = 200 # number of Diracs of the barycenter @@ -58,7 +59,28 @@ b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, on X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init, b) # %% -# Plot the barycenter +# Plot the Wasserstein barycenter +# --------- + +pl.figure(2, (8, 3)) +pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5) +pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5) +pl.scatter(X[:, 0], X[:, 1], s=b * 1000, marker='s', label='2-Wasserstein barycenter') +pl.title('Data measures and their barycenter') +pl.legend(loc="lower right") +pl.show() + +# %% +# Compute free support Sinkhorn barycenter + +k = 200 # number of Diracs of the barycenter +X_init = np.random.normal(0., 1., (k, d)) # initial Dirac locations +b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, only the locations are optimized) + +X = ot.bregman.free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, 20, b, numItermax=15) + +# %% +# Plot the Wasserstein barycenter # --------- pl.figure(2, (8, 3)) diff --git a/examples/barycenters/plot_free_support_sinkhorn_barycenter.py b/examples/barycenters/plot_free_support_sinkhorn_barycenter.py new file mode 100644 index 0000000..ebe1f3b --- /dev/null +++ b/examples/barycenters/plot_free_support_sinkhorn_barycenter.py @@ -0,0 +1,151 @@ +# -*- coding: utf-8 -*- +""" +======================================================== +2D free support Sinkhorn barycenters of distributions +======================================================== + +Illustration of Sinkhorn barycenter calculation between empirical distributions understood as point clouds + +""" + +# Authors: Eduardo Fernandes Montesuma <eduardo.fernandes-montesuma@universite-paris-saclay.fr> +# +# License: MIT License + +import numpy as np +import matplotlib.pyplot as plt +import ot + +# %% +# General Parameters +# ------------------ +reg = 1e-2 # Entropic Regularization +numItermax = 20 # Maximum number of iterations for the Barycenter algorithm +numInnerItermax = 50 # Maximum number of sinkhorn iterations +n_samples = 200 + +# %% +# Generate Data +# ------------- + +X1 = np.random.randn(200, 2) +X2 = 2 * np.concatenate([ + np.concatenate([- np.ones([50, 1]), np.linspace(-1, 1, 50)[:, None]], axis=1), + np.concatenate([np.linspace(-1, 1, 50)[:, None], np.ones([50, 1])], axis=1), + np.concatenate([np.ones([50, 1]), np.linspace(1, -1, 50)[:, None]], axis=1), + np.concatenate([np.linspace(1, -1, 50)[:, None], - np.ones([50, 1])], axis=1), +], axis=0) +X3 = np.random.randn(200, 2) +X3 = 2 * (X3 / np.linalg.norm(X3, axis=1)[:, None]) +X4 = np.random.multivariate_normal(np.array([0, 0]), np.array([[1., 0.5], [0.5, 1.]]), size=200) + +a1, a2, a3, a4 = ot.unif(len(X1)), ot.unif(len(X1)), ot.unif(len(X1)), ot.unif(len(X1)) + +# %% +# Inspect generated distributions +# ------------------------------- + +fig, axes = plt.subplots(1, 4, figsize=(16, 4)) + +axes[0].scatter(x=X1[:, 0], y=X1[:, 1], c='steelblue', edgecolor='k') +axes[1].scatter(x=X2[:, 0], y=X2[:, 1], c='steelblue', edgecolor='k') +axes[2].scatter(x=X3[:, 0], y=X3[:, 1], c='steelblue', edgecolor='k') +axes[3].scatter(x=X4[:, 0], y=X4[:, 1], c='steelblue', edgecolor='k') + +axes[0].set_xlim([-3, 3]) +axes[0].set_ylim([-3, 3]) +axes[0].set_title('Distribution 1') + +axes[1].set_xlim([-3, 3]) +axes[1].set_ylim([-3, 3]) +axes[1].set_title('Distribution 2') + +axes[2].set_xlim([-3, 3]) +axes[2].set_ylim([-3, 3]) +axes[2].set_title('Distribution 3') + +axes[3].set_xlim([-3, 3]) +axes[3].set_ylim([-3, 3]) +axes[3].set_title('Distribution 4') + +plt.tight_layout() +plt.show() + +# %% +# Interpolating Empirical Distributions +# ------------------------------------- + +fig = plt.figure(figsize=(10, 10)) + +weights = np.array([ + [3 / 3, 0 / 3], + [2 / 3, 1 / 3], + [1 / 3, 2 / 3], + [0 / 3, 3 / 3], +]).astype(np.float32) + +for k in range(4): + XB_init = np.random.randn(n_samples, 2) + XB = ot.bregman.free_support_sinkhorn_barycenter( + measures_locations=[X1, X2], + measures_weights=[a1, a2], + weights=weights[k], + X_init=XB_init, + reg=reg, + numItermax=numItermax, + numInnerItermax=numInnerItermax + ) + ax = plt.subplot2grid((4, 4), (0, k)) + ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k') + ax.set_xlim([-3, 3]) + ax.set_ylim([-3, 3]) + +for k in range(1, 4, 1): + XB_init = np.random.randn(n_samples, 2) + XB = ot.bregman.free_support_sinkhorn_barycenter( + measures_locations=[X1, X3], + measures_weights=[a1, a2], + weights=weights[k], + X_init=XB_init, + reg=reg, + numItermax=numItermax, + numInnerItermax=numInnerItermax + ) + ax = plt.subplot2grid((4, 4), (k, 0)) + ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k') + ax.set_xlim([-3, 3]) + ax.set_ylim([-3, 3]) + +for k in range(1, 4, 1): + XB_init = np.random.randn(n_samples, 2) + XB = ot.bregman.free_support_sinkhorn_barycenter( + measures_locations=[X3, X4], + measures_weights=[a1, a2], + weights=weights[k], + X_init=XB_init, + reg=reg, + numItermax=numItermax, + numInnerItermax=numInnerItermax + ) + ax = plt.subplot2grid((4, 4), (3, k)) + ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k') + ax.set_xlim([-3, 3]) + ax.set_ylim([-3, 3]) + +for k in range(1, 3, 1): + XB_init = np.random.randn(n_samples, 2) + XB = ot.bregman.free_support_sinkhorn_barycenter( + measures_locations=[X2, X4], + measures_weights=[a1, a2], + weights=weights[k], + X_init=XB_init, + reg=reg, + numItermax=numItermax, + numInnerItermax=numInnerItermax + ) + ax = plt.subplot2grid((4, 4), (k, 3)) + ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k') + ax.set_xlim([-3, 3]) + ax.set_ylim([-3, 3]) + +plt.show() diff --git a/ot/bregman.py b/ot/bregman.py index 34dcadb..b1321a4 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -1540,6 +1540,126 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, return geometricBar(weights, UKv) +def free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, reg, b=None, weights=None, + numItermax=100, numInnerItermax=1000, stopThr=1e-7, verbose=False, log=None, + **kwargs): + r""" + Solves the free support (locations of the barycenters are optimized, not the weights) regularized Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Sinkhorn divergence), formally: + + .. math:: + \min_\mathbf{X} \quad \sum_{i=1}^N w_i W_{reg}^2(\mathbf{b}, \mathbf{X}, \mathbf{a}_i, \mathbf{X}_i) + + where : + + - :math:`w \in \mathbb{(0, 1)}^{N}`'s are the barycenter weights and sum to one + - `measure_weights` denotes the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}`: empirical measures weights (on simplex) + - `measures_locations` denotes the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d}`: empirical measures atoms locations + - :math:`\mathbf{b} \in \mathbb{R}^{k}` is the desired weights vector of the barycenter + + This problem is considered in :ref:`[20] <references-free-support-barycenter>` (Algorithm 2). + There are two differences with the following codes: + + - we do not optimize over the weights + - we do not do line search for the locations updates, we use i.e. :math:`\theta = 1` in + :ref:`[20] <references-free-support-barycenter>` (Algorithm 2). This can be seen as a discrete + implementation of the fixed-point algorithm of + :ref:`[43] <references-free-support-barycenter>` proposed in the continuous setting. + - at each iteration, instead of solving an exact OT problem, we use the Sinkhorn algorithm for calculating the + transport plan in :ref:`[20] <references-free-support-barycenter>` (Algorithm 2). + + Parameters + ---------- + measures_locations : list of N (k_i,d) array-like + The discrete support of a measure supported on :math:`k_i` locations of a `d`-dimensional space + (:math:`k_i` can be different for each element of the list) + measures_weights : list of N (k_i,) array-like + Numpy arrays where each numpy array has :math:`k_i` non-negatives values summing to one + representing the weights of each discrete input measure + + X_init : (k,d) array-like + Initialization of the support locations (on `k` atoms) of the barycenter + reg : float + Regularization term >0 + b : (k,) array-like + Initialization of the weights of the barycenter (non-negatives, sum to 1) + weights : (N,) array-like + Initialization of the coefficients of the barycenter (non-negatives, sum to 1) + + numItermax : int, optional + Max number of iterations + numInnerItermax : int, optional + Max number of iterations when calculating the transport plans with Sinkhorn + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + Returns + ------- + X : (k,d) array-like + Support locations (on k atoms) of the barycenter + + See Also + -------- + ot.bregman.sinkhorn : Entropic regularized OT solver + ot.lp.free_support_barycenter : Barycenter solver based on Linear Programming + + .. _references-free-support-barycenter: + References + ---------- + .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. + + .. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. + + """ + nx = get_backend(*measures_locations, *measures_weights, X_init) + + iter_count = 0 + + N = len(measures_locations) + k = X_init.shape[0] + d = X_init.shape[1] + if b is None: + b = nx.ones((k,), type_as=X_init) / k + if weights is None: + weights = nx.ones((N,), type_as=X_init) / N + + X = X_init + + log_dict = {} + displacement_square_norms = [] + + displacement_square_norm = stopThr + 1. + + while (displacement_square_norm > stopThr and iter_count < numItermax): + + T_sum = nx.zeros((k, d), type_as=X_init) + + for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights): + M_i = dist(X, measure_locations_i) + T_i = sinkhorn(b, measure_weights_i, M_i, reg=reg, numItermax=numInnerItermax, **kwargs) + T_sum = T_sum + weight_i * 1. / b[:, None] * nx.dot(T_i, measure_locations_i) + + displacement_square_norm = nx.sum((T_sum - X) ** 2) + if log: + displacement_square_norms.append(displacement_square_norm) + + X = T_sum + + if verbose: + print('iteration %d, displacement_square_norm=%f\n', iter_count, displacement_square_norm) + + iter_count += 1 + + if log: + log_dict['displacement_square_norms'] = displacement_square_norms + return X, log_dict + else: + return X + + def _barycenter_sinkhorn_log(A, M, reg, weights=None, numItermax=1000, stopThr=1e-4, verbose=False, log=False, warn=True): r"""Compute the entropic wasserstein barycenter in log-domain diff --git a/test/test_bregman.py b/test/test_bregman.py index 112bfca..e128ea2 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -3,6 +3,7 @@ # Author: Remi Flamary <remi.flamary@unice.fr> # Kilian Fatras <kilian.fatras@irisa.fr> # Quang Huy Tran <quang-huy.tran@univ-ubs.fr> +# Eduardo Fernandes Montesuma <eduardo.fernandes-montesuma@universite-paris-saclay.fr> # # License: MIT License @@ -490,6 +491,31 @@ def test_barycenter(nx, method, verbose, warn): ot.bregman.barycenter(A_nx, M_nx, reg, log=True) +def test_free_support_sinkhorn_barycenter(): + measures_locations = [ + np.array([-1.]).reshape((1, 1)), # First dirac support + np.array([1.]).reshape((1, 1)) # Second dirac support + ] + + measures_weights = [ + np.array([1.]), # First dirac sample weights + np.array([1.]) # Second dirac sample weights + ] + + # Barycenter initialization + X_init = np.array([-12.]).reshape((1, 1)) + + # Obvious barycenter locations. Take a look on test_ot.py, test_free_support_barycenter + bar_locations = np.array([0.]).reshape((1, 1)) + + # Calculate free support barycenter w/ Sinkhorn algorithm. We set the entropic regularization + # term to 1, but this should be, in general, fine-tuned to the problem. + X = ot.bregman.free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, reg=1) + + # Verifies if calculated barycenter matches ground-truth + np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7) + + @pytest.mark.parametrize("method, verbose, warn", product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"], [True, False], [True, False])) |