summaryrefslogtreecommitdiff
path: root/ot/lp/emd_wrap.pyx
diff options
context:
space:
mode:
Diffstat (limited to 'ot/lp/emd_wrap.pyx')
-rw-r--r--ot/lp/emd_wrap.pyx9
1 files changed, 5 insertions, 4 deletions
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index 42e08f4..e5cec89 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -14,13 +14,14 @@ from ..utils import dist
cimport cython
cimport libc.math as math
+from libc.stdint cimport uint64_t
import warnings
cdef extern from "EMD.h":
- int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter) nogil
- int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter, int numThreads) nogil
+ int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter) nogil
+ int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads) nogil
cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED
@@ -39,7 +40,7 @@ def check_result(result_code):
@cython.boundscheck(False)
@cython.wraparound(False)
-def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter, int numThreads):
+def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, uint64_t max_iter, int numThreads):
"""
Solves the Earth Movers distance problem and returns the optimal transport matrix
@@ -75,7 +76,7 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
target histogram
M : (ns,nt) numpy.ndarray, float64
loss matrix
- max_iter : int
+ max_iter : uint64_t
The maximum number of iterations before stopping the optimization
algorithm if it has not converged.