summaryrefslogtreecommitdiff
path: root/ot/lp/__init__.py
diff options
context:
space:
mode:
authoraje <leo_g_autheron@hotmail.fr>2017-08-30 09:56:37 +0200
committeraje <leo_g_autheron@hotmail.fr>2017-08-30 09:56:37 +0200
commit982f36cb0a5f3a6a14454238a26053de7251b0f0 (patch)
treee7ad62954515430a418b9825829f7d563bf469ff /ot/lp/__init__.py
parent308ce24b705dfad9d058138d058da8b18002e081 (diff)
Changes:
- Rename numItermax to max_iter - Default value to 100000 instead of 10000 - Add max_iter to class SinkhornTransport(BaseTransport) - Add norm to all BaseTransport
Diffstat (limited to 'ot/lp/__init__.py')
-rw-r--r--ot/lp/__init__.py46
1 files changed, 23 insertions, 23 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 5143a70..7bef648 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -14,8 +14,7 @@ from ..utils import parmap
import multiprocessing
-
-def emd(a, b, M, numItermax=10000):
+def emd(a, b, M, max_iter=100000):
"""Solves the Earth Movers distance problem and returns the OT matrix
@@ -40,8 +39,9 @@ def emd(a, b, M, numItermax=10000):
Target histogram (uniform weigth if empty list)
M : (ns,nt) ndarray, float64
loss matrix
- numItermax : int
- Maximum number of iterations made by the LP solver.
+ max_iter : int, optional (default=100000)
+ The maximum number of iterations before stopping the optimization
+ algorithm if it has not converged.
Returns
-------
@@ -54,7 +54,7 @@ def emd(a, b, M, numItermax=10000):
Simple example with obvious solution. The function emd accepts lists and
perform automatic conversion to numpy arrays
-
+
>>> import ot
>>> a=[.5,.5]
>>> b=[.5,.5]
@@ -86,10 +86,11 @@ def emd(a, b, M, numItermax=10000):
if len(b) == 0:
b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1]
- return emd_c(a, b, M, numItermax)
+ return emd_c(a, b, M, max_iter)
+
-def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=10000):
- """Solves the Earth Movers distance problem and returns the loss
+def emd2(a, b, M, processes=multiprocessing.cpu_count(), max_iter=100000):
+ """Solves the Earth Movers distance problem and returns the loss
.. math::
\gamma = arg\min_\gamma <\gamma,M>_F
@@ -112,8 +113,9 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=10000):
Target histogram (uniform weigth if empty list)
M : (ns,nt) ndarray, float64
loss matrix
- numItermax : int
- Maximum number of iterations made by the LP solver.
+ max_iter : int, optional (default=100000)
+ The maximum number of iterations before stopping the optimization
+ algorithm if it has not converged.
Returns
-------
@@ -126,15 +128,15 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=10000):
Simple example with obvious solution. The function emd accepts lists and
perform automatic conversion to numpy arrays
-
-
+
+
>>> import ot
>>> a=[.5,.5]
>>> b=[.5,.5]
>>> M=[[0.,1.],[1.,0.]]
>>> ot.emd2(a,b,M)
0.0
-
+
References
----------
@@ -157,16 +159,14 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=10000):
a = np.ones((M.shape[0], ), dtype=np.float64)/M.shape[0]
if len(b) == 0:
b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1]
-
- if len(b.shape)==1:
- return emd2_c(a, b, M, numItermax)
+
+ if len(b.shape) == 1:
+ return emd2_c(a, b, M, max_iter)
else:
- nb=b.shape[1]
- #res=[emd2_c(a,b[:,i].copy(),M, numItermax) for i in range(nb)]
+ nb = b.shape[1]
+ # res = [emd2_c(a, b[:, i].copy(), M, max_iter) for i in range(nb)]
+
def f(b):
- return emd2_c(a,b,M, numItermax)
- res= parmap(f, [b[:,i] for i in range(nb)],processes)
+ return emd2_c(a, b, M, max_iter)
+ res = parmap(f, [b[:, i] for i in range(nb)], processes)
return np.array(res)
-
-
- \ No newline at end of file