summaryrefslogtreecommitdiff
path: root/examples/plot_barycenter_1D.py
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2016-12-02 13:26:23 +0100
committerGitHub <noreply@github.com>2016-12-02 13:26:23 +0100
commit98118c6be1832c2a1294c14d4f03c0eef08e68f9 (patch)
tree56c8a160fb6edcdaca8ca6ce6de1949b9bc33b77 /examples/plot_barycenter_1D.py
parent8dbfd3edae649f5f3e87be4a3ce446c59729b2f7 (diff)
parentf439f777084690ecbf54bcd8d67dadc883fffa31 (diff)
Merge pull request #3 from agramfort/sphx_gallery
first attempt to support sphinx-gallery
Diffstat (limited to 'examples/plot_barycenter_1D.py')
-rw-r--r--examples/plot_barycenter_1D.py135
1 files changed, 135 insertions, 0 deletions
diff --git a/examples/plot_barycenter_1D.py b/examples/plot_barycenter_1D.py
new file mode 100644
index 0000000..5466332
--- /dev/null
+++ b/examples/plot_barycenter_1D.py
@@ -0,0 +1,135 @@
+# -*- coding: utf-8 -*-
+"""
+1D Wasserstein barycenter demo
+
+@author: rflamary
+"""
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+from mpl_toolkits.mplot3d import Axes3D #necessary for 3d plot even if not used
+from matplotlib.collections import PolyCollection
+
+
+#%% parameters
+
+n=100 # nb bins
+
+# bin positions
+x=np.arange(n,dtype=np.float64)
+
+# Gaussian distributions
+a1=ot.datasets.get_1D_gauss(n,m=20,s=5) # m= mean, s= std
+a2=ot.datasets.get_1D_gauss(n,m=60,s=8)
+
+# creating matrix A containing all distributions
+A=np.vstack((a1,a2)).T
+nbd=A.shape[1]
+
+# loss matrix + normalization
+M=ot.utils.dist0(n)
+M/=M.max()
+
+#%% plot the distributions
+
+pl.figure(1)
+for i in range(nbd):
+ pl.plot(x,A[:,i])
+pl.title('Distributions')
+
+#%% barycenter computation
+
+alpha=0.2 # 0<=alpha<=1
+weights=np.array([1-alpha,alpha])
+
+# l2bary
+bary_l2=A.dot(weights)
+
+# wasserstein
+reg=1e-3
+bary_wass=ot.bregman.barycenter(A,M,reg,weights)
+
+pl.figure(2)
+pl.clf()
+pl.subplot(2,1,1)
+for i in range(nbd):
+ pl.plot(x,A[:,i])
+pl.title('Distributions')
+
+pl.subplot(2,1,2)
+pl.plot(x,bary_l2,'r',label='l2')
+pl.plot(x,bary_wass,'g',label='Wasserstein')
+pl.legend()
+pl.title('Barycenters')
+
+
+#%% barycenter interpolation
+
+nbalpha=11
+alphalist=np.linspace(0,1,nbalpha)
+
+
+B_l2=np.zeros((n,nbalpha))
+
+B_wass=np.copy(B_l2)
+
+for i in range(0,nbalpha):
+ alpha=alphalist[i]
+ weights=np.array([1-alpha,alpha])
+ B_l2[:,i]=A.dot(weights)
+ B_wass[:,i]=ot.bregman.barycenter(A,M,reg,weights)
+
+#%% plot interpolation
+
+pl.figure(3,(10,5))
+
+#pl.subplot(1,2,1)
+cmap=pl.cm.get_cmap('viridis')
+verts = []
+zs = alphalist
+for i,z in enumerate(zs):
+ ys = B_l2[:,i]
+ verts.append(list(zip(x, ys)))
+
+ax = pl.gcf().gca(projection='3d')
+
+poly = PolyCollection(verts,facecolors=[cmap(a) for a in alphalist])
+poly.set_alpha(0.7)
+ax.add_collection3d(poly, zs=zs, zdir='y')
+
+ax.set_xlabel('x')
+ax.set_xlim3d(0, n)
+ax.set_ylabel('$\\alpha$')
+ax.set_ylim3d(0,1)
+ax.set_zlabel('')
+ax.set_zlim3d(0, B_l2.max()*1.01)
+pl.title('Barycenter interpolation with l2')
+
+pl.show()
+
+pl.figure(4,(10,5))
+
+#pl.subplot(1,2,1)
+cmap=pl.cm.get_cmap('viridis')
+verts = []
+zs = alphalist
+for i,z in enumerate(zs):
+ ys = B_wass[:,i]
+ verts.append(list(zip(x, ys)))
+
+ax = pl.gcf().gca(projection='3d')
+
+poly = PolyCollection(verts,facecolors=[cmap(a) for a in alphalist])
+poly.set_alpha(0.7)
+ax.add_collection3d(poly, zs=zs, zdir='y')
+
+ax.set_xlabel('x')
+ax.set_xlim3d(0, n)
+ax.set_ylabel('$\\alpha$')
+ax.set_ylim3d(0,1)
+ax.set_zlabel('')
+ax.set_zlim3d(0, B_l2.max()*1.01)
+pl.title('Barycenter interpolation with Wasserstein')
+
+pl.show() \ No newline at end of file