diff options
Diffstat (limited to 'examples/barycenters')
-rw-r--r-- | examples/barycenters/README.txt | 4 | ||||
-rw-r--r-- | examples/barycenters/plot_barycenter_1D.py | 162 | ||||
-rw-r--r-- | examples/barycenters/plot_barycenter_lp_vs_entropic.py | 288 | ||||
-rw-r--r-- | examples/barycenters/plot_convolutional_barycenter.py | 92 | ||||
-rw-r--r-- | examples/barycenters/plot_free_support_barycenter.py | 69 |
5 files changed, 615 insertions, 0 deletions
diff --git a/examples/barycenters/README.txt b/examples/barycenters/README.txt new file mode 100644 index 0000000..8461f7f --- /dev/null +++ b/examples/barycenters/README.txt @@ -0,0 +1,4 @@ + + +Wasserstein barycenters +-----------------------
\ No newline at end of file diff --git a/examples/barycenters/plot_barycenter_1D.py b/examples/barycenters/plot_barycenter_1D.py new file mode 100644 index 0000000..63dc460 --- /dev/null +++ b/examples/barycenters/plot_barycenter_1D.py @@ -0,0 +1,162 @@ +# -*- coding: utf-8 -*- +""" +============================== +1D Wasserstein barycenter demo +============================== + +This example illustrates the computation of regularized Wassersyein Barycenter +as proposed in [3]. + + +[3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). +Iterative Bregman projections for regularized transportation problems +SIAM Journal on Scientific Computing, 37(2), A1111-A1138. + +""" + +# Author: Remi Flamary <remi.flamary@unice.fr> +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 4 + +import numpy as np +import matplotlib.pylab as pl +import ot +# necessary for 3d plot even if not used +from mpl_toolkits.mplot3d import Axes3D # noqa +from matplotlib.collections import PolyCollection + +############################################################################## +# Generate data +# ------------- + +#%% parameters + +n = 100 # nb bins + +# bin positions +x = np.arange(n, dtype=np.float64) + +# Gaussian distributions +a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std +a2 = ot.datasets.make_1D_gauss(n, m=60, s=8) + +# creating matrix A containing all distributions +A = np.vstack((a1, a2)).T +n_distributions = A.shape[1] + +# loss matrix + normalization +M = ot.utils.dist0(n) +M /= M.max() + +############################################################################## +# Plot data +# --------- + +#%% plot the distributions + +pl.figure(1, figsize=(6.4, 3)) +for i in range(n_distributions): + pl.plot(x, A[:, i]) +pl.title('Distributions') +pl.tight_layout() + +############################################################################## +# Barycenter computation +# ---------------------- + +#%% barycenter computation + +alpha = 0.2 # 0<=alpha<=1 +weights = np.array([1 - alpha, alpha]) + +# l2bary +bary_l2 = A.dot(weights) + +# wasserstein +reg = 1e-3 +bary_wass = ot.bregman.barycenter(A, M, reg, weights) + +pl.figure(2) +pl.clf() +pl.subplot(2, 1, 1) +for i in range(n_distributions): + pl.plot(x, A[:, i]) +pl.title('Distributions') + +pl.subplot(2, 1, 2) +pl.plot(x, bary_l2, 'r', label='l2') +pl.plot(x, bary_wass, 'g', label='Wasserstein') +pl.legend() +pl.title('Barycenters') +pl.tight_layout() + +############################################################################## +# Barycentric interpolation +# ------------------------- + +#%% barycenter interpolation + +n_alpha = 11 +alpha_list = np.linspace(0, 1, n_alpha) + + +B_l2 = np.zeros((n, n_alpha)) + +B_wass = np.copy(B_l2) + +for i in range(0, n_alpha): + alpha = alpha_list[i] + weights = np.array([1 - alpha, alpha]) + B_l2[:, i] = A.dot(weights) + B_wass[:, i] = ot.bregman.barycenter(A, M, reg, weights) + +#%% plot interpolation + +pl.figure(3) + +cmap = pl.cm.get_cmap('viridis') +verts = [] +zs = alpha_list +for i, z in enumerate(zs): + ys = B_l2[:, i] + verts.append(list(zip(x, ys))) + +ax = pl.gcf().gca(projection='3d') + +poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list]) +poly.set_alpha(0.7) +ax.add_collection3d(poly, zs=zs, zdir='y') +ax.set_xlabel('x') +ax.set_xlim3d(0, n) +ax.set_ylabel('$\\alpha$') +ax.set_ylim3d(0, 1) +ax.set_zlabel('') +ax.set_zlim3d(0, B_l2.max() * 1.01) +pl.title('Barycenter interpolation with l2') +pl.tight_layout() + +pl.figure(4) +cmap = pl.cm.get_cmap('viridis') +verts = [] +zs = alpha_list +for i, z in enumerate(zs): + ys = B_wass[:, i] + verts.append(list(zip(x, ys))) + +ax = pl.gcf().gca(projection='3d') + +poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list]) +poly.set_alpha(0.7) +ax.add_collection3d(poly, zs=zs, zdir='y') +ax.set_xlabel('x') +ax.set_xlim3d(0, n) +ax.set_ylabel('$\\alpha$') +ax.set_ylim3d(0, 1) +ax.set_zlabel('') +ax.set_zlim3d(0, B_l2.max() * 1.01) +pl.title('Barycenter interpolation with Wasserstein') +pl.tight_layout() + +pl.show() diff --git a/examples/barycenters/plot_barycenter_lp_vs_entropic.py b/examples/barycenters/plot_barycenter_lp_vs_entropic.py new file mode 100644 index 0000000..57a6bac --- /dev/null +++ b/examples/barycenters/plot_barycenter_lp_vs_entropic.py @@ -0,0 +1,288 @@ +# -*- coding: utf-8 -*- +""" +================================================================================= +1D Wasserstein barycenter comparison between exact LP and entropic regularization +================================================================================= + +This example illustrates the computation of regularized Wasserstein Barycenter +as proposed in [3] and exact LP barycenters using standard LP solver. + +It reproduces approximately Figure 3.1 and 3.2 from the following paper: +Cuturi, M., & Peyré, G. (2016). A smoothed dual approach for variational +Wasserstein problems. SIAM Journal on Imaging Sciences, 9(1), 320-343. + +[3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). +Iterative Bregman projections for regularized transportation problems +SIAM Journal on Scientific Computing, 37(2), A1111-A1138. + +""" + +# Author: Remi Flamary <remi.flamary@unice.fr> +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 4 + +import numpy as np +import matplotlib.pylab as pl +import ot +# necessary for 3d plot even if not used +from mpl_toolkits.mplot3d import Axes3D # noqa +from matplotlib.collections import PolyCollection # noqa + +#import ot.lp.cvx as cvx + +############################################################################## +# Gaussian Data +# ------------- + +#%% parameters + +problems = [] + +n = 100 # nb bins + +# bin positions +x = np.arange(n, dtype=np.float64) + +# Gaussian distributions +# Gaussian distributions +a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std +a2 = ot.datasets.make_1D_gauss(n, m=60, s=8) + +# creating matrix A containing all distributions +A = np.vstack((a1, a2)).T +n_distributions = A.shape[1] + +# loss matrix + normalization +M = ot.utils.dist0(n) +M /= M.max() + + +#%% plot the distributions + +pl.figure(1, figsize=(6.4, 3)) +for i in range(n_distributions): + pl.plot(x, A[:, i]) +pl.title('Distributions') +pl.tight_layout() + +#%% barycenter computation + +alpha = 0.5 # 0<=alpha<=1 +weights = np.array([1 - alpha, alpha]) + +# l2bary +bary_l2 = A.dot(weights) + +# wasserstein +reg = 1e-3 +ot.tic() +bary_wass = ot.bregman.barycenter(A, M, reg, weights) +ot.toc() + + +ot.tic() +bary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True) +ot.toc() + +pl.figure(2) +pl.clf() +pl.subplot(2, 1, 1) +for i in range(n_distributions): + pl.plot(x, A[:, i]) +pl.title('Distributions') + +pl.subplot(2, 1, 2) +pl.plot(x, bary_l2, 'r', label='l2') +pl.plot(x, bary_wass, 'g', label='Reg Wasserstein') +pl.plot(x, bary_wass2, 'b', label='LP Wasserstein') +pl.legend() +pl.title('Barycenters') +pl.tight_layout() + +problems.append([A, [bary_l2, bary_wass, bary_wass2]]) + +############################################################################## +# Stair Data +# ---------- + +#%% parameters + +a1 = 1.0 * (x > 10) * (x < 50) +a2 = 1.0 * (x > 60) * (x < 80) + +a1 /= a1.sum() +a2 /= a2.sum() + +# creating matrix A containing all distributions +A = np.vstack((a1, a2)).T +n_distributions = A.shape[1] + +# loss matrix + normalization +M = ot.utils.dist0(n) +M /= M.max() + + +#%% plot the distributions + +pl.figure(1, figsize=(6.4, 3)) +for i in range(n_distributions): + pl.plot(x, A[:, i]) +pl.title('Distributions') +pl.tight_layout() + + +#%% barycenter computation + +alpha = 0.5 # 0<=alpha<=1 +weights = np.array([1 - alpha, alpha]) + +# l2bary +bary_l2 = A.dot(weights) + +# wasserstein +reg = 1e-3 +ot.tic() +bary_wass = ot.bregman.barycenter(A, M, reg, weights) +ot.toc() + + +ot.tic() +bary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True) +ot.toc() + + +problems.append([A, [bary_l2, bary_wass, bary_wass2]]) + +pl.figure(2) +pl.clf() +pl.subplot(2, 1, 1) +for i in range(n_distributions): + pl.plot(x, A[:, i]) +pl.title('Distributions') + +pl.subplot(2, 1, 2) +pl.plot(x, bary_l2, 'r', label='l2') +pl.plot(x, bary_wass, 'g', label='Reg Wasserstein') +pl.plot(x, bary_wass2, 'b', label='LP Wasserstein') +pl.legend() +pl.title('Barycenters') +pl.tight_layout() + + +############################################################################## +# Dirac Data +# ---------- + +#%% parameters + +a1 = np.zeros(n) +a2 = np.zeros(n) + +a1[10] = .25 +a1[20] = .5 +a1[30] = .25 +a2[80] = 1 + + +a1 /= a1.sum() +a2 /= a2.sum() + +# creating matrix A containing all distributions +A = np.vstack((a1, a2)).T +n_distributions = A.shape[1] + +# loss matrix + normalization +M = ot.utils.dist0(n) +M /= M.max() + + +#%% plot the distributions + +pl.figure(1, figsize=(6.4, 3)) +for i in range(n_distributions): + pl.plot(x, A[:, i]) +pl.title('Distributions') +pl.tight_layout() + + +#%% barycenter computation + +alpha = 0.5 # 0<=alpha<=1 +weights = np.array([1 - alpha, alpha]) + +# l2bary +bary_l2 = A.dot(weights) + +# wasserstein +reg = 1e-3 +ot.tic() +bary_wass = ot.bregman.barycenter(A, M, reg, weights) +ot.toc() + + +ot.tic() +bary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True) +ot.toc() + + +problems.append([A, [bary_l2, bary_wass, bary_wass2]]) + +pl.figure(2) +pl.clf() +pl.subplot(2, 1, 1) +for i in range(n_distributions): + pl.plot(x, A[:, i]) +pl.title('Distributions') + +pl.subplot(2, 1, 2) +pl.plot(x, bary_l2, 'r', label='l2') +pl.plot(x, bary_wass, 'g', label='Reg Wasserstein') +pl.plot(x, bary_wass2, 'b', label='LP Wasserstein') +pl.legend() +pl.title('Barycenters') +pl.tight_layout() + + +############################################################################## +# Final figure +# ------------ +# + +#%% plot + +nbm = len(problems) +nbm2 = (nbm // 2) + + +pl.figure(2, (20, 6)) +pl.clf() + +for i in range(nbm): + + A = problems[i][0] + bary_l2 = problems[i][1][0] + bary_wass = problems[i][1][1] + bary_wass2 = problems[i][1][2] + + pl.subplot(2, nbm, 1 + i) + for j in range(n_distributions): + pl.plot(x, A[:, j]) + if i == nbm2: + pl.title('Distributions') + pl.xticks(()) + pl.yticks(()) + + pl.subplot(2, nbm, 1 + i + nbm) + + pl.plot(x, bary_l2, 'r', label='L2 (Euclidean)') + pl.plot(x, bary_wass, 'g', label='Reg Wasserstein') + pl.plot(x, bary_wass2, 'b', label='LP Wasserstein') + if i == nbm - 1: + pl.legend() + if i == nbm2: + pl.title('Barycenters') + + pl.xticks(()) + pl.yticks(()) diff --git a/examples/barycenters/plot_convolutional_barycenter.py b/examples/barycenters/plot_convolutional_barycenter.py new file mode 100644 index 0000000..e74db04 --- /dev/null +++ b/examples/barycenters/plot_convolutional_barycenter.py @@ -0,0 +1,92 @@ + +#%% +# -*- coding: utf-8 -*- +""" +============================================ +Convolutional Wasserstein Barycenter example +============================================ + +This example is designed to illustrate how the Convolutional Wasserstein Barycenter +function of POT works. +""" + +# Author: Nicolas Courty <ncourty@irisa.fr> +# +# License: MIT License + + +import numpy as np +import pylab as pl +import ot + +############################################################################## +# Data preparation +# ---------------- +# +# The four distributions are constructed from 4 simple images + + +f1 = 1 - pl.imread('../data/redcross.png')[:, :, 2] +f2 = 1 - pl.imread('../data/duck.png')[:, :, 2] +f3 = 1 - pl.imread('../data/heart.png')[:, :, 2] +f4 = 1 - pl.imread('../data/tooth.png')[:, :, 2] + +A = [] +f1 = f1 / np.sum(f1) +f2 = f2 / np.sum(f2) +f3 = f3 / np.sum(f3) +f4 = f4 / np.sum(f4) +A.append(f1) +A.append(f2) +A.append(f3) +A.append(f4) +A = np.array(A) + +nb_images = 5 + +# those are the four corners coordinates that will be interpolated by bilinear +# interpolation +v1 = np.array((1, 0, 0, 0)) +v2 = np.array((0, 1, 0, 0)) +v3 = np.array((0, 0, 1, 0)) +v4 = np.array((0, 0, 0, 1)) + + +############################################################################## +# Barycenter computation and visualization +# ---------------------------------------- +# + +pl.figure(figsize=(10, 10)) +pl.title('Convolutional Wasserstein Barycenters in POT') +cm = 'Blues' +# regularization parameter +reg = 0.004 +for i in range(nb_images): + for j in range(nb_images): + pl.subplot(nb_images, nb_images, i * nb_images + j + 1) + tx = float(i) / (nb_images - 1) + ty = float(j) / (nb_images - 1) + + # weights are constructed by bilinear interpolation + tmp1 = (1 - tx) * v1 + tx * v2 + tmp2 = (1 - tx) * v3 + tx * v4 + weights = (1 - ty) * tmp1 + ty * tmp2 + + if i == 0 and j == 0: + pl.imshow(f1, cmap=cm) + pl.axis('off') + elif i == 0 and j == (nb_images - 1): + pl.imshow(f3, cmap=cm) + pl.axis('off') + elif i == (nb_images - 1) and j == 0: + pl.imshow(f2, cmap=cm) + pl.axis('off') + elif i == (nb_images - 1) and j == (nb_images - 1): + pl.imshow(f4, cmap=cm) + pl.axis('off') + else: + # call to barycenter computation + pl.imshow(ot.bregman.convolutional_barycenter2d(A, reg, weights), cmap=cm) + pl.axis('off') +pl.show() diff --git a/examples/barycenters/plot_free_support_barycenter.py b/examples/barycenters/plot_free_support_barycenter.py new file mode 100644 index 0000000..64b89e4 --- /dev/null +++ b/examples/barycenters/plot_free_support_barycenter.py @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +""" +==================================================== +2D free support Wasserstein barycenters of distributions +==================================================== + +Illustration of 2D Wasserstein barycenters if discributions that are weighted +sum of diracs. + +""" + +# Author: Vivien Seguy <vivien.seguy@iip.ist.i.kyoto-u.ac.jp> +# +# License: MIT License + +import numpy as np +import matplotlib.pylab as pl +import ot + + +############################################################################## +# Generate data +# ------------- +#%% parameters and data generation +N = 3 +d = 2 +measures_locations = [] +measures_weights = [] + +for i in range(N): + + n_i = np.random.randint(low=1, high=20) # nb samples + + mu_i = np.random.normal(0., 4., (d,)) # Gaussian mean + + A_i = np.random.rand(d, d) + cov_i = np.dot(A_i, A_i.transpose()) # Gaussian covariance matrix + + x_i = ot.datasets.make_2D_samples_gauss(n_i, mu_i, cov_i) # Dirac locations + b_i = np.random.uniform(0., 1., (n_i,)) + b_i = b_i / np.sum(b_i) # Dirac weights + + measures_locations.append(x_i) + measures_weights.append(b_i) + + +############################################################################## +# Compute free support barycenter +# ------------- + +k = 10 # number of Diracs of the barycenter +X_init = np.random.normal(0., 1., (k, d)) # initial Dirac locations +b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, only the locations are optimized) + +X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init, b) + + +############################################################################## +# Plot data +# --------- + +pl.figure(1) +for (x_i, b_i) in zip(measures_locations, measures_weights): + color = np.random.randint(low=1, high=10 * N) + pl.scatter(x_i[:, 0], x_i[:, 1], s=b_i * 1000, label='input measure') +pl.scatter(X[:, 0], X[:, 1], s=b * 1000, c='black', marker='^', label='2-Wasserstein barycenter') +pl.title('Data measures and their barycenter') +pl.legend(loc=0) +pl.show() |