summaryrefslogtreecommitdiff
path: root/ot/lp/__init__.py
diff options
context:
space:
mode:
authorAntoine Rolet <antoine.rolet@gmail.com>2017-09-07 14:35:35 +0900
committerAntoine Rolet <antoine.rolet@gmail.com>2017-09-07 14:35:35 +0900
commitab65f86304b03a967054eeeaf73b8c8277618d65 (patch)
treeafdf3a385588277dfe32f8ac92d5009f14ea0c4e /ot/lp/__init__.py
parent12d9b3ff72e9669ccc0162e82b7a33beb51d3e25 (diff)
Added log option to muliprocess emd
Diffstat (limited to 'ot/lp/__init__.py')
-rw-r--r--ot/lp/__init__.py39
1 files changed, 26 insertions, 13 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index c15e6b9..8edd8ec 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -7,11 +7,13 @@ Solvers for the original linear program OT problem
#
# License: MIT License
+import multiprocessing
+
import numpy as np
+
# import compiled emd
from .emd_wrap import emd_c
from ..utils import parmap
-import multiprocessing
def emd(a, b, M, numItermax=100000, log=False):
@@ -88,9 +90,9 @@ def emd(a, b, M, numItermax=100000, log=False):
# if empty array given then use unifor distributions
if len(a) == 0:
- a = np.ones((M.shape[0], ), dtype=np.float64)/M.shape[0]
+ a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
if len(b) == 0:
- b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1]
+ b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
G, cost, u, v = emd_c(a, b, M, numItermax)
if log:
@@ -101,7 +103,8 @@ def emd(a, b, M, numItermax=100000, log=False):
return G, log
return G
-def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000):
+
+def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000, log=False):
"""Solves the Earth Movers distance problem and returns the loss
.. math::
@@ -168,16 +171,26 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000):
# if empty array given then use unifor distributions
if len(a) == 0:
- a = np.ones((M.shape[0], ), dtype=np.float64)/M.shape[0]
+ a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
if len(b) == 0:
- b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1]
-
- if len(b.shape)==1:
- return emd_c(a, b, M, numItermax)[1]
+ b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
+
+ if log:
+ def f(b):
+ G, cost, u, v = emd_c(a, b, M, numItermax)
+ log = {}
+ log['G'] = G
+ log['u'] = u
+ log['v'] = v
+ return [cost, log]
+ else:
+ def f(b):
+ return emd_c(a, b, M, numItermax)[1]
+
+ if len(b.shape) == 1:
+ return f(b)
nb = b.shape[1]
# res = [emd2_c(a, b[:, i].copy(), M, numItermax) for i in range(nb)]
- def f(b):
- return emd_c(a,b,M, numItermax)[1]
- res= parmap(f, [b[:,i] for i in range(nb)],processes)
- return np.array(res)
+ res = parmap(f, [b[:, i] for i in range(nb)], processes)
+ return res