summaryrefslogtreecommitdiff
path: root/src/python/test/test_dtm.py
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2020-04-13 01:09:45 +0200
committerMarc Glisse <marc.glisse@inria.fr>2020-04-13 01:10:05 +0200
commit280eb9d2323837619db1ae013b929adb9b45013b (patch)
tree19913f2113c38e9c7783cf5b186dd7a34dd921b2 /src/python/test/test_dtm.py
parent83a1bc1fb6124a35d515f4836d2e830f3dbdf0e7 (diff)
enable_autodiff with keops
This doesn't seem like the best way to handle it, we may want to handle it like a wrapper that gets the indices from knn (whatever backend) and then computes the distances.
Diffstat (limited to 'src/python/test/test_dtm.py')
-rwxr-xr-xsrc/python/test/test_dtm.py8
1 files changed, 8 insertions, 0 deletions
diff --git a/src/python/test/test_dtm.py b/src/python/test/test_dtm.py
index bc0d3698..8709dd07 100755
--- a/src/python/test/test_dtm.py
+++ b/src/python/test/test_dtm.py
@@ -11,6 +11,7 @@
from gudhi.point_cloud.dtm import DistanceToMeasure
import numpy
import pytest
+import torch
def test_dtm_compare_euclidean():
@@ -39,6 +40,13 @@ def test_dtm_compare_euclidean():
dtm = DistanceToMeasure(k, implementation="keops")
r5 = dtm.fit_transform(pts)
assert r5 == pytest.approx(r0)
+ pts2 = torch.tensor(pts, requires_grad=True)
+ assert pts2.grad is None
+ dtm = DistanceToMeasure(k, implementation="keops", enable_autodiff=True)
+ r6 = dtm.fit_transform(pts2)
+ assert r6.detach().numpy() == pytest.approx(r0)
+ r6.sum().backward()
+ assert pts2.grad is not None
def test_dtm_precomputed():