diff options
author | Antoine Rolet <antoine.rolet@gmail.com> | 2017-09-09 17:28:38 +0900 |
---|---|---|
committer | Antoine Rolet <antoine.rolet@gmail.com> | 2017-09-09 17:28:38 +0900 |
commit | 85c56d96f609c4ad458f0963a068386cc910c66c (patch) | |
tree | 92eae8a636531097e1db6569176cef6e4fd551ac /ot/lp/__init__.py | |
parent | c4aca9e527b60d294fe15d7adce9236f35cb0908 (diff) |
Renamed variables
Diffstat (limited to 'ot/lp/__init__.py')
-rw-r--r-- | ot/lp/__init__.py | 31 |
1 files changed, 17 insertions, 14 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 0f40c19..ab7cb97 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -12,11 +12,11 @@ import multiprocessing import numpy as np # import compiled emd -from .emd_wrap import emd_c, checkResult +from .emd_wrap import emd_c, check_result from ..utils import parmap -def emd(a, b, M, numItermax=100000, log=False): +def emd(a, b, M, num_iter_max=100000, log=False): """Solves the Earth Movers distance problem and returns the OT matrix @@ -41,7 +41,7 @@ def emd(a, b, M, numItermax=100000, log=False): Target histogram (uniform weigth if empty list) M : (ns,nt) ndarray, float64 loss matrix - numItermax : int, optional (default=100000) + num_iter_max : int, optional (default=100000) The maximum number of iterations before stopping the optimization algorithm if it has not converged. log: boolean, optional (default=False) @@ -54,7 +54,7 @@ def emd(a, b, M, numItermax=100000, log=False): Optimal transportation matrix for the given parameters log: dict If input log is true, a dictionary containing the cost and dual - variables + variables and exit status Examples @@ -94,20 +94,20 @@ def emd(a, b, M, numItermax=100000, log=False): if len(b) == 0: b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] - G, cost, u, v, resultCode = emd_c(a, b, M, numItermax) - resultCodeString = checkResult(resultCode) + G, cost, u, v, result_code = emd_c(a, b, M, num_iter_max) + resultCodeString = check_result(result_code) if log: log = {} log['cost'] = cost log['u'] = u log['v'] = v log['warning'] = resultCodeString - log['resultCode'] = resultCode + log['result_code'] = result_code return G, log return G -def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000, log=False): +def emd2(a, b, M, processes=multiprocessing.cpu_count(), num_iter_max=100000, log=False): """Solves the Earth Movers distance problem and returns the loss .. math:: @@ -131,7 +131,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000, log= Target histogram (uniform weigth if empty list) M : (ns,nt) ndarray, float64 loss matrix - numItermax : int, optional (default=100000) + num_iter_max : int, optional (default=100000) The maximum number of iterations before stopping the optimization algorithm if it has not converged. @@ -139,6 +139,9 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000, log= ------- gamma: (ns x nt) ndarray Optimal transportation matrix for the given parameters + log: dict + If input log is true, a dictionary containing the cost and dual + variables and exit status Examples @@ -180,19 +183,19 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000, log= if log: def f(b): - G, cost, u, v, resultCode = emd_c(a, b, M, numItermax) - resultCodeString = checkResult(resultCode) + G, cost, u, v, resultCode = emd_c(a, b, M, num_iter_max) + resultCodeString = check_result(resultCode) log = {} log['G'] = G log['u'] = u log['v'] = v log['warning'] = resultCodeString - log['resultCode'] = resultCode + log['result_code'] = resultCode return [cost, log] else: def f(b): - G, cost, u, v, resultCode = emd_c(a, b, M, numItermax) - checkResult(resultCode) + G, cost, u, v, result_code = emd_c(a, b, M, num_iter_max) + check_result(result_code) return cost if len(b.shape) == 1: |