summaryrefslogtreecommitdiff
path: root/src/python/gudhi/point_cloud/knn.py
diff options
context:
space:
mode:
authorMathieuCarriere <mathieu.carriere3@gmail.com>2021-11-03 12:11:14 +0100
committerMathieuCarriere <mathieu.carriere3@gmail.com>2021-11-03 12:11:14 +0100
commit1597a5b4fc1aec9f825e430e80b2a843a9037043 (patch)
tree94bd919d17e6ea220bbddacee831ad1db6326603 /src/python/gudhi/point_cloud/knn.py
parent6b16678c71daa2b9b56cc8fa79a18cde080298cc (diff)
parent728acf3e9ecfba29fc9be7fba5fc88f0a7f49880 (diff)
Merge branch 'master' of https://github.com/GUDHI/gudhi-devel into diff
Diffstat (limited to 'src/python/gudhi/point_cloud/knn.py')
-rw-r--r--src/python/gudhi/point_cloud/knn.py10
1 files changed, 10 insertions, 0 deletions
diff --git a/src/python/gudhi/point_cloud/knn.py b/src/python/gudhi/point_cloud/knn.py
index 829bf1bf..de5844f9 100644
--- a/src/python/gudhi/point_cloud/knn.py
+++ b/src/python/gudhi/point_cloud/knn.py
@@ -8,6 +8,7 @@
# - YYYY/MM Author: Description of the modification
import numpy
+import warnings
# TODO: https://github.com/facebookresearch/faiss
@@ -257,6 +258,9 @@ class KNearestNeighbors:
if ef is not None:
self.graph.set_ef(ef)
neighbors, distances = self.graph.knn_query(X, k, num_threads=self.params["num_threads"])
+ with warnings.catch_warnings():
+ if not(numpy.all(numpy.isfinite(distances))):
+ warnings.warn("Overflow/infinite value encountered while computing 'distances'", RuntimeWarning)
# The k nearest neighbors are always sorted. I couldn't find it in the doc, but the code calls searchKnn,
# which returns a priority_queue, and then fills the return array backwards with top/pop on the queue.
if self.return_index:
@@ -290,6 +294,9 @@ class KNearestNeighbors:
if self.return_index:
if self.return_distance:
distances, neighbors = mat.Kmin_argKmin(k, dim=1)
+ with warnings.catch_warnings():
+ if not(torch.isfinite(distances).all()):
+ warnings.warn("Overflow/infinite value encountered while computing 'distances'", RuntimeWarning)
if p != numpy.inf:
distances = distances ** (1.0 / p)
return neighbors, distances
@@ -298,6 +305,9 @@ class KNearestNeighbors:
return neighbors
if self.return_distance:
distances = mat.Kmin(k, dim=1)
+ with warnings.catch_warnings():
+ if not(torch.isfinite(distances).all()):
+ warnings.warn("Overflow/infinite value encountered while computing 'distances'", RuntimeWarning)
if p != numpy.inf:
distances = distances ** (1.0 / p)
return distances