summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2020-04-14 15:37:31 +0200
committerMarc Glisse <marc.glisse@inria.fr>2020-04-14 15:37:31 +0200
commitf0c5aab988ee966510503a30b0591105594ac67d (patch)
tree8aa97c7d0b0026917ec0acf0979c7caa97e80c28 /src
parentce75f66da5a2d7ad2c479355112d48817c5ba68b (diff)
More testing
Diffstat (limited to 'src')
-rwxr-xr-xsrc/python/test/test_dtm.py7
1 files changed, 7 insertions, 0 deletions
diff --git a/src/python/test/test_dtm.py b/src/python/test/test_dtm.py
index 8709dd07..db3e5df5 100755
--- a/src/python/test/test_dtm.py
+++ b/src/python/test/test_dtm.py
@@ -47,6 +47,13 @@ def test_dtm_compare_euclidean():
assert r6.detach().numpy() == pytest.approx(r0)
r6.sum().backward()
assert pts2.grad is not None
+ pts2 = torch.tensor(pts, requires_grad=True)
+ assert pts2.grad is None
+ dtm = DistanceToMeasure(k, implementation="ckdtree", enable_autodiff=True)
+ r7 = dtm.fit_transform(pts2)
+ assert r7.detach().numpy() == pytest.approx(r0)
+ r7.sum().backward()
+ assert pts2.grad is not None
def test_dtm_precomputed():