diff options
author | Alexandre Gramfort <alexandre.gramfort@m4x.org> | 2017-07-12 22:13:43 +0200 |
---|---|---|
committer | Alexandre Gramfort <alexandre.gramfort@m4x.org> | 2017-07-20 14:05:12 +0200 |
commit | de5903618d2d525e48e3f5ffc205c3f566f0095d (patch) | |
tree | af27845f39ae5a71e0b9d53a211c69fc4703d459 | |
parent | 75c988f515f0a1ee51f88f5fc429a1301a1ca8c5 (diff) |
do plot_compute_emd
-rw-r--r-- | examples/plot_compute_emd.py | 59 |
1 files changed, 31 insertions, 28 deletions
diff --git a/examples/plot_compute_emd.py b/examples/plot_compute_emd.py index f2cdc35..558facb 100644 --- a/examples/plot_compute_emd.py +++ b/examples/plot_compute_emd.py @@ -15,60 +15,63 @@ from ot.datasets import get_1D_gauss as gauss #%% parameters -n=100 # nb bins -n_target=50 # nb target distributions +n = 100 # nb bins +n_target = 50 # nb target distributions # bin positions -x=np.arange(n,dtype=np.float64) +x = np.arange(n, dtype=np.float64) -lst_m=np.linspace(20,90,n_target) +lst_m = np.linspace(20, 90, n_target) # Gaussian distributions -a=gauss(n,m=20,s=5) # m= mean, s= std +a = gauss(n, m=20, s=5) # m= mean, s= std -B=np.zeros((n,n_target)) +B = np.zeros((n, n_target)) -for i,m in enumerate(lst_m): - B[:,i]=gauss(n,m=m,s=5) +for i, m in enumerate(lst_m): + B[:, i] = gauss(n, m=m, s=5) # loss matrix and normalization -M=ot.dist(x.reshape((n,1)),x.reshape((n,1)),'euclidean') -M/=M.max() -M2=ot.dist(x.reshape((n,1)),x.reshape((n,1)),'sqeuclidean') -M2/=M2.max() +M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'euclidean') +M /= M.max() +M2 = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'sqeuclidean') +M2 /= M2.max() #%% plot the distributions pl.figure(1) -pl.subplot(2,1,1) -pl.plot(x,a,'b',label='Source distribution') +pl.subplot(2, 1, 1) +pl.plot(x, a, 'b', label='Source distribution') pl.title('Source distribution') -pl.subplot(2,1,2) -pl.plot(x,B,label='Target distributions') +pl.subplot(2, 1, 2) +pl.plot(x, B, label='Target distributions') pl.title('Target distributions') +pl.tight_layout() #%% Compute and plot distributions and loss matrix -d_emd=ot.emd2(a,B,M) # direct computation of EMD -d_emd2=ot.emd2(a,B,M2) # direct computation of EMD with loss M3 +d_emd = ot.emd2(a, B, M) # direct computation of EMD +d_emd2 = ot.emd2(a, B, M2) # direct computation of EMD with loss M3 pl.figure(2) -pl.plot(d_emd,label='Euclidean EMD') -pl.plot(d_emd2,label='Squared Euclidean EMD') +pl.plot(d_emd, label='Euclidean EMD') +pl.plot(d_emd2, label='Squared Euclidean EMD') pl.title('EMD distances') pl.legend() #%% -reg=1e-2 -d_sinkhorn=ot.sinkhorn2(a,B,M,reg) -d_sinkhorn2=ot.sinkhorn2(a,B,M2,reg) +reg = 1e-2 +d_sinkhorn = ot.sinkhorn2(a, B, M, reg) +d_sinkhorn2 = ot.sinkhorn2(a, B, M2, reg) pl.figure(2) pl.clf() -pl.plot(d_emd,label='Euclidean EMD') -pl.plot(d_emd2,label='Squared Euclidean EMD') -pl.plot(d_sinkhorn,'+',label='Euclidean Sinkhorn') -pl.plot(d_sinkhorn2,'+',label='Squared Euclidean Sinkhorn') +pl.plot(d_emd, label='Euclidean EMD') +pl.plot(d_emd2, label='Squared Euclidean EMD') +pl.plot(d_sinkhorn, '+', label='Euclidean Sinkhorn') +pl.plot(d_sinkhorn2, '+', label='Squared Euclidean Sinkhorn') pl.title('EMD distances') -pl.legend()
\ No newline at end of file +pl.legend() + +pl.show() |