summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Nathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com> 2021-11-08 15:08:33 +0100
committer GitHub <noreply@github.com> 2021-11-08 15:08:33 +0100
commit 0c589912800b23609c730871c080ade0c807cdc1 (patch)
tree 0f4fa22f8ad9a8210efea92038af783930a37c6c
parent f1628794d521a8dfa00af383b5e06cd6d34af619 (diff)
[MRG] Distance calculation bug solve (#306)
* solve bug
* Weights & docs
* tests for dist
* test dist
* pep8
-rw-r--r-- ot/utils.py 10
-rw-r--r-- test/test_utils.py 20
2 files changed, 28 insertions, 2 deletions
diff --git a/ot/utils.py b/ot/utils.py
index c878563..e6c93c8 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -182,7 +182,7 @@ def euclidean_distances(X, Y, squared=False):
return c
-def dist(x1, x2=None, metric='sqeuclidean', p=2):
+def dist(x1, x2=None, metric='sqeuclidean', p=2, w=None):
r"""Compute distance between samples in :math:`\mathbf{x_1}` and :math:`\mathbf{x_2}`
.. note:: This function is backend-compatible and will work on arrays
@@ -202,6 +202,10 @@ def dist(x1, x2=None, metric='sqeuclidean', p=2):
'euclidean', 'hamming', 'jaccard', 'kulsinski', 'mahalanobis',
'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean',
'sokalmichener', 'sokalsneath', 'sqeuclidean', 'wminkowski', 'yule'.
+ p : float, optional
+ p-norm for the Minkowski and the Weighted Minkowski metrics. Default value is 2.
+ w : array-like, rank 1
+ Weights for the weighted metrics.
Returns
@@ -221,7 +225,9 @@ def dist(x1, x2=None, metric='sqeuclidean', p=2):
if not get_backend(x1, x2).__name__ == 'numpy':
raise NotImplementedError()
else:
- return cdist(x1, x2, metric=metric, p=p)
+ if metric.endswith("minkowski"):
+ return cdist(x1, x2, metric=metric, p=p, w=w)
+ return cdist(x1, x2, metric=metric, w=w)
def dist0(n, method='lin_square'):
diff --git a/test/test_utils.py b/test/test_utils.py
index 40f4e49..6b476b2 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -117,6 +117,26 @@ def test_dist():
np.testing.assert_allclose(D, D2, atol=1e-14)
np.testing.assert_allclose(D, D3, atol=1e-14)
+ # tests that every metric runs correctly
+ metrics_w = [
+ 'braycurtis', 'canberra', 'chebyshev', 'cityblock', 'correlation', 'cosine', 'dice',
+ 'euclidean', 'hamming', 'jaccard', 'kulsinski',
+ 'matching', 'minkowski', 'rogerstanimoto', 'russellrao',
+ 'sokalmichener', 'sokalsneath', 'sqeuclidean', 'wminkowski', 'yule'
+ ] # those that support weights
+ metrics = ['mahalanobis', 'seuclidean'] # do not support weights depending on scipy's version
+
+ for metric in metrics_w:
+ print(metric)
+ ot.dist(x, x, metric=metric, p=3, w=np.random.random((2, )))
+ for metric in metrics:
+ print(metric)
+ ot.dist(x, x, metric=metric, p=3)
+
+ # weighted minkowski but with no weights
+ with pytest.raises(ValueError):
+ ot.dist(x, x, metric="wminkowski")
+
def test_dist_backends(nx):