summaryrefslogtreecommitdiff
path: root/ot/utils.py
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2018-07-18 11:34:37 +0200
committerGitHub <noreply@github.com>2018-07-18 11:34:37 +0200
commit5cd6c0aae23a36fe27c188cdd18c7f0fba8a0360 (patch)
treed7b0968d9d50f0e40d225aeddc85b10d8e6c4cca /ot/utils.py
parent7c5c8803b2bdb67545783db3321b9d5a81a063d6 (diff)
parent0764e356325df7e18f72c0ff468bfa8f8ee35059 (diff)
Merge pull request #57 from LeoGautheron/master
Speed-up Sinkhorn
Diffstat (limited to 'ot/utils.py')
-rw-r--r--ot/utils.py31
1 files changed, 30 insertions, 1 deletions
diff --git a/ot/utils.py b/ot/utils.py
index 7dac283..bb21b38 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -77,6 +77,34 @@ def clean_zeros(a, b, M):
return a2, b2, M2
+def euclidean_distances(X, Y, squared=False):
+ """
+ Considering the rows of X (and Y=X) as vectors, compute the
+ distance matrix between each pair of vectors.
+ Parameters
+ ----------
+ X : {array-like}, shape (n_samples_1, n_features)
+ Y : {array-like}, shape (n_samples_2, n_features)
+ squared : boolean, optional
+ Return squared Euclidean distances.
+ Returns
+ -------
+ distances : {array}, shape (n_samples_1, n_samples_2)
+ """
+ XX = np.einsum('ij,ij->i', X, X)[:, np.newaxis]
+ YY = np.einsum('ij,ij->i', Y, Y)[np.newaxis, :]
+ distances = np.dot(X, Y.T)
+ distances *= -2
+ distances += XX
+ distances += YY
+ np.maximum(distances, 0, out=distances)
+ if X is Y:
+ # Ensure that distances between vectors and themselves are set to 0.0.
+ # This may not be the case due to floating point rounding errors.
+ distances.flat[::distances.shape[0] + 1] = 0.0
+ return distances if squared else np.sqrt(distances, out=distances)
+
+
def dist(x1, x2=None, metric='sqeuclidean'):
"""Compute distance between samples in x1 and x2 using function scipy.spatial.distance.cdist
@@ -104,7 +132,8 @@ def dist(x1, x2=None, metric='sqeuclidean'):
"""
if x2 is None:
x2 = x1
-
+ if metric == "sqeuclidean":
+ return euclidean_distances(x1, x2, squared=True)
return cdist(x1, x2, metric=metric)