summaryrefslogtreecommitdiff
path: root/ot/lp/emd_wrap.pyx
diff options
context:
space:
mode:
authorRomain Tavenard <romain.tavenard@univ-rennes2.fr>2019-06-21 11:21:08 +0200
committerRomain Tavenard <romain.tavenard@univ-rennes2.fr>2019-06-21 11:21:08 +0200
commit18502d6861a4977cbade957f2e48eeb8dbb55414 (patch)
tree947cf67b5c118ba6eafd72e38ccb0977085767ca /ot/lp/emd_wrap.pyx
parentcada9a3019997e8efb95d96c86985110f1e937b9 (diff)
Sparse G matrix for EMD1d + standard metrics computed without cdist
Diffstat (limited to 'ot/lp/emd_wrap.pyx')
-rw-r--r--ot/lp/emd_wrap.pyx29
1 files changed, 21 insertions, 8 deletions
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index 2966206..ab88d7f 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -101,8 +101,8 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
@cython.wraparound(False)
def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
np.ndarray[double, ndim=1, mode="c"] v_weights,
- np.ndarray[double, ndim=2, mode="c"] u,
- np.ndarray[double, ndim=2, mode="c"] v,
+ np.ndarray[double, ndim=1, mode="c"] u,
+ np.ndarray[double, ndim=1, mode="c"] v,
str metric='sqeuclidean'):
r"""
Roro's stuff
@@ -118,21 +118,34 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
cdef double m_ij = 0.
- cdef np.ndarray[double, ndim=2, mode="c"] G = np.zeros((n, m),
+ cdef np.ndarray[double, ndim=1, mode="c"] G = np.zeros((n + m - 1, ),
dtype=np.float64)
+ cdef np.ndarray[long, ndim=2, mode="c"] indices = np.zeros((n + m - 1, 2),
+ dtype=np.int)
+ cdef int cur_idx = 0
while i < n and j < m:
- m_ij = dist(u[i].reshape((1, 1)), v[j].reshape((1, 1)),
- metric=metric)[0, 0]
+ if metric == 'sqeuclidean':
+ m_ij = (u[i] - v[j]) ** 2
+ elif metric == 'cityblock' or metric == 'euclidean':
+ m_ij = np.abs(u[i] - v[j])
+ else:
+ m_ij = dist(u[i].reshape((1, 1)), v[j].reshape((1, 1)),
+ metric=metric)[0, 0]
if w_i < w_j or j == m - 1:
cost += m_ij * w_i
- G[i, j] = w_i
+ G[cur_idx] = w_i
+ indices[cur_idx, 0] = i
+ indices[cur_idx, 1] = j
i += 1
w_j -= w_i
w_i = u_weights[i]
else:
cost += m_ij * w_j
- G[i, j] = w_j
+ G[cur_idx] = w_j
+ indices[cur_idx, 0] = i
+ indices[cur_idx, 1] = j
j += 1
w_i -= w_j
w_j = v_weights[j]
- return G, cost \ No newline at end of file
+ cur_idx += 1
+ return G[:cur_idx], indices[:cur_idx], cost