diff options
Diffstat (limited to 'ot/lp/emd_wrap.pyx')
-rw-r--r-- | ot/lp/emd_wrap.pyx | 28 |
1 files changed, 27 insertions, 1 deletions
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 5e055fb..2825ba2 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -105,7 +105,33 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, np.ndarray[double, ndim=1, mode="c"] v, str metric='sqeuclidean'): r""" - Roro's stuff + Solves the Earth Movers distance problem between sorted 1d measures and + returns the OT matrix and the associated cost + + Parameters + ---------- + u_weights : (ns,) ndarray, float64 + Source histogram + v_weights : (nt,) ndarray, float64 + Target histogram + u : (ns,) ndarray, float64 + Source dirac locations (on the real line) + v : (nt,) ndarray, float64 + Target dirac locations (on the real line) + metric: str, optional (default='sqeuclidean') + Metric to be used. Only strings listed in :func:`ot.dist` are accepted. + Due to implementation details, this function runs faster when + `'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used. + + Returns + ------- + gamma: (n, ) ndarray, float64 + Values in the Optimal transportation matrix + indices: (n, 2) ndarray, int64 + Indices of the values stored in gamma for the Optimal transportation + matrix + cost + cost associated to the optimal transportation """ cdef double cost = 0. cdef int n = u_weights.shape[0] |