diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2017-08-30 17:01:01 +0200 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2017-08-30 17:01:01 +0200 |
commit | dc8737a30cb6d9f1305173eb8d16fe6716fd1231 (patch) | |
tree | 1f03384de2af88ed07a1e850e0871db826ed53e7 /docs/source/auto_examples/plot_OT_2D_samples.py | |
parent | c2a7a1f3ab4ba5c4f5adeca0fa22d8d6b4fc079d (diff) |
wroking make!
Diffstat (limited to 'docs/source/auto_examples/plot_OT_2D_samples.py')
-rw-r--r-- | docs/source/auto_examples/plot_OT_2D_samples.py | 57 |
1 files changed, 31 insertions, 26 deletions
diff --git a/docs/source/auto_examples/plot_OT_2D_samples.py b/docs/source/auto_examples/plot_OT_2D_samples.py index edfb781..2a42dc0 100644 --- a/docs/source/auto_examples/plot_OT_2D_samples.py +++ b/docs/source/auto_examples/plot_OT_2D_samples.py @@ -4,57 +4,60 @@ 2D Optimal transport between empirical distributions ==================================================== -@author: rflamary """ +# Author: Remi Flamary <remi.flamary@unice.fr> +# +# License: MIT License + import numpy as np import matplotlib.pylab as pl import ot #%% parameters and data generation -n=50 # nb samples +n = 50 # nb samples -mu_s=np.array([0,0]) -cov_s=np.array([[1,0],[0,1]]) +mu_s = np.array([0, 0]) +cov_s = np.array([[1, 0], [0, 1]]) -mu_t=np.array([4,4]) -cov_t=np.array([[1,-.8],[-.8,1]]) +mu_t = np.array([4, 4]) +cov_t = np.array([[1, -.8], [-.8, 1]]) -xs=ot.datasets.get_2D_samples_gauss(n,mu_s,cov_s) -xt=ot.datasets.get_2D_samples_gauss(n,mu_t,cov_t) +xs = ot.datasets.get_2D_samples_gauss(n, mu_s, cov_s) +xt = ot.datasets.get_2D_samples_gauss(n, mu_t, cov_t) -a,b = ot.unif(n),ot.unif(n) # uniform distribution on samples +a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples # loss matrix -M=ot.dist(xs,xt) -M/=M.max() +M = ot.dist(xs, xt) +M /= M.max() #%% plot samples pl.figure(1) -pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples') -pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples') +pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') pl.legend(loc=0) -pl.title('Source and traget distributions') +pl.title('Source and target distributions') pl.figure(2) -pl.imshow(M,interpolation='nearest') +pl.imshow(M, interpolation='nearest') pl.title('Cost matrix M') #%% EMD -G0=ot.emd(a,b,M) +G0 = ot.emd(a, b, M) pl.figure(3) -pl.imshow(G0,interpolation='nearest') +pl.imshow(G0, interpolation='nearest') pl.title('OT matrix G0') pl.figure(4) -ot.plot.plot2D_samples_mat(xs,xt,G0,c=[.5,.5,1]) -pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples') -pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples') +ot.plot.plot2D_samples_mat(xs, xt, G0, c=[.5, .5, 1]) +pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') pl.legend(loc=0) pl.title('OT matrix with samples') @@ -62,17 +65,19 @@ pl.title('OT matrix with samples') #%% sinkhorn # reg term -lambd=5e-4 +lambd = 1e-3 -Gs=ot.sinkhorn(a,b,M,lambd) +Gs = ot.sinkhorn(a, b, M, lambd) pl.figure(5) -pl.imshow(Gs,interpolation='nearest') +pl.imshow(Gs, interpolation='nearest') pl.title('OT matrix sinkhorn') pl.figure(6) -ot.plot.plot2D_samples_mat(xs,xt,Gs,color=[.5,.5,1]) -pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples') -pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples') +ot.plot.plot2D_samples_mat(xs, xt, Gs, color=[.5, .5, 1]) +pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') pl.legend(loc=0) pl.title('OT matrix Sinkhorn with samples') + +pl.show() |