summaryrefslogtreecommitdiff
path: root/src/python/test
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2020-04-14 20:30:29 +0200
committerMarc Glisse <marc.glisse@inria.fr>2020-04-14 20:30:56 +0200
commitacb9d5b9d1317d3d8168bc3ac46860d078abba84 (patch)
tree8cca1567dc12433432f87285fbe0dd1b4d0c97b2 /src/python/test
parent3f1e4bf5f139afe887ae501f20c5d3f40b5a6f79 (diff)
Check that the gradient is not NaN
This can easily happen with pytorch, and there is special code to avoid it.
Diffstat (limited to 'src/python/test')
-rwxr-xr-xsrc/python/test/test_dtm.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/src/python/test/test_dtm.py b/src/python/test/test_dtm.py
index db3e5df5..de74c42b 100755
--- a/src/python/test/test_dtm.py
+++ b/src/python/test/test_dtm.py
@@ -46,14 +46,14 @@ def test_dtm_compare_euclidean():
r6 = dtm.fit_transform(pts2)
assert r6.detach().numpy() == pytest.approx(r0)
r6.sum().backward()
- assert pts2.grad is not None
+ assert pts2.grad is not None and not torch.isnan(pts2.grad).any()
pts2 = torch.tensor(pts, requires_grad=True)
assert pts2.grad is None
dtm = DistanceToMeasure(k, implementation="ckdtree", enable_autodiff=True)
r7 = dtm.fit_transform(pts2)
assert r7.detach().numpy() == pytest.approx(r0)
r7.sum().backward()
- assert pts2.grad is not None
+ assert pts2.grad is not None and not torch.isnan(pts2.grad).any()
def test_dtm_precomputed():