summaryrefslogtreecommitdiff
path: root/docs/source/auto_examples/plot_barycenter_1D.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2017-08-30 17:01:01 +0200
committerRémi Flamary <remi.flamary@gmail.com>2017-08-30 17:01:01 +0200
commitdc8737a30cb6d9f1305173eb8d16fe6716fd1231 (patch)
tree1f03384de2af88ed07a1e850e0871db826ed53e7 /docs/source/auto_examples/plot_barycenter_1D.py
parentc2a7a1f3ab4ba5c4f5adeca0fa22d8d6b4fc079d (diff)
wroking make!
Diffstat (limited to 'docs/source/auto_examples/plot_barycenter_1D.py')
-rw-r--r--docs/source/auto_examples/plot_barycenter_1D.py117
1 files changed, 58 insertions, 59 deletions
diff --git a/docs/source/auto_examples/plot_barycenter_1D.py b/docs/source/auto_examples/plot_barycenter_1D.py
index 30eecbf..875f44c 100644
--- a/docs/source/auto_examples/plot_barycenter_1D.py
+++ b/docs/source/auto_examples/plot_barycenter_1D.py
@@ -4,135 +4,134 @@
1D Wasserstein barycenter demo
==============================
-
-@author: rflamary
"""
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
import numpy as np
import matplotlib.pylab as pl
import ot
-from mpl_toolkits.mplot3d import Axes3D #necessary for 3d plot even if not used
+# necessary for 3d plot even if not used
+from mpl_toolkits.mplot3d import Axes3D # noqa
from matplotlib.collections import PolyCollection
#%% parameters
-n=100 # nb bins
+n = 100 # nb bins
# bin positions
-x=np.arange(n,dtype=np.float64)
+x = np.arange(n, dtype=np.float64)
# Gaussian distributions
-a1=ot.datasets.get_1D_gauss(n,m=20,s=5) # m= mean, s= std
-a2=ot.datasets.get_1D_gauss(n,m=60,s=8)
+a1 = ot.datasets.get_1D_gauss(n, m=20, s=5) # m= mean, s= std
+a2 = ot.datasets.get_1D_gauss(n, m=60, s=8)
# creating matrix A containing all distributions
-A=np.vstack((a1,a2)).T
-nbd=A.shape[1]
+A = np.vstack((a1, a2)).T
+n_distributions = A.shape[1]
# loss matrix + normalization
-M=ot.utils.dist0(n)
-M/=M.max()
+M = ot.utils.dist0(n)
+M /= M.max()
#%% plot the distributions
-pl.figure(1)
-for i in range(nbd):
- pl.plot(x,A[:,i])
+pl.figure(1, figsize=(6.4, 3))
+for i in range(n_distributions):
+ pl.plot(x, A[:, i])
pl.title('Distributions')
+pl.tight_layout()
#%% barycenter computation
-alpha=0.2 # 0<=alpha<=1
-weights=np.array([1-alpha,alpha])
+alpha = 0.2 # 0<=alpha<=1
+weights = np.array([1 - alpha, alpha])
# l2bary
-bary_l2=A.dot(weights)
+bary_l2 = A.dot(weights)
# wasserstein
-reg=1e-3
-bary_wass=ot.bregman.barycenter(A,M,reg,weights)
+reg = 1e-3
+bary_wass = ot.bregman.barycenter(A, M, reg, weights)
pl.figure(2)
pl.clf()
-pl.subplot(2,1,1)
-for i in range(nbd):
- pl.plot(x,A[:,i])
+pl.subplot(2, 1, 1)
+for i in range(n_distributions):
+ pl.plot(x, A[:, i])
pl.title('Distributions')
-pl.subplot(2,1,2)
-pl.plot(x,bary_l2,'r',label='l2')
-pl.plot(x,bary_wass,'g',label='Wasserstein')
+pl.subplot(2, 1, 2)
+pl.plot(x, bary_l2, 'r', label='l2')
+pl.plot(x, bary_wass, 'g', label='Wasserstein')
pl.legend()
pl.title('Barycenters')
-
+pl.tight_layout()
#%% barycenter interpolation
-nbalpha=11
-alphalist=np.linspace(0,1,nbalpha)
+n_alpha = 11
+alpha_list = np.linspace(0, 1, n_alpha)
-B_l2=np.zeros((n,nbalpha))
+B_l2 = np.zeros((n, n_alpha))
-B_wass=np.copy(B_l2)
+B_wass = np.copy(B_l2)
-for i in range(0,nbalpha):
- alpha=alphalist[i]
- weights=np.array([1-alpha,alpha])
- B_l2[:,i]=A.dot(weights)
- B_wass[:,i]=ot.bregman.barycenter(A,M,reg,weights)
+for i in range(0, n_alpha):
+ alpha = alpha_list[i]
+ weights = np.array([1 - alpha, alpha])
+ B_l2[:, i] = A.dot(weights)
+ B_wass[:, i] = ot.bregman.barycenter(A, M, reg, weights)
#%% plot interpolation
-pl.figure(3,(10,5))
+pl.figure(3)
-#pl.subplot(1,2,1)
-cmap=pl.cm.get_cmap('viridis')
+cmap = pl.cm.get_cmap('viridis')
verts = []
-zs = alphalist
-for i,z in enumerate(zs):
- ys = B_l2[:,i]
+zs = alpha_list
+for i, z in enumerate(zs):
+ ys = B_l2[:, i]
verts.append(list(zip(x, ys)))
ax = pl.gcf().gca(projection='3d')
-poly = PolyCollection(verts,facecolors=[cmap(a) for a in alphalist])
+poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list])
poly.set_alpha(0.7)
ax.add_collection3d(poly, zs=zs, zdir='y')
-
ax.set_xlabel('x')
ax.set_xlim3d(0, n)
ax.set_ylabel('$\\alpha$')
-ax.set_ylim3d(0,1)
+ax.set_ylim3d(0, 1)
ax.set_zlabel('')
-ax.set_zlim3d(0, B_l2.max()*1.01)
+ax.set_zlim3d(0, B_l2.max() * 1.01)
pl.title('Barycenter interpolation with l2')
+pl.tight_layout()
-pl.show()
-
-pl.figure(4,(10,5))
-
-#pl.subplot(1,2,1)
-cmap=pl.cm.get_cmap('viridis')
+pl.figure(4)
+cmap = pl.cm.get_cmap('viridis')
verts = []
-zs = alphalist
-for i,z in enumerate(zs):
- ys = B_wass[:,i]
+zs = alpha_list
+for i, z in enumerate(zs):
+ ys = B_wass[:, i]
verts.append(list(zip(x, ys)))
ax = pl.gcf().gca(projection='3d')
-poly = PolyCollection(verts,facecolors=[cmap(a) for a in alphalist])
+poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list])
poly.set_alpha(0.7)
ax.add_collection3d(poly, zs=zs, zdir='y')
-
ax.set_xlabel('x')
ax.set_xlim3d(0, n)
ax.set_ylabel('$\\alpha$')
-ax.set_ylim3d(0,1)
+ax.set_ylim3d(0, 1)
ax.set_zlabel('')
-ax.set_zlim3d(0, B_l2.max()*1.01)
+ax.set_zlim3d(0, B_l2.max() * 1.01)
pl.title('Barycenter interpolation with Wasserstein')
+pl.tight_layout()
-pl.show() \ No newline at end of file
+pl.show()