From dc8737a30cb6d9f1305173eb8d16fe6716fd1231 Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Wed, 30 Aug 2017 17:01:01 +0200 Subject: wroking make! --- docs/source/auto_examples/plot_barycenter_1D.py | 117 ++++++++++++------------ 1 file changed, 58 insertions(+), 59 deletions(-) (limited to 'docs/source/auto_examples/plot_barycenter_1D.py') diff --git a/docs/source/auto_examples/plot_barycenter_1D.py b/docs/source/auto_examples/plot_barycenter_1D.py index 30eecbf..875f44c 100644 --- a/docs/source/auto_examples/plot_barycenter_1D.py +++ b/docs/source/auto_examples/plot_barycenter_1D.py @@ -4,135 +4,134 @@ 1D Wasserstein barycenter demo ============================== - -@author: rflamary """ +# Author: Remi Flamary +# +# License: MIT License + import numpy as np import matplotlib.pylab as pl import ot -from mpl_toolkits.mplot3d import Axes3D #necessary for 3d plot even if not used +# necessary for 3d plot even if not used +from mpl_toolkits.mplot3d import Axes3D # noqa from matplotlib.collections import PolyCollection #%% parameters -n=100 # nb bins +n = 100 # nb bins # bin positions -x=np.arange(n,dtype=np.float64) +x = np.arange(n, dtype=np.float64) # Gaussian distributions -a1=ot.datasets.get_1D_gauss(n,m=20,s=5) # m= mean, s= std -a2=ot.datasets.get_1D_gauss(n,m=60,s=8) +a1 = ot.datasets.get_1D_gauss(n, m=20, s=5) # m= mean, s= std +a2 = ot.datasets.get_1D_gauss(n, m=60, s=8) # creating matrix A containing all distributions -A=np.vstack((a1,a2)).T -nbd=A.shape[1] +A = np.vstack((a1, a2)).T +n_distributions = A.shape[1] # loss matrix + normalization -M=ot.utils.dist0(n) -M/=M.max() +M = ot.utils.dist0(n) +M /= M.max() #%% plot the distributions -pl.figure(1) -for i in range(nbd): - pl.plot(x,A[:,i]) +pl.figure(1, figsize=(6.4, 3)) +for i in range(n_distributions): + pl.plot(x, A[:, i]) pl.title('Distributions') +pl.tight_layout() #%% barycenter computation -alpha=0.2 # 0<=alpha<=1 -weights=np.array([1-alpha,alpha]) +alpha = 0.2 # 0<=alpha<=1 +weights = np.array([1 - alpha, alpha]) # l2bary -bary_l2=A.dot(weights) +bary_l2 = A.dot(weights) # wasserstein -reg=1e-3 -bary_wass=ot.bregman.barycenter(A,M,reg,weights) +reg = 1e-3 +bary_wass = ot.bregman.barycenter(A, M, reg, weights) pl.figure(2) pl.clf() -pl.subplot(2,1,1) -for i in range(nbd): - pl.plot(x,A[:,i]) +pl.subplot(2, 1, 1) +for i in range(n_distributions): + pl.plot(x, A[:, i]) pl.title('Distributions') -pl.subplot(2,1,2) -pl.plot(x,bary_l2,'r',label='l2') -pl.plot(x,bary_wass,'g',label='Wasserstein') +pl.subplot(2, 1, 2) +pl.plot(x, bary_l2, 'r', label='l2') +pl.plot(x, bary_wass, 'g', label='Wasserstein') pl.legend() pl.title('Barycenters') - +pl.tight_layout() #%% barycenter interpolation -nbalpha=11 -alphalist=np.linspace(0,1,nbalpha) +n_alpha = 11 +alpha_list = np.linspace(0, 1, n_alpha) -B_l2=np.zeros((n,nbalpha)) +B_l2 = np.zeros((n, n_alpha)) -B_wass=np.copy(B_l2) +B_wass = np.copy(B_l2) -for i in range(0,nbalpha): - alpha=alphalist[i] - weights=np.array([1-alpha,alpha]) - B_l2[:,i]=A.dot(weights) - B_wass[:,i]=ot.bregman.barycenter(A,M,reg,weights) +for i in range(0, n_alpha): + alpha = alpha_list[i] + weights = np.array([1 - alpha, alpha]) + B_l2[:, i] = A.dot(weights) + B_wass[:, i] = ot.bregman.barycenter(A, M, reg, weights) #%% plot interpolation -pl.figure(3,(10,5)) +pl.figure(3) -#pl.subplot(1,2,1) -cmap=pl.cm.get_cmap('viridis') +cmap = pl.cm.get_cmap('viridis') verts = [] -zs = alphalist -for i,z in enumerate(zs): - ys = B_l2[:,i] +zs = alpha_list +for i, z in enumerate(zs): + ys = B_l2[:, i] verts.append(list(zip(x, ys))) ax = pl.gcf().gca(projection='3d') -poly = PolyCollection(verts,facecolors=[cmap(a) for a in alphalist]) +poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list]) poly.set_alpha(0.7) ax.add_collection3d(poly, zs=zs, zdir='y') - ax.set_xlabel('x') ax.set_xlim3d(0, n) ax.set_ylabel('$\\alpha$') -ax.set_ylim3d(0,1) +ax.set_ylim3d(0, 1) ax.set_zlabel('') -ax.set_zlim3d(0, B_l2.max()*1.01) +ax.set_zlim3d(0, B_l2.max() * 1.01) pl.title('Barycenter interpolation with l2') +pl.tight_layout() -pl.show() - -pl.figure(4,(10,5)) - -#pl.subplot(1,2,1) -cmap=pl.cm.get_cmap('viridis') +pl.figure(4) +cmap = pl.cm.get_cmap('viridis') verts = [] -zs = alphalist -for i,z in enumerate(zs): - ys = B_wass[:,i] +zs = alpha_list +for i, z in enumerate(zs): + ys = B_wass[:, i] verts.append(list(zip(x, ys))) ax = pl.gcf().gca(projection='3d') -poly = PolyCollection(verts,facecolors=[cmap(a) for a in alphalist]) +poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list]) poly.set_alpha(0.7) ax.add_collection3d(poly, zs=zs, zdir='y') - ax.set_xlabel('x') ax.set_xlim3d(0, n) ax.set_ylabel('$\\alpha$') -ax.set_ylim3d(0,1) +ax.set_ylim3d(0, 1) ax.set_zlabel('') -ax.set_zlim3d(0, B_l2.max()*1.01) +ax.set_zlim3d(0, B_l2.max() * 1.01) pl.title('Barycenter interpolation with Wasserstein') +pl.tight_layout() -pl.show() \ No newline at end of file +pl.show() -- cgit v1.2.3