diff options
author | arolet <antoine.rolet@gmail.com> | 2017-07-21 12:12:48 +0900 |
---|---|---|
committer | arolet <antoine.rolet@gmail.com> | 2017-07-21 12:12:48 +0900 |
commit | 88c62c39a9623e8b58ebb776a9deddc96b43b4a0 (patch) | |
tree | 74506ee862dd8469488a0421c944d53884319465 /ot/lp/emd_wrap.pyx | |
parent | dc3bbd4134f0e2b80e0fe72368bdcf9966f434dc (diff) |
Added dual variables computations
Diffstat (limited to 'ot/lp/emd_wrap.pyx')
-rw-r--r-- | ot/lp/emd_wrap.pyx | 22 |
1 files changed, 16 insertions, 6 deletions
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index c4ba125..813596f 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -12,7 +12,8 @@ cimport cython cdef extern from "EMD.h": - void EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost, int max_iter) + void EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost, + double* alpha, double* beta, int max_iter) @@ -58,6 +59,8 @@ def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mod cdef float cost=0 cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2]) + cdef np.ndarray[double, ndim=1, mode="c"] alpha=np.zeros(n1) + cdef np.ndarray[double, ndim=1, mode="c"] beta=np.zeros(n2) if not len(a): a=np.ones((n1,))/n1 @@ -65,10 +68,13 @@ def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mod if not len(b): b=np.ones((n2,))/n2 + print alpha.size + print beta.size # calling the function - EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data,<double*> &cost, maxiter) + EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data, + <double*> alpha.data, <double*> beta.data, <double*> &cost, maxiter) - return G + return G, alpha, beta @cython.boundscheck(False) @cython.wraparound(False) @@ -112,15 +118,19 @@ def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mo cdef float cost=0 cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2]) + + cdef np.ndarray[double, ndim = 1, mode = "c"] alpha = np.zeros([n1]) + cdef np.ndarray[double, ndim = 1, mode = "c"] beta = np.zeros([n2]) if not len(a): a=np.ones((n1,))/n1 if not len(b): b=np.ones((n2,))/n2 - # calling the function - EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data,<double*> &cost, maxiter) + EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data, + <double*> alpha.data, <double*> beta.data, <double*> &cost, maxiter) + - return cost + return cost, alpha, beta |