diff options
Diffstat (limited to 'ot/lp/emd_wrap.pyx')
-rw-r--r-- | ot/lp/emd_wrap.pyx | 10 |
1 files changed, 8 insertions, 2 deletions
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 2825ba2..7134136 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -103,7 +103,8 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, np.ndarray[double, ndim=1, mode="c"] v_weights, np.ndarray[double, ndim=1, mode="c"] u, np.ndarray[double, ndim=1, mode="c"] v, - str metric='sqeuclidean'): + str metric='sqeuclidean', + double p=1.): r""" Solves the Earth Movers distance problem between sorted 1d measures and returns the OT matrix and the associated cost @@ -121,7 +122,10 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, metric: str, optional (default='sqeuclidean') Metric to be used. Only strings listed in :func:`ot.dist` are accepted. Due to implementation details, this function runs faster when - `'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used. + `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics + are used. + p: float, optional (default=1.0) + The p-norm to apply for if metric='minkowski' Returns ------- @@ -154,6 +158,8 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, m_ij = (u[i] - v[j]) ** 2 elif metric == 'cityblock' or metric == 'euclidean': m_ij = abs(u[i] - v[j]) + elif metric == 'minkowski': + m_ij = abs(u[i] - v[j]) ** p else: m_ij = dist(u[i].reshape((1, 1)), v[j].reshape((1, 1)), metric=metric)[0, 0] |