summaryrefslogtreecommitdiff
path: root/src/python/gudhi
diff options
context:
space:
mode:
authorVincent Rouvreau <10407034+VincentRouvreau@users.noreply.github.com>2021-11-02 17:04:48 +0100
committerGitHub <noreply@github.com>2021-11-02 17:04:48 +0100
commitc9fdd10a487a91601737bef94a1665d7131fd517 (patch)
tree335951b87f8edd36b7eb45cbadc02efc2a6ca707 /src/python/gudhi
parent58c09b0547749604260ea2584b4d2a3a9ea9735b (diff)
parent16044d44e2784508594c4dcb5aa6a0dd52d2c11f (diff)
Merge pull request #525 from Hind-M/knn_dtm_overflow
Warn about overflow/nan in knn
Diffstat (limited to 'src/python/gudhi')
-rw-r--r--src/python/gudhi/point_cloud/knn.py10
1 files changed, 10 insertions, 0 deletions
diff --git a/src/python/gudhi/point_cloud/knn.py b/src/python/gudhi/point_cloud/knn.py
index 829bf1bf..de5844f9 100644
--- a/src/python/gudhi/point_cloud/knn.py
+++ b/src/python/gudhi/point_cloud/knn.py
@@ -8,6 +8,7 @@
# - YYYY/MM Author: Description of the modification
import numpy
+import warnings
# TODO: https://github.com/facebookresearch/faiss
@@ -257,6 +258,9 @@ class KNearestNeighbors:
if ef is not None:
self.graph.set_ef(ef)
neighbors, distances = self.graph.knn_query(X, k, num_threads=self.params["num_threads"])
+ with warnings.catch_warnings():
+ if not(numpy.all(numpy.isfinite(distances))):
+ warnings.warn("Overflow/infinite value encountered while computing 'distances'", RuntimeWarning)
# The k nearest neighbors are always sorted. I couldn't find it in the doc, but the code calls searchKnn,
# which returns a priority_queue, and then fills the return array backwards with top/pop on the queue.
if self.return_index:
@@ -290,6 +294,9 @@ class KNearestNeighbors:
if self.return_index:
if self.return_distance:
distances, neighbors = mat.Kmin_argKmin(k, dim=1)
+ with warnings.catch_warnings():
+ if not(torch.isfinite(distances).all()):
+ warnings.warn("Overflow/infinite value encountered while computing 'distances'", RuntimeWarning)
if p != numpy.inf:
distances = distances ** (1.0 / p)
return neighbors, distances
@@ -298,6 +305,9 @@ class KNearestNeighbors:
return neighbors
if self.return_distance:
distances = mat.Kmin(k, dim=1)
+ with warnings.catch_warnings():
+ if not(torch.isfinite(distances).all()):
+ warnings.warn("Overflow/infinite value encountered while computing 'distances'", RuntimeWarning)
if p != numpy.inf:
distances = distances ** (1.0 / p)
return distances