diff options
Diffstat (limited to 'ot/lp/emd_wrap.pyx')
-rw-r--r-- | ot/lp/emd_wrap.pyx | 89 |
1 files changed, 89 insertions, 0 deletions
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 83ee6aa..42b848f 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -10,7 +10,10 @@ Cython linker with C solver import numpy as np cimport numpy as np +from ..utils import dist + cimport cython +cimport libc.math as math import warnings @@ -93,3 +96,89 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod cdef int result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter) return G, cost, alpha, beta, result_code + + +@cython.boundscheck(False) +@cython.wraparound(False) +def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, + np.ndarray[double, ndim=1, mode="c"] v_weights, + np.ndarray[double, ndim=1, mode="c"] u, + np.ndarray[double, ndim=1, mode="c"] v, + str metric='sqeuclidean', + double p=1.): + r""" + Solves the Earth Movers distance problem between sorted 1d measures and + returns the OT matrix and the associated cost + + Parameters + ---------- + u_weights : (ns,) ndarray, float64 + Source histogram + v_weights : (nt,) ndarray, float64 + Target histogram + u : (ns,) ndarray, float64 + Source dirac locations (on the real line) + v : (nt,) ndarray, float64 + Target dirac locations (on the real line) + metric: str, optional (default='sqeuclidean') + Metric to be used. Only strings listed in :func:`ot.dist` are accepted. + Due to implementation details, this function runs faster when + `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics + are used. + p: float, optional (default=1.0) + The p-norm to apply for if metric='minkowski' + + Returns + ------- + gamma: (n, ) ndarray, float64 + Values in the Optimal transportation matrix + indices: (n, 2) ndarray, int64 + Indices of the values stored in gamma for the Optimal transportation + matrix + cost + cost associated to the optimal transportation + """ + cdef double cost = 0. + cdef int n = u_weights.shape[0] + cdef int m = v_weights.shape[0] + + cdef int i = 0 + cdef double w_i = u_weights[0] + cdef int j = 0 + cdef double w_j = v_weights[0] + + cdef double m_ij = 0. + + cdef np.ndarray[double, ndim=1, mode="c"] G = np.zeros((n + m - 1, ), + dtype=np.float64) + cdef np.ndarray[long, ndim=2, mode="c"] indices = np.zeros((n + m - 1, 2), + dtype=np.int) + cdef int cur_idx = 0 + while i < n and j < m: + if metric == 'sqeuclidean': + m_ij = (u[i] - v[j]) ** 2 + elif metric == 'cityblock' or metric == 'euclidean': + m_ij = abs(u[i] - v[j]) + elif metric == 'minkowski': + m_ij = math.pow(abs(u[i] - v[j]), p) + else: + m_ij = dist(u[i].reshape((1, 1)), v[j].reshape((1, 1)), + metric=metric)[0, 0] + if w_i < w_j or j == m - 1: + cost += m_ij * w_i + G[cur_idx] = w_i + indices[cur_idx, 0] = i + indices[cur_idx, 1] = j + i += 1 + w_j -= w_i + w_i = u_weights[i] + else: + cost += m_ij * w_j + G[cur_idx] = w_j + indices[cur_idx, 0] = i + indices[cur_idx, 1] = j + j += 1 + w_i -= w_j + w_j = v_weights[j] + cur_idx += 1 + return G[:cur_idx], indices[:cur_idx], cost |