From 7ac1b462d23ae0a396742bba4773e146e60e7502 Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Fri, 5 Jul 2019 13:47:43 +0200 Subject: cleanup parmap on windows --- ot/lp/__init__.py | 14 ++++++++++++-- ot/utils.py | 31 ++++++++++++++++++------------- 2 files changed, 30 insertions(+), 15 deletions(-) (limited to 'ot') diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 17f1731..0c92810 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -11,7 +11,7 @@ Solvers for the original linear program OT problem # License: MIT License import multiprocessing - +import sys import numpy as np from scipy.sparse import coo_matrix @@ -151,6 +151,8 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), Target histogram (uniform weight if empty list) M : (ns,nt) numpy.ndarray, float64 Loss matrix (c-order array with type float64) + processes : int, optional (default=nb cpu) + Nb of processes used for multiple emd computation (not used on windows) numItermax : int, optional (default=100000) The maximum number of iterations before stopping the optimization algorithm if it has not converged. @@ -200,6 +202,10 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), b = np.asarray(b, dtype=np.float64) M = np.asarray(M, dtype=np.float64) + # problem with pikling Forks + if sys.platform.endswith('win32'): + processes=1 + # if empty array given then use uniform distributions if len(a) == 0: a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] @@ -228,7 +234,11 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), return f(b) nb = b.shape[1] - res = parmap(f, [b[:, i] for i in range(nb)], processes) + if processes>1: + res = parmap(f, [b[:, i] for i in range(nb)], processes) + else: + res = list(map(f, [b[:, i].copy() for i in range(nb)])) + return res diff --git a/ot/utils.py b/ot/utils.py index e8249ef..5707d9b 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -214,23 +214,28 @@ def fun(f, q_in, q_out): def parmap(f, X, nprocs=multiprocessing.cpu_count()): - """ paralell map for multiprocessing """ - q_in = multiprocessing.Queue(1) - q_out = multiprocessing.Queue() + """ paralell map for multiprocessing (only map on windows)""" - proc = [multiprocessing.Process(target=fun, args=(f, q_in, q_out)) - for _ in range(nprocs)] - for p in proc: - p.daemon = True - p.start() + if not sys.platform.endswith('win32'): - sent = [q_in.put((i, x)) for i, x in enumerate(X)] - [q_in.put((None, None)) for _ in range(nprocs)] - res = [q_out.get() for _ in range(len(sent))] + q_in = multiprocessing.Queue(1) + q_out = multiprocessing.Queue() - [p.join() for p in proc] + proc = [multiprocessing.Process(target=fun, args=(f, q_in, q_out)) + for _ in range(nprocs)] + for p in proc: + p.daemon = True + p.start() - return [x for i, x in sorted(res)] + sent = [q_in.put((i, x)) for i, x in enumerate(X)] + [q_in.put((None, None)) for _ in range(nprocs)] + res = [q_out.get() for _ in range(len(sent))] + + [p.join() for p in proc] + + return [x for i, x in sorted(res)] + else: + return list(map(f, X)) def check_params(**kwargs): -- cgit v1.2.3