summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2020-05-06 16:02:31 +0200
committerGitHub <noreply@github.com>2020-05-06 16:02:31 +0200
commitae0470d807c25b0bd730752b996ff9d664ba97f1 (patch)
treea211f4c8acd40240b6991f7fa5b3b845b6c728c3
parent94d5c8cc9046854f473d8e4526a3bcf214eb5411 (diff)
parentea6642c4873b557b4d284f6f3717d8990e23ad51 (diff)
Merge pull request #170 from AdrienCorenflos/master
fix array bounds issue
-rw-r--r--ot/lp/emd_wrap.pyx17
1 files changed, 11 insertions, 6 deletions
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index c167964..d79d0ca 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -157,12 +157,12 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
cost associated to the optimal transportation
"""
cdef double cost = 0.
- cdef int n = u_weights.shape[0]
- cdef int m = v_weights.shape[0]
+ cdef Py_ssize_t n = u_weights.shape[0]
+ cdef Py_ssize_t m = v_weights.shape[0]
- cdef int i = 0
+ cdef Py_ssize_t i = 0
cdef double w_i = u_weights[0]
- cdef int j = 0
+ cdef Py_ssize_t j = 0
cdef double w_j = v_weights[0]
cdef double m_ij = 0.
@@ -171,8 +171,8 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
dtype=np.float64)
cdef np.ndarray[long, ndim=2, mode="c"] indices = np.zeros((n + m - 1, 2),
dtype=np.int)
- cdef int cur_idx = 0
- while i < n and j < m:
+ cdef Py_ssize_t cur_idx = 0
+ while True:
if metric == 'sqeuclidean':
m_ij = (u[i] - v[j]) * (u[i] - v[j])
elif metric == 'cityblock' or metric == 'euclidean':
@@ -188,6 +188,8 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
indices[cur_idx, 0] = i
indices[cur_idx, 1] = j
i += 1
+ if i == n:
+ break
w_j -= w_i
w_i = u_weights[i]
else:
@@ -196,7 +198,10 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
indices[cur_idx, 0] = i
indices[cur_idx, 1] = j
j += 1
+ if j == m:
+ break
w_i -= w_j
w_j = v_weights[j]
cur_idx += 1
+ cur_idx += 1
return G[:cur_idx], indices[:cur_idx], cost