summaryrefslogtreecommitdiff
path: root/ot/lp/__init__.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2017-03-10 10:18:12 +0100
committerRémi Flamary <remi.flamary@gmail.com>2017-03-10 10:18:12 +0100
commitdf32d77316e79a663312544129048f8fee949817 (patch)
treedb5fe129c08268e0aa72a36b19dfebd3705de540 /ot/lp/__init__.py
parent0b806374d33ae83d39846096a1838b096c0c0b8e (diff)
first try
Diffstat (limited to 'ot/lp/__init__.py')
-rw-r--r--ot/lp/__init__.py107
1 files changed, 106 insertions, 1 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 1e55f5a..5358083 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -6,7 +6,7 @@ Solvers for the original linear program OT problem
import numpy as np
# import compiled emd
from .emd import emd_c
-
+import multiprocessing
def emd(a, b, M):
"""Solves the Earth Movers distance problem and returns the OT matrix
@@ -70,9 +70,114 @@ def emd(a, b, M):
b = np.asarray(b, dtype=np.float64)
M = np.asarray(M, dtype=np.float64)
+ # if empty array given then use uniform distributions
if len(a) == 0:
a = np.ones((M.shape[0], ), dtype=np.float64)/M.shape[0]
if len(b) == 0:
b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1]
return emd_c(a, b, M)
+
def _emd2_cost(args):
    """Return the transport cost ``<gamma, M>_F`` for one ``(a, b, M)`` tuple.

    Takes a single tuple argument so it can be used directly as the per-item
    function handed to ``parmap`` (whose workers call ``f(x)`` with one item).
    """
    a, b, M = args
    return np.sum(emd_c(a, b, M) * M)


def emd2(a, b, M, processes=None):
    r"""Solves the Earth Movers distance problem and returns the loss

    .. math::
        \min_\gamma <\gamma,M>_F

        s.t. \gamma 1 = a
             \gamma^T 1= b
             \gamma\geq 0
    where :

    - M is the metric cost matrix
    - a and b are the sample weights

    Uses the algorithm proposed in [1]_

    Parameters
    ----------
    a : (ns,) ndarray, float64
        Source histogram (uniform weight if empty list)
    b : (nt,) or (nt,nb) ndarray, float64
        Target histogram (uniform weight if empty list). When a 2D array is
        given, each column is treated as a separate target histogram and one
        problem is solved per column.
    M : (ns,nt) ndarray, float64
        loss matrix
    processes : int, optional
        Number of worker processes used when ``b`` is 2D. ``None`` (default)
        uses ``multiprocessing.cpu_count()``. Ignored when ``b`` is 1D.

    Returns
    -------
    loss : float or (nb,) ndarray
        Optimal transportation loss for the given parameters: a scalar when
        ``b`` is 1D, otherwise one loss per column of ``b``.


    Examples
    --------

    Simple example with obvious solution. The function emd2 accepts lists and
    performs automatic conversion to numpy arrays

    >>> import ot
    >>> a=[.5,.5]
    >>> b=[.5,.5]
    >>> M=[[0.,1.],[1.,0.]]
    >>> ot.emd2(a,b,M)
    0.0

    References
    ----------

    .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W.
        (2011, December). Displacement interpolation using Lagrangian mass
        transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p.
        158). ACM.

    See Also
    --------
    ot.bregman.sinkhorn : Entropic regularized OT
    ot.optim.cg : General regularized OT"""

    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    M = np.asarray(M, dtype=np.float64)

    # if empty array given then use uniform distributions
    if len(a) == 0:
        a = np.ones((M.shape[0], ), dtype=np.float64) / M.shape[0]
    if len(b) == 0:
        b = np.ones((M.shape[1], ), dtype=np.float64) / M.shape[1]

    # single target histogram: solve in the current process
    if len(b.shape) == 1:
        return _emd2_cost((a, b, M))

    # parmap builds range(nprocs), so None must be resolved to a real count
    if processes is None:
        processes = multiprocessing.cpu_count()

    # One task per column of b. Each column is copied into its own contiguous
    # 1D array. NOTE(review): emd_c may require contiguous input -- confirm.
    tasks = [(a, np.ascontiguousarray(b[:, k]), M)
             for k in range(b.shape[1])]

    # run emd in multiprocessing and return one loss per column
    return np.array(parmap(_emd2_cost, tasks, processes))
+
+
def fun(f, q_in, q_out):
    """Worker loop: apply ``f`` to items read from ``q_in``.

    Repeatedly reads ``(index, item)`` pairs from ``q_in`` and writes
    ``(index, f(item))`` to ``q_out``. The loop ends as soon as a pair whose
    index is ``None`` is read; that pair is consumed and produces no output.

    Parameters
    ----------
    f : callable
        Function applied to each item.
    q_in : queue-like
        Source of ``(index, item)`` pairs; only ``get()`` is used.
    q_out : queue-like
        Destination of ``(index, result)`` pairs; only ``put()`` is used.
    """
    index, item = q_in.get()
    while index is not None:
        q_out.put((index, f(item)))
        index, item = q_in.get()
+
def parmap(f, X, nprocs=None):
    """Apply ``f`` to every element of ``X`` using worker processes.

    Starts ``nprocs`` daemon processes that each run ``fun``, feeds them
    ``(index, item)`` pairs through a queue, and collects the
    ``(index, result)`` pairs they send back.

    Parameters
    ----------
    f : callable
        Function applied to each element of ``X`` inside a worker process.
    X : iterable
        Items to process.
    nprocs : int, optional
        Number of worker processes. ``None`` (default) uses
        ``multiprocessing.cpu_count()``.

    Returns
    -------
    list
        The results ``f(x)`` in the same order as the items of ``X``.

    Raises
    ------
    ValueError
        If ``nprocs`` is smaller than 1 (with no worker the input queue would
        never be drained and the call would block).

    Notes
    -----
    ``fun`` does not catch exceptions raised by ``f``; if a task fails, its
    result is never posted and this function keeps waiting for it.
    """
    if nprocs is None:
        nprocs = multiprocessing.cpu_count()
    if nprocs < 1:
        raise ValueError("nprocs must be a positive integer")

    # capacity 1: each put waits until a worker has taken the previous item
    q_in = multiprocessing.Queue(1)
    q_out = multiprocessing.Queue()

    workers = [multiprocessing.Process(target=fun, args=(f, q_in, q_out))
               for _ in range(nprocs)]
    for p in workers:
        p.daemon = True
        p.start()

    # tag every item with its position so the input order can be restored
    n_tasks = 0
    for i, x in enumerate(X):
        q_in.put((i, x))
        n_tasks += 1

    # one (None, None) sentinel per worker tells it to leave its loop
    for _ in range(nprocs):
        q_in.put((None, None))

    res = [q_out.get() for _ in range(n_tasks)]

    for p in workers:
        p.join()

    # sort on the index only, so result objects are never compared
    return [x for _, x in sorted(res, key=lambda pair: pair[0])]