summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2020-04-13 01:09:45 +0200
committerMarc Glisse <marc.glisse@inria.fr>2020-04-13 01:10:05 +0200
commit280eb9d2323837619db1ae013b929adb9b45013b (patch)
tree19913f2113c38e9c7783cf5b186dd7a34dd921b2 /src
parent83a1bc1fb6124a35d515f4836d2e830f3dbdf0e7 (diff)
enable_autodiff with keops
This doesn't seem like the best way to handle it, we may want to handle it like a wrapper that gets the indices from knn (whatever backend) and then computes the distances.
Diffstat (limited to 'src')
-rw-r--r--src/python/gudhi/point_cloud/knn.py33
-rwxr-xr-xsrc/python/test/test_dtm.py8
-rwxr-xr-xsrc/python/test/test_knn.py6
3 files changed, 43 insertions, 4 deletions
diff --git a/src/python/gudhi/point_cloud/knn.py b/src/python/gudhi/point_cloud/knn.py
index f6870517..79362c09 100644
--- a/src/python/gudhi/point_cloud/knn.py
+++ b/src/python/gudhi/point_cloud/knn.py
@@ -36,6 +36,9 @@ class KNearestNeighbors:
sort_results (bool): if True, then distances and indices of each point are
sorted on return, so that the first column contains the closest points.
Otherwise, neighbors are returned in an arbitrary order. Defaults to True.
+ enable_autodiff (bool): if the input is a torch.tensor, jax.numpy.array or similar, this instructs
+ the function to compute distances in a way that works with automatic differentiation.
+ This is experimental and not supported for all implementations.
kwargs: additional parameters are forwarded to the backends.
"""
self.k = k
@@ -202,13 +205,18 @@ class KNearestNeighbors:
if self.params["implementation"] == "keops":
import torch
from pykeops.torch import LazyTensor
+ import eagerpy as ep
# 'float64' is slow except on super expensive GPUs. Allow it with some param?
- XX = torch.tensor(X, dtype=torch.float32)
- if X is self.ref_points:
+ queries = X
+ X = ep.astensor(X)
+ XX = torch.as_tensor(X.numpy(), dtype=torch.float32)
+ if queries is self.ref_points:
+ Y = X
YY = XX
else:
- YY = torch.tensor(self.ref_points, dtype=torch.float32)
+ Y = ep.astensor(self.ref_points)
+ YY = torch.as_tensor(Y.numpy(), dtype=torch.float32)
p = self.params["p"]
if p == numpy.inf:
@@ -219,6 +227,24 @@ class KNearestNeighbors:
else:
mat = ((LazyTensor(XX[:, None, :]) - LazyTensor(YY[None, :, :])).abs() ** p).sum(-1)
+ # pykeops does not support autodiff for kmin yet :-(
+ if self.params.get("enable_autodiff", False) and self.return_distance:
+ # Compute the indices of the neighbors, and recompute the relevant distances autodiff-friendly.
+ # Another strategy would be to compute the whole distance matrix with torch.cdist
+ # and use neighbors as indices into it.
+ neighbors = ep.astensor(mat.argKmin(k, dim=1)).numpy()
+ neighbor_pts = Y[neighbors]
+ diff = neighbor_pts - X[:, None, :]
+ if p == numpy.inf:
+ distances = diff.abs().max(-1)
+ elif p == 2:
+ distances = (diff ** 2).sum(-1) ** 0.5
+ else:
+ distances = (diff.abs() ** p).sum(-1) ** (1.0 / p)
+ if self.return_index:
+ return neighbors.raw, distances.raw
+ else:
+ return distances.raw
if self.return_index:
if self.return_distance:
distances, neighbors = mat.Kmin_argKmin(k, dim=1)
@@ -234,7 +260,6 @@ class KNearestNeighbors:
distances = distances ** (1.0 / p)
return distances
return None
- # FIXME: convert everything back to numpy arrays or not?
if self.params["implementation"] == "ckdtree":
qargs = {key: val for key, val in self.params.items() if key in {"p", "eps", "n_jobs"}}
diff --git a/src/python/test/test_dtm.py b/src/python/test/test_dtm.py
index bc0d3698..8709dd07 100755
--- a/src/python/test/test_dtm.py
+++ b/src/python/test/test_dtm.py
@@ -11,6 +11,7 @@
from gudhi.point_cloud.dtm import DistanceToMeasure
import numpy
import pytest
+import torch
def test_dtm_compare_euclidean():
@@ -39,6 +40,13 @@ def test_dtm_compare_euclidean():
dtm = DistanceToMeasure(k, implementation="keops")
r5 = dtm.fit_transform(pts)
assert r5 == pytest.approx(r0)
+ pts2 = torch.tensor(pts, requires_grad=True)
+ assert pts2.grad is None
+ dtm = DistanceToMeasure(k, implementation="keops", enable_autodiff=True)
+ r6 = dtm.fit_transform(pts2)
+ assert r6.detach().numpy() == pytest.approx(r0)
+ r6.sum().backward()
+ assert pts2.grad is not None
def test_dtm_precomputed():
diff --git a/src/python/test/test_knn.py b/src/python/test/test_knn.py
index 6269df54..415c9d48 100755
--- a/src/python/test/test_knn.py
+++ b/src/python/test/test_knn.py
@@ -32,6 +32,12 @@ def test_knn_explicit():
.transform(query)
)
assert r == pytest.approx(np.array([[0.0, 1], [1, 1], [1, 2]]))
+ r = (
+ KNearestNeighbors(2, metric="chebyshev", return_distance=True, return_index=False, implementation="keops", enable_autodiff=True)
+ .fit(base)
+ .transform(query)
+ )
+ assert r == pytest.approx(np.array([[0.0, 1], [1, 1], [1, 2]]))
knn = KNearestNeighbors(2, metric="minkowski", p=3, return_distance=False, return_index=True)
knn.fit(base)