summaryrefslogtreecommitdiff
path: root/ot/lp/emd_wrap.pyx
diff options
context:
space:
mode:
Diffstat (limited to 'ot/lp/emd_wrap.pyx')
-rw-r--r--ot/lp/emd_wrap.pyx100
1 files changed, 22 insertions, 78 deletions
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index 26d3330..83ee6aa 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -12,81 +12,33 @@ cimport numpy as np
cimport cython
+import warnings
cdef extern from "EMD.h":
- int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost, int max_iter)
- cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED
+ int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter)
+ cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED
+def check_result(result_code):
+ if result_code == OPTIMAL:
+ return None
-@cython.boundscheck(False)
-@cython.wraparound(False)
-def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M, int max_iter):
- """
- Solves the Earth Movers distance problem and returns the optimal transport matrix
-
- gamm=emd(a,b,M)
-
- .. math::
- \gamma = arg\min_\gamma <\gamma,M>_F
-
- s.t. \gamma 1 = a
-
- \gamma^T 1= b
-
- \gamma\geq 0
- where :
+ if result_code == INFEASIBLE:
+ message = "Problem infeasible. Check that a and b are in the simplex"
+ elif result_code == UNBOUNDED:
+ message = "Problem unbounded"
+ elif result_code == MAX_ITER_REACHED:
+ message = "numItermax reached before optimality. Try to increase numItermax."
+ warnings.warn(message)
+ return message
- - M is the metric cost matrix
- - a and b are the sample weights
-
- Parameters
- ----------
- a : (ns,) ndarray, float64
- source histogram
- b : (nt,) ndarray, float64
- target histogram
- M : (ns,nt) ndarray, float64
- loss matrix
- max_iter : int
- The maximum number of iterations before stopping the optimization
- algorithm if it has not converged.
-
-
- Returns
- -------
- gamma: (ns x nt) ndarray
- Optimal transportation matrix for the given parameters
-
- """
- cdef int n1= M.shape[0]
- cdef int n2= M.shape[1]
-
- cdef float cost=0
- cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2])
-
- if not len(a):
- a=np.ones((n1,))/n1
-
- if not len(b):
- b=np.ones((n2,))/n2
-
- # calling the function
- cdef int resultSolver = EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data,<double*> &cost, max_iter)
- if resultSolver != OPTIMAL:
- if resultSolver == INFEASIBLE:
- print("Problem infeasible. Try to increase numItermax.")
- elif resultSolver == UNBOUNDED:
- print("Problem unbounded")
-
- return G
@cython.boundscheck(False)
@cython.wraparound(False)
-def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M, int max_iter):
+def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter):
"""
- Solves the Earth Movers distance problem and returns the optimal transport loss
+ Solves the Earth Movers distance problem and returns the optimal transport matrix
gamm=emd(a,b,M)
@@ -125,8 +77,11 @@ def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mo
cdef int n1= M.shape[0]
cdef int n2= M.shape[1]
- cdef float cost=0
+ cdef double cost=0
cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2])
+ cdef np.ndarray[double, ndim=1, mode="c"] alpha=np.zeros(n1)
+ cdef np.ndarray[double, ndim=1, mode="c"] beta=np.zeros(n2)
+
if not len(a):
a=np.ones((n1,))/n1
@@ -135,17 +90,6 @@ def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mo
b=np.ones((n2,))/n2
# calling the function
- cdef int resultSolver = EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data,<double*> &cost, max_iter)
- if resultSolver != OPTIMAL:
- if resultSolver == INFEASIBLE:
- print("Problem infeasible. Try to inscrease numItermax.")
- elif resultSolver == UNBOUNDED:
- print("Problem unbounded")
-
- cost=0
- for i in range(n1):
- for j in range(n2):
- cost+=G[i,j]*M[i,j]
-
- return cost
+ cdef int result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter)
+ return G, cost, alpha, beta, result_code