summaryrefslogtreecommitdiff
path: root/ot/lp/emd_wrap.pyx
diff options
context:
space:
mode:
authorRomain Tavenard <romain.tavenard@univ-rennes2.fr>2019-06-27 10:04:35 +0200
committerRomain Tavenard <romain.tavenard@univ-rennes2.fr>2019-06-27 10:04:35 +0200
commit1140141938c678d267f688dbb9106d3422d633c5 (patch)
tree4802af0b93cc6702dad7f9e9e7e6f6057e51c11a /ot/lp/emd_wrap.pyx
parent0a039eb07a3ca9ae3c5635cca1719428f62bf67d (diff)
Added minkowski variants and wasserstein_1d functions
Diffstat (limited to 'ot/lp/emd_wrap.pyx')
-rw-r--r--ot/lp/emd_wrap.pyx10
1 files changed, 8 insertions, 2 deletions
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index 2825ba2..7134136 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -103,7 +103,8 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
np.ndarray[double, ndim=1, mode="c"] v_weights,
np.ndarray[double, ndim=1, mode="c"] u,
np.ndarray[double, ndim=1, mode="c"] v,
- str metric='sqeuclidean'):
+ str metric='sqeuclidean',
+ double p=1.):
r"""
Solves the Earth Movers distance problem between sorted 1d measures and
returns the OT matrix and the associated cost
@@ -121,7 +122,10 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
metric: str, optional (default='sqeuclidean')
Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
Due to implementation details, this function runs faster when
- `'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used.
+ `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics
+ are used.
+ p: float, optional (default=1.0)
+ The p-norm to apply for if metric='minkowski'
Returns
-------
@@ -154,6 +158,8 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
m_ij = (u[i] - v[j]) ** 2
elif metric == 'cityblock' or metric == 'euclidean':
m_ij = abs(u[i] - v[j])
+ elif metric == 'minkowski':
+ m_ij = abs(u[i] - v[j]) ** p
else:
m_ij = dist(u[i].reshape((1, 1)), v[j].reshape((1, 1)),
metric=metric)[0, 0]