From 77452dd92f607c3f18a6420cb8cd09fa5cd905a6 Mon Sep 17 00:00:00 2001 From: Romain Tavenard Date: Mon, 24 Jun 2019 13:09:36 +0200 Subject: Added more docstrings (Cython) + fixed link to ot.dist doc --- ot/lp/__init__.py | 4 ++-- ot/lp/emd_wrap.pyx | 28 +++++++++++++++++++++++++++- 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index a350d60..645ed8b 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -343,7 +343,7 @@ def emd_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False): b : (nt,) ndarray, float64 Target histogram (uniform weight if empty list) metric: str, optional (default='sqeuclidean') - Metric to be used. Only strings listed in ... are accepted. + Metric to be used. Only strings listed in :func:`ot.dist` are accepted. Due to implementation details, this function runs faster when `'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used. dense: boolean, optional (default=True) @@ -454,7 +454,7 @@ def emd2_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False): b : (nt,) ndarray, float64 Target histogram (uniform weight if empty list) metric: str, optional (default='sqeuclidean') - Metric to be used. Only strings listed in ... are accepted. + Metric to be used. Only strings listed in :func:`ot.dist` are accepted. Due to implementation details, this function runs faster when `'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used. dense: boolean, optional (default=True) diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 5e055fb..2825ba2 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -105,7 +105,33 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, np.ndarray[double, ndim=1, mode="c"] v, str metric='sqeuclidean'): r""" - Roro's stuff + Solves the Earth Movers distance problem between sorted 1d measures and + returns the OT matrix and the associated cost + + Parameters + ---------- + u_weights : (ns,) ndarray, float64 + Source histogram + v_weights : (nt,) ndarray, float64 + Target histogram + u : (ns,) ndarray, float64 + Source dirac locations (on the real line) + v : (nt,) ndarray, float64 + Target dirac locations (on the real line) + metric: str, optional (default='sqeuclidean') + Metric to be used. Only strings listed in :func:`ot.dist` are accepted. + Due to implementation details, this function runs faster when + `'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used. + + Returns + ------- + gamma: (n, ) ndarray, float64 + Values in the Optimal transportation matrix + indices: (n, 2) ndarray, int64 + Indices of the values stored in gamma for the Optimal transportation + matrix + cost + cost associated to the optimal transportation """ cdef double cost = 0. cdef int n = u_weights.shape[0] -- cgit v1.2.3