diff options
-rw-r--r-- | .travis.yml | 51 | ||||
-rw-r--r-- | ot/lp/__init__.py | 14 | ||||
-rw-r--r-- | ot/utils.py | 31 | ||||
-rw-r--r-- | requirements.txt | 1 |
4 files changed, 57 insertions, 40 deletions
diff --git a/.travis.yml b/.travis.yml index 67f0c43..cddf0e0 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,31 +1,32 @@ dist: xenial # required for Python >= 3.7 language: python matrix: -# allow_failures: -# - os: osx - include: -# - os: osx -# language: generic - - os: linux - sudo: required - python: 3.4 - - os: linux - sudo: required - python: 3.5 - - os: linux - sudo: required - python: 3.6 - - os: linux - sudo: required - python: 3.7 - - os: linux - sudo: required - python: 2.7 - - name: "Python 3.7.3 on Windows" - os: windows # Windows 10.0.17134 N/A Build 17134 - language: shell # 'language: python' is an error on Travis CI Windows - before_install: choco install python - env: PATH=/c/Python37:/c/Python37/Scripts:$PATH + allow_failures: + - os: osx + - os: windows + include: + - os: osx + language: generic + - os: linux + sudo: required + python: 3.4 + - os: linux + sudo: required + python: 3.5 + - os: linux + sudo: required + python: 3.6 + - os: linux + sudo: required + python: 3.7 + - os: linux + sudo: required + python: 2.7 + - name: "Python 3.7.3 on Windows" + os: windows # Windows 10.0.17134 N/A Build 17134 + language: shell # 'language: python' is an error on Travis CI Windows + before_install: choco install python + env: PATH=/c/Python37:/c/Python37/Scripts:$PATH before_install: - ./.travis/before_install.sh # command to install dependencies diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 17f1731..0c92810 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -11,7 +11,7 @@ Solvers for the original linear program OT problem # License: MIT License import multiprocessing - +import sys import numpy as np from scipy.sparse import coo_matrix @@ -151,6 +151,8 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), Target histogram (uniform weight if empty list) M : (ns,nt) numpy.ndarray, float64 Loss matrix (c-order array with type float64) + processes : int, optional (default=nb cpu) + Nb of processes used for multiple emd computation (not used on windows) numItermax : int, optional (default=100000) The maximum number of iterations before stopping the optimization algorithm if it has not converged. @@ -200,6 +202,10 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), b = np.asarray(b, dtype=np.float64) M = np.asarray(M, dtype=np.float64) + # problem with pikling Forks + if sys.platform.endswith('win32'): + processes=1 + # if empty array given then use uniform distributions if len(a) == 0: a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] @@ -228,7 +234,11 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), return f(b) nb = b.shape[1] - res = parmap(f, [b[:, i] for i in range(nb)], processes) + if processes>1: + res = parmap(f, [b[:, i] for i in range(nb)], processes) + else: + res = list(map(f, [b[:, i].copy() for i in range(nb)])) + return res diff --git a/ot/utils.py b/ot/utils.py index e8249ef..5707d9b 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -214,23 +214,28 @@ def fun(f, q_in, q_out): def parmap(f, X, nprocs=multiprocessing.cpu_count()): - """ paralell map for multiprocessing """ - q_in = multiprocessing.Queue(1) - q_out = multiprocessing.Queue() + """ paralell map for multiprocessing (only map on windows)""" - proc = [multiprocessing.Process(target=fun, args=(f, q_in, q_out)) - for _ in range(nprocs)] - for p in proc: - p.daemon = True - p.start() + if not sys.platform.endswith('win32'): - sent = [q_in.put((i, x)) for i, x in enumerate(X)] - [q_in.put((None, None)) for _ in range(nprocs)] - res = [q_out.get() for _ in range(len(sent))] + q_in = multiprocessing.Queue(1) + q_out = multiprocessing.Queue() - [p.join() for p in proc] + proc = [multiprocessing.Process(target=fun, args=(f, q_in, q_out)) + for _ in range(nprocs)] + for p in proc: + p.daemon = True + p.start() - return [x for i, x in sorted(res)] + sent = [q_in.put((i, x)) for i, x in enumerate(X)] + [q_in.put((None, None)) for _ in range(nprocs)] + res = [q_out.get() for _ in range(len(sent))] + + [p.join() for p in proc] + + return [x for i, x in sorted(res)] + else: + return list(map(f, X)) def check_params(**kwargs): diff --git a/requirements.txt b/requirements.txt index 97d165b..5a3432b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,5 @@ matplotlib sphinx-gallery autograd pymanopt +cvxopt pytest |