summaryrefslogtreecommitdiff
path: root/ot/lp/__init__.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2017-03-10 10:26:18 +0100
committerRémi Flamary <remi.flamary@gmail.com>2017-03-10 10:26:18 +0100
commit84219d9bd87acd9bbb6d1a832cf4ccaee53fed0b (patch)
treec1877b043ef51556da52b1d92a5287e8f6c376b7 /ot/lp/__init__.py
parentdf32d77316e79a663312544129048f8fee949817 (diff)
runs but not quicker
Diffstat (limited to 'ot/lp/__init__.py')
-rw-r--r--ot/lp/__init__.py10
1 files changed, 6 insertions, 4 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 5358083..57c5b7b 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -78,7 +78,7 @@ def emd(a, b, M):
return emd_c(a, b, M)
-def emd2(a, b, M,processes=None):
+def emd2(a, b, M,processes=multiprocessing.cpu_count()):
"""Solves the Earth Movers distance problem and returns the loss
.. math::
@@ -149,9 +149,11 @@ def emd2(a, b, M,processes=None):
else:
nb=b.shape[1]
ls=[(a,b[:,k],M) for k in range(nb)]
+ def f(l):
+ return emd2(l[0],l[1],l[2])
# run emd in multiprocessing
- res=parmap(emd2, ls,processes)
- np.array(res)
+ res=parmap(f, ls,processes)
+ return np.array(res)
# with Pool(processes) as p:
# res=p.map(f, ls)
# return np.array(res)
@@ -164,7 +166,7 @@ def fun(f, q_in, q_out):
break
q_out.put((i, f(x)))
-def parmap(f, X, nprocs):
+def parmap(f, X, nprocs=multiprocessing.cpu_count()):
q_in = multiprocessing.Queue(1)
q_out = multiprocessing.Queue()