summaryrefslogtreecommitdiff
path: root/test/test_emd_multi.py
diff options
context:
space:
mode:
authorAlexandre Gramfort <alexandre.gramfort@m4x.org>2017-07-12 23:56:27 +0200
committerAlexandre Gramfort <alexandre.gramfort@m4x.org>2017-07-20 14:08:02 +0200
commit95db977e8b931277af5dadbd79eccbd5fbb8bb62 (patch)
tree4b48e558afceddfcb640e42bb0682bdcccb02b55 /test/test_emd_multi.py
parent37ca3142a4c8c808382f5eb1c23bf198c3e4610e (diff)
pep8
Diffstat (limited to 'test/test_emd_multi.py')
-rw-r--r--test/test_emd_multi.py27
1 files changed, 13 insertions, 14 deletions
diff --git a/test/test_emd_multi.py b/test/test_emd_multi.py
index ee0a20e..99173e9 100644
--- a/test/test_emd_multi.py
+++ b/test/test_emd_multi.py
@@ -7,31 +7,30 @@ Created on Fri Mar 10 09:56:06 2017
"""
import numpy as np
-import pylab as pl
-import ot
+import ot
from ot.datasets import get_1D_gauss as gauss
-reload(ot.lp)
+# reload(ot.lp)
#%% parameters
-n=5000 # nb bins
+n = 5000 # nb bins
# bin positions
-x=np.arange(n,dtype=np.float64)
+x = np.arange(n, dtype=np.float64)
# Gaussian distributions
-a=gauss(n,m=20,s=5) # m= mean, s= std
+a = gauss(n, m=20, s=5) # m= mean, s= std
-ls= range(20,1000,10)
-nb=len(ls)
-b=np.zeros((n,nb))
+ls = range(20, 1000, 10)
+nb = len(ls)
+b = np.zeros((n, nb))
for i in range(nb):
- b[:,i]=gauss(n,m=ls[i],s=10)
+ b[:, i] = gauss(n, m=ls[i], s=10)
# loss matrix
-M=ot.dist(x.reshape((n,1)),x.reshape((n,1)))
-#M/=M.max()
+M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
+# M/=M.max()
#%%
@@ -39,10 +38,10 @@ print('Computing {} EMD '.format(nb))
# emd loss 1 proc
ot.tic()
-emd_loss4=ot.emd2(a,b,M,1)
+emd_loss4 = ot.emd2(a, b, M, 1)
ot.toc('1 proc : {} s')
# emd loss multipro proc
ot.tic()
-emd_loss4=ot.emd2(a,b,M)
+emd_loss4 = ot.emd2(a, b, M)
ot.toc('multi proc : {} s')