# -*- coding: utf-8 -*-
"""
=================================================================================
1D Wasserstein barycenter comparison between exact LP and entropic regularization
=================================================================================

This example illustrates the computation of the regularized Wasserstein
barycenter as proposed in [3] and compares it, on three pairs of 1D
distributions (two Gaussians, two uniform steps, and a three-spike
distribution versus a single Dirac), with:

* the L2 barycenter (weighted Euclidean average of the histograms),
* the barycenter returned by ``ot.lp.barycenter`` (linear program solved
  with the interior-point solver).

A final figure gathers, for every problem, the input distributions (top row)
and the three barycenters (bottom row).


[3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015).
Iterative Bregman projections for regularized transportation problems
SIAM Journal on Scientific Computing, 37(2), A1111-A1138.

"""

# Author: Remi Flamary
#
# License: MIT License

# Provenance: added as examples/plot_barycenter_lp_vs_entropic.py by commit
# da8f6119484642eca6e8efb3e5aaecce7a777622 ("add example", Rémi Flamary,
# Mon, 14 May 2018 11:34:59 +0200).

import numpy as np
import matplotlib.pylab as pl
import ot
# necessary for 3d plot even if not used
from mpl_toolkits.mplot3d import Axes3D  # noqa
from matplotlib.collections import PolyCollection  # noqa

#import ot.lp.cvx as cvx

#
# Generate data
# -------------

#%% parameters

# One entry per problem: [A, [bary_l2, bary_wass, bary_wass2]].
# Consumed by the final summary figure at the end of the script.
problems = []

n = 100  # nb bins

# bin positions
x = np.arange(n, dtype=np.float64)

# Gaussian distributions
a1 = ot.datasets.get_1D_gauss(n, m=20, s=5)  # m= mean, s= std
a2 = ot.datasets.get_1D_gauss(n, m=60, s=8)

# creating matrix A containing all distributions (one distribution per column)
A = np.vstack((a1, a2)).T
n_distributions = A.shape[1]

# loss matrix + normalization (scaled so that its largest entry is 1)
M = ot.utils.dist0(n)
M /= M.max()

#
# Plot data
# ---------

#%% plot the distributions

pl.figure(1, figsize=(6.4, 3))
for i in range(n_distributions):
    pl.plot(x, A[:, i])
pl.title('Distributions')
pl.tight_layout()

#
# Barycenter computation
# ----------------------

#%% barycenter computation

alpha = 0.5  # 0<=alpha<=1
weights = np.array([1 - alpha, alpha])

# l2bary: weighted average of the columns of A
bary_l2 = A.dot(weights)

# wasserstein (entropic regularization); tic/toc bracket the call to time it
reg = 1e-3
ot.tic()
bary_wass = ot.bregman.barycenter(A, M, reg, weights)
ot.toc()

# wasserstein (LP formulation, interior-point solver)
ot.tic()
bary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point',
                              verbose=True)
ot.toc()

pl.figure(2)
pl.clf()
pl.subplot(2, 1, 1)
for i in range(n_distributions):
    pl.plot(x, A[:, i])
pl.title('Distributions')

pl.subplot(2, 1, 2)
pl.plot(x, bary_l2, 'r', label='l2')
pl.plot(x, bary_wass, 'g', label='Reg Wasserstein')
pl.plot(x, bary_wass2, 'b', label='LP Wasserstein')
pl.legend()
pl.title('Barycenters')
pl.tight_layout()

problems.append([A, [bary_l2, bary_wass, bary_wass2]])

#%% parameters

# Uniform "step" distributions: indicator of 10 < x < 50 and of 60 < x < 80,
# each normalized to sum to 1
a1 = 1.0 * (x > 10) * (x < 50)
a2 = 1.0 * (x > 60) * (x < 80)

a1 /= a1.sum()
a2 /= a2.sum()

# creating matrix A containing all distributions (one distribution per column)
A = np.vstack((a1, a2)).T
n_distributions = A.shape[1]

# loss matrix + normalization
M = ot.utils.dist0(n)
M /= M.max()


#%% plot the distributions

pl.figure(1, figsize=(6.4, 3))
for i in range(n_distributions):
    pl.plot(x, A[:, i])
pl.title('Distributions')
pl.tight_layout()

#
# Barycenter computation
# ----------------------

#%% barycenter computation

