summaryrefslogtreecommitdiff
path: root/examples/barycenters
diff options
context:
space:
mode:
Diffstat (limited to 'examples/barycenters')
-rw-r--r--examples/barycenters/plot_barycenter_1D.py4
-rw-r--r--examples/barycenters/plot_free_support_barycenter.py28
-rw-r--r--examples/barycenters/plot_free_support_sinkhorn_barycenter.py151
-rw-r--r--examples/barycenters/plot_generalized_free_support_barycenter.py155
4 files changed, 333 insertions, 5 deletions
diff --git a/examples/barycenters/plot_barycenter_1D.py b/examples/barycenters/plot_barycenter_1D.py
index 2373e99..8096245 100644
--- a/examples/barycenters/plot_barycenter_1D.py
+++ b/examples/barycenters/plot_barycenter_1D.py
@@ -106,7 +106,7 @@ for i, z in enumerate(zs):
ys = B_l2[:, i]
verts.append(list(zip(x, ys)))
-ax = plt.gcf().gca(projection='3d')
+ax = plt.gcf().add_subplot(projection='3d')
poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list])
poly.set_alpha(0.7)
@@ -128,7 +128,7 @@ for i, z in enumerate(zs):
ys = B_wass[:, i]
verts.append(list(zip(x, ys)))
-ax = plt.gcf().gca(projection='3d')
+ax = plt.gcf().add_subplot(projection='3d')
poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list])
poly.set_alpha(0.7)
diff --git a/examples/barycenters/plot_free_support_barycenter.py b/examples/barycenters/plot_free_support_barycenter.py
index 226dfeb..f4a13dd 100644
--- a/examples/barycenters/plot_free_support_barycenter.py
+++ b/examples/barycenters/plot_free_support_barycenter.py
@@ -4,13 +4,14 @@
2D free support Wasserstein barycenters of distributions
========================================================
-Illustration of 2D Wasserstein barycenters if distributions are weighted
+Illustration of 2D Wasserstein and Sinkhorn barycenters if distributions are weighted
sum of diracs.
"""
# Authors: Vivien Seguy <vivien.seguy@iip.ist.i.kyoto-u.ac.jp>
# RĂ©mi Flamary <remi.flamary@polytechnique.edu>
+# Eduardo Fernandes Montesuma <eduardo.fernandes-montesuma@universite-paris-saclay.fr>
#
# License: MIT License
@@ -48,7 +49,7 @@ pl.title('Distributions')
# %%
-# Compute free support barycenter
+# Compute free support Wasserstein barycenter
# -------------------------------
k = 200 # number of Diracs of the barycenter
@@ -58,7 +59,28 @@ b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, on
X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init, b)
# %%
-# Plot the barycenter
+# Plot the Wasserstein barycenter
+# ---------
+
+pl.figure(2, (8, 3))
+pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5)
+pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5)
+pl.scatter(X[:, 0], X[:, 1], s=b * 1000, marker='s', label='2-Wasserstein barycenter')
+pl.title('Data measures and their barycenter')
+pl.legend(loc="lower right")
+pl.show()
+
+# %%
+# Compute free support Sinkhorn barycenter
+
+k = 200 # number of Diracs of the barycenter
+X_init = np.random.normal(0., 1., (k, d)) # initial Dirac locations
+b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, only the locations are optimized)
+
+X = ot.bregman.free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, 20, b, numItermax=15)
+
+# %%
+# Plot the Wasserstein barycenter
# ---------
pl.figure(2, (8, 3))
diff --git a/examples/barycenters/plot_free_support_sinkhorn_barycenter.py b/examples/barycenters/plot_free_support_sinkhorn_barycenter.py
new file mode 100644
index 0000000..ebe1f3b
--- /dev/null
+++ b/examples/barycenters/plot_free_support_sinkhorn_barycenter.py
@@ -0,0 +1,151 @@
+# -*- coding: utf-8 -*-
+"""
+========================================================
+2D free support Sinkhorn barycenters of distributions
+========================================================
+
+Illustration of Sinkhorn barycenter calculation between empirical distributions understood as point clouds
+
+"""
+
+# Authors: Eduardo Fernandes Montesuma <eduardo.fernandes-montesuma@universite-paris-saclay.fr>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pyplot as plt
+import ot
+
+# %%
+# General Parameters
+# ------------------
+reg = 1e-2 # Entropic Regularization
+numItermax = 20 # Maximum number of iterations for the Barycenter algorithm
+numInnerItermax = 50 # Maximum number of sinkhorn iterations
+n_samples = 200
+
+# %%
+# Generate Data
+# -------------
+
+X1 = np.random.randn(200, 2)
+X2 = 2 * np.concatenate([
+ np.concatenate([- np.ones([50, 1]), np.linspace(-1, 1, 50)[:, None]], axis=1),
+ np.concatenate([np.linspace(-1, 1, 50)[:, None], np.ones([50, 1])], axis=1),
+ np.concatenate([np.ones([50, 1]), np.linspace(1, -1, 50)[:, None]], axis=1),
+ np.concatenate([np.linspace(1, -1, 50)[:, None], - np.ones([50, 1])], axis=1),
+], axis=0)
+X3 = np.random.randn(200, 2)
+X3 = 2 * (X3 / np.linalg.norm(X3, axis=1)[:, None])
+X4 = np.random.multivariate_normal(np.array([0, 0]), np.array([[1., 0.5], [0.5, 1.]]), size=200)
+
+a1, a2, a3, a4 = ot.unif(len(X1)), ot.unif(len(X1)), ot.unif(len(X1)), ot.unif(len(X1))
+
+# %%
+# Inspect generated distributions
+# -------------------------------
+
+fig, axes = plt.subplots(1, 4, figsize=(16, 4))
+
+axes[0].scatter(x=X1[:, 0], y=X1[:, 1], c='steelblue', edgecolor='k')
+axes[1].scatter(x=X2[:, 0], y=X2[:, 1], c='steelblue', edgecolor='k')
+axes[2].scatter(x=X3[:, 0], y=X3[:, 1], c='steelblue', edgecolor='k')
+axes[3].scatter(x=X4[:, 0], y=X4[:, 1], c='steelblue', edgecolor='k')
+
+axes[0].set_xlim([-3, 3])
+axes[0].set_ylim([-3, 3])
+axes[0].set_title('Distribution 1')
+
+axes[1].set_xlim([-3, 3])
+axes[1].set_ylim([-3, 3])
+axes[1].set_title('Distribution 2')
+
+axes[2].set_xlim([-3, 3])
+axes[2].set_ylim([-3, 3])
+axes[2].set_title('Distribution 3')
+
+axes[3].set_xlim([-3, 3])
+axes[3].set_ylim([-3, 3])
+axes[3].set_title('Distribution 4')
+
+plt.tight_layout()
+plt.show()
+
+# %%
+# Interpolating Empirical Distributions
+# -------------------------------------
+
+fig = plt.figure(figsize=(10, 10))
+
+weights = np.array([
+ [3 / 3, 0 / 3],
+ [2 / 3, 1 / 3],
+ [1 / 3, 2 / 3],
+ [0 / 3, 3 / 3],
+]).astype(np.float32)
+
+for k in range(4):
+ XB_init = np.random.randn(n_samples, 2)
+ XB = ot.bregman.free_support_sinkhorn_barycenter(
+ measures_locations=[X1, X2],
+ measures_weights=[a1, a2],
+ weights=weights[k],
+ X_init=XB_init,
+ reg=reg,
+ numItermax=numItermax,
+ numInnerItermax=numInnerItermax
+ )
+ ax = plt.subplot2grid((4, 4), (0, k))
+ ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k')
+ ax.set_xlim([-3, 3])
+ ax.set_ylim([-3, 3])
+
+for k in range(1, 4, 1):
+ XB_init = np.random.randn(n_samples, 2)
+ XB = ot.bregman.free_support_sinkhorn_barycenter(
+ measures_locations=[X1, X3],
+ measures_weights=[a1, a2],
+ weights=weights[k],
+ X_init=XB_init,
+ reg=reg,
+ numItermax=numItermax,
+ numInnerItermax=numInnerItermax
+ )
+ ax = plt.subplot2grid((4, 4), (k, 0))
+ ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k')
+ ax.set_xlim([-3, 3])
+ ax.set_ylim([-3, 3])
+
+for k in range(1, 4, 1):
+ XB_init = np.random.randn(n_samples, 2)
+ XB = ot.bregman.free_support_sinkhorn_barycenter(
+ measures_locations=[X3, X4],
+ measures_weights=[a1, a2],
+ weights=weights[k],
+ X_init=XB_init,
+ reg=reg,
+ numItermax=numItermax,
+ numInnerItermax=numInnerItermax
+ )
+ ax = plt.subplot2grid((4, 4), (3, k))
+ ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k')
+ ax.set_xlim([-3, 3])
+ ax.set_ylim([-3, 3])
+
+for k in range(1, 3, 1):
+ XB_init = np.random.randn(n_samples, 2)
+ XB = ot.bregman.free_support_sinkhorn_barycenter(
+ measures_locations=[X2, X4],
+ measures_weights=[a1, a2],
+ weights=weights[k],
+ X_init=XB_init,
+ reg=reg,
+ numItermax=numItermax,
+ numInnerItermax=numInnerItermax
+ )
+ ax = plt.subplot2grid((4, 4), (k, 3))
+ ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k')
+ ax.set_xlim([-3, 3])
+ ax.set_ylim([-3, 3])
+
+plt.show()
diff --git a/examples/barycenters/plot_generalized_free_support_barycenter.py b/examples/barycenters/plot_generalized_free_support_barycenter.py
new file mode 100644
index 0000000..e685ec7
--- /dev/null
+++ b/examples/barycenters/plot_generalized_free_support_barycenter.py
@@ -0,0 +1,155 @@
+# -*- coding: utf-8 -*-
+"""
+=======================================
+Generalized Wasserstein Barycenter Demo
+=======================================
+
+This example illustrates the computation of Generalized Wasserstein Barycenter
+as proposed in [42].
+
+
+[42] Delon, J., Gozlan, N., and Saint-Dizier, A..
+Generalized Wasserstein barycenters between probability measures living on different subspaces.
+arXiv preprint arXiv:2105.09755, 2021.
+
+"""
+
+# Author: Eloi Tanguy <eloi.tanguy@polytechnique.edu>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 2
+
+import numpy as np
+import matplotlib.pyplot as plt
+import matplotlib.pylab as pl
+import ot
+import matplotlib.animation as animation
+
+########################
+# Generate and plot data
+# ----------------------
+
+# Input measures
+sub_sample_factor = 8
+I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::sub_sample_factor, ::sub_sample_factor, 2]
+I2 = pl.imread('../../data/tooth.png').astype(np.float64)[::-sub_sample_factor, ::sub_sample_factor, 2]
+I3 = pl.imread('../../data/heart.png').astype(np.float64)[::-sub_sample_factor, ::sub_sample_factor, 2]
+
+sz = I1.shape[0]
+UU, VV = np.meshgrid(np.arange(sz), np.arange(sz))
+
+# Input measure locations in their respective 2D spaces
+X_list = [np.stack((UU[im == 0], VV[im == 0]), 1) * 1.0 for im in [I1, I2, I3]]
+
+# Input measure weights
+a_list = [ot.unif(x.shape[0]) for x in X_list]
+
+# Projections 3D -> 2D
+P1 = np.array([[1, 0, 0], [0, 1, 0]])
+P2 = np.array([[0, 1, 0], [0, 0, 1]])
+P3 = np.array([[1, 0, 0], [0, 0, 1]])
+P_list = [P1, P2, P3]
+
+# Barycenter weights
+weights = np.array([1 / 3, 1 / 3, 1 / 3])
+
+# Number of barycenter points to compute
+n_samples_bary = 150
+
+# Send the input measures into 3D space for visualisation
+X_visu = [Xi @ Pi for (Xi, Pi) in zip(X_list, P_list)]
+
+# Plot the input data
+fig = plt.figure(figsize=(3, 3))
+axis = fig.add_subplot(1, 1, 1, projection="3d")
+for Xi in X_visu:
+ axis.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6)
+axis.view_init(azim=45)
+axis.set_xticks([])
+axis.set_yticks([])
+axis.set_zticks([])
+plt.show()
+
+#################################
+# Barycenter computation and plot
+# -------------------------------
+
+Y = ot.lp.generalized_free_support_barycenter(X_list, a_list, P_list, n_samples_bary)
+fig = plt.figure(figsize=(3, 3))
+
+axis = fig.add_subplot(1, 1, 1, projection="3d")
+for Xi in X_visu:
+ axis.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6)
+axis.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6)
+axis.view_init(azim=45)
+axis.set_xticks([])
+axis.set_yticks([])
+axis.set_zticks([])
+plt.show()
+
+
+#############################
+# Plotting projection matches
+# ---------------------------
+
+fig = plt.figure(figsize=(9, 3))
+
+ax = fig.add_subplot(1, 3, 1, projection='3d')
+for Xi in X_visu:
+ ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6)
+ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6)
+ax.view_init(elev=0, azim=0)
+ax.set_xticks([])
+ax.set_yticks([])
+ax.set_zticks([])
+
+ax = fig.add_subplot(1, 3, 2, projection='3d')
+for Xi in X_visu:
+ ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6)
+ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6)
+ax.view_init(elev=0, azim=90)
+ax.set_xticks([])
+ax.set_yticks([])
+ax.set_zticks([])
+
+ax = fig.add_subplot(1, 3, 3, projection='3d')
+for Xi in X_visu:
+ ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6)
+ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6)
+ax.view_init(elev=90, azim=0)
+ax.set_xticks([])
+ax.set_yticks([])
+ax.set_zticks([])
+
+plt.tight_layout()
+plt.show()
+
+##############################################
+# Rotation animation
+# --------------------------------------------
+
+fig = plt.figure(figsize=(7, 7))
+ax = fig.add_subplot(1, 1, 1, projection="3d")
+
+
+def _init():
+ for Xi in X_visu:
+ ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6)
+ ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6)
+ ax.view_init(elev=0, azim=0)
+ ax.set_xticks([])
+ ax.set_yticks([])
+ ax.set_zticks([])
+ return fig,
+
+
+def _update_plot(i):
+ if i < 45:
+ ax.view_init(elev=0, azim=4 * i)
+ else:
+ ax.view_init(elev=i - 45, azim=4 * i)
+ return fig,
+
+
+ani = animation.FuncAnimation(fig, _update_plot, init_func=_init, frames=136, interval=50, blit=True, repeat_delay=2000)