summaryrefslogtreecommitdiff
path: root/ot/lp/__init__.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2019-07-05 13:47:43 +0200
committerRémi Flamary <remi.flamary@gmail.com>2019-07-05 13:47:43 +0200
commit7ac1b462d23ae0a396742bba4773e146e60e7502 (patch)
tree82b36ba5f9511c133322e1687120ff8e4c315d8f /ot/lp/__init__.py
parent0bc936f62430c98ecbb0f39c9508f29c6054a327 (diff)
cleanup parmap on windows
Diffstat (limited to 'ot/lp/__init__.py')
-rw-r--r--ot/lp/__init__.py14
1 files changed, 12 insertions, 2 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 17f1731..0c92810 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -11,7 +11,7 @@ Solvers for the original linear program OT problem
# License: MIT License
import multiprocessing
-
+import sys
import numpy as np
from scipy.sparse import coo_matrix
@@ -151,6 +151,8 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
Target histogram (uniform weight if empty list)
M : (ns,nt) numpy.ndarray, float64
Loss matrix (c-order array with type float64)
+ processes : int, optional (default=nb cpu)
+ Nb of processes used for multiple emd computation (not used on windows)
numItermax : int, optional (default=100000)
The maximum number of iterations before stopping the optimization
algorithm if it has not converged.
@@ -200,6 +202,10 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
b = np.asarray(b, dtype=np.float64)
M = np.asarray(M, dtype=np.float64)
+ # problem with pikling Forks
+ if sys.platform.endswith('win32'):
+ processes=1
+
# if empty array given then use uniform distributions
if len(a) == 0:
a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
@@ -228,7 +234,11 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
return f(b)
nb = b.shape[1]
- res = parmap(f, [b[:, i] for i in range(nb)], processes)
+ if processes>1:
+ res = parmap(f, [b[:, i] for i in range(nb)], processes)
+ else:
+ res = list(map(f, [b[:, i].copy() for i in range(nb)]))
+
return res