summaryrefslogtreecommitdiff
path: root/ot/utils.py
diff options
context:
space:
mode:
authorLeoGautheron <leo_g_autheron@hotmail.fr>2018-07-16 06:48:54 +0200
committerLeoGautheron <leo_g_autheron@hotmail.fr>2018-07-16 06:48:54 +0200
commit73e61546b5509648ba6b7924d6e5a3ebe7fade5b (patch)
treeeb9f526dfc2d1ac1a0e872a8e1cee943f502dbd3 /ot/utils.py
parentcb6bdc516697e3bad6776b897f22c8b6a22f13cd (diff)
Remove dependency sklearn
Diffstat (limited to 'ot/utils.py')
-rw-r--r--ot/utils.py28
1 files changed, 27 insertions, 1 deletions
diff --git a/ot/utils.py b/ot/utils.py
index 5b052ac..14cc805 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -13,7 +13,6 @@ import time
import numpy as np
from scipy.spatial.distance import cdist
-from sklearn.metrics.pairwise import euclidean_distances
import sys
import warnings
try:
@@ -77,6 +76,33 @@ def clean_zeros(a, b, M):
b2 = b[b > 0]
return a2, b2, M2
+def euclidean_distances(X, Y, squared=False):
+ """
+ Considering the rows of X (and Y=X) as vectors, compute the
+ distance matrix between each pair of vectors.
+ Parameters
+ ----------
+ X : {array-like}, shape (n_samples_1, n_features)
+ Y : {array-like}, shape (n_samples_2, n_features)
+ squared : boolean, optional
+ Return squared Euclidean distances.
+ Returns
+ -------
+ distances : {array}, shape (n_samples_1, n_samples_2)
+ """
+ XX = np.einsum('ij,ij->i', X, X)[:, np.newaxis]
+ YY = np.einsum('ij,ij->i', Y, Y)[np.newaxis, :]
+ distances = np.dot(X, Y.T)
+ distances *= -2
+ distances += XX
+ distances += YY
+ np.maximum(distances, 0, out=distances)
+ if X is Y:
+ # Ensure that distances between vectors and themselves are set to 0.0.
+ # This may not be the case due to floating point rounding errors.
+ distances.flat[::distances.shape[0] + 1] = 0.0
+ return distances if squared else np.sqrt(distances, out=distances)
+
def dist(x1, x2=None, metric='sqeuclidean'):
"""Compute distance between samples in x1 and x2 using function scipy.spatial.distance.cdist