summaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorEduardo Fernandes Montesuma <edumontesuma@gmail.com>2022-07-27 11:16:14 +0200
committerGitHub <noreply@github.com>2022-07-27 11:16:14 +0200
commit818c7ace20da36d8042b0d7ad7a712b27f7afd59 (patch)
tree58dd4e0c9f990ea0c851712d85748de99ce6b236 /examples
parent7c2a9523747c90aebfef711fdf34b5bbdb6f2f4d (diff)
[MRG] Free support Sinkhorn barycenters (#387)
* Adding function for computing Sinkhorn Free Support barycenters * Adding exampel on Free Support Sinkhorn Barycenter * Fixing typo on free support sinkhorn barycenter example * Adding info on new Free Support Barycenter solver * Removing extra line so that code follows pep8 * Fixing issues with pep8 in example * Correcting issues with pep8 standards * Adding tests for free support sinkhorn barycenter * Adding section on Sinkhorn barycenter to the example * Changing distributions for the Sinkhorn barycenter example * Removing file that should not be on the last commit * Adding PR number to REALEASES.md * Adding new contributors * Update CONTRIBUTORS.md Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'examples')
-rw-r--r--examples/barycenters/plot_free_support_barycenter.py28
-rw-r--r--examples/barycenters/plot_free_support_sinkhorn_barycenter.py151
2 files changed, 176 insertions, 3 deletions
diff --git a/examples/barycenters/plot_free_support_barycenter.py b/examples/barycenters/plot_free_support_barycenter.py
index 226dfeb..f4a13dd 100644
--- a/examples/barycenters/plot_free_support_barycenter.py
+++ b/examples/barycenters/plot_free_support_barycenter.py
@@ -4,13 +4,14 @@
2D free support Wasserstein barycenters of distributions
========================================================
-Illustration of 2D Wasserstein barycenters if distributions are weighted
+Illustration of 2D Wasserstein and Sinkhorn barycenters if distributions are weighted
sum of diracs.
"""
# Authors: Vivien Seguy <vivien.seguy@iip.ist.i.kyoto-u.ac.jp>
# Rémi Flamary <remi.flamary@polytechnique.edu>
+# Eduardo Fernandes Montesuma <eduardo.fernandes-montesuma@universite-paris-saclay.fr>
#
# License: MIT License
@@ -48,7 +49,7 @@ pl.title('Distributions')
# %%
-# Compute free support barycenter
+# Compute free support Wasserstein barycenter
# -------------------------------
k = 200 # number of Diracs of the barycenter
@@ -58,7 +59,28 @@ b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, on
X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init, b)
# %%
-# Plot the barycenter
+# Plot the Wasserstein barycenter
+# ---------
+
+pl.figure(2, (8, 3))
+pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5)
+pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5)
+pl.scatter(X[:, 0], X[:, 1], s=b * 1000, marker='s', label='2-Wasserstein barycenter')
+pl.title('Data measures and their barycenter')
+pl.legend(loc="lower right")
+pl.show()
+
+# %%
+# Compute free support Sinkhorn barycenter
+
+k = 200 # number of Diracs of the barycenter
+X_init = np.random.normal(0., 1., (k, d)) # initial Dirac locations
+b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, only the locations are optimized)
+
+X = ot.bregman.free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, 20, b, numItermax=15)
+
+# %%
+# Plot the Wasserstein barycenter
# ---------
pl.figure(2, (8, 3))
diff --git a/examples/barycenters/plot_free_support_sinkhorn_barycenter.py b/examples/barycenters/plot_free_support_sinkhorn_barycenter.py
new file mode 100644
index 0000000..ebe1f3b
--- /dev/null
+++ b/examples/barycenters/plot_free_support_sinkhorn_barycenter.py
@@ -0,0 +1,151 @@
+# -*- coding: utf-8 -*-
+"""
+========================================================
+2D free support Sinkhorn barycenters of distributions
+========================================================
+
+Illustration of Sinkhorn barycenter calculation between empirical distributions understood as point clouds
+
+"""
+
+# Authors: Eduardo Fernandes Montesuma <eduardo.fernandes-montesuma@universite-paris-saclay.fr>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pyplot as plt
+import ot
+
+# %%
+# General Parameters
+# ------------------
+reg = 1e-2 # Entropic Regularization
+numItermax = 20 # Maximum number of iterations for the Barycenter algorithm
+numInnerItermax = 50 # Maximum number of sinkhorn iterations
+n_samples = 200
+
+# %%
+# Generate Data
+# -------------
+
+X1 = np.random.randn(200, 2)
+X2 = 2 * np.concatenate([
+ np.concatenate([- np.ones([50, 1]), np.linspace(-1, 1, 50)[:, None]], axis=1),
+ np.concatenate([np.linspace(-1, 1, 50)[:, None], np.ones([50, 1])], axis=1),
+ np.concatenate([np.ones([50, 1]), np.linspace(1, -1, 50)[:, None]], axis=1),
+ np.concatenate([np.linspace(1, -1, 50)[:, None], - np.ones([50, 1])], axis=1),
+], axis=0)
+X3 = np.random.randn(200, 2)
+X3 = 2 * (X3 / np.linalg.norm(X3, axis=1)[:, None])
+X4 = np.random.multivariate_normal(np.array([0, 0]), np.array([[1., 0.5], [0.5, 1.]]), size=200)
+
+a1, a2, a3, a4 = ot.unif(len(X1)), ot.unif(len(X1)), ot.unif(len(X1)), ot.unif(len(X1))
+
+# %%
+# Inspect generated distributions
+# -------------------------------
+
+fig, axes = plt.subplots(1, 4, figsize=(16, 4))
+
+axes[0].scatter(x=X1[:, 0], y=X1[:, 1], c='steelblue', edgecolor='k')
+axes[1].scatter(x=X2[:, 0], y=X2[:, 1], c='steelblue', edgecolor='k')
+axes[2].scatter(x=X3[:, 0], y=X3[:, 1], c='steelblue', edgecolor='k')
+axes[3].scatter(x=X4[:, 0], y=X4[:, 1], c='steelblue', edgecolor='k')
+
+axes[0].set_xlim([-3, 3])
+axes[0].set_ylim([-3, 3])
+axes[0].set_title('Distribution 1')
+
+axes[1].set_xlim([-3, 3])
+axes[1].set_ylim([-3, 3])
+axes[1].set_title('Distribution 2')
+
+axes[2].set_xlim([-3, 3])
+axes[2].set_ylim([-3, 3])
+axes[2].set_title('Distribution 3')
+
+axes[3].set_xlim([-3, 3])
+axes[3].set_ylim([-3, 3])
+axes[3].set_title('Distribution 4')
+
+plt.tight_layout()
+plt.show()
+
+# %%
+# Interpolating Empirical Distributions
+# -------------------------------------
+
+fig = plt.figure(figsize=(10, 10))
+
+weights = np.array([
+ [3 / 3, 0 / 3],
+ [2 / 3, 1 / 3],
+ [1 / 3, 2 / 3],
+ [0 / 3, 3 / 3],
+]).astype(np.float32)
+
+for k in range(4):
+ XB_init = np.random.randn(n_samples, 2)
+ XB = ot.bregman.free_support_sinkhorn_barycenter(
+ measures_locations=[X1, X2],
+ measures_weights=[a1, a2],
+ weights=weights[k],
+ X_init=XB_init,
+ reg=reg,
+ numItermax=numItermax,
+ numInnerItermax=numInnerItermax
+ )
+ ax = plt.subplot2grid((4, 4), (0, k))
+ ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k')
+ ax.set_xlim([-3, 3])
+ ax.set_ylim([-3, 3])
+
+for k in range(1, 4, 1):
+ XB_init = np.random.randn(n_samples, 2)
+ XB = ot.bregman.free_support_sinkhorn_barycenter(
+ measures_locations=[X1, X3],
+ measures_weights=[a1, a2],
+ weights=weights[k],
+ X_init=XB_init,
+ reg=reg,
+ numItermax=numItermax,
+ numInnerItermax=numInnerItermax
+ )
+ ax = plt.subplot2grid((4, 4), (k, 0))
+ ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k')
+ ax.set_xlim([-3, 3])
+ ax.set_ylim([-3, 3])
+
+for k in range(1, 4, 1):
+ XB_init = np.random.randn(n_samples, 2)
+ XB = ot.bregman.free_support_sinkhorn_barycenter(
+ measures_locations=[X3, X4],
+ measures_weights=[a1, a2],
+ weights=weights[k],
+ X_init=XB_init,
+ reg=reg,
+ numItermax=numItermax,
+ numInnerItermax=numInnerItermax
+ )
+ ax = plt.subplot2grid((4, 4), (3, k))
+ ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k')
+ ax.set_xlim([-3, 3])
+ ax.set_ylim([-3, 3])
+
+for k in range(1, 3, 1):
+ XB_init = np.random.randn(n_samples, 2)
+ XB = ot.bregman.free_support_sinkhorn_barycenter(
+ measures_locations=[X2, X4],
+ measures_weights=[a1, a2],
+ weights=weights[k],
+ X_init=XB_init,
+ reg=reg,
+ numItermax=numItermax,
+ numInnerItermax=numInnerItermax
+ )
+ ax = plt.subplot2grid((4, 4), (k, 3))
+ ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k')
+ ax.set_xlim([-3, 3])
+ ax.set_ylim([-3, 3])
+
+plt.show()