summaryrefslogtreecommitdiff
path: root/ot/utils.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2017-03-14 10:30:45 +0100
committerRémi Flamary <remi.flamary@gmail.com>2017-03-14 10:30:45 +0100
commita84f2c3e23edd1fa89975bd77b08672f518d5ca4 (patch)
tree1aadf05357949e6daec2c332eb900e93346ad465 /ot/utils.py
parent84219d9bd87acd9bbb6d1a832cf4ccaee53fed0b (diff)
add emd2+ multiproc
Diffstat (limited to 'ot/utils.py')
-rw-r--r--ot/utils.py31
1 files changed, 29 insertions, 2 deletions
diff --git a/ot/utils.py b/ot/utils.py
index 2f68775..e5cd6c0 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -4,7 +4,7 @@ Various function that can be usefull
"""
import numpy as np
from scipy.spatial.distance import cdist
-
+import multiprocessing
import time
__time_tic_toc=time.time()
@@ -113,4 +113,31 @@ def dist0(n,method='lin_square'):
def dots(*args):
""" dots function for multiple matrix multiply """
- return reduce(np.dot,args) \ No newline at end of file
+ return reduce(np.dot,args)
+
+def fun(f, q_in, q_out):
+ """ Utility function for parmap with no serializing problems """
+ while True:
+ i, x = q_in.get()
+ if i is None:
+ break
+ q_out.put((i, f(x)))
+
+def parmap(f, X, nprocs=multiprocessing.cpu_count()):
+ """ paralell map for multiprocessing """
+ q_in = multiprocessing.Queue(1)
+ q_out = multiprocessing.Queue()
+
+ proc = [multiprocessing.Process(target=fun, args=(f, q_in, q_out))
+ for _ in range(nprocs)]
+ for p in proc:
+ p.daemon = True
+ p.start()
+
+ sent = [q_in.put((i, x)) for i, x in enumerate(X)]
+ [q_in.put((None, None)) for _ in range(nprocs)]
+ res = [q_out.get() for _ in range(len(sent))]
+
+ [p.join() for p in proc]
+
+ return [x for i, x in sorted(res)] \ No newline at end of file