summaryrefslogtreecommitdiff
path: root/ot/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/utils.py')
-rw-r--r--ot/utils.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/ot/utils.py b/ot/utils.py
index 7dac283..5b052ac 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -13,6 +13,7 @@ import time
import numpy as np
from scipy.spatial.distance import cdist
+from sklearn.metrics.pairwise import euclidean_distances
import sys
import warnings
try:
@@ -104,7 +105,8 @@ def dist(x1, x2=None, metric='sqeuclidean'):
"""
if x2 is None:
x2 = x1
-
+ if metric == "sqeuclidean":
+ return euclidean_distances(x1, x2, squared=True)
return cdist(x1, x2, metric=metric)