diff options
-rw-r--r-- | ot/da.py | 2 | ||||
-rw-r--r-- | ot/lp/__init__.py | 12 | ||||
-rw-r--r-- | ot/lp/emd_wrap.pyx | 2 | ||||
-rw-r--r-- | test/test_ot.py | 4 |
4 files changed, 10 insertions, 10 deletions
@@ -1370,7 +1370,7 @@ class EMDTransport(BaseTransport): # coupling estimation self.coupling_ = emd( - a=self.mu_s, b=self.mu_t, M=self.cost_, max_iter=self.max_iter + a=self.mu_s, b=self.mu_t, M=self.cost_, num_iter_max=self.max_iter ) return self diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index ae5b08c..17f5bb4 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -16,7 +16,7 @@ from .emd_wrap import emd_c, check_result from ..utils import parmap -def emd(a, b, M, max_iter=100000, log=False): +def emd(a, b, M, num_iter_max=100000, log=False): """Solves the Earth Movers distance problem and returns the OT matrix @@ -41,7 +41,7 @@ def emd(a, b, M, max_iter=100000, log=False): Target histogram (uniform weigth if empty list) M : (ns,nt) ndarray, float64 loss matrix - max_iter : int, optional (default=100000) + num_iter_max : int, optional (default=100000) The maximum number of iterations before stopping the optimization algorithm if it has not converged. log: boolean, optional (default=False) @@ -94,7 +94,7 @@ def emd(a, b, M, max_iter=100000, log=False): if len(b) == 0: b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] - G, cost, u, v, result_code = emd_c(a, b, M, max_iter) + G, cost, u, v, result_code = emd_c(a, b, M, num_iter_max) result_code_string = check_result(result_code) if log: log = {} @@ -107,7 +107,7 @@ def emd(a, b, M, max_iter=100000, log=False): return G -def emd2(a, b, M, processes=multiprocessing.cpu_count(), max_iter=100000, log=False): +def emd2(a, b, M, processes=multiprocessing.cpu_count(), num_iter_max=100000, log=False): """Solves the Earth Movers distance problem and returns the loss .. math:: @@ -183,7 +183,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), max_iter=100000, log=Fa if log: def f(b): - G, cost, u, v, resultCode = emd_c(a, b, M, max_iter) + G, cost, u, v, resultCode = emd_c(a, b, M, num_iter_max) result_code_string = check_result(resultCode) log = {} log['G'] = G @@ -194,7 +194,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), max_iter=100000, log=Fa return [cost, log] else: def f(b): - G, cost, u, v, result_code = emd_c(a, b, M, max_iter) + G, cost, u, v, result_code = emd_c(a, b, M, num_iter_max) check_result(result_code) return cost diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 2fcc0e4..45fc988 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -29,7 +29,7 @@ def check_result(result_code): elif result_code == UNBOUNDED: message = "Problem unbounded" elif result_code == MAX_ITER_REACHED: - message = "max_iter reached before optimality. Try to increase max_iter." + message = "num_iter_max reached before optimality. Try to increase num_iter_max." warnings.warn(message) return message diff --git a/test/test_ot.py b/test/test_ot.py index 46fc634..e05e8aa 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -140,8 +140,8 @@ def test_warnings(): with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") print('Computing {} EMD '.format(1)) - ot.emd(a, b, M, max_iter=1) - assert "max_iter" in str(w[-1].message) + ot.emd(a, b, M, num_iter_max=1) + assert "num_iter_max" in str(w[-1].message) assert len(w) == 1 a[0] = 100 print('Computing {} EMD '.format(2)) |