diff options
Diffstat (limited to 'ot/lp/emd_wrap.pyx')
-rw-r--r-- | ot/lp/emd_wrap.pyx | 27 |
1 files changed, 21 insertions, 6 deletions
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 2b6c495..c167964 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -19,7 +19,7 @@ import warnings cdef extern from "EMD.h": - int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter) + int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter) nogil cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED @@ -35,8 +35,7 @@ def check_result(result_code): message = "numItermax reached before optimality. Try to increase numItermax." warnings.warn(message) return message - - + @cython.boundscheck(False) @cython.wraparound(False) def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter): @@ -61,6 +60,12 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod .. warning:: Note that the M matrix needs to be a C-order :py.cls:`numpy.array` + .. warning:: + The C++ solver discards all samples in the distributions with + zeros weights. This means that while the primal variable (transport + matrix) is exact, the solver only returns feasible dual potentials + on the samples with weights different from zero. + Parameters ---------- a : (ns,) numpy.ndarray, float64 @@ -73,7 +78,6 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod The maximum number of iterations before stopping the optimization algorithm if it has not converged. - Returns ------- gamma: (ns x nt) numpy.ndarray @@ -82,12 +86,19 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod """ cdef int n1= M.shape[0] cdef int n2= M.shape[1] + cdef int nmax=n1+n2-1 + cdef int result_code = 0 + cdef int nG=0 cdef double cost=0 - cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2]) cdef np.ndarray[double, ndim=1, mode="c"] alpha=np.zeros(n1) cdef np.ndarray[double, ndim=1, mode="c"] beta=np.zeros(n2) + cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([0, 0]) + + cdef np.ndarray[double, ndim=1, mode="c"] Gv=np.zeros(0) + cdef np.ndarray[long, ndim=1, mode="c"] iG=np.zeros(0,dtype=np.int) + cdef np.ndarray[long, ndim=1, mode="c"] jG=np.zeros(0,dtype=np.int) if not len(a): a=np.ones((n1,))/n1 @@ -95,8 +106,12 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod if not len(b): b=np.ones((n2,))/n2 + # init OT matrix + G=np.zeros([n1, n2]) + # calling the function - cdef int result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter) + with nogil: + result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter) return G, cost, alpha, beta, result_code |