summaryrefslogtreecommitdiff
path: root/ot/lp/__init__.py
diff options
context:
space:
mode:
authorAntoine Rolet <antoine.rolet@gmail.com>2017-09-07 13:50:41 +0900
committerAntoine Rolet <antoine.rolet@gmail.com>2017-09-07 13:50:41 +0900
commit12d9b3ff72e9669ccc0162e82b7a33beb51d3e25 (patch)
tree72a2908e9d0e67f7e8499d7ef9aca1246528a980 /ot/lp/__init__.py
parentf8c1c8740f9974dcf4aaf191851d62149dceb91c (diff)
Return dual variables in an optional dictionary
Also removed some code duplication
Diffstat (limited to 'ot/lp/__init__.py')
-rw-r--r--ot/lp/__init__.py24
1 files changed, 17 insertions, 7 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 6048f60..c15e6b9 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -9,12 +9,12 @@ Solvers for the original linear program OT problem
import numpy as np
# import compiled emd
-from .emd_wrap import emd_c, emd2_c
+from .emd_wrap import emd_c
from ..utils import parmap
import multiprocessing
-def emd(a, b, M, numItermax=100000, dual_variables=False):
+def emd(a, b, M, numItermax=100000, log=False):
"""Solves the Earth Movers distance problem and returns the OT matrix
@@ -42,11 +42,17 @@ def emd(a, b, M, numItermax=100000, dual_variables=False):
numItermax : int, optional (default=100000)
The maximum number of iterations before stopping the optimization
algorithm if it has not converged.
+ log: boolean, optional (default=False)
+ If True, returns a dictionary containing the cost and dual
+ variables. Otherwise returns only the optimal transportation matrix.
Returns
-------
gamma: (ns x nt) ndarray
Optimal transportation matrix for the given parameters
+ log: dict
+ If input log is true, a dictionary containing the cost and dual
+ variables
Examples
@@ -86,9 +92,13 @@ def emd(a, b, M, numItermax=100000, dual_variables=False):
if len(b) == 0:
b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1]
- G, alpha, beta = emd_c(a, b, M, numItermax)
- if dual_variables:
- return G, alpha, beta
+ G, cost, u, v = emd_c(a, b, M, numItermax)
+ if log:
+ log = {}
+ log['cost'] = cost
+ log['u'] = u
+ log['v'] = v
+ return G, log
return G
def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000):
@@ -163,11 +173,11 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000):
b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1]
if len(b.shape)==1:
- return emd2_c(a, b, M, numItermax)[0]
+ return emd_c(a, b, M, numItermax)[1]
nb = b.shape[1]
# res = [emd2_c(a, b[:, i].copy(), M, numItermax) for i in range(nb)]
def f(b):
- return emd2_c(a,b,M, numItermax)[0]
+ return emd_c(a,b,M, numItermax)[1]
res= parmap(f, [b[:,i] for i in range(nb)],processes)
return np.array(res)