summaryrefslogtreecommitdiff
path: root/ot/lp/__init__.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2017-03-14 10:30:45 +0100
committerRémi Flamary <remi.flamary@gmail.com>2017-03-14 10:30:45 +0100
commita84f2c3e23edd1fa89975bd77b08672f518d5ca4 (patch)
tree1aadf05357949e6daec2c332eb900e93346ad465 /ot/lp/__init__.py
parent84219d9bd87acd9bbb6d1a832cf4ccaee53fed0b (diff)
add emd2+ multiproc
Diffstat (limited to 'ot/lp/__init__.py')
-rw-r--r--ot/lp/__init__.py44
1 files changed, 10 insertions, 34 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 57c5b7b..5674cf6 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -5,9 +5,12 @@ Solvers for the original linear program OT problem
import numpy as np
# import compiled emd
-from .emd import emd_c
+from .emd import emd_c, emd2_c
+from ..utils import parmap
import multiprocessing
+
+
def emd(a, b, M):
"""Solves the Earth Movers distance problem and returns the OT matrix
@@ -145,41 +148,14 @@ def emd2(a, b, M,processes=multiprocessing.cpu_count()):
b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1]
if len(b.shape)==1:
- return np.sum(emd_c(a, b, M)*M)
+ return emd2_c(a, b, M)
else:
nb=b.shape[1]
- ls=[(a,b[:,k],M) for k in range(nb)]
- def f(l):
- return emd2(l[0],l[1],l[2])
- # run emd in multiprocessing
- res=parmap(f, ls,processes)
+ #res=[emd2_c(a,b[:,i].copy(),M) for i in range(nb)]
+ def f(b):
+ return emd2_c(a,b,M)
+ res= parmap(f, [b[:,i] for i in range(nb)],processes)
return np.array(res)
-# with Pool(processes) as p:
-# res=p.map(f, ls)
-# return np.array(res)
-def fun(f, q_in, q_out):
- while True:
- i, x = q_in.get()
- if i is None:
- break
- q_out.put((i, f(x)))
-
-def parmap(f, X, nprocs=multiprocessing.cpu_count()):
- q_in = multiprocessing.Queue(1)
- q_out = multiprocessing.Queue()
-
- proc = [multiprocessing.Process(target=fun, args=(f, q_in, q_out))
- for _ in range(nprocs)]
- for p in proc:
- p.daemon = True
- p.start()
-
- sent = [q_in.put((i, x)) for i, x in enumerate(X)]
- [q_in.put((None, None)) for _ in range(nprocs)]
- res = [q_out.get() for _ in range(len(sent))]
-
- [p.join() for p in proc]
-
- return [x for i, x in sorted(res)] \ No newline at end of file
+ \ No newline at end of file