diff options
Diffstat (limited to 'ot/utils.py')
-rw-r--r-- | ot/utils.py | 62 |
1 files changed, 35 insertions, 27 deletions
diff --git a/ot/utils.py b/ot/utils.py index 7ad7637..6a43f61 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -7,31 +7,35 @@ from scipy.spatial.distance import cdist import multiprocessing import time -__time_tic_toc=time.time() +__time_tic_toc = time.time() + def tic(): """ Python implementation of Matlab tic() function """ global __time_tic_toc - __time_tic_toc=time.time() + __time_tic_toc = time.time() + def toc(message='Elapsed time : {} s'): """ Python implementation of Matlab toc() function """ - t=time.time() - print(message.format(t-__time_tic_toc)) - return t-__time_tic_toc + t = time.time() + print(message.format(t - __time_tic_toc)) + return t - __time_tic_toc + def toq(): """ Python implementation of Julia toc() function """ - t=time.time() - return t-__time_tic_toc + t = time.time() + return t - __time_tic_toc -def kernel(x1,x2,method='gaussian',sigma=1,**kwargs): +def kernel(x1, x2, method='gaussian', sigma=1, **kwargs): """Compute kernel matrix""" - if method.lower() in ['gaussian','gauss','rbf']: - K=np.exp(-dist(x1,x2)/(2*sigma**2)) + if method.lower() in ['gaussian', 'gauss', 'rbf']: + K = np.exp(-dist(x1, x2) / (2 * sigma**2)) return K + def unif(n): """ return a uniform histogram of length n (simplex) @@ -48,17 +52,19 @@ def unif(n): """ - return np.ones((n,))/n + return np.ones((n,)) / n -def clean_zeros(a,b,M): - """ Remove all components with zeros weights in a and b + +def clean_zeros(a, b, M): + """ Remove all components with zeros weights in a and b """ - M2=M[a>0,:][:,b>0].copy() # copy force c style matrix (froemd) - a2=a[a>0] - b2=b[b>0] - return a2,b2,M2 + M2 = M[a > 0, :][:, b > 0].copy() # copy force c style matrix (froemd) + a2 = a[a > 0] + b2 = b[b > 0] + return a2, b2, M2 + -def dist(x1,x2=None,metric='sqeuclidean'): +def dist(x1, x2=None, metric='sqeuclidean'): """Compute distance between samples in x1 and x2 using function scipy.spatial.distance.cdist Parameters @@ -84,12 +90,12 @@ def dist(x1,x2=None,metric='sqeuclidean'): """ if x2 is None: - x2=x1 + x2 = x1 - return cdist(x1,x2,metric=metric) + return cdist(x1, x2, metric=metric) -def dist0(n,method='lin_square'): +def dist0(n, method='lin_square'): """Compute standard cost matrices of size (n,n) for OT problems Parameters @@ -111,16 +117,17 @@ def dist0(n,method='lin_square'): """ - res=0 - if method=='lin_square': - x=np.arange(n,dtype=np.float64).reshape((n,1)) - res=dist(x,x) + res = 0 + if method == 'lin_square': + x = np.arange(n, dtype=np.float64).reshape((n, 1)) + res = dist(x, x) return res def dots(*args): """ dots function for multiple matrix multiply """ - return reduce(np.dot,args) + return reduce(np.dot, args) + def fun(f, q_in, q_out): """ Utility function for parmap with no serializing problems """ @@ -130,6 +137,7 @@ def fun(f, q_in, q_out): break q_out.put((i, f(x))) + def parmap(f, X, nprocs=multiprocessing.cpu_count()): """ paralell map for multiprocessing """ q_in = multiprocessing.Queue(1) @@ -147,4 +155,4 @@ def parmap(f, X, nprocs=multiprocessing.cpu_count()): [p.join() for p in proc] - return [x for i, x in sorted(res)]
\ No newline at end of file + return [x for i, x in sorted(res)] |