summaryrefslogtreecommitdiff
path: root/ot/lp/__init__.py
diff options
context:
space:
mode:
authorNicolas Courty <ncourty@irisa.fr>2017-09-13 10:12:36 +0200
committerGitHub <noreply@github.com>2017-09-13 10:12:36 +0200
commitc86cc4feb4003e90c7c3dddba237190b360fc514 (patch)
treef11601b6b1d7ef821a0ae3233960f6f29a1ac0f7 /ot/lp/__init__.py
parent7e5df4cc25e6500ec6f3e85f1c80a7db94863ace (diff)
parenta53ede95f916a11e2150ab7917820d813c0034bc (diff)
Merge branch 'master' into gromov
Diffstat (limited to 'ot/lp/__init__.py')
-rw-r--r--ot/lp/__init__.py80
1 files changed, 60 insertions, 20 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index de91e74..5c09da2 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -7,14 +7,16 @@ Solvers for the original linear program OT problem
#
# License: MIT License
+import multiprocessing
+
import numpy as np
+
# import compiled emd
-from .emd_wrap import emd_c, emd2_c
+from .emd_wrap import emd_c, check_result
from ..utils import parmap
-import multiprocessing
-def emd(a, b, M, numItermax=100000):
+def emd(a, b, M, numItermax=100000, log=False):
"""Solves the Earth Movers distance problem and returns the OT matrix
@@ -42,11 +44,17 @@ def emd(a, b, M, numItermax=100000):
numItermax : int, optional (default=100000)
The maximum number of iterations before stopping the optimization
algorithm if it has not converged.
+ log: boolean, optional (default=False)
+ If True, returns a dictionary containing the cost and dual
+ variables. Otherwise returns only the optimal transportation matrix.
Returns
-------
gamma: (ns x nt) ndarray
Optimal transportation matrix for the given parameters
+ log: dict
+ If input log is true, a dictionary containing the cost and dual
+ variables and exit status
Examples
@@ -82,14 +90,24 @@ def emd(a, b, M, numItermax=100000):
# if empty array given then use unifor distributions
if len(a) == 0:
- a = np.ones((M.shape[0], ), dtype=np.float64)/M.shape[0]
+ a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
if len(b) == 0:
- b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1]
-
- return emd_c(a, b, M, numItermax)
-
-
-def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000):
+ b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
+
+ G, cost, u, v, result_code = emd_c(a, b, M, numItermax)
+ result_code_string = check_result(result_code)
+ if log:
+ log = {}
+ log['cost'] = cost
+ log['u'] = u
+ log['v'] = v
+ log['warning'] = result_code_string
+ log['result_code'] = result_code
+ return G, log
+ return G
+
+
+def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000, log=False, return_matrix=False):
"""Solves the Earth Movers distance problem and returns the loss
.. math::
@@ -116,11 +134,19 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000):
numItermax : int, optional (default=100000)
The maximum number of iterations before stopping the optimization
algorithm if it has not converged.
+ log: boolean, optional (default=False)
+ If True, returns a dictionary containing the cost and dual
+ variables. Otherwise returns only the optimal transportation cost.
+ return_matrix: boolean, optional (default=False)
+ If True, returns the optimal transportation matrix in the log.
Returns
-------
gamma: (ns x nt) ndarray
Optimal transportation matrix for the given parameters
+ log: dict
+ If input log is true, a dictionary containing the cost and dual
+ variables and exit status
Examples
@@ -156,17 +182,31 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000):
# if empty array given then use unifor distributions
if len(a) == 0:
- a = np.ones((M.shape[0], ), dtype=np.float64)/M.shape[0]
+ a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
if len(b) == 0:
- b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1]
+ b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
- if len(b.shape) == 1:
- return emd2_c(a, b, M, numItermax)
+ if log or return_matrix:
+ def f(b):
+ G, cost, u, v, resultCode = emd_c(a, b, M, numItermax)
+ result_code_string = check_result(resultCode)
+ log = {}
+ if return_matrix:
+ log['G'] = G
+ log['u'] = u
+ log['v'] = v
+ log['warning'] = result_code_string
+ log['result_code'] = resultCode
+ return [cost, log]
else:
- nb = b.shape[1]
- # res = [emd2_c(a, b[:, i].copy(), M, numItermax) for i in range(nb)]
-
def f(b):
- return emd2_c(a, b, M, numItermax)
- res = parmap(f, [b[:, i] for i in range(nb)], processes)
- return np.array(res)
+ G, cost, u, v, result_code = emd_c(a, b, M, numItermax)
+ check_result(result_code)
+ return cost
+
+ if len(b.shape) == 1:
+ return f(b)
+ nb = b.shape[1]
+
+ res = parmap(f, [b[:, i] for i in range(nb)], processes)
+ return res