diff options
author | Marc Glisse <marc.glisse@inria.fr> | 2020-04-14 15:37:31 +0200 |
---|---|---|
committer | Marc Glisse <marc.glisse@inria.fr> | 2020-04-14 15:37:31 +0200 |
commit | f0c5aab988ee966510503a30b0591105594ac67d (patch) | |
tree | 8aa97c7d0b0026917ec0acf0979c7caa97e80c28 | |
parent | ce75f66da5a2d7ad2c479355112d48817c5ba68b (diff) |
More testing
-rwxr-xr-x | src/python/test/test_dtm.py | 7 |
1 files changed, 7 insertions, 0 deletions
diff --git a/src/python/test/test_dtm.py b/src/python/test/test_dtm.py index 8709dd07..db3e5df5 100755 --- a/src/python/test/test_dtm.py +++ b/src/python/test/test_dtm.py @@ -47,6 +47,13 @@ def test_dtm_compare_euclidean(): assert r6.detach().numpy() == pytest.approx(r0) r6.sum().backward() assert pts2.grad is not None + pts2 = torch.tensor(pts, requires_grad=True) + assert pts2.grad is None + dtm = DistanceToMeasure(k, implementation="ckdtree", enable_autodiff=True) + r7 = dtm.fit_transform(pts2) + assert r7.detach().numpy() == pytest.approx(r0) + r7.sum().backward() + assert pts2.grad is not None def test_dtm_precomputed(): |