diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2017-08-30 17:01:01 +0200 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2017-08-30 17:01:01 +0200 |
commit | dc8737a30cb6d9f1305173eb8d16fe6716fd1231 (patch) | |
tree | 1f03384de2af88ed07a1e850e0871db826ed53e7 /docs/source/auto_examples/plot_barycenter_1D.rst | |
parent | c2a7a1f3ab4ba5c4f5adeca0fa22d8d6b4fc079d (diff) |
wroking make!
Diffstat (limited to 'docs/source/auto_examples/plot_barycenter_1D.rst')
-rw-r--r-- | docs/source/auto_examples/plot_barycenter_1D.rst | 118 |
1 files changed, 59 insertions, 59 deletions
diff --git a/docs/source/auto_examples/plot_barycenter_1D.rst b/docs/source/auto_examples/plot_barycenter_1D.rst index 1b15c77..af88e80 100644 --- a/docs/source/auto_examples/plot_barycenter_1D.rst +++ b/docs/source/auto_examples/plot_barycenter_1D.rst @@ -8,8 +8,6 @@ ============================== -@author: rflamary - @@ -43,135 +41,137 @@ .. code-block:: python + # Author: Remi Flamary <remi.flamary@unice.fr> + # + # License: MIT License + import numpy as np import matplotlib.pylab as pl import ot - from mpl_toolkits.mplot3d import Axes3D #necessary for 3d plot even if not used + # necessary for 3d plot even if not used + from mpl_toolkits.mplot3d import Axes3D # noqa from matplotlib.collections import PolyCollection #%% parameters - n=100 # nb bins + n = 100 # nb bins # bin positions - x=np.arange(n,dtype=np.float64) + x = np.arange(n, dtype=np.float64) # Gaussian distributions - a1=ot.datasets.get_1D_gauss(n,m=20,s=5) # m= mean, s= std - a2=ot.datasets.get_1D_gauss(n,m=60,s=8) + a1 = ot.datasets.get_1D_gauss(n, m=20, s=5) # m= mean, s= std + a2 = ot.datasets.get_1D_gauss(n, m=60, s=8) # creating matrix A containing all distributions - A=np.vstack((a1,a2)).T - nbd=A.shape[1] + A = np.vstack((a1, a2)).T + n_distributions = A.shape[1] # loss matrix + normalization - M=ot.utils.dist0(n) - M/=M.max() + M = ot.utils.dist0(n) + M /= M.max() #%% plot the distributions - pl.figure(1) - for i in range(nbd): - pl.plot(x,A[:,i]) + pl.figure(1, figsize=(6.4, 3)) + for i in range(n_distributions): + pl.plot(x, A[:, i]) pl.title('Distributions') + pl.tight_layout() #%% barycenter computation - alpha=0.2 # 0<=alpha<=1 - weights=np.array([1-alpha,alpha]) + alpha = 0.2 # 0<=alpha<=1 + weights = np.array([1 - alpha, alpha]) # l2bary - bary_l2=A.dot(weights) + bary_l2 = A.dot(weights) # wasserstein - reg=1e-3 - bary_wass=ot.bregman.barycenter(A,M,reg,weights) + reg = 1e-3 + bary_wass = ot.bregman.barycenter(A, M, reg, weights) pl.figure(2) pl.clf() - pl.subplot(2,1,1) - for i in range(nbd): - pl.plot(x,A[:,i]) + pl.subplot(2, 1, 1) + for i in range(n_distributions): + pl.plot(x, A[:, i]) pl.title('Distributions') - pl.subplot(2,1,2) - pl.plot(x,bary_l2,'r',label='l2') - pl.plot(x,bary_wass,'g',label='Wasserstein') + pl.subplot(2, 1, 2) + pl.plot(x, bary_l2, 'r', label='l2') + pl.plot(x, bary_wass, 'g', label='Wasserstein') pl.legend() pl.title('Barycenters') - + pl.tight_layout() #%% barycenter interpolation - nbalpha=11 - alphalist=np.linspace(0,1,nbalpha) + n_alpha = 11 + alpha_list = np.linspace(0, 1, n_alpha) - B_l2=np.zeros((n,nbalpha)) + B_l2 = np.zeros((n, n_alpha)) - B_wass=np.copy(B_l2) + B_wass = np.copy(B_l2) - for i in range(0,nbalpha): - alpha=alphalist[i] - weights=np.array([1-alpha,alpha]) - B_l2[:,i]=A.dot(weights) - B_wass[:,i]=ot.bregman.barycenter(A,M,reg,weights) + for i in range(0, n_alpha): + alpha = alpha_list[i] + weights = np.array([1 - alpha, alpha]) + B_l2[:, i] = A.dot(weights) + B_wass[:, i] = ot.bregman.barycenter(A, M, reg, weights) #%% plot interpolation - pl.figure(3,(10,5)) + pl.figure(3) - #pl.subplot(1,2,1) - cmap=pl.cm.get_cmap('viridis') + cmap = pl.cm.get_cmap('viridis') verts = [] - zs = alphalist - for i,z in enumerate(zs): - ys = B_l2[:,i] + zs = alpha_list + for i, z in enumerate(zs): + ys = B_l2[:, i] verts.append(list(zip(x, ys))) ax = pl.gcf().gca(projection='3d') - poly = PolyCollection(verts,facecolors=[cmap(a) for a in alphalist]) + poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list]) poly.set_alpha(0.7) ax.add_collection3d(poly, zs=zs, zdir='y') - ax.set_xlabel('x') ax.set_xlim3d(0, n) ax.set_ylabel('$\\alpha$') - ax.set_ylim3d(0,1) + ax.set_ylim3d(0, 1) ax.set_zlabel('') - ax.set_zlim3d(0, B_l2.max()*1.01) + ax.set_zlim3d(0, B_l2.max() * 1.01) pl.title('Barycenter interpolation with l2') + pl.tight_layout() - pl.show() - - pl.figure(4,(10,5)) - - #pl.subplot(1,2,1) - cmap=pl.cm.get_cmap('viridis') + pl.figure(4) + cmap = pl.cm.get_cmap('viridis') verts = [] - zs = alphalist - for i,z in enumerate(zs): - ys = B_wass[:,i] + zs = alpha_list + for i, z in enumerate(zs): + ys = B_wass[:, i] verts.append(list(zip(x, ys))) ax = pl.gcf().gca(projection='3d') - poly = PolyCollection(verts,facecolors=[cmap(a) for a in alphalist]) + poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list]) poly.set_alpha(0.7) ax.add_collection3d(poly, zs=zs, zdir='y') - ax.set_xlabel('x') ax.set_xlim3d(0, n) ax.set_ylabel('$\\alpha$') - ax.set_ylim3d(0,1) + ax.set_ylim3d(0, 1) ax.set_zlabel('') - ax.set_zlim3d(0, B_l2.max()*1.01) + ax.set_zlim3d(0, B_l2.max() * 1.01) pl.title('Barycenter interpolation with Wasserstein') + pl.tight_layout() pl.show() -**Total running time of the script:** ( 0 minutes 2.274 seconds) + +**Total running time of the script:** ( 0 minutes 0.546 seconds) |