summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ot/lp/emd_wrap.pyx6
1 files changed, 5 insertions, 1 deletions
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index c167964..e9e8fba 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -172,7 +172,7 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
cdef np.ndarray[long, ndim=2, mode="c"] indices = np.zeros((n + m - 1, 2),
dtype=np.int)
cdef int cur_idx = 0
- while i < n and j < m:
+ while True:
if metric == 'sqeuclidean':
m_ij = (u[i] - v[j]) * (u[i] - v[j])
elif metric == 'cityblock' or metric == 'euclidean':
@@ -188,6 +188,8 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
indices[cur_idx, 0] = i
indices[cur_idx, 1] = j
i += 1
+ if i == n:
+ break
w_j -= w_i
w_i = u_weights[i]
else:
@@ -196,6 +198,8 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
indices[cur_idx, 0] = i
indices[cur_idx, 1] = j
j += 1
+ if j == m:
+ break
w_i -= w_j
w_j = v_weights[j]
cur_idx += 1