summaryrefslogtreecommitdiff
path: root/ot/utils.py
diff options
context:
space:
mode:
authorAlexandre Gramfort <alexandre.gramfort@m4x.org>2017-07-12 23:56:27 +0200
committerAlexandre Gramfort <alexandre.gramfort@m4x.org>2017-07-20 14:08:02 +0200
commit95db977e8b931277af5dadbd79eccbd5fbb8bb62 (patch)
tree4b48e558afceddfcb640e42bb0682bdcccb02b55 /ot/utils.py
parent37ca3142a4c8c808382f5eb1c23bf198c3e4610e (diff)
pep8
Diffstat (limited to 'ot/utils.py')
-rw-r--r--ot/utils.py62
1 files changed, 35 insertions, 27 deletions
diff --git a/ot/utils.py b/ot/utils.py
index 7ad7637..6a43f61 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -7,31 +7,35 @@ from scipy.spatial.distance import cdist
import multiprocessing
import time
-__time_tic_toc=time.time()
+__time_tic_toc = time.time()
+
def tic():
""" Python implementation of Matlab tic() function """
global __time_tic_toc
- __time_tic_toc=time.time()
+ __time_tic_toc = time.time()
+
def toc(message='Elapsed time : {} s'):
""" Python implementation of Matlab toc() function """
- t=time.time()
- print(message.format(t-__time_tic_toc))
- return t-__time_tic_toc
+ t = time.time()
+ print(message.format(t - __time_tic_toc))
+ return t - __time_tic_toc
+
def toq():
""" Python implementation of Julia toc() function """
- t=time.time()
- return t-__time_tic_toc
+ t = time.time()
+ return t - __time_tic_toc
-def kernel(x1,x2,method='gaussian',sigma=1,**kwargs):
+def kernel(x1, x2, method='gaussian', sigma=1, **kwargs):
"""Compute kernel matrix"""
- if method.lower() in ['gaussian','gauss','rbf']:
- K=np.exp(-dist(x1,x2)/(2*sigma**2))
+ if method.lower() in ['gaussian', 'gauss', 'rbf']:
+ K = np.exp(-dist(x1, x2) / (2 * sigma**2))
return K
+
def unif(n):
""" return a uniform histogram of length n (simplex)
@@ -48,17 +52,19 @@ def unif(n):
"""
- return np.ones((n,))/n
+ return np.ones((n,)) / n
-def clean_zeros(a,b,M):
- """ Remove all components with zeros weights in a and b
+
+def clean_zeros(a, b, M):
+ """ Remove all components with zeros weights in a and b
"""
- M2=M[a>0,:][:,b>0].copy() # copy force c style matrix (froemd)
- a2=a[a>0]
- b2=b[b>0]
- return a2,b2,M2
+ M2 = M[a > 0, :][:, b > 0].copy() # copy force c style matrix (froemd)
+ a2 = a[a > 0]
+ b2 = b[b > 0]
+ return a2, b2, M2
+
-def dist(x1,x2=None,metric='sqeuclidean'):
+def dist(x1, x2=None, metric='sqeuclidean'):
"""Compute distance between samples in x1 and x2 using function scipy.spatial.distance.cdist
Parameters
@@ -84,12 +90,12 @@ def dist(x1,x2=None,metric='sqeuclidean'):
"""
if x2 is None:
- x2=x1
+ x2 = x1
- return cdist(x1,x2,metric=metric)
+ return cdist(x1, x2, metric=metric)
-def dist0(n,method='lin_square'):
+def dist0(n, method='lin_square'):
"""Compute standard cost matrices of size (n,n) for OT problems
Parameters
@@ -111,16 +117,17 @@ def dist0(n,method='lin_square'):
"""
- res=0
- if method=='lin_square':
- x=np.arange(n,dtype=np.float64).reshape((n,1))
- res=dist(x,x)
+ res = 0
+ if method == 'lin_square':
+ x = np.arange(n, dtype=np.float64).reshape((n, 1))
+ res = dist(x, x)
return res
def dots(*args):
""" dots function for multiple matrix multiply """
- return reduce(np.dot,args)
+ return reduce(np.dot, args)
+
def fun(f, q_in, q_out):
""" Utility function for parmap with no serializing problems """
@@ -130,6 +137,7 @@ def fun(f, q_in, q_out):
break
q_out.put((i, f(x)))
+
def parmap(f, X, nprocs=multiprocessing.cpu_count()):
""" paralell map for multiprocessing """
q_in = multiprocessing.Queue(1)
@@ -147,4 +155,4 @@ def parmap(f, X, nprocs=multiprocessing.cpu_count()):
[p.join() for p in proc]
- return [x for i, x in sorted(res)] \ No newline at end of file
+ return [x for i, x in sorted(res)]