summaryrefslogtreecommitdiff
path: root/ot/lp
diff options
context:
space:
mode:
authorAdrienCorenflos <adrien.corenflos@gmail.com>2020-05-05 13:14:35 +0100
committerAdrienCorenflos <adrien.corenflos@gmail.com>2020-05-05 13:14:35 +0100
commitdaec9fe15f9728080b54a7ddbfdb67075e78c6bd (patch)
tree8a3b8cacacf88a5bb0824a81034a1316afba1da6 /ot/lp
parent94d5c8cc9046854f473d8e4526a3bcf214eb5411 (diff)
break before exceeding array size
Diffstat (limited to 'ot/lp')
-rw-r--r--ot/lp/emd_wrap.pyx6
1 files changed, 5 insertions, 1 deletions
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index c167964..e9e8fba 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -172,7 +172,7 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
cdef np.ndarray[long, ndim=2, mode="c"] indices = np.zeros((n + m - 1, 2),
dtype=np.int)
cdef int cur_idx = 0
- while i < n and j < m:
+ while True:
if metric == 'sqeuclidean':
m_ij = (u[i] - v[j]) * (u[i] - v[j])
elif metric == 'cityblock' or metric == 'euclidean':
@@ -188,6 +188,8 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
indices[cur_idx, 0] = i
indices[cur_idx, 1] = j
i += 1
+ if i == n:
+ break
w_j -= w_i
w_i = u_weights[i]
else:
@@ -196,6 +198,8 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
indices[cur_idx, 0] = i
indices[cur_idx, 1] = j
j += 1
+ if j == m:
+ break
w_i -= w_j
w_j = v_weights[j]
cur_idx += 1