diff options
Diffstat (limited to 'ot/lp/__init__.py')
-rw-r--r-- | ot/lp/__init__.py | 10 |
1 files changed, 6 insertions, 4 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 5358083..57c5b7b 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -78,7 +78,7 @@ def emd(a, b, M): return emd_c(a, b, M) -def emd2(a, b, M,processes=None): +def emd2(a, b, M,processes=multiprocessing.cpu_count()): """Solves the Earth Movers distance problem and returns the loss .. math:: @@ -149,9 +149,11 @@ def emd2(a, b, M,processes=None): else: nb=b.shape[1] ls=[(a,b[:,k],M) for k in range(nb)] + def f(l): + return emd2(l[0],l[1],l[2]) # run emd in multiprocessing - res=parmap(emd2, ls,processes) - np.array(res) + res=parmap(f, ls,processes) + return np.array(res) # with Pool(processes) as p: # res=p.map(f, ls) # return np.array(res) @@ -164,7 +166,7 @@ def fun(f, q_in, q_out): break q_out.put((i, f(x))) -def parmap(f, X, nprocs): +def parmap(f, X, nprocs=multiprocessing.cpu_count()): q_in = multiprocessing.Queue(1) q_out = multiprocessing.Queue() |