summaryrefslogtreecommitdiff
path: root/src/python/gudhi/point_cloud/knn.py
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2020-04-13 15:01:51 +0200
committerMarc Glisse <marc.glisse@inria.fr>2020-04-13 15:01:51 +0200
commit2f1576a23cf4ac055565875d384ca604c0ff6844 (patch)
treed36bd276d9ffcede0ea0b81c6e5642e5657e3ddb /src/python/gudhi/point_cloud/knn.py
parent280eb9d2323837619db1ae013b929adb9b45013b (diff)
Small autodiff tweaks
Diffstat (limited to 'src/python/gudhi/point_cloud/knn.py')
-rw-r--r--src/python/gudhi/point_cloud/knn.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/src/python/gudhi/point_cloud/knn.py b/src/python/gudhi/point_cloud/knn.py
index 79362c09..ab3447d4 100644
--- a/src/python/gudhi/point_cloud/knn.py
+++ b/src/python/gudhi/point_cloud/knn.py
@@ -233,16 +233,17 @@ class KNearestNeighbors:
# Another strategy would be to compute the whole distance matrix with torch.cdist
# and use neighbors as indices into it.
neighbors = ep.astensor(mat.argKmin(k, dim=1)).numpy()
- neighbor_pts = Y[neighbors]
+ # Work around https://github.com/pytorch/pytorch/issues/34452
+ neighbor_pts = Y[neighbors,]
diff = neighbor_pts - X[:, None, :]
if p == numpy.inf:
distances = diff.abs().max(-1)
elif p == 2:
- distances = (diff ** 2).sum(-1) ** 0.5
+ distances = (diff ** 2).sum(-1).sqrt()
else:
distances = (diff.abs() ** p).sum(-1) ** (1.0 / p)
if self.return_index:
- return neighbors.raw, distances.raw
+ return neighbors, distances.raw
else:
return distances.raw
if self.return_index: