diff options
-rwxr-xr-x | src/python/test/test_dtm.py | 23 |
1 files changed, 10 insertions, 13 deletions
diff --git a/src/python/test/test_dtm.py b/src/python/test/test_dtm.py index bdf003a3..09876496 100755 --- a/src/python/test/test_dtm.py +++ b/src/python/test/test_dtm.py @@ -13,7 +13,7 @@ import numpy import pytest import torch import math -#import warnings # used in test_dtm_overflow_warnings +import warnings def test_dtm_compare_euclidean(): @@ -89,16 +89,13 @@ def test_density(): density = DTMDensity(weights=[0.5, 0.5], metric="neighbors", dim=1).fit_transform(distances) assert density == pytest.approx(expected) -# TODO Uncomment this test when next version of pykeops (current is 1.5) is released (should fix the problem (cf. issue #543)) -#def test_dtm_overflow_warnings(): - #pts = numpy.array([[10., 100000000000000000000000000000.], [1000., 100000000000000000000000000.]]) - #impl_warn = ["keops", "hnsw"] +def test_dtm_overflow_warnings(): + pts = numpy.array([[10., 100000000000000000000000000000.], [1000., 100000000000000000000000000.]]) - #with warnings.catch_warnings(record=True) as w: - #for impl in impl_warn: - #dtm = DistanceToMeasure(2, q=10000, implementation=impl) - #r = dtm.fit_transform(pts) - #assert len(w) == 2 - #for i in range(len(w)): - #assert issubclass(w[i].category, RuntimeWarning) - #assert "Overflow" in str(w[i].message) + with warnings.catch_warnings(record=True) as w: + # TODO Test "keops" implementation as well when next version of pykeops (current is 1.5) is released (should fix the problem (cf. issue #543)) + dtm = DistanceToMeasure(2, q=10000, implementation="hnsw") + r = dtm.fit_transform(pts) + assert len(w) == 1 + assert issubclass(w[0].category, RuntimeWarning) + assert "Overflow" in str(w[0].message) |