summaryrefslogtreecommitdiff
path: root/ot/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/utils.py')
-rw-r--r--ot/utils.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/ot/utils.py b/ot/utils.py
index 6a782e6..0608aee 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -183,7 +183,7 @@ def euclidean_distances(X, Y, squared=False):
return c
-def dist(x1, x2=None, metric='sqeuclidean'):
+def dist(x1, x2=None, metric='sqeuclidean', p=2):
"""Compute distance between samples in x1 and x2
.. note:: This function is backend-compatible and will work on arrays
@@ -222,7 +222,7 @@ def dist(x1, x2=None, metric='sqeuclidean'):
if not get_backend(x1, x2).__name__ == 'numpy':
raise NotImplementedError()
else:
- return cdist(x1, x2, metric=metric)
+ return cdist(x1, x2, metric=metric, p=p)
def dist0(n, method='lin_square'):