summaryrefslogtreecommitdiff
path: root/ot/lp/emd_wrap.pyx
diff options
context:
space:
mode:
Diffstat (limited to 'ot/lp/emd_wrap.pyx')
-rw-r--r--ot/lp/emd_wrap.pyx10
1 files changed, 5 insertions, 5 deletions
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index e9e8fba..10bc5cf 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -157,12 +157,12 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
cost associated to the optimal transportation
"""
cdef double cost = 0.
- cdef int n = u_weights.shape[0]
- cdef int m = v_weights.shape[0]
+ cdef Py_ssize_t n = u_weights.shape[0]
+ cdef Py_ssize_t m = v_weights.shape[0]
- cdef int i = 0
+ cdef Py_ssize_t i = 0
cdef double w_i = u_weights[0]
- cdef int j = 0
+ cdef Py_ssize_t j = 0
cdef double w_j = v_weights[0]
cdef double m_ij = 0.
@@ -171,7 +171,7 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
dtype=np.float64)
cdef np.ndarray[long, ndim=2, mode="c"] indices = np.zeros((n + m - 1, 2),
dtype=np.int)
- cdef int cur_idx = 0
+ cdef Py_ssize_t cur_idx = 0
while True:
if metric == 'sqeuclidean':
m_ij = (u[i] - v[j]) * (u[i] - v[j])