summaryrefslogtreecommitdiff
path: root/src/python/gudhi/point_cloud/knn.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/gudhi/point_cloud/knn.py')
-rw-r--r--src/python/gudhi/point_cloud/knn.py9
1 files changed, 4 insertions, 5 deletions
diff --git a/src/python/gudhi/point_cloud/knn.py b/src/python/gudhi/point_cloud/knn.py
index ab3447d4..185a7764 100644
--- a/src/python/gudhi/point_cloud/knn.py
+++ b/src/python/gudhi/point_cloud/knn.py
@@ -236,12 +236,11 @@ class KNearestNeighbors:
# Work around https://github.com/pytorch/pytorch/issues/34452
neighbor_pts = Y[neighbors,]
diff = neighbor_pts - X[:, None, :]
- if p == numpy.inf:
- distances = diff.abs().max(-1)
- elif p == 2:
- distances = (diff ** 2).sum(-1).sqrt()
+ if isinstance(diff, ep.PyTorchTensor):
+ # https://github.com/jonasrauber/eagerpy/issues/6
+ distances = ep.astensor(diff.raw.norm(p, -1))
else:
- distances = (diff.abs() ** p).sum(-1) ** (1.0 / p)
+ distances = diff.norms.lp(p, -1)
if self.return_index:
return neighbors, distances.raw
else: