summaryrefslogtreecommitdiff
path: root/src/python/gudhi/point_cloud/knn.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/gudhi/point_cloud/knn.py')
-rw-r--r--src/python/gudhi/point_cloud/knn.py33
1 files changed, 29 insertions, 4 deletions
diff --git a/src/python/gudhi/point_cloud/knn.py b/src/python/gudhi/point_cloud/knn.py
index f6870517..79362c09 100644
--- a/src/python/gudhi/point_cloud/knn.py
+++ b/src/python/gudhi/point_cloud/knn.py
@@ -36,6 +36,9 @@ class KNearestNeighbors:
sort_results (bool): if True, then distances and indices of each point are
sorted on return, so that the first column contains the closest points.
Otherwise, neighbors are returned in an arbitrary order. Defaults to True.
+ enable_autodiff (bool): if the input is a torch.tensor, jax.numpy.array or similar, this instructs
+ the function to compute distances in a way that works with automatic differentiation.
+ This is experimental and not supported for all implementations.
kwargs: additional parameters are forwarded to the backends.
"""
self.k = k
@@ -202,13 +205,18 @@ class KNearestNeighbors:
if self.params["implementation"] == "keops":
import torch
from pykeops.torch import LazyTensor
+ import eagerpy as ep
# 'float64' is slow except on super expensive GPUs. Allow it with some param?
- XX = torch.tensor(X, dtype=torch.float32)
- if X is self.ref_points:
+ queries = X
+ X = ep.astensor(X)
+ XX = torch.as_tensor(X.numpy(), dtype=torch.float32)
+ if queries is self.ref_points:
+ Y = X
YY = XX
else:
- YY = torch.tensor(self.ref_points, dtype=torch.float32)
+ Y = ep.astensor(self.ref_points)
+ YY = torch.as_tensor(Y.numpy(), dtype=torch.float32)
p = self.params["p"]
if p == numpy.inf:
@@ -219,6 +227,24 @@ class KNearestNeighbors:
else:
mat = ((LazyTensor(XX[:, None, :]) - LazyTensor(YY[None, :, :])).abs() ** p).sum(-1)
+ # pykeops does not support autodiff for kmin yet :-(
+ if self.params.get("enable_autodiff", False) and self.return_distance:
+ # Compute the indices of the neighbors, and recompute the relevant distances autodiff-friendly.
+ # Another strategy would be to compute the whole distance matrix with torch.cdist
+ # and use neighbors as indices into it.
+ neighbors = ep.astensor(mat.argKmin(k, dim=1)).numpy()
+ neighbor_pts = Y[neighbors]
+ diff = neighbor_pts - X[:, None, :]
+ if p == numpy.inf:
+ distances = diff.abs().max(-1)
+ elif p == 2:
+ distances = (diff ** 2).sum(-1) ** 0.5
+ else:
+ distances = (diff.abs() ** p).sum(-1) ** (1.0 / p)
+ if self.return_index:
+ return neighbors.raw, distances.raw
+ else:
+ return distances.raw
if self.return_index:
if self.return_distance:
distances, neighbors = mat.Kmin_argKmin(k, dim=1)
@@ -234,7 +260,6 @@ class KNearestNeighbors:
distances = distances ** (1.0 / p)
return distances
return None
- # FIXME: convert everything back to numpy arrays or not?
if self.params["implementation"] == "ckdtree":
qargs = {key: val for key, val in self.params.items() if key in {"p", "eps", "n_jobs"}}