From f0c5aab988ee966510503a30b0591105594ac67d Mon Sep 17 00:00:00 2001 From: Marc Glisse Date: Tue, 14 Apr 2020 15:37:31 +0200 Subject: More testing --- src/python/test/test_dtm.py | 7 +++++++ 1 file changed, 7 insertions(+) (limited to 'src/python/test/test_dtm.py') diff --git a/src/python/test/test_dtm.py b/src/python/test/test_dtm.py index 8709dd07..db3e5df5 100755 --- a/src/python/test/test_dtm.py +++ b/src/python/test/test_dtm.py @@ -47,6 +47,13 @@ def test_dtm_compare_euclidean(): assert r6.detach().numpy() == pytest.approx(r0) r6.sum().backward() assert pts2.grad is not None + pts2 = torch.tensor(pts, requires_grad=True) + assert pts2.grad is None + dtm = DistanceToMeasure(k, implementation="ckdtree", enable_autodiff=True) + r7 = dtm.fit_transform(pts2) + assert r7.detach().numpy() == pytest.approx(r0) + r7.sum().backward() + assert pts2.grad is not None def test_dtm_precomputed(): -- cgit v1.2.3