summaryrefslogtreecommitdiff
path: root/ot/utils.py
diff options
context:
space:
mode:
authorGard Spreemann <gspr@nonempty.org>2021-12-29 19:26:03 +0100
committerGard Spreemann <gspr@nonempty.org>2021-12-29 19:26:03 +0100
commitedab1c60630f95b38db430017585d06253c92817 (patch)
tree4cb2340c51157da0c81aae0907327417ffddd8ab /ot/utils.py
parent1a283cb0c77f79d6f36de7c01fa61dc8d9696bca (diff)
parent5ed61689a41350fac40ce995515e6cbcb7203f48 (diff)
Merge tag '0.8.1' into dfsg/latest
Diffstat (limited to 'ot/utils.py')
-rw-r--r--ot/utils.py10
1 files changed, 8 insertions, 2 deletions
diff --git a/ot/utils.py b/ot/utils.py
index c878563..e6c93c8 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -182,7 +182,7 @@ def euclidean_distances(X, Y, squared=False):
return c
-def dist(x1, x2=None, metric='sqeuclidean', p=2):
+def dist(x1, x2=None, metric='sqeuclidean', p=2, w=None):
r"""Compute distance between samples in :math:`\mathbf{x_1}` and :math:`\mathbf{x_2}`
.. note:: This function is backend-compatible and will work on arrays
@@ -202,6 +202,10 @@ def dist(x1, x2=None, metric='sqeuclidean', p=2):
'euclidean', 'hamming', 'jaccard', 'kulsinski', 'mahalanobis',
'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean',
'sokalmichener', 'sokalsneath', 'sqeuclidean', 'wminkowski', 'yule'.
+ p : float, optional
+ p-norm for the Minkowski and the Weighted Minkowski metrics. Default value is 2.
+ w : array-like, rank 1
+ Weights for the weighted metrics.
Returns
@@ -221,7 +225,9 @@ def dist(x1, x2=None, metric='sqeuclidean', p=2):
if not get_backend(x1, x2).__name__ == 'numpy':
raise NotImplementedError()
else:
- return cdist(x1, x2, metric=metric, p=p)
+ if metric.endswith("minkowski"):
+ return cdist(x1, x2, metric=metric, p=p, w=w)
+ return cdist(x1, x2, metric=metric, w=w)
def dist0(n, method='lin_square'):