summaryrefslogtreecommitdiff
path: root/src/python/gudhi/point_cloud/dtm.py
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2020-04-29 23:00:17 +0200
committerMarc Glisse <marc.glisse@inria.fr>2020-04-29 23:00:17 +0200
commit73a74011e4b5af0794f0463295beca924d32e0ee (patch)
treefb825dc77e9984594109f9ea48838d0544dfca1e /src/python/gudhi/point_cloud/dtm.py
parent74155081bb8b3330c562d5c40d7f0a32fc188012 (diff)
parent0bba67db83f33ff608366057d9c4f005fa6a514b (diff)
Merge remote-tracking branch 'origin/master' into dtmdensity
Diffstat (limited to 'src/python/gudhi/point_cloud/dtm.py')
-rw-r--r--src/python/gudhi/point_cloud/dtm.py24
1 files changed, 19 insertions, 5 deletions
diff --git a/src/python/gudhi/point_cloud/dtm.py b/src/python/gudhi/point_cloud/dtm.py
index e12eefa1..c5405526 100644
--- a/src/python/gudhi/point_cloud/dtm.py
+++ b/src/python/gudhi/point_cloud/dtm.py
@@ -7,11 +7,15 @@
# Modification(s):
# - YYYY/MM Author: Description of the modification
-from .knn import KNN
+from .knn import KNearestNeighbors
import numpy as np
+__author__ = "Marc Glisse"
+__copyright__ = "Copyright (C) 2020 Inria"
+__license__ = "MIT"
-class DTM:
+
+class DistanceToMeasure:
"""
Class to compute the distance to the empirical measure defined by a point set, as introduced in :cite:`dtm`.
"""
@@ -21,7 +25,9 @@ class DTM:
Args:
k (int): number of neighbors (possibly including the point itself).
q (float): order used to compute the distance to measure. Defaults to 2.
- kwargs: same parameters as :class:`~gudhi.point_cloud.knn.KNN`, except that metric="neighbors" means that :func:`transform` expects an array with the distances to the k nearest neighbors.
+ kwargs: same parameters as :class:`~gudhi.point_cloud.knn.KNearestNeighbors`, except that
+ metric="neighbors" means that :func:`transform` expects an array with the distances
+ to the k nearest neighbors.
"""
self.k = k
self.q = q
@@ -36,14 +42,22 @@ class DTM:
X (numpy.array): coordinates for mass points.
"""
if self.params.setdefault("metric", "euclidean") != "neighbors":
- self.knn = KNN(self.k, return_index=False, return_distance=True, sort_results=False, **self.params)
+ self.knn = KNearestNeighbors(
+ self.k, return_index=False, return_distance=True, sort_results=False, **self.params
+ )
self.knn.fit(X)
return self
def transform(self, X):
"""
Args:
- X (numpy.array): coordinates for query points, or distance matrix if metric is "precomputed", or distances to the k nearest neighbors if metric is "neighbors" (if the array has more than k columns, the remaining ones are ignored).
+ X (numpy.array): coordinates for query points, or distance matrix if metric is "precomputed",
+ or distances to the k nearest neighbors if metric is "neighbors" (if the array has more
+ than k columns, the remaining ones are ignored).
+
+ Returns:
+ numpy.array: a 1-d array with, for each point of X, its distance to the measure defined
+ by the argument of :func:`fit`.
"""
if self.params["metric"] == "neighbors":
distances = X[:, : self.k]