summaryrefslogtreecommitdiff
path: root/src/python/test/test_dtm.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/test/test_dtm.py')
-rwxr-xr-xsrc/python/test/test_dtm.py23
1 files changed, 22 insertions, 1 deletions
diff --git a/src/python/test/test_dtm.py b/src/python/test/test_dtm.py
index bff4c267..0a52279e 100755
--- a/src/python/test/test_dtm.py
+++ b/src/python/test/test_dtm.py
@@ -8,10 +8,11 @@
- YYYY/MM Author: Description of the modification
"""
-from gudhi.point_cloud.dtm import DistanceToMeasure
+from gudhi.point_cloud.dtm import DistanceToMeasure, DTMDensity
import numpy
import pytest
import torch
+import math
def test_dtm_compare_euclidean():
@@ -66,3 +67,23 @@ def test_dtm_precomputed():
dtm = DistanceToMeasure(2, q=2, metric="neighbors")
r = dtm.fit_transform(dist)
assert r == pytest.approx([2.0, 0.707, 3.5355], rel=0.01)
+
+
+def test_density_normalized():
+ sample = numpy.random.normal(0, 1, (1000000, 2))
+ queries = numpy.array([[0.0, 0.0], [-0.5, 0.7], [0.4, 1.7]])
+ expected = numpy.exp(-(queries ** 2).sum(-1) / 2) / (2 * math.pi)
+ estimated = DTMDensity(k=150, normalize=True).fit(sample).transform(queries)
+ assert estimated == pytest.approx(expected, rel=0.4)
+
+
+def test_density():
+ distances = [[0, 1, 10], [2, 0, 30], [1, 3, 5]]
+ density = DTMDensity(k=2, metric="neighbors", dim=1).fit_transform(distances)
+ expected = numpy.array([2.0, 1.0, 0.5])
+ assert density == pytest.approx(expected)
+ distances = [[0, 1], [2, 0], [1, 3]]
+ density = DTMDensity(metric="neighbors", dim=1).fit_transform(distances)
+ assert density == pytest.approx(expected)
+ density = DTMDensity(weights=[0.5, 0.5], metric="neighbors", dim=1).fit_transform(distances)
+ assert density == pytest.approx(expected)