summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/python/CMakeLists.txt6
-rw-r--r--src/python/gudhi/point_cloud/dtm.py11
-rw-r--r--src/python/gudhi/point_cloud/knn.py10
-rwxr-xr-xsrc/python/test/test_dtm.py22
4 files changed, 46 insertions, 3 deletions
diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt
index 96107cfe..68bc15fd 100644
--- a/src/python/CMakeLists.txt
+++ b/src/python/CMakeLists.txt
@@ -541,17 +541,17 @@ if(PYTHONINTERP_FOUND)
endif()
# Tomato
- if(SCIPY_FOUND AND SKLEARN_FOUND)
+ if(SCIPY_FOUND AND SKLEARN_FOUND AND TORCH_FOUND)
add_gudhi_py_test(test_tomato)
endif()
# Weighted Rips
- if(SCIPY_FOUND)
+ if(SCIPY_FOUND AND TORCH_FOUND)
add_gudhi_py_test(test_weighted_rips_complex)
endif()
# DTM Rips
- if(SCIPY_FOUND)
+ if(SCIPY_FOUND AND TORCH_FOUND)
add_gudhi_py_test(test_dtm_rips_complex)
endif()
diff --git a/src/python/gudhi/point_cloud/dtm.py b/src/python/gudhi/point_cloud/dtm.py
index 55ac58e6..96a9e7bf 100644
--- a/src/python/gudhi/point_cloud/dtm.py
+++ b/src/python/gudhi/point_cloud/dtm.py
@@ -9,6 +9,7 @@
from .knn import KNearestNeighbors
import numpy as np
+import warnings
__author__ = "Marc Glisse"
__copyright__ = "Copyright (C) 2020 Inria"
@@ -66,6 +67,11 @@ class DistanceToMeasure:
distances = distances ** self.q
dtm = distances.sum(-1) / self.k
dtm = dtm ** (1.0 / self.q)
+ with warnings.catch_warnings():
+ import torch
+ if isinstance(dtm, torch.Tensor):
+ if not(torch.isfinite(dtm).all()):
+ warnings.warn("Overflow/infinite value encountered while computing 'dtm'", RuntimeWarning)
# We compute too many powers, 1/p in knn then q in dtm, 1/q in dtm then q or some log in the caller.
# Add option to skip the final root?
return dtm
@@ -163,6 +169,11 @@ class DTMDensity:
distances = self.knn.transform(X)
distances = distances ** q
dtm = (distances * weights).sum(-1)
+ with warnings.catch_warnings():
+ import torch
+ if isinstance(dtm, torch.Tensor):
+ if not(torch.isfinite(dtm).all()):
+ warnings.warn("Overflow/infinite value encountered while computing 'dtm' for density", RuntimeWarning)
if self.normalize:
dtm /= (np.arange(1, k + 1) ** (q / dim) * weights).sum()
density = dtm ** (-dim / q)
diff --git a/src/python/gudhi/point_cloud/knn.py b/src/python/gudhi/point_cloud/knn.py
index 829bf1bf..de5844f9 100644
--- a/src/python/gudhi/point_cloud/knn.py
+++ b/src/python/gudhi/point_cloud/knn.py
@@ -8,6 +8,7 @@
# - YYYY/MM Author: Description of the modification
import numpy
+import warnings
# TODO: https://github.com/facebookresearch/faiss
@@ -257,6 +258,9 @@ class KNearestNeighbors:
if ef is not None:
self.graph.set_ef(ef)
neighbors, distances = self.graph.knn_query(X, k, num_threads=self.params["num_threads"])
+ with warnings.catch_warnings():
+ if not(numpy.all(numpy.isfinite(distances))):
+ warnings.warn("Overflow/infinite value encountered while computing 'distances'", RuntimeWarning)
# The k nearest neighbors are always sorted. I couldn't find it in the doc, but the code calls searchKnn,
# which returns a priority_queue, and then fills the return array backwards with top/pop on the queue.
if self.return_index:
@@ -290,6 +294,9 @@ class KNearestNeighbors:
if self.return_index:
if self.return_distance:
distances, neighbors = mat.Kmin_argKmin(k, dim=1)
+ with warnings.catch_warnings():
+ if not(torch.isfinite(distances).all()):
+ warnings.warn("Overflow/infinite value encountered while computing 'distances'", RuntimeWarning)
if p != numpy.inf:
distances = distances ** (1.0 / p)
return neighbors, distances
@@ -298,6 +305,9 @@ class KNearestNeighbors:
return neighbors
if self.return_distance:
distances = mat.Kmin(k, dim=1)
+ with warnings.catch_warnings():
+ if not(torch.isfinite(distances).all()):
+ warnings.warn("Overflow/infinite value encountered while computing 'distances'", RuntimeWarning)
if p != numpy.inf:
distances = distances ** (1.0 / p)
return distances
diff --git a/src/python/test/test_dtm.py b/src/python/test/test_dtm.py
index 0a52279e..52468d0f 100755
--- a/src/python/test/test_dtm.py
+++ b/src/python/test/test_dtm.py
@@ -13,6 +13,7 @@ import numpy
import pytest
import torch
import math
+import warnings
def test_dtm_compare_euclidean():
@@ -87,3 +88,24 @@ def test_density():
assert density == pytest.approx(expected)
density = DTMDensity(weights=[0.5, 0.5], metric="neighbors", dim=1).fit_transform(distances)
assert density == pytest.approx(expected)
+
+def test_dtm_overflow_warnings():
+ pts = numpy.array([[10., 100000000000000000000000000000.], [1000., 100000000000000000000000000.]])
+ impl_warn = ["keops", "hnsw"]
+
+ with warnings.catch_warnings(record=True) as w:
+ for impl in impl_warn:
+ dtm = DistanceToMeasure(2, q=10000, implementation=impl)
+ r = dtm.fit_transform(pts)
+ assert len(w) == 3
+ for i in range(len(w)):
+ assert issubclass(w[i].category, RuntimeWarning)
+ assert "Overflow" in str(w[i].message)
+
+def test_density_overflow_warning():
+ distances = numpy.array([[10., 100.], [10000000000000., 10.]])
+ with warnings.catch_warnings(record=True) as w:
+ density = DTMDensity(k=2, q=100000, implementation="keops", dim=1).fit_transform(distances)
+ assert len(w) == 1
+ assert issubclass(w[0].category, RuntimeWarning)
+ assert "Overflow" in str(w[0].message)