summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
authorRomain Tavenard <romain.tavenard@univ-rennes2.fr>2019-06-24 13:09:36 +0200
committerRomain Tavenard <romain.tavenard@univ-rennes2.fr>2019-06-24 13:09:36 +0200
commit77452dd92f607c3f18a6420cb8cd09fa5cd905a6 (patch)
tree9d11d69ccfcd19ab4e368cd0df6b49720d6c4346 /ot
parent71f9b5adfb8d8f4481948391f22e49f45494d071 (diff)
Added more docstrings (Cython) + fixed link to ot.dist doc
Diffstat (limited to 'ot')
-rw-r--r--ot/lp/__init__.py4
-rw-r--r--ot/lp/emd_wrap.pyx28
2 files changed, 29 insertions, 3 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index a350d60..645ed8b 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -343,7 +343,7 @@ def emd_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
b : (nt,) ndarray, float64
Target histogram (uniform weight if empty list)
metric: str, optional (default='sqeuclidean')
- Metric to be used. Only strings listed in ... are accepted.
+ Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
Due to implementation details, this function runs faster when
`'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used.
dense: boolean, optional (default=True)
@@ -454,7 +454,7 @@ def emd2_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
b : (nt,) ndarray, float64
Target histogram (uniform weight if empty list)
metric: str, optional (default='sqeuclidean')
- Metric to be used. Only strings listed in ... are accepted.
+ Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
Due to implementation details, this function runs faster when
`'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used.
dense: boolean, optional (default=True)
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index 5e055fb..2825ba2 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -105,7 +105,33 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
np.ndarray[double, ndim=1, mode="c"] v,
str metric='sqeuclidean'):
r"""
- Roro's stuff
+ Solves the Earth Movers distance problem between sorted 1d measures and
+ returns the OT matrix and the associated cost
+
+ Parameters
+ ----------
+ u_weights : (ns,) ndarray, float64
+ Source histogram
+ v_weights : (nt,) ndarray, float64
+ Target histogram
+ u : (ns,) ndarray, float64
+ Source dirac locations (on the real line)
+ v : (nt,) ndarray, float64
+ Target dirac locations (on the real line)
+ metric: str, optional (default='sqeuclidean')
+ Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
+ Due to implementation details, this function runs faster when
+ `'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used.
+
+ Returns
+ -------
+ gamma: (n, ) ndarray, float64
+ Values in the Optimal transportation matrix
+ indices: (n, 2) ndarray, int64
+ Indices of the values stored in gamma for the Optimal transportation
+ matrix
+ cost
+ cost associated to the optimal transportation
"""
cdef double cost = 0.
cdef int n = u_weights.shape[0]