diff options
author | Antoine Rolet <antoine.rolet@gmail.com> | 2017-07-13 15:30:39 +0900 |
---|---|---|
committer | Antoine Rolet <antoine.rolet@gmail.com> | 2017-07-13 15:30:39 +0900 |
commit | 55a38f8253e5831105d2c329f4d8ed77686d1330 (patch) | |
tree | 6e8890ca1e0817c9043ce3927495c1f8679d5d6c /ot/lp/emd_wrap.pyx | |
parent | 0e86d1bdbc0dcf7ffdb943637f62df5de4612ad0 (diff) |
Added optional maximal number of iteration
Diffstat (limited to 'ot/lp/emd_wrap.pyx')
-rw-r--r-- | ot/lp/emd_wrap.pyx | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 46794ab..e8fdba4 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -12,13 +12,13 @@ cimport cython cdef extern from "EMD.h": - void EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost) + void EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost, int max_iter) @cython.boundscheck(False) @cython.wraparound(False) -def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M): +def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M, int maxiter): """ Solves the Earth Movers distance problem and returns the optimal transport matrix @@ -66,13 +66,13 @@ def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mod b=np.ones((n2,))/n2 # calling the function - EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data,<double*> &cost) + EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data,<double*> &cost, maxiter) return G @cython.boundscheck(False) @cython.wraparound(False) -def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M): +def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M, int maxiter): """ Solves the Earth Movers distance problem and returns the optimal transport loss @@ -120,7 +120,7 @@ def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mo b=np.ones((n2,))/n2 # calling the function - EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data,<double*> &cost) + EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data,<double*> &cost, maxiter) cost=0 for i in range(n1): |