diff options
author | Romain Tavenard <romain.tavenard@univ-rennes2.fr> | 2019-06-27 10:04:35 +0200 |
---|---|---|
committer | Romain Tavenard <romain.tavenard@univ-rennes2.fr> | 2019-06-27 10:04:35 +0200 |
commit | 1140141938c678d267f688dbb9106d3422d633c5 (patch) | |
tree | 4802af0b93cc6702dad7f9e9e7e6f6057e51c11a /ot/lp/emd_wrap.pyx | |
parent | 0a039eb07a3ca9ae3c5635cca1719428f62bf67d (diff) |
Added minkowski variants and wasserstein_1d functions
Diffstat (limited to 'ot/lp/emd_wrap.pyx')
-rw-r--r-- | ot/lp/emd_wrap.pyx | 10 |
1 files changed, 8 insertions, 2 deletions
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 2825ba2..7134136 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -103,7 +103,8 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, np.ndarray[double, ndim=1, mode="c"] v_weights, np.ndarray[double, ndim=1, mode="c"] u, np.ndarray[double, ndim=1, mode="c"] v, - str metric='sqeuclidean'): + str metric='sqeuclidean', + double p=1.): r""" Solves the Earth Movers distance problem between sorted 1d measures and returns the OT matrix and the associated cost @@ -121,7 +122,10 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, metric: str, optional (default='sqeuclidean') Metric to be used. Only strings listed in :func:`ot.dist` are accepted. Due to implementation details, this function runs faster when - `'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used. + `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics + are used. + p: float, optional (default=1.0) + The p-norm to apply for if metric='minkowski' Returns ------- @@ -154,6 +158,8 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, m_ij = (u[i] - v[j]) ** 2 elif metric == 'cityblock' or metric == 'euclidean': m_ij = abs(u[i] - v[j]) + elif metric == 'minkowski': + m_ij = abs(u[i] - v[j]) ** p else: m_ij = dist(u[i].reshape((1, 1)), v[j].reshape((1, 1)), metric=metric)[0, 0] |