diff options
author | Romain Tavenard <romain.tavenard@univ-rennes2.fr> | 2019-06-20 14:52:23 +0200 |
---|---|---|
committer | Romain Tavenard <romain.tavenard@univ-rennes2.fr> | 2019-06-20 14:52:23 +0200 |
commit | 15b21611a3a93043d30c4eaaf9d622200453a884 (patch) | |
tree | 20ff23c61f35b876e56d435a5c522249f396e3c5 /ot | |
parent | f63f34f8adb6943b6410f8b773b4b4d8f1c7b4ba (diff) |
EMD 1d without doc made faster
Diffstat (limited to 'ot')
-rw-r--r-- | ot/lp/__init__.py | 5 | ||||
-rw-r--r-- | ot/lp/emd_wrap.pyx | 14 |
2 files changed, 13 insertions, 6 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 49ded5b..c4457dc 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -333,9 +333,8 @@ def emd_1d(a, b, x_a, x_b, metric='sqeuclidean', log=False): inv_perm_a = np.argsort(perm_a) inv_perm_b = np.argsort(perm_b) - M = dist(x_a[perm_a], x_b[perm_b], metric=metric) - - G_sorted, cost = emd_1d_sorted(a, b, M) + G_sorted, cost = emd_1d_sorted(a, b, x_a[perm_a], x_b[perm_b], + metric=metric) G = G_sorted[inv_perm_a, :][:, inv_perm_b] if log: log = {} diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index a3d189d..2966206 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -10,6 +10,8 @@ Cython linker with C solver import numpy as np cimport numpy as np +from ..utils import dist + cimport cython import warnings @@ -99,7 +101,9 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod @cython.wraparound(False) def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, np.ndarray[double, ndim=1, mode="c"] v_weights, - np.ndarray[double, ndim=2, mode="c"] M): + np.ndarray[double, ndim=2, mode="c"] u, + np.ndarray[double, ndim=2, mode="c"] v, + str metric='sqeuclidean'): r""" Roro's stuff """ @@ -112,17 +116,21 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, cdef int j = 0 cdef double w_j = v_weights[0] + cdef double m_ij = 0. + cdef np.ndarray[double, ndim=2, mode="c"] G = np.zeros((n, m), dtype=np.float64) while i < n and j < m: + m_ij = dist(u[i].reshape((1, 1)), v[j].reshape((1, 1)), + metric=metric)[0, 0] if w_i < w_j or j == m - 1: - cost += M[i, j] * w_i + cost += m_ij * w_i G[i, j] = w_i i += 1 w_j -= w_i w_i = u_weights[i] else: - cost += M[i, j] * w_j + cost += m_ij * w_j G[i, j] = w_j j += 1 w_i -= w_j |