summaryrefslogtreecommitdiff
path: root/examples/barycenters
diff options
context:
space:
mode:
authorHicham Janati <hicham.janati100@gmail.com>2021-11-03 08:41:35 +0100
committerGitHub <noreply@github.com>2021-11-03 08:41:35 +0100
commite1b67c641da3b3e497db6811af2c200022b10302 (patch)
tree44d42e1ae50d653bb07dd6ef9c1de14f71b21642 /examples/barycenters
parent61340d526702616ff000d9e1cf71f52dd199a103 (diff)
[WIP] Add debiased barycenter (Sinkhorn + convolutional sinkhorn) (#291)
* add debiased sinkhorn barycenter + make loops pythonic * add debiased arg in tests * add 1d and 2d examples of debiased barycenters * fix doctest * fix flake8 * pep8 + make func private + add convergence warnings * remove rel paths + add rng + pylab to pyplot * fix stopping criterion debiased * pass alex * change params with new API * add logdomain barycenters + separate debiased API * test new API * fix jax read-only ? * raise error for jax * test catch jax error * fix pytest catch error * fix relative path * fix flake8 * add warn arg everywhere * fix ref number * catch warnings in tests * add contrib to readme + change ref number * fix convolution example + gallery thumbnails * increase coverage * fix flake Co-authored-by: Hicham Janati <hicham.janati@inria.fr> Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com> Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
Diffstat (limited to 'examples/barycenters')
-rw-r--r--examples/barycenters/plot_barycenter_1D.py63
-rw-r--r--examples/barycenters/plot_barycenter_lp_vs_entropic.py2
-rw-r--r--examples/barycenters/plot_convolutional_barycenter.py53
-rw-r--r--examples/barycenters/plot_debiased_barycenter.py131
4 files changed, 180 insertions, 69 deletions
diff --git a/examples/barycenters/plot_barycenter_1D.py b/examples/barycenters/plot_barycenter_1D.py
index 63dc460..2373e99 100644
--- a/examples/barycenters/plot_barycenter_1D.py
+++ b/examples/barycenters/plot_barycenter_1D.py
@@ -18,10 +18,10 @@ SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
#
# License: MIT License
-# sphinx_gallery_thumbnail_number = 4
+# sphinx_gallery_thumbnail_number = 1
import numpy as np
-import matplotlib.pylab as pl
+import matplotlib.pyplot as plt
import ot
# necessary for 3d plot even if not used
from mpl_toolkits.mplot3d import Axes3D # noqa
@@ -51,18 +51,6 @@ M = ot.utils.dist0(n)
M /= M.max()
##############################################################################
-# Plot data
-# ---------
-
-#%% plot the distributions
-
-pl.figure(1, figsize=(6.4, 3))
-for i in range(n_distributions):
- pl.plot(x, A[:, i])
-pl.title('Distributions')
-pl.tight_layout()
-
-##############################################################################
# Barycenter computation
# ----------------------
@@ -78,24 +66,20 @@ bary_l2 = A.dot(weights)
reg = 1e-3
bary_wass = ot.bregman.barycenter(A, M, reg, weights)
-pl.figure(2)
-pl.clf()
-pl.subplot(2, 1, 1)
-for i in range(n_distributions):
- pl.plot(x, A[:, i])
-pl.title('Distributions')
+f, (ax1, ax2) = plt.subplots(2, 1, tight_layout=True, num=1)
+ax1.plot(x, A, color="black")
+ax1.set_title('Distributions')
-pl.subplot(2, 1, 2)
-pl.plot(x, bary_l2, 'r', label='l2')
-pl.plot(x, bary_wass, 'g', label='Wasserstein')
-pl.legend()
-pl.title('Barycenters')
-pl.tight_layout()
+ax2.plot(x, bary_l2, 'r', label='l2')
+ax2.plot(x, bary_wass, 'g', label='Wasserstein')
+ax2.set_title('Barycenters')
+
+plt.legend()
+plt.show()
##############################################################################
# Barycentric interpolation
# -------------------------
-
#%% barycenter interpolation
n_alpha = 11
@@ -106,24 +90,23 @@ B_l2 = np.zeros((n, n_alpha))
B_wass = np.copy(B_l2)
-for i in range(0, n_alpha):
+for i in range(n_alpha):
alpha = alpha_list[i]
weights = np.array([1 - alpha, alpha])
B_l2[:, i] = A.dot(weights)
B_wass[:, i] = ot.bregman.barycenter(A, M, reg, weights)
#%% plot interpolation
+plt.figure(2)
-pl.figure(3)
-
-cmap = pl.cm.get_cmap('viridis')
+cmap = plt.cm.get_cmap('viridis')
verts = []
zs = alpha_list
for i, z in enumerate(zs):
ys = B_l2[:, i]
verts.append(list(zip(x, ys)))
-ax = pl.gcf().gca(projection='3d')
+ax = plt.gcf().gca(projection='3d')
poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list])
poly.set_alpha(0.7)
@@ -134,18 +117,18 @@ ax.set_ylabel('$\\alpha$')
ax.set_ylim3d(0, 1)
ax.set_zlabel('')
ax.set_zlim3d(0, B_l2.max() * 1.01)
-pl.title('Barycenter interpolation with l2')
-pl.tight_layout()
+plt.title('Barycenter interpolation with l2')
+plt.tight_layout()
-pl.figure(4)
-cmap = pl.cm.get_cmap('viridis')
+plt.figure(3)
+cmap = plt.cm.get_cmap('viridis')
verts = []
zs = alpha_list
for i, z in enumerate(zs):
ys = B_wass[:, i]
verts.append(list(zip(x, ys)))
-ax = pl.gcf().gca(projection='3d')
+ax = plt.gcf().gca(projection='3d')
poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list])
poly.set_alpha(0.7)
@@ -156,7 +139,7 @@ ax.set_ylabel('$\\alpha$')
ax.set_ylim3d(0, 1)
ax.set_zlabel('')
ax.set_zlim3d(0, B_l2.max() * 1.01)
-pl.title('Barycenter interpolation with Wasserstein')
-pl.tight_layout()
+plt.title('Barycenter interpolation with Wasserstein')
+plt.tight_layout()
-pl.show()
+plt.show()
diff --git a/examples/barycenters/plot_barycenter_lp_vs_entropic.py b/examples/barycenters/plot_barycenter_lp_vs_entropic.py
index 57a6bac..6502f16 100644
--- a/examples/barycenters/plot_barycenter_lp_vs_entropic.py
+++ b/examples/barycenters/plot_barycenter_lp_vs_entropic.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
"""
=================================================================================
-1D Wasserstein barycenter comparison between exact LP and entropic regularization
+1D Wasserstein barycenter: exact LP vs entropic regularization
=================================================================================
This example illustrates the computation of regularized Wasserstein Barycenter
diff --git a/examples/barycenters/plot_convolutional_barycenter.py b/examples/barycenters/plot_convolutional_barycenter.py
index cbcd4a1..3721f31 100644
--- a/examples/barycenters/plot_convolutional_barycenter.py
+++ b/examples/barycenters/plot_convolutional_barycenter.py
@@ -6,17 +6,18 @@
Convolutional Wasserstein Barycenter example
============================================
-This example is designed to illustrate how the Convolutional Wasserstein Barycenter
-function of POT works.
+This example is designed to illustrate how the Convolutional Wasserstein
+Barycenter function of POT works.
"""
# Author: Nicolas Courty <ncourty@irisa.fr>
#
# License: MIT License
-
+import os
+from pathlib import Path
import numpy as np
-import pylab as pl
+import matplotlib.pyplot as plt
import ot
##############################################################################
@@ -25,22 +26,19 @@ import ot
#
# The four distributions are constructed from 4 simple images
+this_file = os.path.realpath('__file__')
+data_path = os.path.join(Path(this_file).parent.parent.parent, 'data')
-f1 = 1 - pl.imread('../../data/redcross.png')[:, :, 2]
-f2 = 1 - pl.imread('../../data/duck.png')[:, :, 2]
-f3 = 1 - pl.imread('../../data/heart.png')[:, :, 2]
-f4 = 1 - pl.imread('../../data/tooth.png')[:, :, 2]
+f1 = 1 - plt.imread(os.path.join(data_path, 'redcross.png'))[:, :, 2]
+f2 = 1 - plt.imread(os.path.join(data_path, 'tooth.png'))[:, :, 2]
+f3 = 1 - plt.imread(os.path.join(data_path, 'heart.png'))[:, :, 2]
+f4 = 1 - plt.imread(os.path.join(data_path, 'duck.png'))[:, :, 2]
-A = []
f1 = f1 / np.sum(f1)
f2 = f2 / np.sum(f2)
f3 = f3 / np.sum(f3)
f4 = f4 / np.sum(f4)
-A.append(f1)
-A.append(f2)
-A.append(f3)
-A.append(f4)
-A = np.array(A)
+A = np.array([f1, f2, f3, f4])
nb_images = 5
@@ -57,14 +55,13 @@ v4 = np.array((0, 0, 0, 1))
# ----------------------------------------
#
-pl.figure(figsize=(10, 10))
-pl.title('Convolutional Wasserstein Barycenters in POT')
+fig, axes = plt.subplots(nb_images, nb_images, figsize=(7, 7))
+plt.suptitle('Convolutional Wasserstein Barycenters in POT')
cm = 'Blues'
# regularization parameter
reg = 0.004
for i in range(nb_images):
for j in range(nb_images):
- pl.subplot(nb_images, nb_images, i * nb_images + j + 1)
tx = float(i) / (nb_images - 1)
ty = float(j) / (nb_images - 1)
@@ -74,19 +71,19 @@ for i in range(nb_images):
weights = (1 - ty) * tmp1 + ty * tmp2
if i == 0 and j == 0:
- pl.imshow(f1, cmap=cm)
- pl.axis('off')
+ axes[i, j].imshow(f1, cmap=cm)
elif i == 0 and j == (nb_images - 1):
- pl.imshow(f3, cmap=cm)
- pl.axis('off')
+ axes[i, j].imshow(f3, cmap=cm)
elif i == (nb_images - 1) and j == 0:
- pl.imshow(f2, cmap=cm)
- pl.axis('off')
+ axes[i, j].imshow(f2, cmap=cm)
elif i == (nb_images - 1) and j == (nb_images - 1):
- pl.imshow(f4, cmap=cm)
- pl.axis('off')
+ axes[i, j].imshow(f4, cmap=cm)
else:
# call to barycenter computation
- pl.imshow(ot.bregman.convolutional_barycenter2d(A, reg, weights), cmap=cm)
- pl.axis('off')
-pl.show()
+ axes[i, j].imshow(
+ ot.bregman.convolutional_barycenter2d(A, reg, weights),
+ cmap=cm
+ )
+ axes[i, j].axis('off')
+plt.tight_layout()
+plt.show()
diff --git a/examples/barycenters/plot_debiased_barycenter.py b/examples/barycenters/plot_debiased_barycenter.py
new file mode 100644
index 0000000..2a603dd
--- /dev/null
+++ b/examples/barycenters/plot_debiased_barycenter.py
@@ -0,0 +1,131 @@
+# -*- coding: utf-8 -*-
+"""
+=================================
+Debiased Sinkhorn barycenter demo
+=================================
+
+This example illustrates the computation of the debiased Sinkhorn barycenter
+as proposed in [37]_.
+
+
+.. [37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th
+ International Conference on Machine Learning, PMLR 119:4692-4701, 2020
+"""
+
+# Author: Hicham Janati <hicham.janati100@gmail.com>
+#
+# License: MIT License
+# sphinx_gallery_thumbnail_number = 3
+
+import os
+from pathlib import Path
+
+import numpy as np
+import matplotlib.pyplot as plt
+
+import ot
+from ot.bregman import (barycenter, barycenter_debiased,
+ convolutional_barycenter2d,
+ convolutional_barycenter2d_debiased)
+
+##############################################################################
+# Debiased barycenter of 1D Gaussians
+# ------------------------------------
+
+#%% parameters
+
+n = 100 # nb bins
+
+# bin positions
+x = np.arange(n, dtype=np.float64)
+
+# Gaussian distributions
+a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std
+a2 = ot.datasets.make_1D_gauss(n, m=60, s=8)
+
+# creating matrix A containing all distributions
+A = np.vstack((a1, a2)).T
+n_distributions = A.shape[1]
+
+# loss matrix + normalization
+M = ot.utils.dist0(n)
+M /= M.max()
+
+#%% barycenter computation
+
+alpha = 0.2 # 0<=alpha<=1
+weights = np.array([1 - alpha, alpha])
+
+epsilons = [5e-3, 1e-2, 5e-2]
+
+
+bars = [barycenter(A, M, reg, weights) for reg in epsilons]
+bars_debiased = [barycenter_debiased(A, M, reg, weights) for reg in epsilons]
+labels = ["Sinkhorn barycenter", "Debiased barycenter"]
+colors = ["indianred", "gold"]
+
+f, axes = plt.subplots(1, len(epsilons), tight_layout=True, sharey=True,
+ figsize=(12, 4), num=1)
+for ax, eps, bar, bar_debiased in zip(axes, epsilons, bars, bars_debiased):
+ ax.plot(A[:, 0], color="k", ls="--", label="Input data", alpha=0.3)
+ ax.plot(A[:, 1], color="k", ls="--", alpha=0.3)
+ for data, label, color in zip([bar, bar_debiased], labels, colors):
+ ax.plot(data, color=color, label=label, lw=2)
+ ax.set_title(r"$\varepsilon = %.3f$" % eps)
+plt.legend()
+plt.show()
+
+
+##############################################################################
+# Debiased barycenter of 2D images
+# ---------------------------------
+this_file = os.path.realpath('__file__')
+data_path = os.path.join(Path(this_file).parent.parent.parent, 'data')
+f1 = 1 - plt.imread(os.path.join(data_path, 'heart.png'))[:, :, 2]
+f2 = 1 - plt.imread(os.path.join(data_path, 'duck.png'))[:, :, 2]
+
+A = np.asarray([f1, f2]) + 1e-2
+A /= A.sum(axis=(1, 2))[:, None, None]
+
+##############################################################################
+# Display the input images
+
+fig, axes = plt.subplots(1, 2, figsize=(7, 4), num=2)
+for ax, img in zip(axes, A):
+ ax.imshow(img, cmap="Greys")
+ ax.axis("off")
+fig.tight_layout()
+plt.show()
+
+
+##############################################################################
+# Barycenter computation and visualization
+# ----------------------------------------
+#
+
+bars_sinkhorn, bars_debiased = [], []
+epsilons = [5e-3, 7e-3, 1e-2]
+for eps in epsilons:
+ bar = convolutional_barycenter2d(A, eps)
+ bar_debiased, log = convolutional_barycenter2d_debiased(A, eps, log=True)
+ bars_sinkhorn.append(bar)
+ bars_debiased.append(bar_debiased)
+
+titles = ["Sinkhorn", "Debiased"]
+all_bars = [bars_sinkhorn, bars_debiased]
+fig, axes = plt.subplots(2, 3, figsize=(8, 6), num=3)
+for jj, (method, ax_row, bars) in enumerate(zip(titles, axes, all_bars)):
+ for ii, (ax, img, eps) in enumerate(zip(ax_row, bars, epsilons)):
+ ax.imshow(img, cmap="Greys")
+ if jj == 0:
+ ax.set_title(r"$\varepsilon = %.3f$" % eps, fontsize=13)
+ ax.set_xticks([])
+ ax.set_yticks([])
+ ax.spines['top'].set_visible(False)
+ ax.spines['right'].set_visible(False)
+ ax.spines['bottom'].set_visible(False)
+ ax.spines['left'].set_visible(False)
+ if ii == 0:
+ ax.set_ylabel(method, fontsize=15)
+fig.tight_layout()
+plt.show()