From c87a1f10e048477d210ae0abd657da87bba1102a Mon Sep 17 00:00:00 2001 From: Marc Glisse Date: Tue, 12 May 2020 20:36:38 +0200 Subject: test + reformat --- src/python/gudhi/point_cloud/dtm.py | 9 ++++++--- src/python/test/test_dtm.py | 11 ++++++++++- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/src/python/gudhi/point_cloud/dtm.py b/src/python/gudhi/point_cloud/dtm.py index f8cca2c1..4454d8a2 100644 --- a/src/python/gudhi/point_cloud/dtm.py +++ b/src/python/gudhi/point_cloud/dtm.py @@ -108,8 +108,8 @@ class DTMDensity: self.q = q self.dim = dim self.params = kwargs - self.normalize=normalize - self.n_samples=n_samples + self.normalize = normalize + self.n_samples = n_samples def fit_transform(self, X, y=None): return self.fit(X).transform(X) @@ -120,7 +120,9 @@ class DTMDensity: X (numpy.array): coordinates for mass points. """ if self.params.setdefault("metric", "euclidean") != "neighbors": - self.knn = KNearestNeighbors(self.k, return_index=False, return_distance=True, sort_results=False, **self.params) + self.knn = KNearestNeighbors( + self.k, return_index=False, return_distance=True, sort_results=False, **self.params + ) self.knn.fit(X) if self.params["metric"] != "precomputed": self.n_samples = len(X) @@ -154,6 +156,7 @@ class DTMDensity: density = dtm ** (-dim / q) if self.normalize: import math + if self.params["metric"] == "precomputed": self.n_samples = len(X[0]) # Volume of d-ball diff --git a/src/python/test/test_dtm.py b/src/python/test/test_dtm.py index bff4c267..34d28d4d 100755 --- a/src/python/test/test_dtm.py +++ b/src/python/test/test_dtm.py @@ -8,10 +8,11 @@ - YYYY/MM Author: Description of the modification """ -from gudhi.point_cloud.dtm import DistanceToMeasure +from gudhi.point_cloud.dtm import DistanceToMeasure, DTMDensity import numpy import pytest import torch +import math def test_dtm_compare_euclidean(): @@ -66,3 +67,11 @@ def test_dtm_precomputed(): dtm = DistanceToMeasure(2, q=2, metric="neighbors") r = dtm.fit_transform(dist) assert r == pytest.approx([2.0, 0.707, 3.5355], rel=0.01) + + +def test_density_normalized(): + sample = numpy.random.normal(0, 1, (1000000, 2)) + queries = numpy.array([[0.0, 0.0], [-0.5, 0.7], [0.4, 1.7]]) + expected = numpy.exp(-(queries ** 2).sum(-1) / 2) / (2 * math.pi) + estimated = DTMDensity(k=150, normalize=True).fit(sample).transform(queries) + assert estimated == pytest.approx(expected, rel=0.4) -- cgit v1.2.3