From a84f2c3e23edd1fa89975bd77b08672f518d5ca4 Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Tue, 14 Mar 2017 10:30:45 +0100 Subject: add emd2+ multiproc --- ot/lp/__init__.py | 44 ++++++++++---------------------------------- 1 file changed, 10 insertions(+), 34 deletions(-) (limited to 'ot/lp/__init__.py') diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 57c5b7b..5674cf6 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -5,9 +5,12 @@ Solvers for the original linear program OT problem import numpy as np # import compiled emd -from .emd import emd_c +from .emd import emd_c, emd2_c +from ..utils import parmap import multiprocessing + + def emd(a, b, M): """Solves the Earth Movers distance problem and returns the OT matrix @@ -145,41 +148,14 @@ def emd2(a, b, M,processes=multiprocessing.cpu_count()): b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1] if len(b.shape)==1: - return np.sum(emd_c(a, b, M)*M) + return emd2_c(a, b, M) else: nb=b.shape[1] - ls=[(a,b[:,k],M) for k in range(nb)] - def f(l): - return emd2(l[0],l[1],l[2]) - # run emd in multiprocessing - res=parmap(f, ls,processes) + #res=[emd2_c(a,b[:,i].copy(),M) for i in range(nb)] + def f(b): + return emd2_c(a,b,M) + res= parmap(f, [b[:,i] for i in range(nb)],processes) return np.array(res) -# with Pool(processes) as p: -# res=p.map(f, ls) -# return np.array(res) -def fun(f, q_in, q_out): - while True: - i, x = q_in.get() - if i is None: - break - q_out.put((i, f(x))) - -def parmap(f, X, nprocs=multiprocessing.cpu_count()): - q_in = multiprocessing.Queue(1) - q_out = multiprocessing.Queue() - - proc = [multiprocessing.Process(target=fun, args=(f, q_in, q_out)) - for _ in range(nprocs)] - for p in proc: - p.daemon = True - p.start() - - sent = [q_in.put((i, x)) for i, x in enumerate(X)] - [q_in.put((None, None)) for _ in range(nprocs)] - res = [q_out.get() for _ in range(len(sent))] - - [p.join() for p in proc] - - return [x for i, x in sorted(res)] \ No newline at end of file + \ No newline at end of file -- cgit v1.2.3