diff options
Diffstat (limited to 'ot/utils.py')
-rw-r--r-- | ot/utils.py | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/ot/utils.py b/ot/utils.py index 7dac283..5b052ac 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -13,6 +13,7 @@ import time import numpy as np from scipy.spatial.distance import cdist +from sklearn.metrics.pairwise import euclidean_distances import sys import warnings try: @@ -104,7 +105,8 @@ def dist(x1, x2=None, metric='sqeuclidean'): """ if x2 is None: x2 = x1 - + if metric == "sqeuclidean": + return euclidean_distances(x1, x2, squared=True) return cdist(x1, x2, metric=metric) |