summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRomain Tavenard <romain.tavenard@univ-rennes2.fr>2019-06-21 18:27:42 +0200
committerRomain Tavenard <romain.tavenard@univ-rennes2.fr>2019-06-21 18:27:42 +0200
commit9e1d74f44473deb1f4766329bb0d1c8af4dfdd73 (patch)
treefc5cdc506400bf635ea88a3b21eac3e24d60f620
parent67d3bd4bf0f593aa611d6bf09bbd3a9c883299ba (diff)
Started documenting
-rw-r--r--ot/lp/__init__.py77
-rw-r--r--test/test_ot.py8
2 files changed, 79 insertions, 6 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index decff29..e9635a1 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -313,10 +313,83 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
return X
-def emd_1d(a, b, x_a, x_b, metric='sqeuclidean', dense=True, log=False):
+def emd_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
"""Solves the Earth Movers distance problem between 1d measures and returns
the OT matrix
+
+ .. math::
+ \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j])
+
+ s.t. \gamma 1 = a
+ \gamma^T 1= b
+ \gamma\geq 0
+ where :
+
+ - d is the metric
+ - x_a and x_b are the samples
+ - a and b are the sample weights
+
+ Uses the algorithm proposed in [1]_
+
+ Parameters
+ ----------
+ x_a : (ns,) or (ns, 1) ndarray, float64
+ Source histogram (uniform weight if empty list)
+ x_b : (nt,) or (ns, 1) ndarray, float64
+ Target histogram (uniform weight if empty list)
+ a : (ns,) ndarray, float64
+ Source histogram (uniform weight if empty list)
+ b : (nt,) ndarray, float64
+ Target histogram (uniform weight if empty list)
+ dense: boolean, optional (default=True)
+ If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
+ Otherwise returns a sparse representation using scipy's `coo_matrix`
+ format.
+ Due to implementation details, this function runs faster when
+ dense is set to False.
+ metric: str, optional (default='sqeuclidean')
+ Metric to be used. Has to be a string.
+ Due to implementation details, this function runs faster when
+ `'sqeuclidean'` or `'euclidean'` metrics are used.
+ log: boolean, optional (default=False)
+ If True, returns a dictionary containing the cost.
+ Otherwise returns only the optimal transportation matrix.
+
+ Returns
+ -------
+ gamma: (ns, nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log: dict
+ If input log is True, a dictionary containing the cost
+
+
+ Examples
+ --------
+
+ Simple example with obvious solution. The function emd_1d accepts lists and
+ perform automatic conversion to numpy arrays
+
+ >>> import ot
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> x_a = [0., 2.]
+ >>> x_b = [0., 3.]
+ >>> ot.emd_1d(a, b, x_a, x_b)
+ array([[ 0.5, 0. ],
+ [ 0. , 0.5]])
+
+ References
+ ----------
+
+ .. [1] TODO
+
+ See Also
+ --------
+ ot.lp.emd : EMD for multidimensional distributions
+ ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the
+ transportation matrix)
+
"""
a = np.asarray(a, dtype=np.float64)
b = np.asarray(b, dtype=np.float64)
@@ -353,7 +426,7 @@ def emd_1d(a, b, x_a, x_b, metric='sqeuclidean', dense=True, log=False):
return G
-def emd2_1d(a, b, x_a, x_b, metric='sqeuclidean', dense=True, log=False):
+def emd2_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
"""Solves the Earth Movers distance problem between 1d measures and returns
the loss
diff --git a/test/test_ot.py b/test/test_ot.py
index 2a2e0a5..6d6ea26 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -59,10 +59,10 @@ def test_emd_1d_emd2_1d():
G, log = ot.emd([], [], M, log=True)
wass = log["cost"]
- G_1d, log = ot.emd_1d([], [], u, v, metric='sqeuclidean', log=True)
+ G_1d, log = ot.emd_1d(u, v, [], [], metric='sqeuclidean', log=True)
wass1d = log["cost"]
- wass1d_emd2 = ot.emd2_1d([], [], u, v, metric='sqeuclidean', log=False)
- wass1d_euc = ot.emd2_1d([], [], u, v, metric='euclidean', log=False)
+ wass1d_emd2 = ot.emd2_1d(u, v, [], [], metric='sqeuclidean', log=False)
+ wass1d_euc = ot.emd2_1d(u, v, [], [], metric='euclidean', log=False)
# check loss is similar
np.testing.assert_allclose(wass, wass1d)
@@ -82,7 +82,7 @@ def test_emd_1d_emd2_1d():
# check AssertionError is raised if called on non 1d arrays
u = np.random.randn(n, 2)
v = np.random.randn(m, 2)
- np.testing.assert_raises(AssertionError, ot.emd_1d, [], [], u, v)
+ np.testing.assert_raises(AssertionError, ot.emd_1d, u, v, [], [])
def test_emd_empty():