diff options
author | Nathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com> | 2021-11-08 15:08:33 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-11-08 15:08:33 +0100 |
commit | 0c589912800b23609c730871c080ade0c807cdc1 (patch) | |
tree | 0f4fa22f8ad9a8210efea92038af783930a37c6c /ot/utils.py | |
parent | f1628794d521a8dfa00af383b5e06cd6d34af619 (diff) |
[MRG] Distance calculation bug solve (#306)
* solve bug
* Weights & docs
* tests for dist
* test dist
* pep8
Diffstat (limited to 'ot/utils.py')
-rw-r--r-- | ot/utils.py | 10 |
1 files changed, 8 insertions, 2 deletions
diff --git a/ot/utils.py b/ot/utils.py index c878563..e6c93c8 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -182,7 +182,7 @@ def euclidean_distances(X, Y, squared=False): return c -def dist(x1, x2=None, metric='sqeuclidean', p=2): +def dist(x1, x2=None, metric='sqeuclidean', p=2, w=None): r"""Compute distance between samples in :math:`\mathbf{x_1}` and :math:`\mathbf{x_2}` .. note:: This function is backend-compatible and will work on arrays @@ -202,6 +202,10 @@ def dist(x1, x2=None, metric='sqeuclidean', p=2): 'euclidean', 'hamming', 'jaccard', 'kulsinski', 'mahalanobis', 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean', 'sokalmichener', 'sokalsneath', 'sqeuclidean', 'wminkowski', 'yule'. + p : float, optional + p-norm for the Minkowski and the Weighted Minkowski metrics. Default value is 2. + w : array-like, rank 1 + Weights for the weighted metrics. Returns @@ -221,7 +225,9 @@ def dist(x1, x2=None, metric='sqeuclidean', p=2): if not get_backend(x1, x2).__name__ == 'numpy': raise NotImplementedError() else: - return cdist(x1, x2, metric=metric, p=p) + if metric.endswith("minkowski"): + return cdist(x1, x2, metric=metric, p=p, w=w) + return cdist(x1, x2, metric=metric, w=w) def dist0(n, method='lin_square'): |