From 06429e5a34790ec51eb1c921293b24c37b81b952 Mon Sep 17 00:00:00 2001 From: Antoine Rolet Date: Sat, 9 Sep 2017 18:23:05 +0900 Subject: Returned to old variable name to follow repo convention --- ot/lp/__init__.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) (limited to 'ot/lp/__init__.py') diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index ae5b08c..17f5bb4 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -16,7 +16,7 @@ from .emd_wrap import emd_c, check_result from ..utils import parmap -def emd(a, b, M, max_iter=100000, log=False): +def emd(a, b, M, num_iter_max=100000, log=False): """Solves the Earth Movers distance problem and returns the OT matrix @@ -41,7 +41,7 @@ def emd(a, b, M, max_iter=100000, log=False): Target histogram (uniform weigth if empty list) M : (ns,nt) ndarray, float64 loss matrix - max_iter : int, optional (default=100000) + num_iter_max : int, optional (default=100000) The maximum number of iterations before stopping the optimization algorithm if it has not converged. log: boolean, optional (default=False) @@ -94,7 +94,7 @@ def emd(a, b, M, max_iter=100000, log=False): if len(b) == 0: b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] - G, cost, u, v, result_code = emd_c(a, b, M, max_iter) + G, cost, u, v, result_code = emd_c(a, b, M, num_iter_max) result_code_string = check_result(result_code) if log: log = {} @@ -107,7 +107,7 @@ def emd(a, b, M, max_iter=100000, log=False): return G -def emd2(a, b, M, processes=multiprocessing.cpu_count(), max_iter=100000, log=False): +def emd2(a, b, M, processes=multiprocessing.cpu_count(), num_iter_max=100000, log=False): """Solves the Earth Movers distance problem and returns the loss .. math:: @@ -183,7 +183,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), max_iter=100000, log=Fa if log: def f(b): - G, cost, u, v, resultCode = emd_c(a, b, M, max_iter) + G, cost, u, v, resultCode = emd_c(a, b, M, num_iter_max) result_code_string = check_result(resultCode) log = {} log['G'] = G @@ -194,7 +194,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), max_iter=100000, log=Fa return [cost, log] else: def f(b): - G, cost, u, v, result_code = emd_c(a, b, M, max_iter) + G, cost, u, v, result_code = emd_c(a, b, M, num_iter_max) check_result(result_code) return cost -- cgit v1.2.3