summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
authorRomain Tavenard <romain.tavenard@univ-rennes2.fr>2019-06-24 09:17:54 +0200
committerRomain Tavenard <romain.tavenard@univ-rennes2.fr>2019-06-24 09:17:54 +0200
commit71f9b5adfb8d8f4481948391f22e49f45494d071 (patch)
tree8919d83f0e6bc25b3ed5a44b232d0c06d30b4255 /ot
parent9e1d74f44473deb1f4766329bb0d1c8af4dfdd73 (diff)
Added docstrings
Diffstat (limited to 'ot')
-rw-r--r--ot/lp/__init__.py114
1 files changed, 92 insertions, 22 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index e9635a1..a350d60 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -321,8 +321,8 @@ def emd_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
.. math::
\gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j])
- s.t. \gamma 1 = a
- \gamma^T 1= b
+ s.t. \gamma 1 = a,
+ \gamma^T 1= b,
\gamma\geq 0
where :
@@ -330,28 +330,27 @@ def emd_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
- x_a and x_b are the samples
- a and b are the sample weights
- Uses the algorithm proposed in [1]_
+ Uses the algorithm detailed in [1]_
Parameters
----------
x_a : (ns,) or (ns, 1) ndarray, float64
- Source histogram (uniform weight if empty list)
+ Source dirac locations (on the real line)
x_b : (nt,) or (nt, 1) ndarray, float64
- Target histogram (uniform weight if empty list)
+ Target dirac locations (on the real line)
a : (ns,) ndarray, float64
Source histogram (uniform weight if empty list)
b : (nt,) ndarray, float64
Target histogram (uniform weight if empty list)
+ metric: str, optional (default='sqeuclidean')
+ Metric to be used. Only strings listed in ... are accepted.
+ Due to implementation details, this function runs faster when
+ `'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used.
dense: boolean, optional (default=True)
If True, returns :math:`\gamma` as a dense ndarray of shape (ns, nt).
Otherwise returns a sparse representation using scipy's `coo_matrix`
- format.
- Due to implementation details, this function runs faster when
+ format. Due to implementation details, this function runs faster when
dense is set to False.
- metric: str, optional (default='sqeuclidean')
- Metric to be used. Has to be a string.
- Due to implementation details, this function runs faster when
- `'sqeuclidean'` or `'euclidean'` metrics are used.
log: boolean, optional (default=False)
If True, returns a dictionary containing the cost.
Otherwise returns only the optimal transportation matrix.
@@ -368,28 +367,28 @@ def emd_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
--------
Simple example with obvious solution. The function emd_1d accepts lists and
- perform automatic conversion to numpy arrays
+ performs automatic conversion to numpy arrays
>>> import ot
>>> a=[.5, .5]
>>> b=[.5, .5]
- >>> x_a = [0., 2.]
+ >>> x_a = [2., 0.]
>>> x_b = [0., 3.]
- >>> ot.emd_1d(a, b, x_a, x_b)
- array([[ 0.5, 0. ],
- [ 0. , 0.5]])
+ >>> ot.emd_1d(x_a, x_b, a, b)
+ array([[0. , 0.5],
+ [0.5, 0. ]])
References
----------
- .. [1] TODO
+ .. [1] Peyré, G., & Cuturi, M. (2018). "Computational Optimal
+ Transport", arXiv:1803.00567.
See Also
--------
ot.lp.emd : EMD for multidimensional distributions
ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the
transportation matrix)
-
"""
a = np.asarray(a, dtype=np.float64)
b = np.asarray(b, dtype=np.float64)
@@ -418,10 +417,9 @@ def emd_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
G = coo_matrix((G_sorted, (perm_a[indices[:, 0]], perm_b[indices[:, 1]])),
shape=(a.shape[0], b.shape[0]))
if dense:
- G = G.todense()
+ G = G.toarray()
if log:
- log = {}
- log['cost'] = cost
+ log = {'cost': cost}
return G, log
return G
@@ -430,10 +428,82 @@ def emd2_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
"""Solves the Earth Movers distance problem between 1d measures and returns
the loss
+
+ .. math::
+ \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j])
+
+ s.t. \gamma 1 = a,
+ \gamma^T 1= b,
+ \gamma\geq 0
+ where :
+
+ - d is the metric
+ - x_a and x_b are the samples
+ - a and b are the sample weights
+
+ Uses the algorithm detailed in [1]_
+
+ Parameters
+ ----------
+ x_a : (ns,) or (ns, 1) ndarray, float64
+ Source dirac locations (on the real line)
+ x_b : (nt,) or (nt, 1) ndarray, float64
+ Target dirac locations (on the real line)
+ a : (ns,) ndarray, float64
+ Source histogram (uniform weight if empty list)
+ b : (nt,) ndarray, float64
+ Target histogram (uniform weight if empty list)
+ metric: str, optional (default='sqeuclidean')
+ Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
+ Due to implementation details, this function runs faster when
+ `'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used.
+ dense: boolean, optional (default=True)
+ If True, returns :math:`\gamma` as a dense ndarray of shape (ns, nt).
+ Otherwise returns a sparse representation using scipy's `coo_matrix`
+ format. Only used if log is set to True. Due to implementation details,
+ this function runs faster when dense is set to False.
+ log: boolean, optional (default=False)
+ If True, returns a dictionary containing the transportation matrix.
+ Otherwise returns only the loss.
+
+ Returns
+ -------
+ loss: float
+ Cost associated to the optimal transportation
+ log: dict
+ If input log is True, a dictionary containing the Optimal transportation
+ matrix for the given parameters
+
+
+ Examples
+ --------
+
+ Simple example with obvious solution. The function emd2_1d accepts lists and
+ performs automatic conversion to numpy arrays
+
+ >>> import ot
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> x_a = [2., 0.]
+ >>> x_b = [0., 3.]
+ >>> ot.emd2_1d(x_a, x_b, a, b)
+ 0.5
+
+ References
+ ----------
+
+ .. [1] Peyré, G., & Cuturi, M. (2018). "Computational Optimal
+ Transport", arXiv:1803.00567.
+
+ See Also
+ --------
+ ot.lp.emd2 : EMD for multidimensional distributions
+ ot.lp.emd_1d : EMD for 1d distributions (returns the transportation matrix
+ instead of the cost)
"""
# If we do not return G (log==False), then we should not cast it to dense
# (useless overhead)
- G, log_emd = emd_1d(a=a, b=b, x_a=x_a, x_b=x_b, metric=metric,
+ G, log_emd = emd_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric=metric,
dense=dense and log, log=True)
cost = log_emd['cost']
if log: