From acb9d5b9d1317d3d8168bc3ac46860d078abba84 Mon Sep 17 00:00:00 2001 From: Marc Glisse Date: Tue, 14 Apr 2020 20:30:29 +0200 Subject: Check that the gradient is not NaN This can easily happen with pytorch, and there is special code to avoid it. --- src/python/test/test_dtm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'src/python') diff --git a/src/python/test/test_dtm.py b/src/python/test/test_dtm.py index db3e5df5..de74c42b 100755 --- a/src/python/test/test_dtm.py +++ b/src/python/test/test_dtm.py @@ -46,14 +46,14 @@ def test_dtm_compare_euclidean(): r6 = dtm.fit_transform(pts2) assert r6.detach().numpy() == pytest.approx(r0) r6.sum().backward() - assert pts2.grad is not None + assert pts2.grad is not None and not torch.isnan(pts2.grad).any() pts2 = torch.tensor(pts, requires_grad=True) assert pts2.grad is None dtm = DistanceToMeasure(k, implementation="ckdtree", enable_autodiff=True) r7 = dtm.fit_transform(pts2) assert r7.detach().numpy() == pytest.approx(r0) r7.sum().backward() - assert pts2.grad is not None + assert pts2.grad is not None and not torch.isnan(pts2.grad).any() def test_dtm_precomputed(): -- cgit v1.2.3