diff options
Diffstat (limited to 'ot/lp/emd_wrap.pyx')
-rw-r--r-- | ot/lp/emd_wrap.pyx | 100 |
1 files changed, 22 insertions, 78 deletions
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 26d3330..83ee6aa 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -12,81 +12,33 @@ cimport numpy as np cimport cython +import warnings cdef extern from "EMD.h": - int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost, int max_iter) - cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED + int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter) + cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED +def check_result(result_code): + if result_code == OPTIMAL: + return None -@cython.boundscheck(False) -@cython.wraparound(False) -def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M, int max_iter): - """ - Solves the Earth Movers distance problem and returns the optimal transport matrix - - gamm=emd(a,b,M) - - .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F - - s.t. \gamma 1 = a - - \gamma^T 1= b - - \gamma\geq 0 - where : + if result_code == INFEASIBLE: + message = "Problem infeasible. Check that a and b are in the simplex" + elif result_code == UNBOUNDED: + message = "Problem unbounded" + elif result_code == MAX_ITER_REACHED: + message = "numItermax reached before optimality. Try to increase numItermax." + warnings.warn(message) + return message - - M is the metric cost matrix - - a and b are the sample weights - - Parameters - ---------- - a : (ns,) ndarray, float64 - source histogram - b : (nt,) ndarray, float64 - target histogram - M : (ns,nt) ndarray, float64 - loss matrix - max_iter : int - The maximum number of iterations before stopping the optimization - algorithm if it has not converged. - - - Returns - ------- - gamma: (ns x nt) ndarray - Optimal transportation matrix for the given parameters - - """ - cdef int n1= M.shape[0] - cdef int n2= M.shape[1] - - cdef float cost=0 - cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2]) - - if not len(a): - a=np.ones((n1,))/n1 - - if not len(b): - b=np.ones((n2,))/n2 - - # calling the function - cdef int resultSolver = EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data,<double*> &cost, max_iter) - if resultSolver != OPTIMAL: - if resultSolver == INFEASIBLE: - print("Problem infeasible. Try to increase numItermax.") - elif resultSolver == UNBOUNDED: - print("Problem unbounded") - - return G @cython.boundscheck(False) @cython.wraparound(False) -def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M, int max_iter): +def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter): """ - Solves the Earth Movers distance problem and returns the optimal transport loss + Solves the Earth Movers distance problem and returns the optimal transport matrix gamm=emd(a,b,M) @@ -125,8 +77,11 @@ def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mo cdef int n1= M.shape[0] cdef int n2= M.shape[1] - cdef float cost=0 + cdef double cost=0 cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2]) + cdef np.ndarray[double, ndim=1, mode="c"] alpha=np.zeros(n1) + cdef np.ndarray[double, ndim=1, mode="c"] beta=np.zeros(n2) + if not len(a): a=np.ones((n1,))/n1 @@ -135,17 +90,6 @@ def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mo b=np.ones((n2,))/n2 # calling the function - cdef int resultSolver = EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data,<double*> &cost, max_iter) - if resultSolver != OPTIMAL: - if resultSolver == INFEASIBLE: - print("Problem infeasible. Try to inscrease numItermax.") - elif resultSolver == UNBOUNDED: - print("Problem unbounded") - - cost=0 - for i in range(n1): - for j in range(n2): - cost+=G[i,j]*M[i,j] - - return cost + cdef int result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter) + return G, cost, alpha, beta, result_code |