diff options
Diffstat (limited to 'examples/demo_barycenter_1D.py')
-rw-r--r-- | examples/demo_barycenter_1D.py | 86 |
1 files changed, 81 insertions, 5 deletions
diff --git a/examples/demo_barycenter_1D.py b/examples/demo_barycenter_1D.py index db1ff4a..7e478c2 100644 --- a/examples/demo_barycenter_1D.py +++ b/examples/demo_barycenter_1D.py @@ -9,7 +9,9 @@ import numpy as np import matplotlib.pylab as pl import ot - +from mpl_toolkits.mplot3d import Axes3D +from matplotlib.collections import PolyCollection +from matplotlib.colors import colorConverter #%% parameters @@ -19,8 +21,8 @@ n=100 # nb bins x=np.arange(n,dtype=np.float64) # Gaussian distributions -a1=ot.datasets.get_1D_gauss(n,m=20,s=20) # m= mean, s= std -a2=ot.datasets.get_1D_gauss(n,m=60,s=60) +a1=ot.datasets.get_1D_gauss(n,m=20,s=5) # m= mean, s= std +a2=ot.datasets.get_1D_gauss(n,m=60,s=8) # creating matrix A containing all distributions A=np.vstack((a1,a2)).T @@ -39,12 +41,15 @@ pl.title('Distributions') #%% barycenter computation +alpha=0.2 # 0<=alpha<=1 +weights=np.array([1-alpha,alpha]) + # l2bary -bary_l2=A.mean(1) +bary_l2=A.dot(weights) # wasserstein reg=1e-3 -bary_wass=ot.bregman.barycenter(A,M,reg) +bary_wass=ot.bregman.barycenter(A,M,reg,weights) pl.figure(2) pl.clf() @@ -58,3 +63,74 @@ pl.plot(x,bary_l2,'r',label='l2') pl.plot(x,bary_wass,'g',label='Wasserstein') pl.legend() pl.title('Barycenters') + + +#%% barycenter interpolation + +nbalpha=11 +alphalist=np.linspace(0,1,nbalpha) + + +B_l2=np.zeros((n,nbalpha)) + +B_wass=np.copy(B_l2) + +for i in range(0,nbalpha): + alpha=alphalist[i] + weights=np.array([1-alpha,alpha]) + B_l2[:,i]=A.dot(weights) + B_wass[:,i]=ot.bregman.barycenter(A,M,reg,weights) + +#%% plot interpolation + +pl.figure(3,(10,5)) + +#pl.subplot(1,2,1) +cmap=pl.cm.get_cmap('viridis') +verts = [] +zs = alphalist +for i,z in enumerate(zs): + ys = B_l2[:,i] + verts.append(list(zip(x, ys))) + +ax = pl.gcf().gca(projection='3d') + +poly = PolyCollection(verts,facecolors=[cmap(a) for a in alphalist]) +poly.set_alpha(0.7) +ax.add_collection3d(poly, zs=zs, zdir='y') + +ax.set_xlabel('x') +ax.set_xlim3d(0, n) +ax.set_ylabel('$\\alpha$') +ax.set_ylim3d(0,1) +ax.set_zlabel('') +ax.set_zlim3d(0, B_l2.max()*1.01) +pl.title('Barycenter interpolation with l2') + +pl.show() + +pl.figure(4,(10,5)) + +#pl.subplot(1,2,1) +cmap=pl.cm.get_cmap('viridis') +verts = [] +zs = alphalist +for i,z in enumerate(zs): + ys = B_wass[:,i] + verts.append(list(zip(x, ys))) + +ax = pl.gcf().gca(projection='3d') + +poly = PolyCollection(verts,facecolors=[cmap(a) for a in alphalist]) +poly.set_alpha(0.7) +ax.add_collection3d(poly, zs=zs, zdir='y') + +ax.set_xlabel('x') +ax.set_xlim3d(0, n) +ax.set_ylabel('$\\alpha$') +ax.set_ylim3d(0,1) +ax.set_zlabel('') +ax.set_zlim3d(0, B_l2.max()*1.01) +pl.title('Barycenter interpolation with Wasserstein') + +pl.show()
\ No newline at end of file |