summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEduardo Fernandes Montesuma <edumontesuma@gmail.com>2022-07-27 11:16:14 +0200
committerGitHub <noreply@github.com>2022-07-27 11:16:14 +0200
commit818c7ace20da36d8042b0d7ad7a712b27f7afd59 (patch)
tree58dd4e0c9f990ea0c851712d85748de99ce6b236
parent7c2a9523747c90aebfef711fdf34b5bbdb6f2f4d (diff)
[MRG] Free support Sinkhorn barycenters (#387)
* Adding function for computing Sinkhorn Free Support barycenters * Adding exampel on Free Support Sinkhorn Barycenter * Fixing typo on free support sinkhorn barycenter example * Adding info on new Free Support Barycenter solver * Removing extra line so that code follows pep8 * Fixing issues with pep8 in example * Correcting issues with pep8 standards * Adding tests for free support sinkhorn barycenter * Adding section on Sinkhorn barycenter to the example * Changing distributions for the Sinkhorn barycenter example * Removing file that should not be on the last commit * Adding PR number to REALEASES.md * Adding new contributors * Update CONTRIBUTORS.md Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
-rw-r--r--CONTRIBUTORS.md1
-rw-r--r--RELEASES.md1
-rw-r--r--examples/barycenters/plot_free_support_barycenter.py28
-rw-r--r--examples/barycenters/plot_free_support_sinkhorn_barycenter.py151
-rw-r--r--ot/bregman.py120
-rw-r--r--test/test_bregman.py26
6 files changed, 324 insertions, 3 deletions
diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index c535c09..0524151 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -39,6 +39,7 @@ The contributors to this library are:
* [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning)
* [Eloi Tanguy](https://github.com/eloitanguy) (Generalized Wasserstein Barycenters)
* [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug)
+* [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter)
## Acknowledgments
diff --git a/RELEASES.md b/RELEASES.md
index 78a7d9e..14d11c4 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -5,6 +5,7 @@
#### New features
- Added Generalized Wasserstein Barycenter solver + example (PR #372), fixed graphical details on the example (PR #376)
+- Added Free Support Sinkhorn Barycenter + example (PR #387)
#### Closed issues
diff --git a/examples/barycenters/plot_free_support_barycenter.py b/examples/barycenters/plot_free_support_barycenter.py
index 226dfeb..f4a13dd 100644
--- a/examples/barycenters/plot_free_support_barycenter.py
+++ b/examples/barycenters/plot_free_support_barycenter.py
@@ -4,13 +4,14 @@
2D free support Wasserstein barycenters of distributions
========================================================
-Illustration of 2D Wasserstein barycenters if distributions are weighted
+Illustration of 2D Wasserstein and Sinkhorn barycenters if distributions are weighted
sum of diracs.
"""
# Authors: Vivien Seguy <vivien.seguy@iip.ist.i.kyoto-u.ac.jp>
# Rémi Flamary <remi.flamary@polytechnique.edu>
+# Eduardo Fernandes Montesuma <eduardo.fernandes-montesuma@universite-paris-saclay.fr>
#
# License: MIT License
@@ -48,7 +49,7 @@ pl.title('Distributions')
# %%
-# Compute free support barycenter
+# Compute free support Wasserstein barycenter
# -------------------------------
k = 200 # number of Diracs of the barycenter
@@ -58,7 +59,28 @@ b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, on
X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init, b)
# %%
-# Plot the barycenter
+# Plot the Wasserstein barycenter
+# ---------
+
+pl.figure(2, (8, 3))
+pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5)
+pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5)
+pl.scatter(X[:, 0], X[:, 1], s=b * 1000, marker='s', label='2-Wasserstein barycenter')
+pl.title('Data measures and their barycenter')
+pl.legend(loc="lower right")
+pl.show()
+
+# %%
+# Compute free support Sinkhorn barycenter
+
+k = 200 # number of Diracs of the barycenter
+X_init = np.random.normal(0., 1., (k, d)) # initial Dirac locations
+b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, only the locations are optimized)
+
+X = ot.bregman.free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, 20, b, numItermax=15)
+
+# %%
+# Plot the Wasserstein barycenter
# ---------
pl.figure(2, (8, 3))
diff --git a/examples/barycenters/plot_free_support_sinkhorn_barycenter.py b/examples/barycenters/plot_free_support_sinkhorn_barycenter.py
new file mode 100644
index 0000000..ebe1f3b
--- /dev/null
+++ b/examples/barycenters/plot_free_support_sinkhorn_barycenter.py
@@ -0,0 +1,151 @@
+# -*- coding: utf-8 -*-
+"""
+========================================================
+2D free support Sinkhorn barycenters of distributions
+========================================================
+
+Illustration of Sinkhorn barycenter calculation between empirical distributions understood as point clouds
+
+"""
+
+# Authors: Eduardo Fernandes Montesuma <eduardo.fernandes-montesuma@universite-paris-saclay.fr>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pyplot as plt
+import ot
+
+# %%
+# General Parameters
+# ------------------
+reg = 1e-2 # Entropic Regularization
+numItermax = 20 # Maximum number of iterations for the Barycenter algorithm
+numInnerItermax = 50 # Maximum number of sinkhorn iterations
+n_samples = 200
+
+# %%
+# Generate Data
+# -------------
+
+X1 = np.random.randn(200, 2)
+X2 = 2 * np.concatenate([
+ np.concatenate([- np.ones([50, 1]), np.linspace(-1, 1, 50)[:, None]], axis=1),
+ np.concatenate([np.linspace(-1, 1, 50)[:, None], np.ones([50, 1])], axis=1),
+ np.concatenate([np.ones([50, 1]), np.linspace(1, -1, 50)[:, None]], axis=1),
+ np.concatenate([np.linspace(1, -1, 50)[:, None], - np.ones([50, 1])], axis=1),
+], axis=0)
+X3 = np.random.randn(200, 2)
+X3 = 2 * (X3 / np.linalg.norm(X3, axis=1)[:, None])
+X4 = np.random.multivariate_normal(np.array([0, 0]), np.array([[1., 0.5], [0.5, 1.]]), size=200)
+
+a1, a2, a3, a4 = ot.unif(len(X1)), ot.unif(len(X1)), ot.unif(len(X1)), ot.unif(len(X1))
+
+# %%
+# Inspect generated distributions
+# -------------------------------
+
+fig, axes = plt.subplots(1, 4, figsize=(16, 4))
+
+axes[0].scatter(x=X1[:, 0], y=X1[:, 1], c='steelblue', edgecolor='k')
+axes[1].scatter(x=X2[:, 0], y=X2[:, 1], c='steelblue', edgecolor='k')
+axes[2].scatter(x=X3[:, 0], y=X3[:, 1], c='steelblue', edgecolor='k')
+axes[3].scatter(x=X4[:, 0], y=X4[:, 1], c='steelblue', edgecolor='k')
+
+axes[0].set_xlim([-3, 3])
+axes[0].set_ylim([-3, 3])
+axes[0].set_title('Distribution 1')
+
+axes[1].set_xlim([-3, 3])
+axes[1].set_ylim([-3, 3])
+axes[1].set_title('Distribution 2')
+
+axes[2].set_xlim([-3, 3])
+axes[2].set_ylim([-3, 3])
+axes[2].set_title('Distribution 3')
+
+axes[3].set_xlim([-3, 3])
+axes[3].set_ylim([-3, 3])
+axes[3].set_title('Distribution 4')
+
+plt.tight_layout()
+plt.show()
+
+# %%
+# Interpolating Empirical Distributions
+# -------------------------------------
+
+fig = plt.figure(figsize=(10, 10))
+
+weights = np.array([
+ [3 / 3, 0 / 3],
+ [2 / 3, 1 / 3],
+ [1 / 3, 2 / 3],
+ [0 / 3, 3 / 3],
+]).astype(np.float32)
+
+for k in range(4):
+ XB_init = np.random.randn(n_samples, 2)
+ XB = ot.bregman.free_support_sinkhorn_barycenter(
+ measures_locations=[X1, X2],
+ measures_weights=[a1, a2],
+ weights=weights[k],
+ X_init=XB_init,
+ reg=reg,
+ numItermax=numItermax,
+ numInnerItermax=numInnerItermax
+ )
+ ax = plt.subplot2grid((4, 4), (0, k))
+ ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k')
+ ax.set_xlim([-3, 3])
+ ax.set_ylim([-3, 3])
+
+for k in range(1, 4, 1):
+ XB_init = np.random.randn(n_samples, 2)
+ XB = ot.bregman.free_support_sinkhorn_barycenter(
+ measures_locations=[X1, X3],
+ measures_weights=[a1, a2],
+ weights=weights[k],
+ X_init=XB_init,
+ reg=reg,
+ numItermax=numItermax,
+ numInnerItermax=numInnerItermax
+ )
+ ax = plt.subplot2grid((4, 4), (k, 0))
+ ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k')
+ ax.set_xlim([-3, 3])
+ ax.set_ylim([-3, 3])
+
+for k in range(1, 4, 1):
+ XB_init = np.random.randn(n_samples, 2)
+ XB = ot.bregman.free_support_sinkhorn_barycenter(
+ measures_locations=[X3, X4],
+ measures_weights=[a1, a2],
+ weights=weights[k],
+ X_init=XB_init,
+ reg=reg,
+ numItermax=numItermax,
+ numInnerItermax=numInnerItermax
+ )
+ ax = plt.subplot2grid((4, 4), (3, k))
+ ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k')
+ ax.set_xlim([-3, 3])
+ ax.set_ylim([-3, 3])
+
+for k in range(1, 3, 1):
+ XB_init = np.random.randn(n_samples, 2)
+ XB = ot.bregman.free_support_sinkhorn_barycenter(
+ measures_locations=[X2, X4],
+ measures_weights=[a1, a2],
+ weights=weights[k],
+ X_init=XB_init,
+ reg=reg,
+ numItermax=numItermax,
+ numInnerItermax=numInnerItermax
+ )
+ ax = plt.subplot2grid((4, 4), (k, 3))
+ ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k')
+ ax.set_xlim([-3, 3])
+ ax.set_ylim([-3, 3])
+
+plt.show()
diff --git a/ot/bregman.py b/ot/bregman.py
index 34dcadb..b1321a4 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -1540,6 +1540,126 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000,
return geometricBar(weights, UKv)
+def free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, reg, b=None, weights=None,
+ numItermax=100, numInnerItermax=1000, stopThr=1e-7, verbose=False, log=None,
+ **kwargs):
+ r"""
+ Solves the free support (locations of the barycenters are optimized, not the weights) regularized Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Sinkhorn divergence), formally:
+
+ .. math::
+ \min_\mathbf{X} \quad \sum_{i=1}^N w_i W_{reg}^2(\mathbf{b}, \mathbf{X}, \mathbf{a}_i, \mathbf{X}_i)
+
+ where :
+
+ - :math:`w \in \mathbb{(0, 1)}^{N}`'s are the barycenter weights and sum to one
+ - `measure_weights` denotes the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}`: empirical measures weights (on simplex)
+ - `measures_locations` denotes the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d}`: empirical measures atoms locations
+ - :math:`\mathbf{b} \in \mathbb{R}^{k}` is the desired weights vector of the barycenter
+
+ This problem is considered in :ref:`[20] <references-free-support-barycenter>` (Algorithm 2).
+ There are two differences with the following codes:
+
+ - we do not optimize over the weights
+ - we do not do line search for the locations updates, we use i.e. :math:`\theta = 1` in
+ :ref:`[20] <references-free-support-barycenter>` (Algorithm 2). This can be seen as a discrete
+ implementation of the fixed-point algorithm of
+ :ref:`[43] <references-free-support-barycenter>` proposed in the continuous setting.
+ - at each iteration, instead of solving an exact OT problem, we use the Sinkhorn algorithm for calculating the
+ transport plan in :ref:`[20] <references-free-support-barycenter>` (Algorithm 2).
+
+ Parameters
+ ----------
+ measures_locations : list of N (k_i,d) array-like
+ The discrete support of a measure supported on :math:`k_i` locations of a `d`-dimensional space
+ (:math:`k_i` can be different for each element of the list)
+ measures_weights : list of N (k_i,) array-like
+ Numpy arrays where each numpy array has :math:`k_i` non-negatives values summing to one
+ representing the weights of each discrete input measure
+
+ X_init : (k,d) array-like
+ Initialization of the support locations (on `k` atoms) of the barycenter
+ reg : float
+ Regularization term >0
+ b : (k,) array-like
+ Initialization of the weights of the barycenter (non-negatives, sum to 1)
+ weights : (N,) array-like
+ Initialization of the coefficients of the barycenter (non-negatives, sum to 1)
+
+ numItermax : int, optional
+ Max number of iterations
+ numInnerItermax : int, optional
+ Max number of iterations when calculating the transport plans with Sinkhorn
+ stopThr : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+ Returns
+ -------
+ X : (k,d) array-like
+ Support locations (on k atoms) of the barycenter
+
+ See Also
+ --------
+ ot.bregman.sinkhorn : Entropic regularized OT solver
+ ot.lp.free_support_barycenter : Barycenter solver based on Linear Programming
+
+ .. _references-free-support-barycenter:
+ References
+ ----------
+ .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.
+
+ .. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
+
+ """
+ nx = get_backend(*measures_locations, *measures_weights, X_init)
+
+ iter_count = 0
+
+ N = len(measures_locations)
+ k = X_init.shape[0]
+ d = X_init.shape[1]
+ if b is None:
+ b = nx.ones((k,), type_as=X_init) / k
+ if weights is None:
+ weights = nx.ones((N,), type_as=X_init) / N
+
+ X = X_init
+
+ log_dict = {}
+ displacement_square_norms = []
+
+ displacement_square_norm = stopThr + 1.
+
+ while (displacement_square_norm > stopThr and iter_count < numItermax):
+
+ T_sum = nx.zeros((k, d), type_as=X_init)
+
+ for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights):
+ M_i = dist(X, measure_locations_i)
+ T_i = sinkhorn(b, measure_weights_i, M_i, reg=reg, numItermax=numInnerItermax, **kwargs)
+ T_sum = T_sum + weight_i * 1. / b[:, None] * nx.dot(T_i, measure_locations_i)
+
+ displacement_square_norm = nx.sum((T_sum - X) ** 2)
+ if log:
+ displacement_square_norms.append(displacement_square_norm)
+
+ X = T_sum
+
+ if verbose:
+ print('iteration %d, displacement_square_norm=%f\n', iter_count, displacement_square_norm)
+
+ iter_count += 1
+
+ if log:
+ log_dict['displacement_square_norms'] = displacement_square_norms
+ return X, log_dict
+ else:
+ return X
+
+
def _barycenter_sinkhorn_log(A, M, reg, weights=None, numItermax=1000,
stopThr=1e-4, verbose=False, log=False, warn=True):
r"""Compute the entropic wasserstein barycenter in log-domain
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 112bfca..e128ea2 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -3,6 +3,7 @@
# Author: Remi Flamary <remi.flamary@unice.fr>
# Kilian Fatras <kilian.fatras@irisa.fr>
# Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
+# Eduardo Fernandes Montesuma <eduardo.fernandes-montesuma@universite-paris-saclay.fr>
#
# License: MIT License
@@ -490,6 +491,31 @@ def test_barycenter(nx, method, verbose, warn):
ot.bregman.barycenter(A_nx, M_nx, reg, log=True)
+def test_free_support_sinkhorn_barycenter():
+ measures_locations = [
+ np.array([-1.]).reshape((1, 1)), # First dirac support
+ np.array([1.]).reshape((1, 1)) # Second dirac support
+ ]
+
+ measures_weights = [
+ np.array([1.]), # First dirac sample weights
+ np.array([1.]) # Second dirac sample weights
+ ]
+
+ # Barycenter initialization
+ X_init = np.array([-12.]).reshape((1, 1))
+
+ # Obvious barycenter locations. Take a look on test_ot.py, test_free_support_barycenter
+ bar_locations = np.array([0.]).reshape((1, 1))
+
+ # Calculate free support barycenter w/ Sinkhorn algorithm. We set the entropic regularization
+ # term to 1, but this should be, in general, fine-tuned to the problem.
+ X = ot.bregman.free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, reg=1)
+
+ # Verifies if calculated barycenter matches ground-truth
+ np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7)
+
+
@pytest.mark.parametrize("method, verbose, warn",
product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"],
[True, False], [True, False]))