diff options
author | Alexandre Gramfort <alexandre.gramfort@m4x.org> | 2017-07-12 23:56:27 +0200 |
---|---|---|
committer | Alexandre Gramfort <alexandre.gramfort@m4x.org> | 2017-07-20 14:08:02 +0200 |
commit | 95db977e8b931277af5dadbd79eccbd5fbb8bb62 (patch) | |
tree | 4b48e558afceddfcb640e42bb0682bdcccb02b55 | |
parent | 37ca3142a4c8c808382f5eb1c23bf198c3e4610e (diff) |
pep8
-rw-r--r-- | ot/gpu/bregman.py | 11 | ||||
-rw-r--r-- | ot/gpu/da.py | 2 | ||||
-rw-r--r-- | ot/utils.py | 62 | ||||
-rw-r--r-- | test/test_emd_multi.py | 27 | ||||
-rw-r--r-- | test/test_gpu_sinkhorn.py | 6 | ||||
-rw-r--r-- | test/test_load_module.py | 4 |
6 files changed, 61 insertions, 51 deletions
diff --git a/ot/gpu/bregman.py b/ot/gpu/bregman.py index 2c3e317..7881c65 100644 --- a/ot/gpu/bregman.py +++ b/ot/gpu/bregman.py @@ -82,7 +82,7 @@ def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False, ot.lp.emd : Unregularized OT ot.optim.cg : General regularized OT - """ + """ # init data Nini = len(a) Nfin = len(b) @@ -92,11 +92,11 @@ def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False, # we assume that no distances are null except those of the diagonal of # distances - u = (np.ones(Nini)/Nini).reshape((Nini, 1)) + u = (np.ones(Nini) / Nini).reshape((Nini, 1)) u_GPU = cudamat.CUDAMatrix(u) a_GPU = cudamat.CUDAMatrix(a.reshape((Nini, 1))) ones_GPU = cudamat.empty(u_GPU.shape).assign(1) - v = (np.ones(Nfin)/Nfin).reshape((Nfin, 1)) + v = (np.ones(Nfin) / Nfin).reshape((Nfin, 1)) v_GPU = cudamat.CUDAMatrix(v) b_GPU = cudamat.CUDAMatrix(b.reshape((Nfin, 1))) @@ -121,7 +121,7 @@ def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False, ones_GPU.divide(Kp_GPU.dot(v_GPU), target=u_GPU) if (np.any(KtransposeU_GPU.asarray() == 0) or - not u_GPU.allfinite() or not v_GPU.allfinite()): + not u_GPU.allfinite() or not v_GPU.allfinite()): # we have reached the machine precision # come back to previous solution and quit loop print('Warning: numerical errors at iteration', cpt) @@ -142,7 +142,8 @@ def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False, if verbose: if cpt % 200 == 0: - print('{:5s}|{:12s}'.format('It.', 'Err')+'\n'+'-'*19) + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) print('{:5d}|{:8e}|'.format(cpt, err)) cpt += 1 if log: diff --git a/ot/gpu/da.py b/ot/gpu/da.py index b05ff70..c66e755 100644 --- a/ot/gpu/da.py +++ b/ot/gpu/da.py @@ -167,7 +167,7 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M_GPU, reg, eta=0.1, numItermax=10, tmpC_GPU = cudamat.empty((Nfin, nbRow)).assign(0) transp_GPU.transpose().select_columns(indices_labels[i], tmpC_GPU) majs_GPU = tmpC_GPU.sum(axis=1).add(epsilon) - cudamat.pow(majs_GPU, (p-1)) + cudamat.pow(majs_GPU, (p - 1)) majs_GPU.mult(p) tmpC_GPU.assign(0) diff --git a/ot/utils.py b/ot/utils.py index 7ad7637..6a43f61 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -7,31 +7,35 @@ from scipy.spatial.distance import cdist import multiprocessing import time -__time_tic_toc=time.time() +__time_tic_toc = time.time() + def tic(): """ Python implementation of Matlab tic() function """ global __time_tic_toc - __time_tic_toc=time.time() + __time_tic_toc = time.time() + def toc(message='Elapsed time : {} s'): """ Python implementation of Matlab toc() function """ - t=time.time() - print(message.format(t-__time_tic_toc)) - return t-__time_tic_toc + t = time.time() + print(message.format(t - __time_tic_toc)) + return t - __time_tic_toc + def toq(): """ Python implementation of Julia toc() function """ - t=time.time() - return t-__time_tic_toc + t = time.time() + return t - __time_tic_toc -def kernel(x1,x2,method='gaussian',sigma=1,**kwargs): +def kernel(x1, x2, method='gaussian', sigma=1, **kwargs): """Compute kernel matrix""" - if method.lower() in ['gaussian','gauss','rbf']: - K=np.exp(-dist(x1,x2)/(2*sigma**2)) + if method.lower() in ['gaussian', 'gauss', 'rbf']: + K = np.exp(-dist(x1, x2) / (2 * sigma**2)) return K + def unif(n): """ return a uniform histogram of length n (simplex) @@ -48,17 +52,19 @@ def unif(n): """ - return np.ones((n,))/n + return np.ones((n,)) / n -def clean_zeros(a,b,M): - """ Remove all components with zeros weights in a and b + +def clean_zeros(a, b, M): + """ Remove all components with zeros weights in a and b """ - M2=M[a>0,:][:,b>0].copy() # copy force c style matrix (froemd) - a2=a[a>0] - b2=b[b>0] - return a2,b2,M2 + M2 = M[a > 0, :][:, b > 0].copy() # copy force c style matrix (froemd) + a2 = a[a > 0] + b2 = b[b > 0] + return a2, b2, M2 + -def dist(x1,x2=None,metric='sqeuclidean'): +def dist(x1, x2=None, metric='sqeuclidean'): """Compute distance between samples in x1 and x2 using function scipy.spatial.distance.cdist Parameters @@ -84,12 +90,12 @@ def dist(x1,x2=None,metric='sqeuclidean'): """ if x2 is None: - x2=x1 + x2 = x1 - return cdist(x1,x2,metric=metric) + return cdist(x1, x2, metric=metric) -def dist0(n,method='lin_square'): +def dist0(n, method='lin_square'): """Compute standard cost matrices of size (n,n) for OT problems Parameters @@ -111,16 +117,17 @@ def dist0(n,method='lin_square'): """ - res=0 - if method=='lin_square': - x=np.arange(n,dtype=np.float64).reshape((n,1)) - res=dist(x,x) + res = 0 + if method == 'lin_square': + x = np.arange(n, dtype=np.float64).reshape((n, 1)) + res = dist(x, x) return res def dots(*args): """ dots function for multiple matrix multiply """ - return reduce(np.dot,args) + return reduce(np.dot, args) + def fun(f, q_in, q_out): """ Utility function for parmap with no serializing problems """ @@ -130,6 +137,7 @@ def fun(f, q_in, q_out): break q_out.put((i, f(x))) + def parmap(f, X, nprocs=multiprocessing.cpu_count()): """ paralell map for multiprocessing """ q_in = multiprocessing.Queue(1) @@ -147,4 +155,4 @@ def parmap(f, X, nprocs=multiprocessing.cpu_count()): [p.join() for p in proc] - return [x for i, x in sorted(res)]
\ No newline at end of file + return [x for i, x in sorted(res)] diff --git a/test/test_emd_multi.py b/test/test_emd_multi.py index ee0a20e..99173e9 100644 --- a/test/test_emd_multi.py +++ b/test/test_emd_multi.py @@ -7,31 +7,30 @@ Created on Fri Mar 10 09:56:06 2017 """ import numpy as np -import pylab as pl -import ot +import ot from ot.datasets import get_1D_gauss as gauss -reload(ot.lp) +# reload(ot.lp) #%% parameters -n=5000 # nb bins +n = 5000 # nb bins # bin positions -x=np.arange(n,dtype=np.float64) +x = np.arange(n, dtype=np.float64) # Gaussian distributions -a=gauss(n,m=20,s=5) # m= mean, s= std +a = gauss(n, m=20, s=5) # m= mean, s= std -ls= range(20,1000,10) -nb=len(ls) -b=np.zeros((n,nb)) +ls = range(20, 1000, 10) +nb = len(ls) +b = np.zeros((n, nb)) for i in range(nb): - b[:,i]=gauss(n,m=ls[i],s=10) + b[:, i] = gauss(n, m=ls[i], s=10) # loss matrix -M=ot.dist(x.reshape((n,1)),x.reshape((n,1))) -#M/=M.max() +M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1))) +# M/=M.max() #%% @@ -39,10 +38,10 @@ print('Computing {} EMD '.format(nb)) # emd loss 1 proc ot.tic() -emd_loss4=ot.emd2(a,b,M,1) +emd_loss4 = ot.emd2(a, b, M, 1) ot.toc('1 proc : {} s') # emd loss multipro proc ot.tic() -emd_loss4=ot.emd2(a,b,M) +emd_loss4 = ot.emd2(a, b, M) ot.toc('multi proc : {} s') diff --git a/test/test_gpu_sinkhorn.py b/test/test_gpu_sinkhorn.py index bfa2cd2..841f062 100644 --- a/test/test_gpu_sinkhorn.py +++ b/test/test_gpu_sinkhorn.py @@ -3,8 +3,10 @@ import numpy as np import time import ot.gpu + def describeRes(r): - print("min:{:.3E}, max::{:.3E}, mean::{:.3E}, std::{:.3E}".format(np.min(r),np.max(r),np.mean(r),np.std(r))) + print("min:{:.3E}, max::{:.3E}, mean::{:.3E}, std::{:.3E}".format( + np.min(r), np.max(r), np.mean(r), np.std(r))) for n in [5000, 10000, 15000, 20000]: @@ -23,4 +25,4 @@ for n in [5000, 10000, 15000, 20000]: print("Normal sinkhorn, time: {:6.2f} sec ".format(time2 - time1)) describeRes(G1) print(" GPU sinkhorn, time: {:6.2f} sec ".format(time3 - time2)) - describeRes(G2)
\ No newline at end of file + describeRes(G2) diff --git a/test/test_load_module.py b/test/test_load_module.py index a04c5df..d77261e 100644 --- a/test/test_load_module.py +++ b/test/test_load_module.py @@ -4,7 +4,7 @@ import ot import doctest # test lp solver -doctest.testmod(ot.lp,verbose=True) +doctest.testmod(ot.lp, verbose=True) # test bregman solver -doctest.testmod(ot.bregman,verbose=True) +doctest.testmod(ot.bregman, verbose=True) |