From ab65f86304b03a967054eeeaf73b8c8277618d65 Mon Sep 17 00:00:00 2001 From: Antoine Rolet Date: Thu, 7 Sep 2017 14:35:35 +0900 Subject: Added log option to muliprocess emd --- ot/lp/__init__.py | 39 ++++++++++++++++++++++++++------------- 1 file changed, 26 insertions(+), 13 deletions(-) (limited to 'ot/lp') diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index c15e6b9..8edd8ec 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -7,11 +7,13 @@ Solvers for the original linear program OT problem # # License: MIT License +import multiprocessing + import numpy as np + # import compiled emd from .emd_wrap import emd_c from ..utils import parmap -import multiprocessing def emd(a, b, M, numItermax=100000, log=False): @@ -88,9 +90,9 @@ def emd(a, b, M, numItermax=100000, log=False): # if empty array given then use unifor distributions if len(a) == 0: - a = np.ones((M.shape[0], ), dtype=np.float64)/M.shape[0] + a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] if len(b) == 0: - b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1] + b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] G, cost, u, v = emd_c(a, b, M, numItermax) if log: @@ -101,7 +103,8 @@ def emd(a, b, M, numItermax=100000, log=False): return G, log return G -def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000): + +def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000, log=False): """Solves the Earth Movers distance problem and returns the loss .. math:: @@ -168,16 +171,26 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000): # if empty array given then use unifor distributions if len(a) == 0: - a = np.ones((M.shape[0], ), dtype=np.float64)/M.shape[0] + a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] if len(b) == 0: - b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1] - - if len(b.shape)==1: - return emd_c(a, b, M, numItermax)[1] + b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] + + if log: + def f(b): + G, cost, u, v = emd_c(a, b, M, numItermax) + log = {} + log['G'] = G + log['u'] = u + log['v'] = v + return [cost, log] + else: + def f(b): + return emd_c(a, b, M, numItermax)[1] + + if len(b.shape) == 1: + return f(b) nb = b.shape[1] # res = [emd2_c(a, b[:, i].copy(), M, numItermax) for i in range(nb)] - def f(b): - return emd_c(a,b,M, numItermax)[1] - res= parmap(f, [b[:,i] for i in range(nb)],processes) - return np.array(res) + res = parmap(f, [b[:, i] for i in range(nb)], processes) + return res -- cgit v1.2.3