summaryrefslogtreecommitdiff
path: root/ot/lp/emd_wrap.pyx
diff options
context:
space:
mode:
Diffstat (limited to 'ot/lp/emd_wrap.pyx')
-rw-r--r--ot/lp/emd_wrap.pyx28
1 files changed, 27 insertions, 1 deletions
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index 5e055fb..2825ba2 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -105,7 +105,33 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
np.ndarray[double, ndim=1, mode="c"] v,
str metric='sqeuclidean'):
r"""
- Roro's stuff
+ Solves the Earth Movers distance problem between sorted 1d measures and
+ returns the OT matrix and the associated cost
+
+ Parameters
+ ----------
+ u_weights : (ns,) ndarray, float64
+ Source histogram
+ v_weights : (nt,) ndarray, float64
+ Target histogram
+ u : (ns,) ndarray, float64
+ Source dirac locations (on the real line)
+ v : (nt,) ndarray, float64
+ Target dirac locations (on the real line)
+ metric: str, optional (default='sqeuclidean')
+ Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
+ Due to implementation details, this function runs faster when
+ `'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used.
+
+ Returns
+ -------
+ gamma: (n, ) ndarray, float64
+ Values in the Optimal transportation matrix
+ indices: (n, 2) ndarray, int64
+ Indices of the values stored in gamma for the Optimal transportation
+ matrix
+ cost
+ cost associated to the optimal transportation
"""
cdef double cost = 0.
cdef int n = u_weights.shape[0]