diff options
Diffstat (limited to 'ot/utils.py')
-rw-r--r-- | ot/utils.py | 31 |
1 files changed, 29 insertions, 2 deletions
diff --git a/ot/utils.py b/ot/utils.py index 2f68775..e5cd6c0 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -4,7 +4,7 @@ Various function that can be usefull """ import numpy as np from scipy.spatial.distance import cdist - +import multiprocessing import time __time_tic_toc=time.time() @@ -113,4 +113,31 @@ def dist0(n,method='lin_square'): def dots(*args): """ dots function for multiple matrix multiply """ - return reduce(np.dot,args)
\ No newline at end of file + return reduce(np.dot,args) + +def fun(f, q_in, q_out): + """ Utility function for parmap with no serializing problems """ + while True: + i, x = q_in.get() + if i is None: + break + q_out.put((i, f(x))) + +def parmap(f, X, nprocs=multiprocessing.cpu_count()): + """ paralell map for multiprocessing """ + q_in = multiprocessing.Queue(1) + q_out = multiprocessing.Queue() + + proc = [multiprocessing.Process(target=fun, args=(f, q_in, q_out)) + for _ in range(nprocs)] + for p in proc: + p.daemon = True + p.start() + + sent = [q_in.put((i, x)) for i, x in enumerate(X)] + [q_in.put((None, None)) for _ in range(nprocs)] + res = [q_out.get() for _ in range(len(sent))] + + [p.join() for p in proc] + + return [x for i, x in sorted(res)]
\ No newline at end of file |