alpha = 0.5  # 0<=alpha<=1
weights = np.array([1 - alpha, alpha])

# l2bary
bary_l2 = A.dot(weights)

# wasserstein (entropic regularization)
reg = 1e-3
ot.tic()
bary_wass = ot.bregman.barycenter(A, M, reg, weights)
ot.toc()

# wasserstein (LP formulation, interior-point solver)
ot.tic()
bary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point',
                              verbose=True)
ot.toc()


problems.append([A, [bary_l2, bary_wass, bary_wass2]])

pl.figure(2)
pl.clf()
pl.subplot(2, 1, 1)
for i in range(n_distributions):
    pl.plot(x, A[:, i])
pl.title('Distributions')

pl.subplot(2, 1, 2)
pl.plot(x, bary_l2, 'r', label='l2')
pl.plot(x, bary_wass, 'g', label='Reg Wasserstein')
pl.plot(x, bary_wass2, 'b', label='LP Wasserstein')
pl.legend()
pl.title('Barycenters')
pl.tight_layout()

#%% parameters

# Spike distributions: three weighted spikes (bins 10, 20, 30) versus a single
# spike at bin 80
a1 = np.zeros(n)
a2 = np.zeros(n)

a1[10] = .25
a1[20] = .5
a1[30] = .25
a2[80] = 1


a1 /= a1.sum()
a2 /= a2.sum()

# creating matrix A containing all distributions (one distribution per column)
A = np.vstack((a1, a2)).T
n_distributions = A.shape[1]

# loss matrix + normalization
M = ot.utils.dist0(n)
M /= M.max()


#%% plot the distributions

pl.figure(1, figsize=(6.4, 3))
for i in range(n_distributions):
    pl.plot(x, A[:, i])
pl.title('Distributions')
pl.tight_layout()

#
# Barycenter computation
# ----------------------

#%% barycenter computation

alpha = 0.5  # 0<=alpha<=1
weights = np.array([1 - alpha, alpha])

# l2bary
bary_l2 = A.dot(weights)

# wasserstein (entropic regularization)
reg = 1e-3
ot.tic()
bary_wass = ot.bregman.barycenter(A, M, reg, weights)
ot.toc()

# wasserstein (LP formulation, interior-point solver)
ot.tic()
bary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point',
                              verbose=True)
ot.toc()


problems.append([A, [bary_l2, bary_wass, bary_wass2]])

pl.figure(2)
pl.clf()
pl.subplot(2, 1, 1)
for i in range(n_distributions):
    pl.plot(x, A[:, i])
pl.title('Distributions')

pl.subplot(2, 1, 2)
pl.plot(x, bary_l2, 'r', label='l2')
pl.plot(x, bary_wass, 'g', label='Reg Wasserstein')
pl.plot(x, bary_wass2, 'b', label='LP Wasserstein')
pl.legend()
pl.title('Barycenters')
pl.tight_layout()


#
# Final figure
# ------------
#

#%% plot

nbm = len(problems)
nbm2 = (nbm // 2)  # middle column: the only one that carries the row titles


pl.figure(2, (20, 6))
pl.clf()

for i in range(nbm):

    A = problems[i][0]
    bary_l2 = problems[i][1][0]
    bary_wass = problems[i][1][1]
    bary_wass2 = problems[i][1][2]

    # Top row (subplot indices 1..nbm): input distributions of problem i.
    pl.subplot(2, nbm, 1 + i)
    # Use the stored matrix's own column count rather than the module-level
    # n_distributions left over from the last problem.
    for j in range(A.shape[1]):
        pl.plot(x, A[:, j])
    if i == nbm2:
        pl.title('Distributions')
    pl.xticks(())
    pl.yticks(())

    # Bottom row (subplot indices nbm+1..2*nbm): barycenters of problem i.
    # Subplot indices run row by row, so the second row is offset by nbm;
    # reusing 1 + i here would draw over the distributions and leave the
    # second row empty.
    pl.subplot(2, nbm, 1 + i + nbm)

    pl.plot(x, bary_l2, 'r', label='L2 (Euclidean)')
    pl.plot(x, bary_wass, 'g', label='Reg Wasserstein')
    pl.plot(x, bary_wass2, 'b', label='LP Wasserstein')
    if i == nbm - 1:
        pl.legend()
    if i == nbm2:
        pl.title('Barycenters')