summaryrefslogtreecommitdiff
path: root/ot/lp/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/lp/__init__.py')
-rw-r--r--ot/lp/__init__.py48
1 files changed, 27 insertions, 21 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 915a18c..a14d4e4 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -3,6 +3,10 @@
Solvers for the original linear program OT problem
"""
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
import numpy as np
# import compiled emd
from .emd_wrap import emd_c, emd2_c
@@ -10,8 +14,7 @@ from ..utils import parmap
import multiprocessing
-
-def emd(a, b, M, dual_variables=False, max_iter=-1):
+def emd(a, b, M, numItermax=100000, dual_variables=False):
"""Solves the Earth Movers distance problem and returns the OT matrix
@@ -36,6 +39,9 @@ def emd(a, b, M, dual_variables=False, max_iter=-1):
Target histogram (uniform weigth if empty list)
M : (ns,nt) ndarray, float64
loss matrix
+ numItermax : int, optional (default=100000)
+ The maximum number of iterations before stopping the optimization
+ algorithm if it has not converged.
Returns
-------
@@ -48,7 +54,7 @@ def emd(a, b, M, dual_variables=False, max_iter=-1):
Simple example with obvious solution. The function emd accepts lists and
perform automatic conversion to numpy arrays
-
+
>>> import ot
>>> a=[.5,.5]
>>> b=[.5,.5]
@@ -80,13 +86,13 @@ def emd(a, b, M, dual_variables=False, max_iter=-1):
if len(b) == 0:
b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1]
- G, alpha, beta = emd_c(a, b, M, max_iter)
+ G, alpha, beta = emd_c(a, b, M, numItermax)
if dual_variables:
return G, alpha, beta
return G
-def emd2(a, b, M,processes=multiprocessing.cpu_count(), max_iter=-1):
- """Solves the Earth Movers distance problem and returns the loss
+def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000):
+ """Solves the Earth Movers distance problem and returns the loss
.. math::
\gamma = arg\min_\gamma <\gamma,M>_F
@@ -109,6 +115,9 @@ def emd2(a, b, M,processes=multiprocessing.cpu_count(), max_iter=-1):
Target histogram (uniform weigth if empty list)
M : (ns,nt) ndarray, float64
loss matrix
+ numItermax : int, optional (default=100000)
+ The maximum number of iterations before stopping the optimization
+ algorithm if it has not converged.
Returns
-------
@@ -121,15 +130,15 @@ def emd2(a, b, M,processes=multiprocessing.cpu_count(), max_iter=-1):
Simple example with obvious solution. The function emd accepts lists and
perform automatic conversion to numpy arrays
-
-
+
+
>>> import ot
>>> a=[.5,.5]
>>> b=[.5,.5]
>>> M=[[0.,1.],[1.,0.]]
>>> ot.emd2(a,b,M)
0.0
-
+
References
----------
@@ -152,16 +161,13 @@ def emd2(a, b, M,processes=multiprocessing.cpu_count(), max_iter=-1):
a = np.ones((M.shape[0], ), dtype=np.float64)/M.shape[0]
if len(b) == 0:
b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1]
-
+
if len(b.shape)==1:
- return emd2_c(a, b, M, max_iter)[0]
- else:
- nb=b.shape[1]
- #res=[emd2_c(a,b[:,i].copy(),M) for i in range(nb)]
- def f(b):
- return emd2_c(a,b,M, max_iter)[0]
- res= parmap(f, [b[:,i] for i in range(nb)],processes)
- return np.array(res)
-
-
- \ No newline at end of file
+ return emd2_c(a, b, M, numItermax)[0]
+ nb = b.shape[1]
+ # res = [emd2_c(a, b[:, i].copy(), M, numItermax) for i in range(nb)]
+
+ def f(b):
+ return emd2_c(a,b,M, max_iter)[0]
+ res= parmap(f, [b[:,i] for i in range(nb)],processes)
+ return np.array(res)