summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xsrc/python/test/test_dtm.py7
1 files changed, 7 insertions, 0 deletions
diff --git a/src/python/test/test_dtm.py b/src/python/test/test_dtm.py
index 8709dd07..db3e5df5 100755
--- a/src/python/test/test_dtm.py
+++ b/src/python/test/test_dtm.py
@@ -47,6 +47,13 @@ def test_dtm_compare_euclidean():
assert r6.detach().numpy() == pytest.approx(r0)
r6.sum().backward()
assert pts2.grad is not None
+ pts2 = torch.tensor(pts, requires_grad=True)
+ assert pts2.grad is None
+ dtm = DistanceToMeasure(k, implementation="ckdtree", enable_autodiff=True)
+ r7 = dtm.fit_transform(pts2)
+ assert r7.detach().numpy() == pytest.approx(r0)
+ r7.sum().backward()
+ assert pts2.grad is not None
def test_dtm_precomputed():