diff options
author | RĂ©mi Flamary <remi.flamary@gmail.com> | 2017-08-30 15:47:16 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-08-30 15:47:16 +0200 |
commit | 16697047eff9326a0ecb483317c13a854a3d3a71 (patch) | |
tree | b9a8659370286820563a1fd1a9ea09ed0a9003a3 /ot/lp/__init__.py | |
parent | a2ec6e55e458c719484e86a4e6a6e764c2e38dc8 (diff) | |
parent | fadaf2ab3c3844d281b22f8d5c3404c3c4cf7d97 (diff) |
Merge pull request #25 from aje/master
Add iter_max to lp solver and fixes #24
Diffstat (limited to 'ot/lp/__init__.py')
-rw-r--r-- | ot/lp/__init__.py | 42 |
1 files changed, 23 insertions, 19 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 6e0bdb8..de91e74 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -14,8 +14,7 @@ from ..utils import parmap import multiprocessing - -def emd(a, b, M): +def emd(a, b, M, numItermax=100000): """Solves the Earth Movers distance problem and returns the OT matrix @@ -40,6 +39,9 @@ def emd(a, b, M): Target histogram (uniform weigth if empty list) M : (ns,nt) ndarray, float64 loss matrix + numItermax : int, optional (default=100000) + The maximum number of iterations before stopping the optimization + algorithm if it has not converged. Returns ------- @@ -52,7 +54,7 @@ def emd(a, b, M): Simple example with obvious solution. The function emd accepts lists and perform automatic conversion to numpy arrays - + >>> import ot >>> a=[.5,.5] >>> b=[.5,.5] @@ -84,10 +86,11 @@ def emd(a, b, M): if len(b) == 0: b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1] - return emd_c(a, b, M) + return emd_c(a, b, M, numItermax) + -def emd2(a, b, M,processes=multiprocessing.cpu_count()): - """Solves the Earth Movers distance problem and returns the loss +def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000): + """Solves the Earth Movers distance problem and returns the loss .. math:: \gamma = arg\min_\gamma <\gamma,M>_F @@ -110,6 +113,9 @@ def emd2(a, b, M,processes=multiprocessing.cpu_count()): Target histogram (uniform weigth if empty list) M : (ns,nt) ndarray, float64 loss matrix + numItermax : int, optional (default=100000) + The maximum number of iterations before stopping the optimization + algorithm if it has not converged. Returns ------- @@ -122,15 +128,15 @@ def emd2(a, b, M,processes=multiprocessing.cpu_count()): Simple example with obvious solution. The function emd accepts lists and perform automatic conversion to numpy arrays - - + + >>> import ot >>> a=[.5,.5] >>> b=[.5,.5] >>> M=[[0.,1.],[1.,0.]] >>> ot.emd2(a,b,M) 0.0 - + References ---------- @@ -153,16 +159,14 @@ def emd2(a, b, M,processes=multiprocessing.cpu_count()): a = np.ones((M.shape[0], ), dtype=np.float64)/M.shape[0] if len(b) == 0: b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1] - - if len(b.shape)==1: - return emd2_c(a, b, M) + + if len(b.shape) == 1: + return emd2_c(a, b, M, numItermax) else: - nb=b.shape[1] - #res=[emd2_c(a,b[:,i].copy(),M) for i in range(nb)] + nb = b.shape[1] + # res = [emd2_c(a, b[:, i].copy(), M, numItermax) for i in range(nb)] + def f(b): - return emd2_c(a,b,M) - res= parmap(f, [b[:,i] for i in range(nb)],processes) + return emd2_c(a, b, M, numItermax) + res = parmap(f, [b[:, i] for i in range(nb)], processes) return np.array(res) - - -
\ No newline at end of file |