diff options
author | Alexandre Gramfort <alexandre.gramfort@m4x.org> | 2017-07-11 21:43:22 +0200 |
---|---|---|
committer | Alexandre Gramfort <alexandre.gramfort@m4x.org> | 2017-07-20 14:05:12 +0200 |
commit | 35b25adf1fd9ec35fb6f6da105e268ed5b467d79 (patch) | |
tree | c35f4a69a008e138847b2b8fea692e6d3d23cf9d /examples | |
parent | 95b2a584d02da1a08e71f7ff3895d958e42ed2dc (diff) |
pimp + pep8 on plot_OT_2D_samples
Diffstat (limited to 'examples')
-rw-r--r-- | examples/plot_OT_2D_samples.py | 82 |
1 files changed, 42 insertions, 40 deletions
diff --git a/examples/plot_OT_2D_samples.py b/examples/plot_OT_2D_samples.py index edfb781..3a93591 100644 --- a/examples/plot_OT_2D_samples.py +++ b/examples/plot_OT_2D_samples.py @@ -8,71 +8,73 @@ """ import numpy as np -import matplotlib.pylab as pl +import matplotlib.pylab as plt import ot #%% parameters and data generation -n=50 # nb samples +n = 50 # nb samples -mu_s=np.array([0,0]) -cov_s=np.array([[1,0],[0,1]]) +mu_s = np.array([0, 0]) +cov_s = np.array([[1, 0], [0, 1]]) -mu_t=np.array([4,4]) -cov_t=np.array([[1,-.8],[-.8,1]]) +mu_t = np.array([4, 4]) +cov_t = np.array([[1, -.8], [-.8, 1]]) -xs=ot.datasets.get_2D_samples_gauss(n,mu_s,cov_s) -xt=ot.datasets.get_2D_samples_gauss(n,mu_t,cov_t) +xs = ot.datasets.get_2D_samples_gauss(n, mu_s, cov_s) +xt = ot.datasets.get_2D_samples_gauss(n, mu_t, cov_t) -a,b = ot.unif(n),ot.unif(n) # uniform distribution on samples +a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples # loss matrix -M=ot.dist(xs,xt) -M/=M.max() +M = ot.dist(xs, xt) +M /= M.max() #%% plot samples -pl.figure(1) -pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples') -pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples') -pl.legend(loc=0) -pl.title('Source and traget distributions') +plt.figure(1) +plt.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +plt.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +plt.legend(loc=0) +plt.title('Source and target distributions') -pl.figure(2) -pl.imshow(M,interpolation='nearest') -pl.title('Cost matrix M') +plt.figure(2) +plt.imshow(M, interpolation='nearest') +plt.title('Cost matrix M') #%% EMD -G0=ot.emd(a,b,M) +G0 = ot.emd(a, b, M) -pl.figure(3) -pl.imshow(G0,interpolation='nearest') -pl.title('OT matrix G0') +plt.figure(3) +plt.imshow(G0, interpolation='nearest') +plt.title('OT matrix G0') -pl.figure(4) -ot.plot.plot2D_samples_mat(xs,xt,G0,c=[.5,.5,1]) -pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples') -pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples') -pl.legend(loc=0) -pl.title('OT matrix with samples') +plt.figure(4) +ot.plot.plot2D_samples_mat(xs, xt, G0, c=[.5, .5, 1]) +plt.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +plt.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +plt.legend(loc=0) +plt.title('OT matrix with samples') #%% sinkhorn # reg term -lambd=5e-4 +lambd = 5e-4 -Gs=ot.sinkhorn(a,b,M,lambd) +Gs = ot.sinkhorn(a, b, M, lambd) -pl.figure(5) -pl.imshow(Gs,interpolation='nearest') -pl.title('OT matrix sinkhorn') +plt.figure(5) +plt.imshow(Gs, interpolation='nearest') +plt.title('OT matrix sinkhorn') -pl.figure(6) -ot.plot.plot2D_samples_mat(xs,xt,Gs,color=[.5,.5,1]) -pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples') -pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples') -pl.legend(loc=0) -pl.title('OT matrix Sinkhorn with samples') +plt.figure(6) +ot.plot.plot2D_samples_mat(xs, xt, Gs, color=[.5, .5, 1]) +plt.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +plt.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +plt.legend(loc=0) +plt.title('OT matrix Sinkhorn with samples') + +plt.show() |