summaryrefslogtreecommitdiff
path: root/docs/source/auto_examples/plot_barycenter_1D.rst
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2017-09-15 14:54:21 +0200
committerGitHub <noreply@github.com>2017-09-15 14:54:21 +0200
commit81b2796226f3abde29fc024752728444da77509a (patch)
treec52cec3c38552f9f8c15361758aa9a80c30c3ef3 /docs/source/auto_examples/plot_barycenter_1D.rst
parente70d5420204db78691af2d0fbe04cc3d4416a8f4 (diff)
parent7fea2cd3e8ad29bf3fa442d7642bae124ee2bab0 (diff)
Merge pull request #27 from rflamary/autonb
auto notebooks + release update (fixes #16)
Diffstat (limited to 'docs/source/auto_examples/plot_barycenter_1D.rst')
-rw-r--r--docs/source/auto_examples/plot_barycenter_1D.rst211
1 files changed, 135 insertions, 76 deletions
diff --git a/docs/source/auto_examples/plot_barycenter_1D.rst b/docs/source/auto_examples/plot_barycenter_1D.rst
index 1b15c77..f17f2c2 100644
--- a/docs/source/auto_examples/plot_barycenter_1D.rst
+++ b/docs/source/auto_examples/plot_barycenter_1D.rst
@@ -7,171 +7,230 @@
1D Wasserstein barycenter demo
==============================
+This example illustrates the computation of regularized Wassersyein Barycenter
+as proposed in [3].
-@author: rflamary
+[3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015).
+Iterative Bregman projections for regularized transportation problems
+SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
-.. rst-class:: sphx-glr-horizontal
+.. code-block:: python
- *
- .. image:: /auto_examples/images/sphx_glr_plot_barycenter_1D_001.png
- :scale: 47
+ # Author: Remi Flamary <remi.flamary@unice.fr>
+ #
+ # License: MIT License
- *
+ import numpy as np
+ import matplotlib.pylab as pl
+ import ot
+ # necessary for 3d plot even if not used
+ from mpl_toolkits.mplot3d import Axes3D # noqa
+ from matplotlib.collections import PolyCollection
- .. image:: /auto_examples/images/sphx_glr_plot_barycenter_1D_002.png
- :scale: 47
- *
- .. image:: /auto_examples/images/sphx_glr_plot_barycenter_1D_003.png
- :scale: 47
- *
- .. image:: /auto_examples/images/sphx_glr_plot_barycenter_1D_004.png
- :scale: 47
+Generate data
+-------------
.. code-block:: python
- import numpy as np
- import matplotlib.pylab as pl
- import ot
- from mpl_toolkits.mplot3d import Axes3D #necessary for 3d plot even if not used
- from matplotlib.collections import PolyCollection
-
-
#%% parameters
- n=100 # nb bins
+ n = 100 # nb bins
# bin positions
- x=np.arange(n,dtype=np.float64)
+ x = np.arange(n, dtype=np.float64)
# Gaussian distributions
- a1=ot.datasets.get_1D_gauss(n,m=20,s=5) # m= mean, s= std
- a2=ot.datasets.get_1D_gauss(n,m=60,s=8)
+ a1 = ot.datasets.get_1D_gauss(n, m=20, s=5) # m= mean, s= std
+ a2 = ot.datasets.get_1D_gauss(n, m=60, s=8)
# creating matrix A containing all distributions
- A=np.vstack((a1,a2)).T
- nbd=A.shape[1]
+ A = np.vstack((a1, a2)).T
+ n_distributions = A.shape[1]
# loss matrix + normalization
- M=ot.utils.dist0(n)
- M/=M.max()
+ M = ot.utils.dist0(n)
+ M /= M.max()
+
+
+
+
+
+
+
+Plot data
+---------
+
+
+
+.. code-block:: python
+
#%% plot the distributions
- pl.figure(1)
- for i in range(nbd):
- pl.plot(x,A[:,i])
+ pl.figure(1, figsize=(6.4, 3))
+ for i in range(n_distributions):
+ pl.plot(x, A[:, i])
pl.title('Distributions')
+ pl.tight_layout()
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_barycenter_1D_001.png
+ :align: center
+
+
+
+
+Barycenter computation
+----------------------
+
+
+
+.. code-block:: python
+
#%% barycenter computation
- alpha=0.2 # 0<=alpha<=1
- weights=np.array([1-alpha,alpha])
+ alpha = 0.2 # 0<=alpha<=1
+ weights = np.array([1 - alpha, alpha])
# l2bary
- bary_l2=A.dot(weights)
+ bary_l2 = A.dot(weights)
# wasserstein
- reg=1e-3
- bary_wass=ot.bregman.barycenter(A,M,reg,weights)
+ reg = 1e-3
+ bary_wass = ot.bregman.barycenter(A, M, reg, weights)
pl.figure(2)
pl.clf()
- pl.subplot(2,1,1)
- for i in range(nbd):
- pl.plot(x,A[:,i])
+ pl.subplot(2, 1, 1)
+ for i in range(n_distributions):
+ pl.plot(x, A[:, i])
pl.title('Distributions')
- pl.subplot(2,1,2)
- pl.plot(x,bary_l2,'r',label='l2')
- pl.plot(x,bary_wass,'g',label='Wasserstein')
+ pl.subplot(2, 1, 2)
+ pl.plot(x, bary_l2, 'r', label='l2')
+ pl.plot(x, bary_wass, 'g', label='Wasserstein')
pl.legend()
pl.title('Barycenters')
+ pl.tight_layout()
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_barycenter_1D_003.png
+ :align: center
+
+
+
+
+Barycentric interpolation
+-------------------------
+
+
+
+.. code-block:: python
#%% barycenter interpolation
- nbalpha=11
- alphalist=np.linspace(0,1,nbalpha)
+ n_alpha = 11
+ alpha_list = np.linspace(0, 1, n_alpha)
- B_l2=np.zeros((n,nbalpha))
+ B_l2 = np.zeros((n, n_alpha))
- B_wass=np.copy(B_l2)
+ B_wass = np.copy(B_l2)
- for i in range(0,nbalpha):
- alpha=alphalist[i]
- weights=np.array([1-alpha,alpha])
- B_l2[:,i]=A.dot(weights)
- B_wass[:,i]=ot.bregman.barycenter(A,M,reg,weights)
+ for i in range(0, n_alpha):
+ alpha = alpha_list[i]
+ weights = np.array([1 - alpha, alpha])
+ B_l2[:, i] = A.dot(weights)
+ B_wass[:, i] = ot.bregman.barycenter(A, M, reg, weights)
#%% plot interpolation
- pl.figure(3,(10,5))
+ pl.figure(3)
- #pl.subplot(1,2,1)
- cmap=pl.cm.get_cmap('viridis')
+ cmap = pl.cm.get_cmap('viridis')
verts = []
- zs = alphalist
- for i,z in enumerate(zs):
- ys = B_l2[:,i]
+ zs = alpha_list
+ for i, z in enumerate(zs):
+ ys = B_l2[:, i]
verts.append(list(zip(x, ys)))
ax = pl.gcf().gca(projection='3d')
- poly = PolyCollection(verts,facecolors=[cmap(a) for a in alphalist])
+ poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list])
poly.set_alpha(0.7)
ax.add_collection3d(poly, zs=zs, zdir='y')
-
ax.set_xlabel('x')
ax.set_xlim3d(0, n)
ax.set_ylabel('$\\alpha$')
- ax.set_ylim3d(0,1)
+ ax.set_ylim3d(0, 1)
ax.set_zlabel('')
- ax.set_zlim3d(0, B_l2.max()*1.01)
+ ax.set_zlim3d(0, B_l2.max() * 1.01)
pl.title('Barycenter interpolation with l2')
+ pl.tight_layout()
- pl.show()
-
- pl.figure(4,(10,5))
-
- #pl.subplot(1,2,1)
- cmap=pl.cm.get_cmap('viridis')
+ pl.figure(4)
+ cmap = pl.cm.get_cmap('viridis')
verts = []
- zs = alphalist
- for i,z in enumerate(zs):
- ys = B_wass[:,i]
+ zs = alpha_list
+ for i, z in enumerate(zs):
+ ys = B_wass[:, i]
verts.append(list(zip(x, ys)))
ax = pl.gcf().gca(projection='3d')
- poly = PolyCollection(verts,facecolors=[cmap(a) for a in alphalist])
+ poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list])
poly.set_alpha(0.7)
ax.add_collection3d(poly, zs=zs, zdir='y')
-
ax.set_xlabel('x')
ax.set_xlim3d(0, n)
ax.set_ylabel('$\\alpha$')
- ax.set_ylim3d(0,1)
+ ax.set_ylim3d(0, 1)
ax.set_zlabel('')
- ax.set_zlim3d(0, B_l2.max()*1.01)
+ ax.set_zlim3d(0, B_l2.max() * 1.01)
pl.title('Barycenter interpolation with Wasserstein')
+ pl.tight_layout()
pl.show()
-**Total running time of the script:** ( 0 minutes 2.274 seconds)
+
+
+
+.. rst-class:: sphx-glr-horizontal
+
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_barycenter_1D_005.png
+ :scale: 47
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_barycenter_1D_006.png
+ :scale: 47
+
+
+
+
+**Total running time of the script:** ( 0 minutes 0.814 seconds)