summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2020-05-12 20:36:38 +0200
committerMarc Glisse <marc.glisse@inria.fr>2020-05-12 20:36:38 +0200
commitc87a1f10e048477d210ae0abd657da87bba1102a (patch)
tree29d681e85bfb0d5b4eebdc17ce1a0b567704b7da
parent7bbbe63ffa2a812dc49c37c77b4f4a4be46b2a49 (diff)
test + reformat
-rw-r--r--src/python/gudhi/point_cloud/dtm.py9
-rwxr-xr-xsrc/python/test/test_dtm.py11
2 files changed, 16 insertions, 4 deletions
diff --git a/src/python/gudhi/point_cloud/dtm.py b/src/python/gudhi/point_cloud/dtm.py
index f8cca2c1..4454d8a2 100644
--- a/src/python/gudhi/point_cloud/dtm.py
+++ b/src/python/gudhi/point_cloud/dtm.py
@@ -108,8 +108,8 @@ class DTMDensity:
self.q = q
self.dim = dim
self.params = kwargs
- self.normalize=normalize
- self.n_samples=n_samples
+ self.normalize = normalize
+ self.n_samples = n_samples
def fit_transform(self, X, y=None):
return self.fit(X).transform(X)
@@ -120,7 +120,9 @@ class DTMDensity:
X (numpy.array): coordinates for mass points.
"""
if self.params.setdefault("metric", "euclidean") != "neighbors":
- self.knn = KNearestNeighbors(self.k, return_index=False, return_distance=True, sort_results=False, **self.params)
+ self.knn = KNearestNeighbors(
+ self.k, return_index=False, return_distance=True, sort_results=False, **self.params
+ )
self.knn.fit(X)
if self.params["metric"] != "precomputed":
self.n_samples = len(X)
@@ -154,6 +156,7 @@ class DTMDensity:
density = dtm ** (-dim / q)
if self.normalize:
import math
+
if self.params["metric"] == "precomputed":
self.n_samples = len(X[0])
# Volume of d-ball
diff --git a/src/python/test/test_dtm.py b/src/python/test/test_dtm.py
index bff4c267..34d28d4d 100755
--- a/src/python/test/test_dtm.py
+++ b/src/python/test/test_dtm.py
@@ -8,10 +8,11 @@
- YYYY/MM Author: Description of the modification
"""
-from gudhi.point_cloud.dtm import DistanceToMeasure
+from gudhi.point_cloud.dtm import DistanceToMeasure, DTMDensity
import numpy
import pytest
import torch
+import math
def test_dtm_compare_euclidean():
@@ -66,3 +67,11 @@ def test_dtm_precomputed():
dtm = DistanceToMeasure(2, q=2, metric="neighbors")
r = dtm.fit_transform(dist)
assert r == pytest.approx([2.0, 0.707, 3.5355], rel=0.01)
+
+
+def test_density_normalized():
+ sample = numpy.random.normal(0, 1, (1000000, 2))
+ queries = numpy.array([[0.0, 0.0], [-0.5, 0.7], [0.4, 1.7]])
+ expected = numpy.exp(-(queries ** 2).sum(-1) / 2) / (2 * math.pi)
+ estimated = DTMDensity(k=150, normalize=True).fit(sample).transform(queries)
+ assert estimated == pytest.approx(expected, rel=0.4)