diff options
author | Marc Glisse <marc.glisse@inria.fr> | 2020-04-13 01:09:45 +0200 |
---|---|---|
committer | Marc Glisse <marc.glisse@inria.fr> | 2020-04-13 01:10:05 +0200 |
commit | 280eb9d2323837619db1ae013b929adb9b45013b (patch) | |
tree | 19913f2113c38e9c7783cf5b186dd7a34dd921b2 /src/python/test/test_dtm.py | |
parent | 83a1bc1fb6124a35d515f4836d2e830f3dbdf0e7 (diff) |
enable_autodiff with keops
This doesn't seem like the best way to handle it, we may want to handle
it like a wrapper that gets the indices from knn (whatever backend) and
then computes the distances.
Diffstat (limited to 'src/python/test/test_dtm.py')
-rwxr-xr-x | src/python/test/test_dtm.py | 8 |
1 files changed, 8 insertions, 0 deletions
diff --git a/src/python/test/test_dtm.py b/src/python/test/test_dtm.py index bc0d3698..8709dd07 100755 --- a/src/python/test/test_dtm.py +++ b/src/python/test/test_dtm.py @@ -11,6 +11,7 @@ from gudhi.point_cloud.dtm import DistanceToMeasure import numpy import pytest +import torch def test_dtm_compare_euclidean(): @@ -39,6 +40,13 @@ def test_dtm_compare_euclidean(): dtm = DistanceToMeasure(k, implementation="keops") r5 = dtm.fit_transform(pts) assert r5 == pytest.approx(r0) + pts2 = torch.tensor(pts, requires_grad=True) + assert pts2.grad is None + dtm = DistanceToMeasure(k, implementation="keops", enable_autodiff=True) + r6 = dtm.fit_transform(pts2) + assert r6.detach().numpy() == pytest.approx(r0) + r6.sum().backward() + assert pts2.grad is not None def test_dtm_precomputed(): |