diff options
author | RĂ©mi Flamary <remi.flamary@gmail.com> | 2019-06-27 14:08:21 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-06-27 14:08:21 +0200 |
commit | a9b8af146648ee2ae50baf46e69e6281f6b279e4 (patch) | |
tree | c97a7359bdf7de19d7a1cc325304852622a8f580 /ot/lp/emd_wrap.pyx | |
parent | 2364d56aad650d501753cc93a69ea1b8ddf28b0a (diff) | |
parent | 362a7f8fa20cf7ae6f2e36d7e47c7ca9f81d3c51 (diff) |
Merge pull request #89 from rtavenar/master
[MRG] EMD and Wasserstein 1D
Diffstat (limited to 'ot/lp/emd_wrap.pyx')
-rw-r--r-- | ot/lp/emd_wrap.pyx | 89 |
1 files changed, 89 insertions, 0 deletions
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 83ee6aa..8a4aec9 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -10,7 +10,10 @@ Cython linker with C solver import numpy as np cimport numpy as np +from ..utils import dist + cimport cython +cimport libc.math as math import warnings @@ -93,3 +96,89 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod cdef int result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter) return G, cost, alpha, beta, result_code + + +@cython.boundscheck(False) +@cython.wraparound(False) +def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, + np.ndarray[double, ndim=1, mode="c"] v_weights, + np.ndarray[double, ndim=1, mode="c"] u, + np.ndarray[double, ndim=1, mode="c"] v, + str metric='sqeuclidean', + double p=1.): + r""" + Solves the Earth Movers distance problem between sorted 1d measures and + returns the OT matrix and the associated cost + + Parameters + ---------- + u_weights : (ns,) ndarray, float64 + Source histogram + v_weights : (nt,) ndarray, float64 + Target histogram + u : (ns,) ndarray, float64 + Source dirac locations (on the real line) + v : (nt,) ndarray, float64 + Target dirac locations (on the real line) + metric: str, optional (default='sqeuclidean') + Metric to be used. Only strings listed in :func:`ot.dist` are accepted. + Due to implementation details, this function runs faster when + `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics + are used. + p: float, optional (default=1.0) + The p-norm to apply for if metric='minkowski' + + Returns + ------- + gamma: (n, ) ndarray, float64 + Values in the Optimal transportation matrix + indices: (n, 2) ndarray, int64 + Indices of the values stored in gamma for the Optimal transportation + matrix + cost + cost associated to the optimal transportation + """ + cdef double cost = 0. + cdef int n = u_weights.shape[0] + cdef int m = v_weights.shape[0] + + cdef int i = 0 + cdef double w_i = u_weights[0] + cdef int j = 0 + cdef double w_j = v_weights[0] + + cdef double m_ij = 0. + + cdef np.ndarray[double, ndim=1, mode="c"] G = np.zeros((n + m - 1, ), + dtype=np.float64) + cdef np.ndarray[long, ndim=2, mode="c"] indices = np.zeros((n + m - 1, 2), + dtype=np.int) + cdef int cur_idx = 0 + while i < n and j < m: + if metric == 'sqeuclidean': + m_ij = (u[i] - v[j]) * (u[i] - v[j]) + elif metric == 'cityblock' or metric == 'euclidean': + m_ij = math.fabs(u[i] - v[j]) + elif metric == 'minkowski': + m_ij = math.pow(math.fabs(u[i] - v[j]), p) + else: + m_ij = dist(u[i].reshape((1, 1)), v[j].reshape((1, 1)), + metric=metric)[0, 0] + if w_i < w_j or j == m - 1: + cost += m_ij * w_i + G[cur_idx] = w_i + indices[cur_idx, 0] = i + indices[cur_idx, 1] = j + i += 1 + w_j -= w_i + w_i = u_weights[i] + else: + cost += m_ij * w_j + G[cur_idx] = w_j + indices[cur_idx, 0] = i + indices[cur_idx, 1] = j + j += 1 + w_i -= w_j + w_j = v_weights[j] + cur_idx += 1 + return G[:cur_idx], indices[:cur_idx], cost |