diff options
Diffstat (limited to 'ot/lp')
-rw-r--r-- | ot/lp/__init__.py | 12 | ||||
-rw-r--r-- | ot/lp/emd_wrap.pyx | 2 |
2 files changed, 7 insertions, 7 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index ae5b08c..17f5bb4 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -16,7 +16,7 @@ from .emd_wrap import emd_c, check_result from ..utils import parmap -def emd(a, b, M, max_iter=100000, log=False): +def emd(a, b, M, num_iter_max=100000, log=False): """Solves the Earth Movers distance problem and returns the OT matrix @@ -41,7 +41,7 @@ def emd(a, b, M, max_iter=100000, log=False): Target histogram (uniform weigth if empty list) M : (ns,nt) ndarray, float64 loss matrix - max_iter : int, optional (default=100000) + num_iter_max : int, optional (default=100000) The maximum number of iterations before stopping the optimization algorithm if it has not converged. log: boolean, optional (default=False) @@ -94,7 +94,7 @@ def emd(a, b, M, max_iter=100000, log=False): if len(b) == 0: b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] - G, cost, u, v, result_code = emd_c(a, b, M, max_iter) + G, cost, u, v, result_code = emd_c(a, b, M, num_iter_max) result_code_string = check_result(result_code) if log: log = {} @@ -107,7 +107,7 @@ def emd(a, b, M, max_iter=100000, log=False): return G -def emd2(a, b, M, processes=multiprocessing.cpu_count(), max_iter=100000, log=False): +def emd2(a, b, M, processes=multiprocessing.cpu_count(), num_iter_max=100000, log=False): """Solves the Earth Movers distance problem and returns the loss .. math:: @@ -183,7 +183,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), max_iter=100000, log=Fa if log: def f(b): - G, cost, u, v, resultCode = emd_c(a, b, M, max_iter) + G, cost, u, v, resultCode = emd_c(a, b, M, num_iter_max) result_code_string = check_result(resultCode) log = {} log['G'] = G @@ -194,7 +194,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), max_iter=100000, log=Fa return [cost, log] else: def f(b): - G, cost, u, v, result_code = emd_c(a, b, M, max_iter) + G, cost, u, v, result_code = emd_c(a, b, M, num_iter_max) check_result(result_code) return cost diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 2fcc0e4..45fc988 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -29,7 +29,7 @@ def check_result(result_code): elif result_code == UNBOUNDED: message = "Problem unbounded" elif result_code == MAX_ITER_REACHED: - message = "max_iter reached before optimality. Try to increase max_iter." + message = "num_iter_max reached before optimality. Try to increase num_iter_max." warnings.warn(message) return message |