summaryrefslogtreecommitdiff
path: root/examples/plot_compute_emd.py
diff options
context:
space:
mode:
authorAlexandre Gramfort <alexandre.gramfort@m4x.org>2017-07-12 22:13:43 +0200
committerAlexandre Gramfort <alexandre.gramfort@m4x.org>2017-07-20 14:05:12 +0200
commitde5903618d2d525e48e3f5ffc205c3f566f0095d (patch)
treeaf27845f39ae5a71e0b9d53a211c69fc4703d459 /examples/plot_compute_emd.py
parent75c988f515f0a1ee51f88f5fc429a1301a1ca8c5 (diff)
do plot_compute_emd
Diffstat (limited to 'examples/plot_compute_emd.py')
-rw-r--r--examples/plot_compute_emd.py59
1 files changed, 31 insertions, 28 deletions
diff --git a/examples/plot_compute_emd.py b/examples/plot_compute_emd.py
index f2cdc35..558facb 100644
--- a/examples/plot_compute_emd.py
+++ b/examples/plot_compute_emd.py
@@ -15,60 +15,63 @@ from ot.datasets import get_1D_gauss as gauss
#%% parameters
-n=100 # nb bins
-n_target=50 # nb target distributions
+n = 100 # nb bins
+n_target = 50 # nb target distributions
# bin positions
-x=np.arange(n,dtype=np.float64)
+x = np.arange(n, dtype=np.float64)
-lst_m=np.linspace(20,90,n_target)
+lst_m = np.linspace(20, 90, n_target)
# Gaussian distributions
-a=gauss(n,m=20,s=5) # m= mean, s= std
+a = gauss(n, m=20, s=5) # m= mean, s= std
-B=np.zeros((n,n_target))
+B = np.zeros((n, n_target))
-for i,m in enumerate(lst_m):
- B[:,i]=gauss(n,m=m,s=5)
+for i, m in enumerate(lst_m):
+ B[:, i] = gauss(n, m=m, s=5)
# loss matrix and normalization
-M=ot.dist(x.reshape((n,1)),x.reshape((n,1)),'euclidean')
-M/=M.max()
-M2=ot.dist(x.reshape((n,1)),x.reshape((n,1)),'sqeuclidean')
-M2/=M2.max()
+M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'euclidean')
+M /= M.max()
+M2 = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'sqeuclidean')
+M2 /= M2.max()
#%% plot the distributions
pl.figure(1)
-pl.subplot(2,1,1)
-pl.plot(x,a,'b',label='Source distribution')
+pl.subplot(2, 1, 1)
+pl.plot(x, a, 'b', label='Source distribution')
pl.title('Source distribution')
-pl.subplot(2,1,2)
-pl.plot(x,B,label='Target distributions')
+pl.subplot(2, 1, 2)
+pl.plot(x, B, label='Target distributions')
pl.title('Target distributions')
+pl.tight_layout()
#%% Compute and plot distributions and loss matrix
-d_emd=ot.emd2(a,B,M) # direct computation of EMD
-d_emd2=ot.emd2(a,B,M2) # direct computation of EMD with loss M3
+d_emd = ot.emd2(a, B, M) # direct computation of EMD
+d_emd2 = ot.emd2(a, B, M2) # direct computation of EMD with loss M3
pl.figure(2)
-pl.plot(d_emd,label='Euclidean EMD')
-pl.plot(d_emd2,label='Squared Euclidean EMD')
+pl.plot(d_emd, label='Euclidean EMD')
+pl.plot(d_emd2, label='Squared Euclidean EMD')
pl.title('EMD distances')
pl.legend()
#%%
-reg=1e-2
-d_sinkhorn=ot.sinkhorn2(a,B,M,reg)
-d_sinkhorn2=ot.sinkhorn2(a,B,M2,reg)
+reg = 1e-2
+d_sinkhorn = ot.sinkhorn2(a, B, M, reg)
+d_sinkhorn2 = ot.sinkhorn2(a, B, M2, reg)
pl.figure(2)
pl.clf()
-pl.plot(d_emd,label='Euclidean EMD')
-pl.plot(d_emd2,label='Squared Euclidean EMD')
-pl.plot(d_sinkhorn,'+',label='Euclidean Sinkhorn')
-pl.plot(d_sinkhorn2,'+',label='Squared Euclidean Sinkhorn')
+pl.plot(d_emd, label='Euclidean EMD')
+pl.plot(d_emd2, label='Squared Euclidean EMD')
+pl.plot(d_sinkhorn, '+', label='Euclidean Sinkhorn')
+pl.plot(d_sinkhorn2, '+', label='Squared Euclidean Sinkhorn')
pl.title('EMD distances')
-pl.legend() \ No newline at end of file
+pl.legend()
+
+pl.show()