summaryrefslogtreecommitdiff
path: root/src/python/gudhi/point_cloud/knn.py
diff options
context:
space:
mode:
authorHind-M <hind.montassif@gmail.com>2021-10-07 15:25:25 +0200
committerHind-M <hind.montassif@gmail.com>2021-10-07 15:25:25 +0200
commitdbdc62a494e54c3dd409a2e80fa169560355ce19 (patch)
treef09ebd703694a5d964f271bd286383627d618f7f /src/python/gudhi/point_cloud/knn.py
parent145fcba2de5f174b8fcdeab5ac1997978ffcdc0d (diff)
Move warnings import to the beginning of knn.py file
Use isfinite instead of isinf and isnan Use catch_warnings context manager instead of "always" with simplefilter
Diffstat (limited to 'src/python/gudhi/point_cloud/knn.py')
-rw-r--r--src/python/gudhi/point_cloud/knn.py34
1 files changed, 10 insertions, 24 deletions
diff --git a/src/python/gudhi/point_cloud/knn.py b/src/python/gudhi/point_cloud/knn.py
index 0724ce94..de5844f9 100644
--- a/src/python/gudhi/point_cloud/knn.py
+++ b/src/python/gudhi/point_cloud/knn.py
@@ -8,6 +8,7 @@
# - YYYY/MM Author: Description of the modification
import numpy
+import warnings
# TODO: https://github.com/facebookresearch/faiss
@@ -257,14 +258,9 @@ class KNearestNeighbors:
if ef is not None:
self.graph.set_ef(ef)
neighbors, distances = self.graph.knn_query(X, k, num_threads=self.params["num_threads"])
- if numpy.any(numpy.isnan(distances)):
- import warnings
- warnings.simplefilter("always")
- warnings.warn("NaN value encountered while computing 'distances'", RuntimeWarning)
- if numpy.any(numpy.isinf(distances)):
- import warnings
- warnings.simplefilter("always")
- warnings.warn("Overflow value encountered while computing 'distances'", RuntimeWarning)
+ with warnings.catch_warnings():
+ if not(numpy.all(numpy.isfinite(distances))):
+ warnings.warn("Overflow/infinite value encountered while computing 'distances'", RuntimeWarning)
# The k nearest neighbors are always sorted. I couldn't find it in the doc, but the code calls searchKnn,
# which returns a priority_queue, and then fills the return array backwards with top/pop on the queue.
if self.return_index:
@@ -298,14 +294,9 @@ class KNearestNeighbors:
if self.return_index:
if self.return_distance:
distances, neighbors = mat.Kmin_argKmin(k, dim=1)
- if torch.isnan(distances).any():
- import warnings
- warnings.simplefilter("always")
- warnings.warn("NaN value encountered while computing 'distances'", RuntimeWarning)
- if torch.isinf(distances).any():
- import warnings
- warnings.simplefilter("always")
- warnings.warn("Overflow encountered while computing 'distances'", RuntimeWarning)
+ with warnings.catch_warnings():
+ if not(torch.isfinite(distances).all()):
+ warnings.warn("Overflow/infinite value encountered while computing 'distances'", RuntimeWarning)
if p != numpy.inf:
distances = distances ** (1.0 / p)
return neighbors, distances
@@ -314,14 +305,9 @@ class KNearestNeighbors:
return neighbors
if self.return_distance:
distances = mat.Kmin(k, dim=1)
- if torch.isnan(distances).any():
- import warnings
- warnings.simplefilter("always")
- warnings.warn("NaN value encountered while computing 'distances'", RuntimeWarning)
- if torch.isinf(distances).any():
- import warnings
- warnings.simplefilter("always")
- warnings.warn("Overflow encountered while computing 'distances'", RuntimeWarning)
+ with warnings.catch_warnings():
+ if not(torch.isfinite(distances).all()):
+ warnings.warn("Overflow/infinite value encountered while computing 'distances'", RuntimeWarning)
if p != numpy.inf:
distances = distances ** (1.0 / p)
return distances