diff options
author | Antoine Rolet <antoine.rolet@gmail.com> | 2017-09-05 15:30:50 +0900 |
---|---|---|
committer | Antoine Rolet <antoine.rolet@gmail.com> | 2017-09-05 15:30:50 +0900 |
commit | 13dfb3ddbbd8926b4751b82dd41c5570253b1f07 (patch) | |
tree | b28098e98640c64483a599103e2fdb5df46d2c79 /ot/lp/__init__.py | |
parent | 185eb3e2ef34b5ce6b8f90a28a5bcc78432b7fd3 (diff) | |
parent | 16697047eff9326a0ecb483317c13a854a3d3a71 (diff) |
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'ot/lp/__init__.py')
-rw-r--r-- | ot/lp/__init__.py | 48 |
1 files changed, 27 insertions, 21 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 915a18c..a14d4e4 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -3,6 +3,10 @@ Solvers for the original linear program OT problem """ +# Author: Remi Flamary <remi.flamary@unice.fr> +# +# License: MIT License + import numpy as np # import compiled emd from .emd_wrap import emd_c, emd2_c @@ -10,8 +14,7 @@ from ..utils import parmap import multiprocessing - -def emd(a, b, M, dual_variables=False, max_iter=-1): +def emd(a, b, M, numItermax=100000, dual_variables=False): """Solves the Earth Movers distance problem and returns the OT matrix @@ -36,6 +39,9 @@ def emd(a, b, M, dual_variables=False, max_iter=-1): Target histogram (uniform weigth if empty list) M : (ns,nt) ndarray, float64 loss matrix + numItermax : int, optional (default=100000) + The maximum number of iterations before stopping the optimization + algorithm if it has not converged. Returns ------- @@ -48,7 +54,7 @@ def emd(a, b, M, dual_variables=False, max_iter=-1): Simple example with obvious solution. The function emd accepts lists and perform automatic conversion to numpy arrays - + >>> import ot >>> a=[.5,.5] >>> b=[.5,.5] @@ -80,13 +86,13 @@ def emd(a, b, M, dual_variables=False, max_iter=-1): if len(b) == 0: b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1] - G, alpha, beta = emd_c(a, b, M, max_iter) + G, alpha, beta = emd_c(a, b, M, numItermax) if dual_variables: return G, alpha, beta return G -def emd2(a, b, M,processes=multiprocessing.cpu_count(), max_iter=-1): - """Solves the Earth Movers distance problem and returns the loss +def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000): + """Solves the Earth Movers distance problem and returns the loss .. math:: \gamma = arg\min_\gamma <\gamma,M>_F @@ -109,6 +115,9 @@ def emd2(a, b, M,processes=multiprocessing.cpu_count(), max_iter=-1): Target histogram (uniform weigth if empty list) M : (ns,nt) ndarray, float64 loss matrix + numItermax : int, optional (default=100000) + The maximum number of iterations before stopping the optimization + algorithm if it has not converged. Returns ------- @@ -121,15 +130,15 @@ def emd2(a, b, M,processes=multiprocessing.cpu_count(), max_iter=-1): Simple example with obvious solution. The function emd accepts lists and perform automatic conversion to numpy arrays - - + + >>> import ot >>> a=[.5,.5] >>> b=[.5,.5] >>> M=[[0.,1.],[1.,0.]] >>> ot.emd2(a,b,M) 0.0 - + References ---------- @@ -152,16 +161,13 @@ def emd2(a, b, M,processes=multiprocessing.cpu_count(), max_iter=-1): a = np.ones((M.shape[0], ), dtype=np.float64)/M.shape[0] if len(b) == 0: b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1] - + if len(b.shape)==1: - return emd2_c(a, b, M, max_iter)[0] - else: - nb=b.shape[1] - #res=[emd2_c(a,b[:,i].copy(),M) for i in range(nb)] - def f(b): - return emd2_c(a,b,M, max_iter)[0] - res= parmap(f, [b[:,i] for i in range(nb)],processes) - return np.array(res) - - -
\ No newline at end of file + return emd2_c(a, b, M, numItermax)[0] + nb = b.shape[1] + # res = [emd2_c(a, b[:, i].copy(), M, numItermax) for i in range(nb)] + + def f(b): + return emd2_c(a,b,M, max_iter)[0] + res= parmap(f, [b[:,i] for i in range(nb)],processes) + return np.array(res) |