summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/python/gudhi/point_cloud/knn.py33
-rwxr-xr-xsrc/python/test/test_dtm.py8
-rwxr-xr-xsrc/python/test/test_knn.py6
3 files changed, 43 insertions, 4 deletions
diff --git a/src/python/gudhi/point_cloud/knn.py b/src/python/gudhi/point_cloud/knn.py
index f6870517..79362c09 100644
--- a/src/python/gudhi/point_cloud/knn.py
+++ b/src/python/gudhi/point_cloud/knn.py
@@ -36,6 +36,9 @@ class KNearestNeighbors:
sort_results (bool): if True, then distances and indices of each point are
sorted on return, so that the first column contains the closest points.
Otherwise, neighbors are returned in an arbitrary order. Defaults to True.
+ enable_autodiff (bool): if the input is a torch.tensor, jax.numpy.array or similar, this instructs
+ the function to compute distances in a way that works with automatic differentiation.
+ This is experimental and not supported for all implementations.
kwargs: additional parameters are forwarded to the backends.
"""
self.k = k
@@ -202,13 +205,18 @@ class KNearestNeighbors:
if self.params["implementation"] == "keops":
import torch
from pykeops.torch import LazyTensor
+ import eagerpy as ep
# 'float64' is slow except on super expensive GPUs. Allow it with some param?
- XX = torch.tensor(X, dtype=torch.float32)
- if X is self.ref_points:
+ queries = X
+ X = ep.astensor(X)
+ XX = torch.as_tensor(X.numpy(), dtype=torch.float32)
+ if queries is self.ref_points:
+ Y = X
YY = XX
else:
- YY = torch.tensor(self.ref_points, dtype=torch.float32)
+ Y = ep.astensor(self.ref_points)
+ YY = torch.as_tensor(Y.numpy(), dtype=torch.float32)
p = self.params["p"]
if p == numpy.inf:
@@ -219,6 +227,24 @@ class KNearestNeighbors:
else:
mat = ((LazyTensor(XX[:, None, :]) - LazyTensor(YY[None, :, :])).abs() ** p).sum(-1)
+ # pykeops does not support autodiff for kmin yet :-(
+ if self.params.get("enable_autodiff", False) and self.return_distance:
+ # Compute the indices of the neighbors, and recompute the relevant distances autodiff-friendly.
+ # Another strategy would be to compute the whole distance matrix with torch.cdist
+ # and use neighbors as indices into it.
+ neighbors = ep.astensor(mat.argKmin(k, dim=1)).numpy()
+ neighbor_pts = Y[neighbors]
+ diff = neighbor_pts - X[:, None, :]
+ if p == numpy.inf:
+ distances = diff.abs().max(-1)
+ elif p == 2:
+ distances = (diff ** 2).sum(-1) ** 0.5
+ else:
+ distances = (diff.abs() ** p).sum(-1) ** (1.0 / p)
+ if self.return_index:
+ return neighbors.raw, distances.raw
+ else:
+ return distances.raw
if self.return_index:
if self.return_distance:
distances, neighbors = mat.Kmin_argKmin(k, dim=1)
@@ -234,7 +260,6 @@ class KNearestNeighbors:
distances = distances ** (1.0 / p)
return distances
return None
- # FIXME: convert everything back to numpy arrays or not?
if self.params["implementation"] == "ckdtree":
qargs = {key: val for key, val in self.params.items() if key in {"p", "eps", "n_jobs"}}
diff --git a/src/python/test/test_dtm.py b/src/python/test/test_dtm.py
index bc0d3698..8709dd07 100755
--- a/src/python/test/test_dtm.py
+++ b/src/python/test/test_dtm.py
@@ -11,6 +11,7 @@
from gudhi.point_cloud.dtm import DistanceToMeasure
import numpy
import pytest
+import torch
def test_dtm_compare_euclidean():
@@ -39,6 +40,13 @@ def test_dtm_compare_euclidean():
dtm = DistanceToMeasure(k, implementation="keops")
r5 = dtm.fit_transform(pts)
assert r5 == pytest.approx(r0)
+ pts2 = torch.tensor(pts, requires_grad=True)
+ assert pts2.grad is None
+ dtm = DistanceToMeasure(k, implementation="keops", enable_autodiff=True)
+ r6 = dtm.fit_transform(pts2)
+ assert r6.detach().numpy() == pytest.approx(r0)
+ r6.sum().backward()
+ assert pts2.grad is not None
def test_dtm_precomputed():
diff --git a/src/python/test/test_knn.py b/src/python/test/test_knn.py
index 6269df54..415c9d48 100755
--- a/src/python/test/test_knn.py
+++ b/src/python/test/test_knn.py
@@ -32,6 +32,12 @@ def test_knn_explicit():
.transform(query)
)
assert r == pytest.approx(np.array([[0.0, 1], [1, 1], [1, 2]]))
+ r = (
+ KNearestNeighbors(2, metric="chebyshev", return_distance=True, return_index=False, implementation="keops", enable_autodiff=True)
+ .fit(base)
+ .transform(query)
+ )
+ assert r == pytest.approx(np.array([[0.0, 1], [1, 1], [1, 2]]))
knn = KNearestNeighbors(2, metric="minkowski", p=3, return_distance=False, return_index=True)
knn.fit(base)