diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2017-03-14 10:30:45 +0100 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2017-03-14 10:30:45 +0100 |
commit | a84f2c3e23edd1fa89975bd77b08672f518d5ca4 (patch) | |
tree | 1aadf05357949e6daec2c332eb900e93346ad465 /ot/utils.py | |
parent | 84219d9bd87acd9bbb6d1a832cf4ccaee53fed0b (diff) |
add emd2+ multiproc
Diffstat (limited to 'ot/utils.py')
-rw-r--r-- | ot/utils.py | 31 |
1 files changed, 29 insertions, 2 deletions
diff --git a/ot/utils.py b/ot/utils.py index 2f68775..e5cd6c0 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -4,7 +4,7 @@ Various function that can be usefull """ import numpy as np from scipy.spatial.distance import cdist - +import multiprocessing import time __time_tic_toc=time.time() @@ -113,4 +113,31 @@ def dist0(n,method='lin_square'): def dots(*args): """ dots function for multiple matrix multiply """ - return reduce(np.dot,args)
\ No newline at end of file + return reduce(np.dot,args) + +def fun(f, q_in, q_out): + """ Utility function for parmap with no serializing problems """ + while True: + i, x = q_in.get() + if i is None: + break + q_out.put((i, f(x))) + +def parmap(f, X, nprocs=multiprocessing.cpu_count()): + """ paralell map for multiprocessing """ + q_in = multiprocessing.Queue(1) + q_out = multiprocessing.Queue() + + proc = [multiprocessing.Process(target=fun, args=(f, q_in, q_out)) + for _ in range(nprocs)] + for p in proc: + p.daemon = True + p.start() + + sent = [q_in.put((i, x)) for i, x in enumerate(X)] + [q_in.put((None, None)) for _ in range(nprocs)] + res = [q_out.get() for _ in range(len(sent))] + + [p.join() for p in proc] + + return [x for i, x in sorted(res)]
\ No newline at end of file |