diff options
Diffstat (limited to 'ot/lp/__init__.py')
-rw-r--r-- | ot/lp/__init__.py | 42 |
1 files changed, 23 insertions, 19 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 6e0bdb8..de91e74 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -14,8 +14,7 @@ from ..utils import parmap import multiprocessing - -def emd(a, b, M): +def emd(a, b, M, numItermax=100000): """Solves the Earth Movers distance problem and returns the OT matrix @@ -40,6 +39,9 @@ def emd(a, b, M): Target histogram (uniform weigth if empty list) M : (ns,nt) ndarray, float64 loss matrix + numItermax : int, optional (default=100000) + The maximum number of iterations before stopping the optimization + algorithm if it has not converged. Returns ------- @@ -52,7 +54,7 @@ def emd(a, b, M): Simple example with obvious solution. The function emd accepts lists and perform automatic conversion to numpy arrays - + >>> import ot >>> a=[.5,.5] >>> b=[.5,.5] @@ -84,10 +86,11 @@ def emd(a, b, M): if len(b) == 0: b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1] - return emd_c(a, b, M) + return emd_c(a, b, M, numItermax) + -def emd2(a, b, M,processes=multiprocessing.cpu_count()): - """Solves the Earth Movers distance problem and returns the loss +def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000): + """Solves the Earth Movers distance problem and returns the loss .. math:: \gamma = arg\min_\gamma <\gamma,M>_F @@ -110,6 +113,9 @@ def emd2(a, b, M,processes=multiprocessing.cpu_count()): Target histogram (uniform weigth if empty list) M : (ns,nt) ndarray, float64 loss matrix + numItermax : int, optional (default=100000) + The maximum number of iterations before stopping the optimization + algorithm if it has not converged. Returns ------- @@ -122,15 +128,15 @@ def emd2(a, b, M,processes=multiprocessing.cpu_count()): Simple example with obvious solution. The function emd accepts lists and perform automatic conversion to numpy arrays - - + + >>> import ot >>> a=[.5,.5] >>> b=[.5,.5] >>> M=[[0.,1.],[1.,0.]] >>> ot.emd2(a,b,M) 0.0 - + References ---------- @@ -153,16 +159,14 @@ def emd2(a, b, M,processes=multiprocessing.cpu_count()): a = np.ones((M.shape[0], ), dtype=np.float64)/M.shape[0] if len(b) == 0: b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1] - - if len(b.shape)==1: - return emd2_c(a, b, M) + + if len(b.shape) == 1: + return emd2_c(a, b, M, numItermax) else: - nb=b.shape[1] - #res=[emd2_c(a,b[:,i].copy(),M) for i in range(nb)] + nb = b.shape[1] + # res = [emd2_c(a, b[:, i].copy(), M, numItermax) for i in range(nb)] + def f(b): - return emd2_c(a,b,M) - res= parmap(f, [b[:,i] for i in range(nb)],processes) + return emd2_c(a, b, M, numItermax) + res = parmap(f, [b[:, i] for i in range(nb)], processes) return np.array(res) - - -
\ No newline at end of file |