diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2017-09-13 08:11:00 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-09-13 08:11:00 +0200 |
commit | a53ede95f916a11e2150ab7917820d813c0034bc (patch) | |
tree | 24304d83267d51b962e18722553973bbc75509f2 /ot/lp/__init__.py | |
parent | 62dcfbfb78a2be24379cd5cdb4aec70d8c4befaa (diff) | |
parent | e52b6eb41228a7f8e381cf73c06e0dffba5773be (diff) |
Merge pull request #29 from arolet/ot_dual_variables
Dual variables in EMD_wrapper
Diffstat (limited to 'ot/lp/__init__.py')
-rw-r--r-- | ot/lp/__init__.py | 80 |
1 files changed, 60 insertions, 20 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index de91e74..5c09da2 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -7,14 +7,16 @@ Solvers for the original linear program OT problem # # License: MIT License +import multiprocessing + import numpy as np + # import compiled emd -from .emd_wrap import emd_c, emd2_c +from .emd_wrap import emd_c, check_result from ..utils import parmap -import multiprocessing -def emd(a, b, M, numItermax=100000): +def emd(a, b, M, numItermax=100000, log=False): """Solves the Earth Movers distance problem and returns the OT matrix @@ -42,11 +44,17 @@ def emd(a, b, M, numItermax=100000): numItermax : int, optional (default=100000) The maximum number of iterations before stopping the optimization algorithm if it has not converged. + log: boolean, optional (default=False) + If True, returns a dictionary containing the cost and dual + variables. Otherwise returns only the optimal transportation matrix. Returns ------- gamma: (ns x nt) ndarray Optimal transportation matrix for the given parameters + log: dict + If input log is true, a dictionary containing the cost and dual + variables and exit status Examples @@ -82,14 +90,24 @@ def emd(a, b, M, numItermax=100000): # if empty array given then use unifor distributions if len(a) == 0: - a = np.ones((M.shape[0], ), dtype=np.float64)/M.shape[0] + a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] if len(b) == 0: - b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1] - - return emd_c(a, b, M, numItermax) - - -def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000): + b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] + + G, cost, u, v, result_code = emd_c(a, b, M, numItermax) + result_code_string = check_result(result_code) + if log: + log = {} + log['cost'] = cost + log['u'] = u + log['v'] = v + log['warning'] = result_code_string + log['result_code'] = result_code + return G, log + return G + + +def emd2(a, b, M, 
processes=multiprocessing.cpu_count(), numItermax=100000, log=False, return_matrix=False): """Solves the Earth Movers distance problem and returns the loss .. math:: @@ -116,11 +134,19 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000): numItermax : int, optional (default=100000) The maximum number of iterations before stopping the optimization algorithm if it has not converged. + log: boolean, optional (default=False) + If True, returns a dictionary containing the cost and dual + variables. Otherwise returns only the optimal transportation cost. + return_matrix: boolean, optional (default=False) + If True, returns the optimal transportation matrix in the log. Returns ------- gamma: (ns x nt) ndarray Optimal transportation matrix for the given parameters + log: dict + If input log is true, a dictionary containing the cost and dual + variables and exit status Examples @@ -156,17 +182,31 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000): # if empty array given then use unifor distributions if len(a) == 0: - a = np.ones((M.shape[0], ), dtype=np.float64)/M.shape[0] + a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] if len(b) == 0: - b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1] + b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] - if len(b.shape) == 1: - return emd2_c(a, b, M, numItermax) + if log or return_matrix: + def f(b): + G, cost, u, v, resultCode = emd_c(a, b, M, numItermax) + result_code_string = check_result(resultCode) + log = {} + if return_matrix: + log['G'] = G + log['u'] = u + log['v'] = v + log['warning'] = result_code_string + log['result_code'] = resultCode + return [cost, log] else: - nb = b.shape[1] - # res = [emd2_c(a, b[:, i].copy(), M, numItermax) for i in range(nb)] - def f(b): - return emd2_c(a, b, M, numItermax) - res = parmap(f, [b[:, i] for i in range(nb)], processes) - return np.array(res) + G, cost, u, v, result_code = emd_c(a, b, M, numItermax) + 
check_result(result_code) + return cost + + if len(b.shape) == 1: + return f(b) + nb = b.shape[1] + + res = parmap(f, [b[:, i] for i in range(nb)], processes) + return res |