summaryrefslogtreecommitdiff
path: root/docs/source/auto_examples/plot_barycenter_1D.rst
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2016-12-02 15:38:59 +0100
committerRémi Flamary <remi.flamary@gmail.com>2016-12-02 15:38:59 +0100
commite458b7a58d9790e7c5ff40dea235402d9c4c8662 (patch)
treeac9da575654c78aa04a177723603935051b5d42d /docs/source/auto_examples/plot_barycenter_1D.rst
parent7609f9e6a4103e13beb294873f4dac562b1d45e1 (diff)
add doc for gallery
Diffstat (limited to 'docs/source/auto_examples/plot_barycenter_1D.rst')
-rw-r--r--docs/source/auto_examples/plot_barycenter_1D.rst193
1 files changed, 193 insertions, 0 deletions
diff --git a/docs/source/auto_examples/plot_barycenter_1D.rst b/docs/source/auto_examples/plot_barycenter_1D.rst
new file mode 100644
index 0000000..1b15c77
--- /dev/null
+++ b/docs/source/auto_examples/plot_barycenter_1D.rst
@@ -0,0 +1,193 @@
+
+
+.. _sphx_glr_auto_examples_plot_barycenter_1D.py:
+
+
+==============================
+1D Wasserstein barycenter demo
+==============================
+
+
+@author: rflamary
+
+
+
+
+.. rst-class:: sphx-glr-horizontal
+
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_barycenter_1D_001.png
+ :scale: 47
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_barycenter_1D_002.png
+ :scale: 47
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_barycenter_1D_003.png
+ :scale: 47
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_barycenter_1D_004.png
+ :scale: 47
+
+
+
+
+
+.. code-block:: python
+
+
+ import numpy as np
+ import matplotlib.pylab as pl
+ import ot
+ from mpl_toolkits.mplot3d import Axes3D #necessary for 3d plot even if not used
+ from matplotlib.collections import PolyCollection
+
+
+ #%% parameters
+
+ n=100 # nb bins
+
+ # bin positions
+ x=np.arange(n,dtype=np.float64)
+
+ # Gaussian distributions
+ a1=ot.datasets.get_1D_gauss(n,m=20,s=5) # m= mean, s= std
+ a2=ot.datasets.get_1D_gauss(n,m=60,s=8)
+
+ # creating matrix A containing all distributions
+ A=np.vstack((a1,a2)).T
+ nbd=A.shape[1]
+
+ # loss matrix + normalization
+ M=ot.utils.dist0(n)
+ M/=M.max()
+
+ #%% plot the distributions
+
+ pl.figure(1)
+ for i in range(nbd):
+ pl.plot(x,A[:,i])
+ pl.title('Distributions')
+
+ #%% barycenter computation
+
+ alpha=0.2 # 0<=alpha<=1
+ weights=np.array([1-alpha,alpha])
+
+ # l2bary
+ bary_l2=A.dot(weights)
+
+ # wasserstein
+ reg=1e-3
+ bary_wass=ot.bregman.barycenter(A,M,reg,weights)
+
+ pl.figure(2)
+ pl.clf()
+ pl.subplot(2,1,1)
+ for i in range(nbd):
+ pl.plot(x,A[:,i])
+ pl.title('Distributions')
+
+ pl.subplot(2,1,2)
+ pl.plot(x,bary_l2,'r',label='l2')
+ pl.plot(x,bary_wass,'g',label='Wasserstein')
+ pl.legend()
+ pl.title('Barycenters')
+
+
+ #%% barycenter interpolation
+
+ nbalpha=11
+ alphalist=np.linspace(0,1,nbalpha)
+
+
+ B_l2=np.zeros((n,nbalpha))
+
+ B_wass=np.copy(B_l2)
+
+ for i in range(0,nbalpha):
+ alpha=alphalist[i]
+ weights=np.array([1-alpha,alpha])
+ B_l2[:,i]=A.dot(weights)
+ B_wass[:,i]=ot.bregman.barycenter(A,M,reg,weights)
+
+ #%% plot interpolation
+
+ pl.figure(3,(10,5))
+
+ #pl.subplot(1,2,1)
+ cmap=pl.cm.get_cmap('viridis')
+ verts = []
+ zs = alphalist
+ for i,z in enumerate(zs):
+ ys = B_l2[:,i]
+ verts.append(list(zip(x, ys)))
+
+ ax = pl.gcf().gca(projection='3d')
+
+ poly = PolyCollection(verts,facecolors=[cmap(a) for a in alphalist])
+ poly.set_alpha(0.7)
+ ax.add_collection3d(poly, zs=zs, zdir='y')
+
+ ax.set_xlabel('x')
+ ax.set_xlim3d(0, n)
+ ax.set_ylabel('$\\alpha$')
+ ax.set_ylim3d(0,1)
+ ax.set_zlabel('')
+ ax.set_zlim3d(0, B_l2.max()*1.01)
+ pl.title('Barycenter interpolation with l2')
+
+ pl.show()
+
+ pl.figure(4,(10,5))
+
+ #pl.subplot(1,2,1)
+ cmap=pl.cm.get_cmap('viridis')
+ verts = []
+ zs = alphalist
+ for i,z in enumerate(zs):
+ ys = B_wass[:,i]
+ verts.append(list(zip(x, ys)))
+
+ ax = pl.gcf().gca(projection='3d')
+
+ poly = PolyCollection(verts,facecolors=[cmap(a) for a in alphalist])
+ poly.set_alpha(0.7)
+ ax.add_collection3d(poly, zs=zs, zdir='y')
+
+ ax.set_xlabel('x')
+ ax.set_xlim3d(0, n)
+ ax.set_ylabel('$\\alpha$')
+ ax.set_ylim3d(0,1)
+ ax.set_zlabel('')
+ ax.set_zlim3d(0, B_l2.max()*1.01)
+ pl.title('Barycenter interpolation with Wasserstein')
+
+ pl.show()
+**Total running time of the script:** ( 0 minutes 2.274 seconds)
+
+
+
+.. container:: sphx-glr-footer
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Python source code: plot_barycenter_1D.py <plot_barycenter_1D.py>`
+
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Jupyter notebook: plot_barycenter_1D.ipynb <plot_barycenter_1D.ipynb>`
+
+.. rst-class:: sphx-glr-signature
+
+ `Generated by Sphinx-Gallery <http://sphinx-gallery.readthedocs.io>`_