summaryrefslogtreecommitdiff
path: root/src/python/gudhi/point_cloud/knn.py
diff options
context:
space:
mode:
authorHind-M <hind.montassif@gmail.com>2021-08-27 11:34:08 +0200
committerHind-M <hind.montassif@gmail.com>2021-08-27 11:34:08 +0200
commit2024c0af61c1b14e50eccfae9a0011cb061b16d2 (patch)
tree062bfb066f7f4550fd5c25dd94369ef91c8ed34b /src/python/gudhi/point_cloud/knn.py
parent19b7d011ee20066ea6895387e0f68d3dd789e0ee (diff)
Fix issue #314
Add overflow and nan warnings in knn when using torch and hnswlib
Diffstat (limited to 'src/python/gudhi/point_cloud/knn.py')
-rw-r--r--src/python/gudhi/point_cloud/knn.py18
1 files changed, 18 insertions, 0 deletions
diff --git a/src/python/gudhi/point_cloud/knn.py b/src/python/gudhi/point_cloud/knn.py
index 829bf1bf..7a5616e3 100644
--- a/src/python/gudhi/point_cloud/knn.py
+++ b/src/python/gudhi/point_cloud/knn.py
@@ -257,6 +257,12 @@ class KNearestNeighbors:
if ef is not None:
self.graph.set_ef(ef)
neighbors, distances = self.graph.knn_query(X, k, num_threads=self.params["num_threads"])
+ if numpy.any(numpy.isnan(distances)):
+ import warnings
+ warnings.warn("NaN value encountered while computing 'distances'", RuntimeWarning)
+ if numpy.any(numpy.isinf(distances)):
+ import warnings
+ warnings.warn("Overflow value encountered while computing 'distances'", RuntimeWarning)
# The k nearest neighbors are always sorted. I couldn't find it in the doc, but the code calls searchKnn,
# which returns a priority_queue, and then fills the return array backwards with top/pop on the queue.
if self.return_index:
@@ -290,6 +296,12 @@ class KNearestNeighbors:
if self.return_index:
if self.return_distance:
distances, neighbors = mat.Kmin_argKmin(k, dim=1)
+ if distances.isnan().any():
+ import warnings
+ warnings.warn("NaN value encountered while computing 'distances'", RuntimeWarning)
+ if distances.isinf().any():
+ import warnings
+ warnings.warn("Overflow encountered while computing 'distances'", RuntimeWarning)
if p != numpy.inf:
distances = distances ** (1.0 / p)
return neighbors, distances
@@ -298,6 +310,12 @@ class KNearestNeighbors:
return neighbors
if self.return_distance:
distances = mat.Kmin(k, dim=1)
+ if distances.isnan().any():
+ import warnings
+ warnings.warn("NaN value encountered while computing 'distances'", RuntimeWarning)
+ if distances.isinf().any():
+ import warnings
+ warnings.warn("Overflow encountered while computing 'distances'", RuntimeWarning)
if p != numpy.inf:
distances = distances ** (1.0 / p)
return distances