summaryrefslogtreecommitdiff
path: root/src/python/gudhi/point_cloud
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2020-04-13 01:09:45 +0200
committerMarc Glisse <marc.glisse@inria.fr>2020-04-13 01:10:05 +0200
commit280eb9d2323837619db1ae013b929adb9b45013b (patch)
tree19913f2113c38e9c7783cf5b186dd7a34dd921b2 /src/python/gudhi/point_cloud
parent83a1bc1fb6124a35d515f4836d2e830f3dbdf0e7 (diff)
enable_autodiff with keops
This doesn't seem like the best way to handle it, we may want to handle it like a wrapper that gets the indices from knn (whatever backend) and then computes the distances.
Diffstat (limited to 'src/python/gudhi/point_cloud')
-rw-r--r--src/python/gudhi/point_cloud/knn.py33
1 files changed, 29 insertions, 4 deletions
diff --git a/src/python/gudhi/point_cloud/knn.py b/src/python/gudhi/point_cloud/knn.py
index f6870517..79362c09 100644
--- a/src/python/gudhi/point_cloud/knn.py
+++ b/src/python/gudhi/point_cloud/knn.py
@@ -36,6 +36,9 @@ class KNearestNeighbors:
sort_results (bool): if True, then distances and indices of each point are
sorted on return, so that the first column contains the closest points.
Otherwise, neighbors are returned in an arbitrary order. Defaults to True.
+ enable_autodiff (bool): if the input is a torch.tensor, jax.numpy.array or similar, this instructs
+ the function to compute distances in a way that works with automatic differentiation.
+ This is experimental and not supported for all implementations.
kwargs: additional parameters are forwarded to the backends.
"""
self.k = k
@@ -202,13 +205,18 @@ class KNearestNeighbors:
if self.params["implementation"] == "keops":
import torch
from pykeops.torch import LazyTensor
+ import eagerpy as ep
# 'float64' is slow except on super expensive GPUs. Allow it with some param?
- XX = torch.tensor(X, dtype=torch.float32)
- if X is self.ref_points:
+ queries = X
+ X = ep.astensor(X)
+ XX = torch.as_tensor(X.numpy(), dtype=torch.float32)
+ if queries is self.ref_points:
+ Y = X
YY = XX
else:
- YY = torch.tensor(self.ref_points, dtype=torch.float32)
+ Y = ep.astensor(self.ref_points)
+ YY = torch.as_tensor(Y.numpy(), dtype=torch.float32)
p = self.params["p"]
if p == numpy.inf:
@@ -219,6 +227,24 @@ class KNearestNeighbors:
else:
mat = ((LazyTensor(XX[:, None, :]) - LazyTensor(YY[None, :, :])).abs() ** p).sum(-1)
+ # pykeops does not support autodiff for kmin yet :-(
+ if self.params.get("enable_autodiff", False) and self.return_distance:
+ # Compute the indices of the neighbors, and recompute the relevant distances autodiff-friendly.
+ # Another strategy would be to compute the whole distance matrix with torch.cdist
+ # and use neighbors as indices into it.
+ neighbors = ep.astensor(mat.argKmin(k, dim=1)).numpy()
+ neighbor_pts = Y[neighbors]
+ diff = neighbor_pts - X[:, None, :]
+ if p == numpy.inf:
+ distances = diff.abs().max(-1)
+ elif p == 2:
+ distances = (diff ** 2).sum(-1) ** 0.5
+ else:
+ distances = (diff.abs() ** p).sum(-1) ** (1.0 / p)
+ if self.return_index:
+ return neighbors.raw, distances.raw
+ else:
+ return distances.raw
if self.return_index:
if self.return_distance:
distances, neighbors = mat.Kmin_argKmin(k, dim=1)
@@ -234,7 +260,6 @@ class KNearestNeighbors:
distances = distances ** (1.0 / p)
return distances
return None
- # FIXME: convert everything back to numpy arrays or not?
if self.params["implementation"] == "ckdtree":
qargs = {key: val for key, val in self.params.items() if key in {"p", "eps", "n_jobs"}}