summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlexandre Gramfort <alexandre.gramfort@m4x.org>2017-07-12 23:56:27 +0200
committerAlexandre Gramfort <alexandre.gramfort@m4x.org>2017-07-20 14:08:02 +0200
commit95db977e8b931277af5dadbd79eccbd5fbb8bb62 (patch)
tree4b48e558afceddfcb640e42bb0682bdcccb02b55
parent37ca3142a4c8c808382f5eb1c23bf198c3e4610e (diff)
pep8
-rw-r--r--ot/gpu/bregman.py11
-rw-r--r--ot/gpu/da.py2
-rw-r--r--ot/utils.py62
-rw-r--r--test/test_emd_multi.py27
-rw-r--r--test/test_gpu_sinkhorn.py6
-rw-r--r--test/test_load_module.py4
6 files changed, 61 insertions, 51 deletions
diff --git a/ot/gpu/bregman.py b/ot/gpu/bregman.py
index 2c3e317..7881c65 100644
--- a/ot/gpu/bregman.py
+++ b/ot/gpu/bregman.py
@@ -82,7 +82,7 @@ def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False,
ot.lp.emd : Unregularized OT
ot.optim.cg : General regularized OT
- """
+ """
# init data
Nini = len(a)
Nfin = len(b)
@@ -92,11 +92,11 @@ def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False,
# we assume that no distances are null except those of the diagonal of
# distances
- u = (np.ones(Nini)/Nini).reshape((Nini, 1))
+ u = (np.ones(Nini) / Nini).reshape((Nini, 1))
u_GPU = cudamat.CUDAMatrix(u)
a_GPU = cudamat.CUDAMatrix(a.reshape((Nini, 1)))
ones_GPU = cudamat.empty(u_GPU.shape).assign(1)
- v = (np.ones(Nfin)/Nfin).reshape((Nfin, 1))
+ v = (np.ones(Nfin) / Nfin).reshape((Nfin, 1))
v_GPU = cudamat.CUDAMatrix(v)
b_GPU = cudamat.CUDAMatrix(b.reshape((Nfin, 1)))
@@ -121,7 +121,7 @@ def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False,
ones_GPU.divide(Kp_GPU.dot(v_GPU), target=u_GPU)
if (np.any(KtransposeU_GPU.asarray() == 0) or
- not u_GPU.allfinite() or not v_GPU.allfinite()):
+ not u_GPU.allfinite() or not v_GPU.allfinite()):
# we have reached the machine precision
# come back to previous solution and quit loop
print('Warning: numerical errors at iteration', cpt)
@@ -142,7 +142,8 @@ def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False,
if verbose:
if cpt % 200 == 0:
- print('{:5s}|{:12s}'.format('It.', 'Err')+'\n'+'-'*19)
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(cpt, err))
cpt += 1
if log:
diff --git a/ot/gpu/da.py b/ot/gpu/da.py
index b05ff70..c66e755 100644
--- a/ot/gpu/da.py
+++ b/ot/gpu/da.py
@@ -167,7 +167,7 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M_GPU, reg, eta=0.1, numItermax=10,
tmpC_GPU = cudamat.empty((Nfin, nbRow)).assign(0)
transp_GPU.transpose().select_columns(indices_labels[i], tmpC_GPU)
majs_GPU = tmpC_GPU.sum(axis=1).add(epsilon)
- cudamat.pow(majs_GPU, (p-1))
+ cudamat.pow(majs_GPU, (p - 1))
majs_GPU.mult(p)
tmpC_GPU.assign(0)
diff --git a/ot/utils.py b/ot/utils.py
index 7ad7637..6a43f61 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -7,31 +7,35 @@ from scipy.spatial.distance import cdist
import multiprocessing
import time
-__time_tic_toc=time.time()
+__time_tic_toc = time.time()
+
def tic():
""" Python implementation of Matlab tic() function """
global __time_tic_toc
- __time_tic_toc=time.time()
+ __time_tic_toc = time.time()
+
def toc(message='Elapsed time : {} s'):
""" Python implementation of Matlab toc() function """
- t=time.time()
- print(message.format(t-__time_tic_toc))
- return t-__time_tic_toc
+ t = time.time()
+ print(message.format(t - __time_tic_toc))
+ return t - __time_tic_toc
+
def toq():
""" Python implementation of Julia toc() function """
- t=time.time()
- return t-__time_tic_toc
+ t = time.time()
+ return t - __time_tic_toc
-def kernel(x1,x2,method='gaussian',sigma=1,**kwargs):
+def kernel(x1, x2, method='gaussian', sigma=1, **kwargs):
"""Compute kernel matrix"""
- if method.lower() in ['gaussian','gauss','rbf']:
- K=np.exp(-dist(x1,x2)/(2*sigma**2))
+ if method.lower() in ['gaussian', 'gauss', 'rbf']:
+ K = np.exp(-dist(x1, x2) / (2 * sigma**2))
return K
+
def unif(n):
""" return a uniform histogram of length n (simplex)
@@ -48,17 +52,19 @@ def unif(n):
"""
- return np.ones((n,))/n
+ return np.ones((n,)) / n
-def clean_zeros(a,b,M):
- """ Remove all components with zeros weights in a and b
+
+def clean_zeros(a, b, M):
+ """ Remove all components with zeros weights in a and b
"""
- M2=M[a>0,:][:,b>0].copy() # copy force c style matrix (froemd)
- a2=a[a>0]
- b2=b[b>0]
- return a2,b2,M2
+ M2 = M[a > 0, :][:, b > 0].copy() # copy force c style matrix (froemd)
+ a2 = a[a > 0]
+ b2 = b[b > 0]
+ return a2, b2, M2
+
-def dist(x1,x2=None,metric='sqeuclidean'):
+def dist(x1, x2=None, metric='sqeuclidean'):
"""Compute distance between samples in x1 and x2 using function scipy.spatial.distance.cdist
Parameters
@@ -84,12 +90,12 @@ def dist(x1,x2=None,metric='sqeuclidean'):
"""
if x2 is None:
- x2=x1
+ x2 = x1
- return cdist(x1,x2,metric=metric)
+ return cdist(x1, x2, metric=metric)
-def dist0(n,method='lin_square'):
+def dist0(n, method='lin_square'):
"""Compute standard cost matrices of size (n,n) for OT problems
Parameters
@@ -111,16 +117,17 @@ def dist0(n,method='lin_square'):
"""
- res=0
- if method=='lin_square':
- x=np.arange(n,dtype=np.float64).reshape((n,1))
- res=dist(x,x)
+ res = 0
+ if method == 'lin_square':
+ x = np.arange(n, dtype=np.float64).reshape((n, 1))
+ res = dist(x, x)
return res
def dots(*args):
""" dots function for multiple matrix multiply """
- return reduce(np.dot,args)
+ return reduce(np.dot, args)
+
def fun(f, q_in, q_out):
""" Utility function for parmap with no serializing problems """
@@ -130,6 +137,7 @@ def fun(f, q_in, q_out):
break
q_out.put((i, f(x)))
+
def parmap(f, X, nprocs=multiprocessing.cpu_count()):
""" paralell map for multiprocessing """
q_in = multiprocessing.Queue(1)
@@ -147,4 +155,4 @@ def parmap(f, X, nprocs=multiprocessing.cpu_count()):
[p.join() for p in proc]
- return [x for i, x in sorted(res)] \ No newline at end of file
+ return [x for i, x in sorted(res)]
diff --git a/test/test_emd_multi.py b/test/test_emd_multi.py
index ee0a20e..99173e9 100644
--- a/test/test_emd_multi.py
+++ b/test/test_emd_multi.py
@@ -7,31 +7,30 @@ Created on Fri Mar 10 09:56:06 2017
"""
import numpy as np
-import pylab as pl
-import ot
+import ot
from ot.datasets import get_1D_gauss as gauss
-reload(ot.lp)
+# reload(ot.lp)
#%% parameters
-n=5000 # nb bins
+n = 5000 # nb bins
# bin positions
-x=np.arange(n,dtype=np.float64)
+x = np.arange(n, dtype=np.float64)
# Gaussian distributions
-a=gauss(n,m=20,s=5) # m= mean, s= std
+a = gauss(n, m=20, s=5) # m= mean, s= std
-ls= range(20,1000,10)
-nb=len(ls)
-b=np.zeros((n,nb))
+ls = range(20, 1000, 10)
+nb = len(ls)
+b = np.zeros((n, nb))
for i in range(nb):
- b[:,i]=gauss(n,m=ls[i],s=10)
+ b[:, i] = gauss(n, m=ls[i], s=10)
# loss matrix
-M=ot.dist(x.reshape((n,1)),x.reshape((n,1)))
-#M/=M.max()
+M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
+# M/=M.max()
#%%
@@ -39,10 +38,10 @@ print('Computing {} EMD '.format(nb))
# emd loss 1 proc
ot.tic()
-emd_loss4=ot.emd2(a,b,M,1)
+emd_loss4 = ot.emd2(a, b, M, 1)
ot.toc('1 proc : {} s')
# emd loss multipro proc
ot.tic()
-emd_loss4=ot.emd2(a,b,M)
+emd_loss4 = ot.emd2(a, b, M)
ot.toc('multi proc : {} s')
diff --git a/test/test_gpu_sinkhorn.py b/test/test_gpu_sinkhorn.py
index bfa2cd2..841f062 100644
--- a/test/test_gpu_sinkhorn.py
+++ b/test/test_gpu_sinkhorn.py
@@ -3,8 +3,10 @@ import numpy as np
import time
import ot.gpu
+
def describeRes(r):
- print("min:{:.3E}, max::{:.3E}, mean::{:.3E}, std::{:.3E}".format(np.min(r),np.max(r),np.mean(r),np.std(r)))
+ print("min:{:.3E}, max::{:.3E}, mean::{:.3E}, std::{:.3E}".format(
+ np.min(r), np.max(r), np.mean(r), np.std(r)))
for n in [5000, 10000, 15000, 20000]:
@@ -23,4 +25,4 @@ for n in [5000, 10000, 15000, 20000]:
print("Normal sinkhorn, time: {:6.2f} sec ".format(time2 - time1))
describeRes(G1)
print(" GPU sinkhorn, time: {:6.2f} sec ".format(time3 - time2))
- describeRes(G2) \ No newline at end of file
+ describeRes(G2)
diff --git a/test/test_load_module.py b/test/test_load_module.py
index a04c5df..d77261e 100644
--- a/test/test_load_module.py
+++ b/test/test_load_module.py
@@ -4,7 +4,7 @@ import ot
import doctest
# test lp solver
-doctest.testmod(ot.lp,verbose=True)
+doctest.testmod(ot.lp, verbose=True)
# test bregman solver
-doctest.testmod(ot.bregman,verbose=True)
+doctest.testmod(ot.bregman, verbose=True)