summaryrefslogtreecommitdiff
path: root/src/python/test/test_dtm.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/test/test_dtm.py')
-rwxr-xr-xsrc/python/test/test_dtm.py8
1 files changed, 8 insertions, 0 deletions
diff --git a/src/python/test/test_dtm.py b/src/python/test/test_dtm.py
index bc0d3698..8709dd07 100755
--- a/src/python/test/test_dtm.py
+++ b/src/python/test/test_dtm.py
@@ -11,6 +11,7 @@
from gudhi.point_cloud.dtm import DistanceToMeasure
import numpy
import pytest
+import torch
def test_dtm_compare_euclidean():
@@ -39,6 +40,13 @@ def test_dtm_compare_euclidean():
dtm = DistanceToMeasure(k, implementation="keops")
r5 = dtm.fit_transform(pts)
assert r5 == pytest.approx(r0)
+ pts2 = torch.tensor(pts, requires_grad=True)
+ assert pts2.grad is None
+ dtm = DistanceToMeasure(k, implementation="keops", enable_autodiff=True)
+ r6 = dtm.fit_transform(pts2)
+ assert r6.detach().numpy() == pytest.approx(r0)
+ r6.sum().backward()
+ assert pts2.grad is not None
def test_dtm_precomputed():