diff options
author | aje <leo_g_autheron@hotmail.fr> | 2017-08-30 09:56:37 +0200 |
---|---|---|
committer | aje <leo_g_autheron@hotmail.fr> | 2017-08-30 09:56:37 +0200 |
commit | 982f36cb0a5f3a6a14454238a26053de7251b0f0 (patch) | |
tree | e7ad62954515430a418b9825829f7d563bf469ff /ot/lp/__init__.py | |
parent | 308ce24b705dfad9d058138d058da8b18002e081 (diff) |
Changes:
- Rename numItermax to max_iter
- Default value to 100000 instead of 10000
- Add max_iter to class SinkhornTransport(BaseTransport)
- Add norm to all BaseTransport
Diffstat (limited to 'ot/lp/__init__.py')
-rw-r--r-- | ot/lp/__init__.py | 46 |
1 files changed, 23 insertions, 23 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 5143a70..7bef648 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -14,8 +14,7 @@ from ..utils import parmap import multiprocessing - -def emd(a, b, M, numItermax=10000): +def emd(a, b, M, max_iter=100000): """Solves the Earth Movers distance problem and returns the OT matrix @@ -40,8 +39,9 @@ def emd(a, b, M, numItermax=10000): Target histogram (uniform weigth if empty list) M : (ns,nt) ndarray, float64 loss matrix - numItermax : int - Maximum number of iterations made by the LP solver. + max_iter : int, optional (default=100000) + The maximum number of iterations before stopping the optimization + algorithm if it has not converged. Returns ------- @@ -54,7 +54,7 @@ def emd(a, b, M, numItermax=10000): Simple example with obvious solution. The function emd accepts lists and perform automatic conversion to numpy arrays - + >>> import ot >>> a=[.5,.5] >>> b=[.5,.5] @@ -86,10 +86,11 @@ def emd(a, b, M, numItermax=10000): if len(b) == 0: b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1] - return emd_c(a, b, M, numItermax) + return emd_c(a, b, M, max_iter) + -def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=10000): - """Solves the Earth Movers distance problem and returns the loss +def emd2(a, b, M, processes=multiprocessing.cpu_count(), max_iter=100000): + """Solves the Earth Movers distance problem and returns the loss .. math:: \gamma = arg\min_\gamma <\gamma,M>_F @@ -112,8 +113,9 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=10000): Target histogram (uniform weigth if empty list) M : (ns,nt) ndarray, float64 loss matrix - numItermax : int - Maximum number of iterations made by the LP solver. + max_iter : int, optional (default=100000) + The maximum number of iterations before stopping the optimization + algorithm if it has not converged. Returns ------- @@ -126,15 +128,15 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=10000): Simple example with obvious solution. The function emd accepts lists and perform automatic conversion to numpy arrays - - + + >>> import ot >>> a=[.5,.5] >>> b=[.5,.5] >>> M=[[0.,1.],[1.,0.]] >>> ot.emd2(a,b,M) 0.0 - + References ---------- @@ -157,16 +159,14 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=10000): a = np.ones((M.shape[0], ), dtype=np.float64)/M.shape[0] if len(b) == 0: b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1] - - if len(b.shape)==1: - return emd2_c(a, b, M, numItermax) + + if len(b.shape) == 1: + return emd2_c(a, b, M, max_iter) else: - nb=b.shape[1] - #res=[emd2_c(a,b[:,i].copy(),M, numItermax) for i in range(nb)] + nb = b.shape[1] + # res = [emd2_c(a, b[:, i].copy(), M, max_iter) for i in range(nb)] + def f(b): - return emd2_c(a,b,M, numItermax) - res= parmap(f, [b[:,i] for i in range(nb)],processes) + return emd2_c(a, b, M, max_iter) + res = parmap(f, [b[:, i] for i in range(nb)], processes) return np.array(res) - - -
\ No newline at end of file |