diff options
author | AdrienCorenflos <adrien.corenflos@gmail.com> | 2020-05-05 13:14:35 +0100 |
---|---|---|
committer | AdrienCorenflos <adrien.corenflos@gmail.com> | 2020-05-05 13:14:35 +0100 |
commit | daec9fe15f9728080b54a7ddbfdb67075e78c6bd (patch) | |
tree | 8a3b8cacacf88a5bb0824a81034a1316afba1da6 /ot | |
parent | 94d5c8cc9046854f473d8e4526a3bcf214eb5411 (diff) |
break before exceeding array size
Diffstat (limited to 'ot')
-rw-r--r-- | ot/lp/emd_wrap.pyx | 6 |
1 files changed, 5 insertions, 1 deletions
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index c167964..e9e8fba 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -172,7 +172,7 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, cdef np.ndarray[long, ndim=2, mode="c"] indices = np.zeros((n + m - 1, 2), dtype=np.int) cdef int cur_idx = 0 - while i < n and j < m: + while True: if metric == 'sqeuclidean': m_ij = (u[i] - v[j]) * (u[i] - v[j]) elif metric == 'cityblock' or metric == 'euclidean': @@ -188,6 +188,8 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, indices[cur_idx, 0] = i indices[cur_idx, 1] = j i += 1 + if i == n: + break w_j -= w_i w_i = u_weights[i] else: @@ -196,6 +198,8 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, indices[cur_idx, 0] = i indices[cur_idx, 1] = j j += 1 + if j == m: + break w_i -= w_j w_j = v_weights[j] cur_idx += 1 |