summaryrefslogtreecommitdiff
path: root/ot/lp/emd_wrap.pyx
diff options
context:
space:
mode:
authoraje <leo_g_autheron@hotmail.fr>2017-08-30 09:56:37 +0200
committeraje <leo_g_autheron@hotmail.fr>2017-08-30 09:56:37 +0200
commit982f36cb0a5f3a6a14454238a26053de7251b0f0 (patch)
treee7ad62954515430a418b9825829f7d563bf469ff /ot/lp/emd_wrap.pyx
parent308ce24b705dfad9d058138d058da8b18002e081 (diff)
Changes:
- Rename numItermax to max_iter - Default value to 100000 instead of 10000 - Add max_iter to class SinkhornTransport(BaseTransport) - Add norm to all BaseTransport
Diffstat (limited to 'ot/lp/emd_wrap.pyx')
-rw-r--r--ot/lp/emd_wrap.pyx82
1 files changed, 42 insertions, 40 deletions
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index 6039e1f..26d3330 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -15,56 +15,57 @@ cimport cython
cdef extern from "EMD.h":
- int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost, int numItermax)
+ int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost, int max_iter)
cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED
@cython.boundscheck(False)
@cython.wraparound(False)
-def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M, int numItermax):
+def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M, int max_iter):
"""
Solves the Earth Movers distance problem and returns the optimal transport matrix
-
+
gamm=emd(a,b,M)
-
+
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F
-
+ \gamma = arg\min_\gamma <\gamma,M>_F
+
s.t. \gamma 1 = a
-
- \gamma^T 1= b
-
+
+ \gamma^T 1= b
+
\gamma\geq 0
where :
-
+
- M is the metric cost matrix
- a and b are the sample weights
-
+
Parameters
----------
a : (ns,) ndarray, float64
- source histogram
+ source histogram
b : (nt,) ndarray, float64
target histogram
M : (ns,nt) ndarray, float64
- loss matrix
- numItermax : int
- Maximum number of iterations made by the LP solver.
-
-
+ loss matrix
+ max_iter : int
+ The maximum number of iterations before stopping the optimization
+ algorithm if it has not converged.
+
+
Returns
-------
gamma: (ns x nt) ndarray
Optimal transportation matrix for the given parameters
-
+
"""
cdef int n1= M.shape[0]
cdef int n2= M.shape[1]
cdef float cost=0
cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2])
-
+
if not len(a):
a=np.ones((n1,))/n1
@@ -72,7 +73,7 @@ def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mod
b=np.ones((n2,))/n2
# calling the function
- cdef int resultSolver = EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data,<double*> &cost, numItermax)
+ cdef int resultSolver = EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data,<double*> &cost, max_iter)
if resultSolver != OPTIMAL:
if resultSolver == INFEASIBLE:
print("Problem infeasible. Try to increase numItermax.")
@@ -83,49 +84,50 @@ def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mod
@cython.boundscheck(False)
@cython.wraparound(False)
-def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M, int numItermax):
+def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M, int max_iter):
"""
Solves the Earth Movers distance problem and returns the optimal transport loss
-
+
gamm=emd(a,b,M)
-
+
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F
-
+ \gamma = arg\min_\gamma <\gamma,M>_F
+
s.t. \gamma 1 = a
-
- \gamma^T 1= b
-
+
+ \gamma^T 1= b
+
\gamma\geq 0
where :
-
+
- M is the metric cost matrix
- a and b are the sample weights
-
+
Parameters
----------
a : (ns,) ndarray, float64
- source histogram
+ source histogram
b : (nt,) ndarray, float64
target histogram
M : (ns,nt) ndarray, float64
- loss matrix
- numItermax : int
- Maximum number of iterations made by the LP solver.
-
-
+ loss matrix
+ max_iter : int
+ The maximum number of iterations before stopping the optimization
+ algorithm if it has not converged.
+
+
Returns
-------
gamma: (ns x nt) ndarray
Optimal transportation matrix for the given parameters
-
+
"""
cdef int n1= M.shape[0]
cdef int n2= M.shape[1]
cdef float cost=0
cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2])
-
+
if not len(a):
a=np.ones((n1,))/n1
@@ -133,13 +135,13 @@ def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mo
b=np.ones((n2,))/n2
# calling the function
- cdef int resultSolver = EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data,<double*> &cost, numItermax)
+ cdef int resultSolver = EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data,<double*> &cost, max_iter)
if resultSolver != OPTIMAL:
if resultSolver == INFEASIBLE:
print("Problem infeasible. Try to inscrease numItermax.")
elif resultSolver == UNBOUNDED:
print("Problem unbounded")
-
+
cost=0
for i in range(n1):
for j in range(n2):