diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2019-07-05 13:47:43 +0200 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2019-07-05 13:47:43 +0200 |
commit | 7ac1b462d23ae0a396742bba4773e146e60e7502 (patch) | |
tree | 82b36ba5f9511c133322e1687120ff8e4c315d8f /ot/lp | |
parent | 0bc936f62430c98ecbb0f39c9508f29c6054a327 (diff) |
cleanup parmap on windows
Diffstat (limited to 'ot/lp')
-rw-r--r-- | ot/lp/__init__.py | 14 |
1 files changed, 12 insertions, 2 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 17f1731..0c92810 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -11,7 +11,7 @@ Solvers for the original linear program OT problem # License: MIT License import multiprocessing - +import sys import numpy as np from scipy.sparse import coo_matrix @@ -151,6 +151,8 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), Target histogram (uniform weight if empty list) M : (ns,nt) numpy.ndarray, float64 Loss matrix (c-order array with type float64) + processes : int, optional (default=nb cpu) + Nb of processes used for multiple emd computation (not used on windows) numItermax : int, optional (default=100000) The maximum number of iterations before stopping the optimization algorithm if it has not converged. @@ -200,6 +202,10 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), b = np.asarray(b, dtype=np.float64) M = np.asarray(M, dtype=np.float64) + # problem with pikling Forks + if sys.platform.endswith('win32'): + processes=1 + # if empty array given then use uniform distributions if len(a) == 0: a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] @@ -228,7 +234,11 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), return f(b) nb = b.shape[1] - res = parmap(f, [b[:, i] for i in range(nb)], processes) + if processes>1: + res = parmap(f, [b[:, i] for i in range(nb)], processes) + else: + res = list(map(f, [b[:, i].copy() for i in range(nb)])) + return res |