diff options
author | Nicolas Courty <ncourty@irisa.fr> | 2020-02-28 11:46:51 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-02-28 11:46:51 +0100 |
commit | 00535f30ba5374f95df0a9f0781a12bea37abf63 (patch) | |
tree | ff4d363ce36677928d9d828c26148c3eac07a5da /ot/lp/emd_wrap.pyx | |
parent | e23f4d0646a3e8d28cc146c28574359585295249 (diff) | |
parent | 81e9d425b905b1b7fc0ee888556e60c692d9bb18 (diff) |
Merge branch 'master' into osx-issue
Diffstat (limited to 'ot/lp/emd_wrap.pyx')
-rw-r--r-- | ot/lp/emd_wrap.pyx | 150 |
1 files changed, 140 insertions, 10 deletions
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 83ee6aa..d345fd4 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -10,13 +10,19 @@ Cython linker with C solver import numpy as np cimport numpy as np +from ..utils import dist + cimport cython +cimport libc.math as math import warnings cdef extern from "EMD.h": int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter) + int EMD_wrap_return_sparse(int n1, int n2, double *X, double *Y, double *D, + long *iG, long *jG, double *G, long * nG, + double* alpha, double* beta, double *cost, int maxIter) cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED @@ -34,9 +40,11 @@ def check_result(result_code): return message + + @cython.boundscheck(False) @cython.wraparound(False) -def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter): +def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter, bint dense): """ Solves the Earth Movers distance problem and returns the optimal transport matrix @@ -55,33 +63,50 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod - M is the metric cost matrix - a and b are the sample weights + .. warning:: + Note that the M matrix needs to be a C-order :py.cls:`numpy.array` + + .. warning:: + The C++ solver discards all samples in the distributions with + zeros weights. This means that while the primal variable (transport + matrix) is exact, the solver only returns feasible dual potentials + on the samples with weights different from zero. + Parameters ---------- - a : (ns,) ndarray, float64 + a : (ns,) numpy.ndarray, float64 source histogram - b : (nt,) ndarray, float64 + b : (nt,) numpy.ndarray, float64 target histogram - M : (ns,nt) ndarray, float64 + M : (ns,nt) numpy.ndarray, float64 loss matrix max_iter : int The maximum number of iterations before stopping the optimization algorithm if it has not converged. - + dense : bool + Return a sparse transport matrix if set to False Returns ------- - gamma: (ns x nt) ndarray + gamma: (ns x nt) numpy.ndarray Optimal transportation matrix for the given parameters """ cdef int n1= M.shape[0] cdef int n2= M.shape[1] + cdef int nmax=n1+n2-1 + cdef int result_code = 0 + cdef int nG=0 cdef double cost=0 - cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2]) cdef np.ndarray[double, ndim=1, mode="c"] alpha=np.zeros(n1) cdef np.ndarray[double, ndim=1, mode="c"] beta=np.zeros(n2) + cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([0, 0]) + + cdef np.ndarray[double, ndim=1, mode="c"] Gv=np.zeros(0) + cdef np.ndarray[long, ndim=1, mode="c"] iG=np.zeros(0,dtype=np.int) + cdef np.ndarray[long, ndim=1, mode="c"] jG=np.zeros(0,dtype=np.int) if not len(a): a=np.ones((n1,))/n1 @@ -89,7 +114,112 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod if not len(b): b=np.ones((n2,))/n2 - # calling the function - cdef int result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter) + if dense: + # init OT matrix + G=np.zeros([n1, n2]) + + # calling the function + result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter) + + return G, cost, alpha, beta, result_code + + + else: + + # init sparse OT matrix + Gv=np.zeros(nmax) + iG=np.zeros(nmax,dtype=np.int) + jG=np.zeros(nmax,dtype=np.int) + + + result_code = EMD_wrap_return_sparse(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <long*> iG.data, <long*> jG.data, <double*> Gv.data, <long*> &nG, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter) - return G, cost, alpha, beta, result_code + + return Gv[:nG], iG[:nG], jG[:nG], cost, alpha, beta, result_code + + + +@cython.boundscheck(False) +@cython.wraparound(False) +def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, + np.ndarray[double, ndim=1, mode="c"] v_weights, + np.ndarray[double, ndim=1, mode="c"] u, + np.ndarray[double, ndim=1, mode="c"] v, + str metric='sqeuclidean', + double p=1.): + r""" + Solves the Earth Movers distance problem between sorted 1d measures and + returns the OT matrix and the associated cost + + Parameters + ---------- + u_weights : (ns,) ndarray, float64 + Source histogram + v_weights : (nt,) ndarray, float64 + Target histogram + u : (ns,) ndarray, float64 + Source dirac locations (on the real line) + v : (nt,) ndarray, float64 + Target dirac locations (on the real line) + metric: str, optional (default='sqeuclidean') + Metric to be used. Only strings listed in :func:`ot.dist` are accepted. + Due to implementation details, this function runs faster when + `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics + are used. + p: float, optional (default=1.0) + The p-norm to apply for if metric='minkowski' + + Returns + ------- + gamma: (n, ) ndarray, float64 + Values in the Optimal transportation matrix + indices: (n, 2) ndarray, int64 + Indices of the values stored in gamma for the Optimal transportation + matrix + cost + cost associated to the optimal transportation + """ + cdef double cost = 0. + cdef int n = u_weights.shape[0] + cdef int m = v_weights.shape[0] + + cdef int i = 0 + cdef double w_i = u_weights[0] + cdef int j = 0 + cdef double w_j = v_weights[0] + + cdef double m_ij = 0. + + cdef np.ndarray[double, ndim=1, mode="c"] G = np.zeros((n + m - 1, ), + dtype=np.float64) + cdef np.ndarray[long, ndim=2, mode="c"] indices = np.zeros((n + m - 1, 2), + dtype=np.int) + cdef int cur_idx = 0 + while i < n and j < m: + if metric == 'sqeuclidean': + m_ij = (u[i] - v[j]) * (u[i] - v[j]) + elif metric == 'cityblock' or metric == 'euclidean': + m_ij = math.fabs(u[i] - v[j]) + elif metric == 'minkowski': + m_ij = math.pow(math.fabs(u[i] - v[j]), p) + else: + m_ij = dist(u[i].reshape((1, 1)), v[j].reshape((1, 1)), + metric=metric)[0, 0] + if w_i < w_j or j == m - 1: + cost += m_ij * w_i + G[cur_idx] = w_i + indices[cur_idx, 0] = i + indices[cur_idx, 1] = j + i += 1 + w_j -= w_i + w_i = u_weights[i] + else: + cost += m_ij * w_j + G[cur_idx] = w_j + indices[cur_idx, 0] = i + indices[cur_idx, 1] = j + j += 1 + w_i -= w_j + w_j = v_weights[j] + cur_idx += 1 + return G[:cur_idx], indices[:cur_idx], cost |