summaryrefslogtreecommitdiff
path: root/src/python/gudhi/point_cloud
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2020-04-13 21:38:24 +0200
committerMarc Glisse <marc.glisse@inria.fr>2020-04-13 21:38:24 +0200
commitce75f66da5a2d7ad2c479355112d48817c5ba68b (patch)
tree5e82f7d6b3c36523778a9dc534a0092b62868c4b /src/python/gudhi/point_cloud
parent521d8c17c2b7d71c46a51f0490ff2c13c809fc87 (diff)
Tweak to detect fit_transform
Diffstat (limited to 'src/python/gudhi/point_cloud')
-rw-r--r--src/python/gudhi/point_cloud/knn.py11
1 files changed, 7 insertions, 4 deletions
diff --git a/src/python/gudhi/point_cloud/knn.py b/src/python/gudhi/point_cloud/knn.py
index f2cddb38..8b3cdb46 100644
--- a/src/python/gudhi/point_cloud/knn.py
+++ b/src/python/gudhi/point_cloud/knn.py
@@ -11,6 +11,7 @@ import numpy
# TODO: https://github.com/facebookresearch/faiss
+
class KNearestNeighbors:
"""
Class wrapping several implementations for computing the k nearest neighbors in a point set.
@@ -82,6 +83,7 @@ class KNearestNeighbors:
self.ref_points = X
if self.params.get("enable_autodiff", False):
import eagerpy as ep
+
X = ep.astensor(X)
if self.params["implementation"] != "keops" or not isinstance(X, ep.PyTorchTensor):
# I don't know a clever way to reuse a GPU tensor from tensorflow in pytorch
@@ -127,17 +129,19 @@ class KNearestNeighbors:
# pykeops does not support autodiff for kmin yet, but when it does in the future,
# we may want a special path.
import eagerpy as ep
+
save_return_index = self.return_index
self.return_index = True
self.return_distance = False
self.params["enable_autodiff"] = False
try:
- # FIXME: how do we test "X is ref_points" then?
newX = ep.astensor(X)
- if self.params["implementation"] != "keops" or not isinstance(newX, ep.PyTorchTensor):
+ if self.params["implementation"] != "keops" or (
+ not isinstance(newX, ep.PyTorchTensor) and not isinstance(newX, ep.NumPyTensor)
+ ):
newX = newX.numpy()
else:
- newX = X
+ newX = newX.raw
neighbors = self.transform(newX)
finally:
self.return_index = save_return_index
@@ -159,7 +163,6 @@ class KNearestNeighbors:
else:
return distances.raw
-
metric = self.metric
k = self.k