From c5fca5477cc6fff77acedf7b5324eb5f8b417ed3 Mon Sep 17 00:00:00 2001 From: Marc Glisse Date: Tue, 12 May 2020 22:31:42 +0200 Subject: More test --- src/python/doc/point_cloud_sum.inc | 4 ++-- src/python/gudhi/point_cloud/dtm.py | 4 ++-- src/python/test/test_dtm.py | 7 +++++++ 3 files changed, 11 insertions(+), 4 deletions(-) (limited to 'src/python') diff --git a/src/python/doc/point_cloud_sum.inc b/src/python/doc/point_cloud_sum.inc index d4761aba..d28f387a 100644 --- a/src/python/doc/point_cloud_sum.inc +++ b/src/python/doc/point_cloud_sum.inc @@ -3,8 +3,8 @@ +----------------------------------------------------------------+------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------+ | | :math:`(x_1, x_2, \ldots, x_d)` | Utilities to process point clouds: read from file, subsample, | :Authors: Vincent Rouvreau, Marc Glisse, Masatoshi Takenouchi | - | | :math:`(y_1, y_2, \ldots, y_d)` | find neighbors, embed time series in higher dimension, etc. | | - | | | :Since: GUDHI 2.0.0 | + | | :math:`(y_1, y_2, \ldots, y_d)` | find neighbors, embed time series in higher dimension, estimate | | + | | a density, etc. | :Since: GUDHI 2.0.0 | | | | | | | | :License: MIT (`GPL v3 `_, BSD-3-Clause, Apache-2.0) | | | Parts of this package require CGAL. | | diff --git a/src/python/gudhi/point_cloud/dtm.py b/src/python/gudhi/point_cloud/dtm.py index 4454d8a2..88f197e7 100644 --- a/src/python/gudhi/point_cloud/dtm.py +++ b/src/python/gudhi/point_cloud/dtm.py @@ -89,7 +89,7 @@ class DTMDensity: weights (numpy.array): weights of each of the k neighbors, optional. They are supposed to sum to 1. q (float): order used to compute the distance to measure. Defaults to dim. dim (float): final exponent representing the dimension. Defaults to the dimension, and must be specified - when the dimension cannot be read from the input (metric="neighbors" or metric="precomputed"). + when the dimension cannot be read from the input (metric is "neighbors" or "precomputed"). normalize (bool): normalize the density so it corresponds to a probability measure on ℝᵈ. Only available for the Euclidean metric, defaults to False. n_samples (int): number of sample points used for fitting. Only needed if `normalize` is True and @@ -146,7 +146,7 @@ class DTMDensity: if q is None: q = dim if self.params["metric"] == "neighbors": - distances = X[:, : self.k] + distances = np.asarray(X)[:, : self.k] else: distances = self.knn.transform(X) distances = distances ** q diff --git a/src/python/test/test_dtm.py b/src/python/test/test_dtm.py index 34d28d4d..8ab0cc44 100755 --- a/src/python/test/test_dtm.py +++ b/src/python/test/test_dtm.py @@ -75,3 +75,10 @@ def test_density_normalized(): expected = numpy.exp(-(queries ** 2).sum(-1) / 2) / (2 * math.pi) estimated = DTMDensity(k=150, normalize=True).fit(sample).transform(queries) assert estimated == pytest.approx(expected, rel=0.4) + + +def test_density(): + distances = [[0, 1, 10], [2, 0, 30], [1, 3, 5]] + density = DTMDensity(k=2, metric="neighbors", dim=1).fit_transform(distances) + expected = numpy.array([2.0, 1.0, 0.5]) + assert density == pytest.approx(expected) -- cgit v1.2.3