path: root/src/python/gudhi/point_cloud/knn.py
author    Marc Glisse <marc.glisse@inria.fr>    2020-04-13 20:32:39 +0200
committer Marc Glisse <marc.glisse@inria.fr>    2020-04-13 20:32:39 +0200
commit    3afce326428dddd638e22ab37ee4b2afe52eba75 (patch)
tree      d1933e4367583cade47d61d68d2c623ebcbaad01 /src/python/gudhi/point_cloud/knn.py
parent    3a86402b733a48d9c25a4995325e72c7438c06c0 (diff)
Generalize enable_autodiff to more implementations
Still limited to L^p
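
For context, a minimal usage sketch of the generalized code path (assuming gudhi, torch and eagerpy are installed; the point counts, the value of k and the choice of the ckdtree implementation are illustrative, not part of this commit):

import torch
from gudhi.point_cloud.knn import KNearestNeighbors

# Reference cloud and query points; we want gradients w.r.t. the query coordinates.
ref = torch.rand(100, 3)
query = torch.rand(10, 3, requires_grad=True)

# enable_autodiff now also works with non-keops implementations (still limited to L^p).
knn = KNearestNeighbors(k=3, return_index=False, return_distance=True,
                        implementation="ckdtree", enable_autodiff=True)
knn.fit(ref)
dist = knn.transform(query)  # torch tensor that is part of the autograd graph
dist.sum().backward()        # gradients flow back to the query coordinates
print(query.grad.shape)      # expected: torch.Size([10, 3])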
Diffstat (limited to 'src/python/gudhi/point_cloud/knn.py')
-rw-r--r--  src/python/gudhi/point_cloud/knn.py | 76
1 file changed, 55 insertions(+), 21 deletions(-)
diff --git a/src/python/gudhi/point_cloud/knn.py b/src/python/gudhi/point_cloud/knn.py
index 185a7764..87b2798e 100644
--- a/src/python/gudhi/point_cloud/knn.py
+++ b/src/python/gudhi/point_cloud/knn.py
@@ -9,6 +9,7 @@
 import numpy
+# TODO: https://github.com/facebookresearch/faiss
 class KNearestNeighbors:
     """
@@ -67,6 +68,8 @@ class KNearestNeighbors:
                 self.params["implementation"] = "ckdtree"
             else:
                 self.params["implementation"] = "sklearn"
+        if not return_distance:
+            self.params["enable_autodiff"] = False
 
     def fit_transform(self, X, y=None):
         return self.fit(X).transform(X)
@@ -77,6 +80,10 @@ class KNearestNeighbors:
             X (numpy.array): coordinates for reference points.
         """
         self.ref_points = X
+        if self.params.get("enable_autodiff", False):
+            import eagerpy as ep
+            if self.params["implementation"] != "keops" or not isinstance(X, ep.PyTorchTensor):
+                X = ep.astensor(X).numpy()
         if self.params["implementation"] == "ckdtree":
             # sklearn could handle this, but it is much slower
             from scipy.spatial import cKDTree
@@ -113,6 +120,41 @@ class KNearestNeighbors:
         Args:
             X (numpy.array): coordinates for query points, or distance matrix if metric is "precomputed".
         """
+        if self.params.get("enable_autodiff", False):
+            # pykeops does not support autodiff for kmin yet, but when it does in the future,
+            # we may want a special path.
+            import eagerpy as ep
+            save_return_index = self.return_index
+            self.return_index = True
+            self.return_distance = False
+            self.params["enable_autodiff"] = False
+            try:
+                # FIXME: how do we test "X is ref_points" then?
+                newX = ep.astensor(X)
+                if self.params["implementation"] != "keops" or not isinstance(newX, ep.PyTorchTensor):
+                    newX = newX.numpy()
+                neighbors = self.transform(newX)
+            finally:
+                self.return_index = save_return_index
+                self.return_distance = True
+                self.params["enable_autodiff"] = True
+            # We can implement more later as needed
+            assert self.metric == "minkowski"
+            p = self.params["p"]
+            Y = ep.astensor(self.ref_points)
+            neighbor_pts = Y[neighbors,]
+            diff = neighbor_pts - X[:, None, :]
+            if isinstance(diff, ep.PyTorchTensor):
+                # https://github.com/jonasrauber/eagerpy/issues/6
+                distances = ep.astensor(diff.raw.norm(p, -1))
+            else:
+                distances = diff.norms.lp(p, -1)
+            if self.return_index:
+                return neighbors, distances.raw
+            else:
+                return distances.raw
+
+
         metric = self.metric
         k = self.k
@@ -207,16 +249,26 @@ class KNearestNeighbors:
             from pykeops.torch import LazyTensor
             import eagerpy as ep
-            # 'float64' is slow except on super expensive GPUs. Allow it with some param?
             queries = X
             X = ep.astensor(X)
-            XX = torch.as_tensor(X.numpy(), dtype=torch.float32)
+            if isinstance(X, ep.PyTorchTensor):
+                XX = X.raw
+            else:
+                # I don't know a clever way to reuse a GPU tensor from tensorflow in pytorch
+                # without copying to/from the CPU.
+                XX = X.numpy()
+            # 'float64' is slow except on super expensive GPUs. Allow it with some param?
+            XX = torch.as_tensor(XX, dtype=torch.float32)
             if queries is self.ref_points:
                 Y = X
                 YY = XX
             else:
                 Y = ep.astensor(self.ref_points)
-                YY = torch.as_tensor(Y.numpy(), dtype=torch.float32)
+                if isinstance(Y, ep.PyTorchTensor):
+                    YY = Y.raw
+                else:
+                    YY = Y.numpy()
+                YY = torch.as_tensor(YY, dtype=torch.float32)
             p = self.params["p"]
             if p == numpy.inf:
@@ -227,24 +279,6 @@ class KNearestNeighbors:
             else:
                 mat = ((LazyTensor(XX[:, None, :]) - LazyTensor(YY[None, :, :])).abs() ** p).sum(-1)
-            # pykeops does not support autodiff for kmin yet :-(
-            if self.params.get("enable_autodiff", False) and self.return_distance:
-                # Compute the indices of the neighbors, and recompute the relevant distances autodiff-friendly.
-                # Another strategy would be to compute the whole distance matrix with torch.cdist
-                # and use neighbors as indices into it.
-                neighbors = ep.astensor(mat.argKmin(k, dim=1)).numpy()
-                # Work around https://github.com/pytorch/pytorch/issues/34452
-                neighbor_pts = Y[neighbors,]
-                diff = neighbor_pts - X[:, None, :]
-                if isinstance(diff, ep.PyTorchTensor):
-                    # https://github.com/jonasrauber/eagerpy/issues/6
-                    distances = ep.astensor(diff.raw.norm(p, -1))
-                else:
-                    distances = diff.norms.lp(p, -1)
-                if self.return_index:
-                    return neighbors, distances.raw
-                else:
-                    return distances.raw
             if self.return_index:
                 if self.return_distance:
                     distances, neighbors = mat.Kmin_argKmin(k, dim=1)
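
For reference, the strategy this commit generalizes (pick the neighbor indices without autograd, then recompute only those k distances differentiably) can be sketched in plain PyTorch as follows; the function name is made up for illustration, and torch.cdist merely stands in for whichever backend actually selects the neighbors:

import torch

def knn_distances_autodiff(query, ref, k, p=2):
    # Select neighbor indices with autograd disabled; any backend
    # (ckdtree, hnsw, keops, ...) could produce these indices.
    with torch.no_grad():
        idx = torch.cdist(query, ref, p=p).topk(k, dim=1, largest=False).indices
    # Recompute the k relevant distances in a differentiable way.
    neighbor_pts = ref[idx]                  # shape (n_query, k, dim)
    diff = neighbor_pts - query[:, None, :]  # broadcast against each query point
    return idx, diff.norm(p=p, dim=-1)       # distances differentiable w.r.t. query

ref = torch.rand(100, 3)
query = torch.rand(10, 3, requires_grad=True)
idx, dist = knn_distances_autodiff(query, ref, k=3)
dist.sum().backward()   # query.grad now holds d(sum of kNN distances)/d(query)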