summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2019-07-05 13:47:43 +0200
committerRémi Flamary <remi.flamary@gmail.com>2019-07-05 13:47:43 +0200
commit7ac1b462d23ae0a396742bba4773e146e60e7502 (patch)
tree82b36ba5f9511c133322e1687120ff8e4c315d8f
parent0bc936f62430c98ecbb0f39c9508f29c6054a327 (diff)
cleanup parmap on windows
-rw-r--r--.travis.yml51
-rw-r--r--ot/lp/__init__.py14
-rw-r--r--ot/utils.py31
-rw-r--r--requirements.txt1
4 files changed, 57 insertions, 40 deletions
diff --git a/.travis.yml b/.travis.yml
index 67f0c43..cddf0e0 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -1,31 +1,32 @@
dist: xenial # required for Python >= 3.7
language: python
matrix:
-# allow_failures:
-# - os: osx
- include:
-# - os: osx
-# language: generic
- - os: linux
- sudo: required
- python: 3.4
- - os: linux
- sudo: required
- python: 3.5
- - os: linux
- sudo: required
- python: 3.6
- - os: linux
- sudo: required
- python: 3.7
- - os: linux
- sudo: required
- python: 2.7
- - name: "Python 3.7.3 on Windows"
- os: windows # Windows 10.0.17134 N/A Build 17134
- language: shell # 'language: python' is an error on Travis CI Windows
- before_install: choco install python
- env: PATH=/c/Python37:/c/Python37/Scripts:$PATH
+ allow_failures:
+ - os: osx
+ - os: windows
+ include:
+ - os: osx
+ language: generic
+ - os: linux
+ sudo: required
+ python: 3.4
+ - os: linux
+ sudo: required
+ python: 3.5
+ - os: linux
+ sudo: required
+ python: 3.6
+ - os: linux
+ sudo: required
+ python: 3.7
+ - os: linux
+ sudo: required
+ python: 2.7
+ - name: "Python 3.7.3 on Windows"
+ os: windows # Windows 10.0.17134 N/A Build 17134
+ language: shell # 'language: python' is an error on Travis CI Windows
+ before_install: choco install python
+ env: PATH=/c/Python37:/c/Python37/Scripts:$PATH
before_install:
- ./.travis/before_install.sh
# command to install dependencies
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 17f1731..0c92810 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -11,7 +11,7 @@ Solvers for the original linear program OT problem
# License: MIT License
import multiprocessing
-
+import sys
import numpy as np
from scipy.sparse import coo_matrix
@@ -151,6 +151,8 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
Target histogram (uniform weight if empty list)
M : (ns,nt) numpy.ndarray, float64
Loss matrix (c-order array with type float64)
+ processes : int, optional (default=nb cpu)
+ Nb of processes used for multiple emd computation (not used on windows)
numItermax : int, optional (default=100000)
The maximum number of iterations before stopping the optimization
algorithm if it has not converged.
@@ -200,6 +202,10 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
b = np.asarray(b, dtype=np.float64)
M = np.asarray(M, dtype=np.float64)
+ # problem with pikling Forks
+ if sys.platform.endswith('win32'):
+ processes=1
+
# if empty array given then use uniform distributions
if len(a) == 0:
a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
@@ -228,7 +234,11 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
return f(b)
nb = b.shape[1]
- res = parmap(f, [b[:, i] for i in range(nb)], processes)
+ if processes>1:
+ res = parmap(f, [b[:, i] for i in range(nb)], processes)
+ else:
+ res = list(map(f, [b[:, i].copy() for i in range(nb)]))
+
return res
diff --git a/ot/utils.py b/ot/utils.py
index e8249ef..5707d9b 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -214,23 +214,28 @@ def fun(f, q_in, q_out):
def parmap(f, X, nprocs=multiprocessing.cpu_count()):
- """ paralell map for multiprocessing """
- q_in = multiprocessing.Queue(1)
- q_out = multiprocessing.Queue()
+ """ paralell map for multiprocessing (only map on windows)"""
- proc = [multiprocessing.Process(target=fun, args=(f, q_in, q_out))
- for _ in range(nprocs)]
- for p in proc:
- p.daemon = True
- p.start()
+ if not sys.platform.endswith('win32'):
- sent = [q_in.put((i, x)) for i, x in enumerate(X)]
- [q_in.put((None, None)) for _ in range(nprocs)]
- res = [q_out.get() for _ in range(len(sent))]
+ q_in = multiprocessing.Queue(1)
+ q_out = multiprocessing.Queue()
- [p.join() for p in proc]
+ proc = [multiprocessing.Process(target=fun, args=(f, q_in, q_out))
+ for _ in range(nprocs)]
+ for p in proc:
+ p.daemon = True
+ p.start()
- return [x for i, x in sorted(res)]
+ sent = [q_in.put((i, x)) for i, x in enumerate(X)]
+ [q_in.put((None, None)) for _ in range(nprocs)]
+ res = [q_out.get() for _ in range(len(sent))]
+
+ [p.join() for p in proc]
+
+ return [x for i, x in sorted(res)]
+ else:
+ return list(map(f, X))
def check_params(**kwargs):
diff --git a/requirements.txt b/requirements.txt
index 97d165b..5a3432b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,4 +5,5 @@ matplotlib
sphinx-gallery
autograd
pymanopt
+cvxopt
pytest