summaryrefslogtreecommitdiff
path: root/src/python/gudhi/point_cloud/knn.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/gudhi/point_cloud/knn.py')
-rw-r--r--src/python/gudhi/point_cloud/knn.py22
1 files changed, 19 insertions, 3 deletions
diff --git a/src/python/gudhi/point_cloud/knn.py b/src/python/gudhi/point_cloud/knn.py
index 86008bc3..7dc83817 100644
--- a/src/python/gudhi/point_cloud/knn.py
+++ b/src/python/gudhi/point_cloud/knn.py
@@ -8,6 +8,7 @@
# - YYYY/MM Author: Description of the modification
import numpy
+import warnings
# TODO: https://github.com/facebookresearch/faiss
@@ -46,7 +47,7 @@ class KNearestNeighbors:
sort_results (bool): if True, then distances and indices of each point are
sorted on return, so that the first column contains the closest points.
Otherwise, neighbors are returned in an arbitrary order. Defaults to True.
- enable_autodiff (bool): if the input is a torch.tensor, jax.numpy.ndarray or tensorflow.Tensor, this
+ enable_autodiff (bool): if the input is a torch.tensor or tensorflow.Tensor, this
instructs the function to compute distances in a way that works with automatic differentiation.
This is experimental, not supported for all metrics, and requires the package EagerPy.
Defaults to False.
@@ -111,7 +112,7 @@ class KNearestNeighbors:
nargs = {
k: v for k, v in self.params.items() if k in {"p", "n_jobs", "metric_params", "algorithm", "leaf_size"}
}
- self.nn = NearestNeighbors(self.k, metric=self.metric, **nargs)
+ self.nn = NearestNeighbors(n_neighbors=self.k, metric=self.metric, **nargs)
self.nn.fit(X)
if self.params["implementation"] == "hnsw":
@@ -257,6 +258,9 @@ class KNearestNeighbors:
if ef is not None:
self.graph.set_ef(ef)
neighbors, distances = self.graph.knn_query(X, k, num_threads=self.params["num_threads"])
+ with warnings.catch_warnings():
+ if not(numpy.all(numpy.isfinite(distances))):
+ warnings.warn("Overflow/infinite value encountered while computing 'distances'", RuntimeWarning)
# The k nearest neighbors are always sorted. I couldn't find it in the doc, but the code calls searchKnn,
# which returns a priority_queue, and then fills the return array backwards with top/pop on the queue.
if self.return_index:
@@ -290,6 +294,9 @@ class KNearestNeighbors:
if self.return_index:
if self.return_distance:
distances, neighbors = mat.Kmin_argKmin(k, dim=1)
+ with warnings.catch_warnings():
+ if not(torch.isfinite(distances).all()):
+ warnings.warn("Overflow/infinite value encountered while computing 'distances'", RuntimeWarning)
if p != numpy.inf:
distances = distances ** (1.0 / p)
return neighbors, distances
@@ -298,14 +305,23 @@ class KNearestNeighbors:
return neighbors
if self.return_distance:
distances = mat.Kmin(k, dim=1)
+ with warnings.catch_warnings():
+ if not(torch.isfinite(distances).all()):
+ warnings.warn("Overflow/infinite value encountered while computing 'distances'", RuntimeWarning)
if p != numpy.inf:
distances = distances ** (1.0 / p)
return distances
return None
if self.params["implementation"] == "ckdtree":
- qargs = {key: val for key, val in self.params.items() if key in {"p", "eps", "n_jobs"}}
+ qargs = {key: val for key, val in self.params.items() if key in {"p", "eps"}}
+ # SciPy renamed n_jobs to workers
+ qargs["workers"] = self.params.get("workers") or self.params.get("n_jobs") or 1
distances, neighbors = self.kdtree.query(X, k=self.k, **qargs)
+ if k == 1:
+ # SciPy decided to squeeze the last dimension for k=1
+ distances = distances[:, None]
+ neighbors = neighbors[:, None]
if self.return_index:
if self.return_distance:
return neighbors, distances