summaryrefslogtreecommitdiff
path: root/src/python/gudhi/point_cloud/knn.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/gudhi/point_cloud/knn.py')
-rw-r--r--src/python/gudhi/point_cloud/knn.py14
1 files changed, 12 insertions, 2 deletions
diff --git a/src/python/gudhi/point_cloud/knn.py b/src/python/gudhi/point_cloud/knn.py
index 4652fe80..de5844f9 100644
--- a/src/python/gudhi/point_cloud/knn.py
+++ b/src/python/gudhi/point_cloud/knn.py
@@ -8,6 +8,7 @@
# - YYYY/MM Author: Description of the modification
import numpy
+import warnings
# TODO: https://github.com/facebookresearch/faiss
@@ -46,7 +47,7 @@ class KNearestNeighbors:
sort_results (bool): if True, then distances and indices of each point are
sorted on return, so that the first column contains the closest points.
Otherwise, neighbors are returned in an arbitrary order. Defaults to True.
- enable_autodiff (bool): if the input is a torch.tensor, jax.numpy.ndarray or tensorflow.Tensor, this
+ enable_autodiff (bool): if the input is a torch.tensor or tensorflow.Tensor, this
instructs the function to compute distances in a way that works with automatic differentiation.
This is experimental, not supported for all metrics, and requires the package EagerPy.
Defaults to False.
@@ -111,7 +112,7 @@ class KNearestNeighbors:
nargs = {
k: v for k, v in self.params.items() if k in {"p", "n_jobs", "metric_params", "algorithm", "leaf_size"}
}
- self.nn = NearestNeighbors(self.k, metric=self.metric, **nargs)
+ self.nn = NearestNeighbors(n_neighbors=self.k, metric=self.metric, **nargs)
self.nn.fit(X)
if self.params["implementation"] == "hnsw":
@@ -257,6 +258,9 @@ class KNearestNeighbors:
if ef is not None:
self.graph.set_ef(ef)
neighbors, distances = self.graph.knn_query(X, k, num_threads=self.params["num_threads"])
+ with warnings.catch_warnings():
+ if not(numpy.all(numpy.isfinite(distances))):
+ warnings.warn("Overflow/infinite value encountered while computing 'distances'", RuntimeWarning)
# The k nearest neighbors are always sorted. I couldn't find it in the doc, but the code calls searchKnn,
# which returns a priority_queue, and then fills the return array backwards with top/pop on the queue.
if self.return_index:
@@ -290,6 +294,9 @@ class KNearestNeighbors:
if self.return_index:
if self.return_distance:
distances, neighbors = mat.Kmin_argKmin(k, dim=1)
+ with warnings.catch_warnings():
+ if not(torch.isfinite(distances).all()):
+ warnings.warn("Overflow/infinite value encountered while computing 'distances'", RuntimeWarning)
if p != numpy.inf:
distances = distances ** (1.0 / p)
return neighbors, distances
@@ -298,6 +305,9 @@ class KNearestNeighbors:
return neighbors
if self.return_distance:
distances = mat.Kmin(k, dim=1)
+ with warnings.catch_warnings():
+ if not(torch.isfinite(distances).all()):
+ warnings.warn("Overflow/infinite value encountered while computing 'distances'", RuntimeWarning)
if p != numpy.inf:
distances = distances ** (1.0 / p)
return distances