diff options
author | Romain Tavenard <romain.tavenard@univ-rennes2.fr> | 2019-06-21 11:21:08 +0200 |
---|---|---|
committer | Romain Tavenard <romain.tavenard@univ-rennes2.fr> | 2019-06-21 11:21:08 +0200 |
commit | 18502d6861a4977cbade957f2e48eeb8dbb55414 (patch) | |
tree | 947cf67b5c118ba6eafd72e38ccb0977085767ca /ot/lp | |
parent | cada9a3019997e8efb95d96c86985110f1e937b9 (diff) |
Sparse G matrix for EMD1d + standard metrics computed without cdist
Diffstat (limited to 'ot/lp')
-rw-r--r-- | ot/lp/emd_wrap.pyx | 29 |
1 file changed, 21 insertions, 8 deletions
```diff
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index 2966206..ab88d7f 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -101,8 +101,8 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
 @cython.wraparound(False)
 def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
                   np.ndarray[double, ndim=1, mode="c"] v_weights,
-                  np.ndarray[double, ndim=2, mode="c"] u,
-                  np.ndarray[double, ndim=2, mode="c"] v,
+                  np.ndarray[double, ndim=1, mode="c"] u,
+                  np.ndarray[double, ndim=1, mode="c"] v,
                   str metric='sqeuclidean'):
     r"""
     Roro's stuff
@@ -118,21 +118,34 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
 
     cdef double m_ij = 0.
 
-    cdef np.ndarray[double, ndim=2, mode="c"] G = np.zeros((n, m),
+    cdef np.ndarray[double, ndim=1, mode="c"] G = np.zeros((n + m - 1, ),
                                                            dtype=np.float64)
+    cdef np.ndarray[long, ndim=2, mode="c"] indices = np.zeros((n + m - 1, 2),
+                                                               dtype=np.int)
+    cdef int cur_idx = 0
     while i < n and j < m:
-        m_ij = dist(u[i].reshape((1, 1)), v[j].reshape((1, 1)),
-                    metric=metric)[0, 0]
+        if metric == 'sqeuclidean':
+            m_ij = (u[i] - v[j]) ** 2
+        elif metric == 'cityblock' or metric == 'euclidean':
+            m_ij = np.abs(u[i] - v[j])
+        else:
+            m_ij = dist(u[i].reshape((1, 1)), v[j].reshape((1, 1)),
+                        metric=metric)[0, 0]
         if w_i < w_j or j == m - 1:
             cost += m_ij * w_i
-            G[i, j] = w_i
+            G[cur_idx] = w_i
+            indices[cur_idx, 0] = i
+            indices[cur_idx, 1] = j
             i += 1
             w_j -= w_i
             w_i = u_weights[i]
         else:
             cost += m_ij * w_j
-            G[i, j] = w_j
+            G[cur_idx] = w_j
+            indices[cur_idx, 0] = i
+            indices[cur_idx, 1] = j
             j += 1
             w_i -= w_j
             w_j = v_weights[j]
-    return G, cost
\ No newline at end of file
+        cur_idx += 1
+    return G[:cur_idx], indices[:cur_idx], cost
```