diff options
author | Alexandre Gramfort <alexandre.gramfort@m4x.org> | 2017-07-12 22:11:25 +0200 |
---|---|---|
committer | Alexandre Gramfort <alexandre.gramfort@m4x.org> | 2017-07-20 14:05:12 +0200 |
commit | 75c988f515f0a1ee51f88f5fc429a1301a1ca8c5 (patch) | |
tree | 3ec979dbacdf91c2b52fa5892f1080057c33e05e /examples/plot_barycenter_1D.py | |
parent | c6cb1cd666a3e1b761b83a6e0f9339268e69f099 (diff) |
do plot_barycenter_1D
Diffstat (limited to 'examples/plot_barycenter_1D.py')
-rw-r--r-- | examples/plot_barycenter_1D.py | 111 |
1 files changed, 54 insertions, 57 deletions
diff --git a/examples/plot_barycenter_1D.py b/examples/plot_barycenter_1D.py index 30eecbf..ab236e1 100644 --- a/examples/plot_barycenter_1D.py +++ b/examples/plot_barycenter_1D.py @@ -11,128 +11,125 @@ import numpy as np import matplotlib.pylab as pl import ot -from mpl_toolkits.mplot3d import Axes3D #necessary for 3d plot even if not used +# necessary for 3d plot even if not used +from mpl_toolkits.mplot3d import Axes3D # noqa from matplotlib.collections import PolyCollection #%% parameters -n=100 # nb bins +n = 100 # nb bins # bin positions -x=np.arange(n,dtype=np.float64) +x = np.arange(n, dtype=np.float64) # Gaussian distributions -a1=ot.datasets.get_1D_gauss(n,m=20,s=5) # m= mean, s= std -a2=ot.datasets.get_1D_gauss(n,m=60,s=8) +a1 = ot.datasets.get_1D_gauss(n, m=20, s=5) # m= mean, s= std +a2 = ot.datasets.get_1D_gauss(n, m=60, s=8) # creating matrix A containing all distributions -A=np.vstack((a1,a2)).T -nbd=A.shape[1] +A = np.vstack((a1, a2)).T +n_distributions = A.shape[1] # loss matrix + normalization -M=ot.utils.dist0(n) -M/=M.max() +M = ot.utils.dist0(n) +M /= M.max() #%% plot the distributions -pl.figure(1) -for i in range(nbd): - pl.plot(x,A[:,i]) +pl.figure(1, figsize=(6.4, 3)) +for i in range(n_distributions): + pl.plot(x, A[:, i]) pl.title('Distributions') +pl.tight_layout() #%% barycenter computation -alpha=0.2 # 0<=alpha<=1 -weights=np.array([1-alpha,alpha]) +alpha = 0.2 # 0<=alpha<=1 +weights = np.array([1 - alpha, alpha]) # l2bary -bary_l2=A.dot(weights) +bary_l2 = A.dot(weights) # wasserstein -reg=1e-3 -bary_wass=ot.bregman.barycenter(A,M,reg,weights) +reg = 1e-3 +bary_wass = ot.bregman.barycenter(A, M, reg, weights) pl.figure(2) pl.clf() -pl.subplot(2,1,1) -for i in range(nbd): - pl.plot(x,A[:,i]) +pl.subplot(2, 1, 1) +for i in range(n_distributions): + pl.plot(x, A[:, i]) pl.title('Distributions') -pl.subplot(2,1,2) -pl.plot(x,bary_l2,'r',label='l2') -pl.plot(x,bary_wass,'g',label='Wasserstein') +pl.subplot(2, 1, 2) +pl.plot(x, bary_l2, 'r', label='l2') +pl.plot(x, bary_wass, 'g', label='Wasserstein') pl.legend() pl.title('Barycenters') - +pl.tight_layout() #%% barycenter interpolation -nbalpha=11 -alphalist=np.linspace(0,1,nbalpha) +n_alpha = 11 +alpha_list = np.linspace(0, 1, n_alpha) -B_l2=np.zeros((n,nbalpha)) +B_l2 = np.zeros((n, n_alpha)) -B_wass=np.copy(B_l2) +B_wass = np.copy(B_l2) -for i in range(0,nbalpha): - alpha=alphalist[i] - weights=np.array([1-alpha,alpha]) - B_l2[:,i]=A.dot(weights) - B_wass[:,i]=ot.bregman.barycenter(A,M,reg,weights) +for i in range(0, n_alpha): + alpha = alpha_list[i] + weights = np.array([1 - alpha, alpha]) + B_l2[:, i] = A.dot(weights) + B_wass[:, i] = ot.bregman.barycenter(A, M, reg, weights) #%% plot interpolation -pl.figure(3,(10,5)) +pl.figure(3) -#pl.subplot(1,2,1) -cmap=pl.cm.get_cmap('viridis') +cmap = pl.cm.get_cmap('viridis') verts = [] -zs = alphalist -for i,z in enumerate(zs): - ys = B_l2[:,i] +zs = alpha_list +for i, z in enumerate(zs): + ys = B_l2[:, i] verts.append(list(zip(x, ys))) ax = pl.gcf().gca(projection='3d') -poly = PolyCollection(verts,facecolors=[cmap(a) for a in alphalist]) +poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list]) poly.set_alpha(0.7) ax.add_collection3d(poly, zs=zs, zdir='y') - ax.set_xlabel('x') ax.set_xlim3d(0, n) ax.set_ylabel('$\\alpha$') -ax.set_ylim3d(0,1) +ax.set_ylim3d(0, 1) ax.set_zlabel('') -ax.set_zlim3d(0, B_l2.max()*1.01) +ax.set_zlim3d(0, B_l2.max() * 1.01) pl.title('Barycenter interpolation with l2') +pl.tight_layout() -pl.show() - -pl.figure(4,(10,5)) - -#pl.subplot(1,2,1) -cmap=pl.cm.get_cmap('viridis') +pl.figure(4) +cmap = pl.cm.get_cmap('viridis') verts = [] -zs = alphalist -for i,z in enumerate(zs): - ys = B_wass[:,i] +zs = alpha_list +for i, z in enumerate(zs): + ys = B_wass[:, i] verts.append(list(zip(x, ys))) ax = pl.gcf().gca(projection='3d') -poly = PolyCollection(verts,facecolors=[cmap(a) for a in alphalist]) +poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list]) poly.set_alpha(0.7) ax.add_collection3d(poly, zs=zs, zdir='y') - ax.set_xlabel('x') ax.set_xlim3d(0, n) ax.set_ylabel('$\\alpha$') -ax.set_ylim3d(0,1) +ax.set_ylim3d(0, 1) ax.set_zlabel('') -ax.set_zlim3d(0, B_l2.max()*1.01) +ax.set_zlim3d(0, B_l2.max() * 1.01) pl.title('Barycenter interpolation with Wasserstein') +pl.tight_layout() -pl.show()
\ No newline at end of file +pl.show() |