summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
authorRomain Tavenard <romain.tavenard@univ-rennes2.fr>2019-06-21 18:27:42 +0200
committerRomain Tavenard <romain.tavenard@univ-rennes2.fr>2019-06-21 18:27:42 +0200
commit9e1d74f44473deb1f4766329bb0d1c8af4dfdd73 (patch)
treefc5cdc506400bf635ea88a3b21eac3e24d60f620 /ot
parent67d3bd4bf0f593aa611d6bf09bbd3a9c883299ba (diff)
Started documenting
Diffstat (limited to 'ot')
-rw-r--r--ot/lp/__init__.py77
1 files changed, 75 insertions, 2 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index decff29..e9635a1 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -313,10 +313,83 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
return X
-def emd_1d(a, b, x_a, x_b, metric='sqeuclidean', dense=True, log=False):
+def emd_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
"""Solves the Earth Movers distance problem between 1d measures and returns
the OT matrix
+
+ .. math::
+ \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j])
+
+ s.t. \gamma 1 = a
+ \gamma^T 1= b
+ \gamma\geq 0
+ where :
+
+ - d is the metric
+ - x_a and x_b are the samples
+ - a and b are the sample weights
+
+ Uses the algorithm proposed in [1]_
+
+ Parameters
+ ----------
+ x_a : (ns,) or (ns, 1) ndarray, float64
+ Source histogram (uniform weight if empty list)
+ x_b : (nt,) or (ns, 1) ndarray, float64
+ Target histogram (uniform weight if empty list)
+ a : (ns,) ndarray, float64
+ Source histogram (uniform weight if empty list)
+ b : (nt,) ndarray, float64
+ Target histogram (uniform weight if empty list)
+ dense: boolean, optional (default=True)
+ If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
+ Otherwise returns a sparse representation using scipy's `coo_matrix`
+ format.
+ Due to implementation details, this function runs faster when
+ dense is set to False.
+ metric: str, optional (default='sqeuclidean')
+ Metric to be used. Has to be a string.
+ Due to implementation details, this function runs faster when
+ `'sqeuclidean'` or `'euclidean'` metrics are used.
+ log: boolean, optional (default=False)
+ If True, returns a dictionary containing the cost.
+ Otherwise returns only the optimal transportation matrix.
+
+ Returns
+ -------
+ gamma: (ns, nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log: dict
+ If input log is True, a dictionary containing the cost
+
+
+ Examples
+ --------
+
+ Simple example with obvious solution. The function emd_1d accepts lists and
+ perform automatic conversion to numpy arrays
+
+ >>> import ot
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> x_a = [0., 2.]
+ >>> x_b = [0., 3.]
+ >>> ot.emd_1d(a, b, x_a, x_b)
+ array([[ 0.5, 0. ],
+ [ 0. , 0.5]])
+
+ References
+ ----------
+
+ .. [1] TODO
+
+ See Also
+ --------
+ ot.lp.emd : EMD for multidimensional distributions
+ ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the
+ transportation matrix)
+
"""
a = np.asarray(a, dtype=np.float64)
b = np.asarray(b, dtype=np.float64)
@@ -353,7 +426,7 @@ def emd_1d(a, b, x_a, x_b, metric='sqeuclidean', dense=True, log=False):
return G
-def emd2_1d(a, b, x_a, x_b, metric='sqeuclidean', dense=True, log=False):
+def emd2_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
"""Solves the Earth Movers distance problem between 1d measures and returns
the loss