diff options
author | Marc Glisse <marc.glisse@inria.fr> | 2020-04-29 23:00:17 +0200 |
---|---|---|
committer | Marc Glisse <marc.glisse@inria.fr> | 2020-04-29 23:00:17 +0200 |
commit | 73a74011e4b5af0794f0463295beca924d32e0ee (patch) | |
tree | fb825dc77e9984594109f9ea48838d0544dfca1e /src/python/gudhi/point_cloud/dtm.py | |
parent | 74155081bb8b3330c562d5c40d7f0a32fc188012 (diff) | |
parent | 0bba67db83f33ff608366057d9c4f005fa6a514b (diff) |
Merge remote-tracking branch 'origin/master' into dtmdensity
Diffstat (limited to 'src/python/gudhi/point_cloud/dtm.py')
-rw-r--r-- | src/python/gudhi/point_cloud/dtm.py | 24 |
1 files changed, 19 insertions, 5 deletions
diff --git a/src/python/gudhi/point_cloud/dtm.py b/src/python/gudhi/point_cloud/dtm.py index e12eefa1..c5405526 100644 --- a/src/python/gudhi/point_cloud/dtm.py +++ b/src/python/gudhi/point_cloud/dtm.py @@ -7,11 +7,15 @@ # Modification(s): # - YYYY/MM Author: Description of the modification -from .knn import KNN +from .knn import KNearestNeighbors import numpy as np +__author__ = "Marc Glisse" +__copyright__ = "Copyright (C) 2020 Inria" +__license__ = "MIT" -class DTM: + +class DistanceToMeasure: """ Class to compute the distance to the empirical measure defined by a point set, as introduced in :cite:`dtm`. """ @@ -21,7 +25,9 @@ class DTM: Args: k (int): number of neighbors (possibly including the point itself). q (float): order used to compute the distance to measure. Defaults to 2. - kwargs: same parameters as :class:`~gudhi.point_cloud.knn.KNN`, except that metric="neighbors" means that :func:`transform` expects an array with the distances to the k nearest neighbors. + kwargs: same parameters as :class:`~gudhi.point_cloud.knn.KNearestNeighbors`, except that + metric="neighbors" means that :func:`transform` expects an array with the distances + to the k nearest neighbors. """ self.k = k self.q = q @@ -36,14 +42,22 @@ class DTM: X (numpy.array): coordinates for mass points. """ if self.params.setdefault("metric", "euclidean") != "neighbors": - self.knn = KNN(self.k, return_index=False, return_distance=True, sort_results=False, **self.params) + self.knn = KNearestNeighbors( + self.k, return_index=False, return_distance=True, sort_results=False, **self.params + ) self.knn.fit(X) return self def transform(self, X): """ Args: - X (numpy.array): coordinates for query points, or distance matrix if metric is "precomputed", or distances to the k nearest neighbors if metric is "neighbors" (if the array has more than k columns, the remaining ones are ignored). + X (numpy.array): coordinates for query points, or distance matrix if metric is "precomputed", + or distances to the k nearest neighbors if metric is "neighbors" (if the array has more + than k columns, the remaining ones are ignored). + + Returns: + numpy.array: a 1-d array with, for each point of X, its distance to the measure defined + by the argument of :func:`fit`. """ if self.params["metric"] == "neighbors": distances = X[:, : self.k] |