summaryrefslogtreecommitdiff
path: root/src/python/test/test_dtm.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/test/test_dtm.py')
-rwxr-xr-xsrc/python/test/test_dtm.py40
1 files changed, 29 insertions, 11 deletions
diff --git a/src/python/test/test_dtm.py b/src/python/test/test_dtm.py
index 93b13e1a..bff4c267 100755
--- a/src/python/test/test_dtm.py
+++ b/src/python/test/test_dtm.py
@@ -8,43 +8,61 @@
- YYYY/MM Author: Description of the modification
"""
-from gudhi.point_cloud.dtm import DTM
+from gudhi.point_cloud.dtm import DistanceToMeasure
import numpy
import pytest
+import torch
def test_dtm_compare_euclidean():
pts = numpy.random.rand(1000, 4)
- k = 3
- dtm = DTM(k, implementation="ckdtree")
+ k = 6
+ dtm = DistanceToMeasure(k, implementation="ckdtree")
r0 = dtm.fit_transform(pts)
- dtm = DTM(k, implementation="sklearn")
+ dtm = DistanceToMeasure(k, implementation="sklearn")
r1 = dtm.fit_transform(pts)
assert r1 == pytest.approx(r0)
- dtm = DTM(k, implementation="sklearn", algorithm="brute")
+ dtm = DistanceToMeasure(k, implementation="sklearn", algorithm="brute")
r2 = dtm.fit_transform(pts)
assert r2 == pytest.approx(r0)
- dtm = DTM(k, implementation="hnsw")
+ dtm = DistanceToMeasure(k, implementation="hnsw")
r3 = dtm.fit_transform(pts)
- assert r3 == pytest.approx(r0)
+ assert r3 == pytest.approx(r0, rel=0.1)
from scipy.spatial.distance import cdist
d = cdist(pts, pts)
- dtm = DTM(k, metric="precomputed")
+ dtm = DistanceToMeasure(k, metric="precomputed")
r4 = dtm.fit_transform(d)
assert r4 == pytest.approx(r0)
- dtm = DTM(k, implementation="keops")
+ dtm = DistanceToMeasure(k, metric="precomputed", n_jobs=2)
+ r4b = dtm.fit_transform(d)
+ assert r4b == pytest.approx(r0)
+ dtm = DistanceToMeasure(k, implementation="keops")
r5 = dtm.fit_transform(pts)
assert r5 == pytest.approx(r0)
+ pts2 = torch.tensor(pts, requires_grad=True)
+ assert pts2.grad is None
+ dtm = DistanceToMeasure(k, implementation="keops", enable_autodiff=True)
+ r6 = dtm.fit_transform(pts2)
+ assert r6.detach().numpy() == pytest.approx(r0)
+ r6.sum().backward()
+ assert not torch.isnan(pts2.grad).any()
+ pts2 = torch.tensor(pts, requires_grad=True)
+ assert pts2.grad is None
+ dtm = DistanceToMeasure(k, implementation="ckdtree", enable_autodiff=True)
+ r7 = dtm.fit_transform(pts2)
+ assert r7.detach().numpy() == pytest.approx(r0)
+ r7.sum().backward()
+ assert not torch.isnan(pts2.grad).any()
def test_dtm_precomputed():
dist = numpy.array([[1.0, 3, 8], [1, 5, 5], [0, 2, 3]])
- dtm = DTM(2, q=1, metric="neighbors")
+ dtm = DistanceToMeasure(2, q=1, metric="neighbors")
r = dtm.fit_transform(dist)
assert r == pytest.approx([2.0, 3, 1])
dist = numpy.array([[2.0, 2], [0, 1], [3, 4]])
- dtm = DTM(2, q=2, metric="neighbors")
+ dtm = DistanceToMeasure(2, q=2, metric="neighbors")
r = dtm.fit_transform(dist)
assert r == pytest.approx([2.0, 0.707, 3.5355], rel=0.01)