summaryrefslogtreecommitdiff
path: root/src/python/gudhi/point_cloud/dtm.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/gudhi/point_cloud/dtm.py')
-rw-r--r--src/python/gudhi/point_cloud/dtm.py17
1 files changed, 12 insertions, 5 deletions
diff --git a/src/python/gudhi/point_cloud/dtm.py b/src/python/gudhi/point_cloud/dtm.py
index e4096c5e..520cbea8 100644
--- a/src/python/gudhi/point_cloud/dtm.py
+++ b/src/python/gudhi/point_cloud/dtm.py
@@ -15,10 +15,11 @@ class DTM:
Class to compute the distance to the empirical measure defined by a point set.
"""
- def __init__(self, k, q=2, **kwargs):
+ def __init__(self, k, q=None, **kwargs):
"""
Args:
- q (float): order used to compute the distance to measure. Defaults to the dimension, or 2 if input_type is 'distance_matrix'.
+ k (int): number of neighbors (possibly including the point itself).
+ q (float): order used to compute the distance to measure. Defaults to the dimension, or 2 if metric is "neighbors" or "distance_matrix".
kwargs: Same parameters as :class:`~gudhi.point_cloud.knn.KNN`, except that metric="neighbors" means that :func:`transform` expects an array with the distances to the k nearest neighbors.
"""
self.k = k
@@ -31,7 +32,7 @@ class DTM:
def fit(self, X, y=None):
"""
Args:
- X (numpy.array): coordinates for mass points
+ X (numpy.array): coordinates for mass points.
"""
if self.params.setdefault("metric", "euclidean") != "neighbors":
# KNN gives sorted distances, which is unnecessary here.
@@ -45,11 +46,17 @@ class DTM:
Args:
X (numpy.array): coordinates for query points, or distance matrix if metric is "precomputed", or distances to the k nearest neighbors if metric is "neighbors" (if the array has more than k columns, the remaining ones are ignored).
"""
+ q = self.q
+ if q is None:
+ if self.params["metric"] in {"neighbors", "precomputed"}:
+ q = 2
+ else:
+ q = len(X[0])
if self.params["metric"] == "neighbors":
distances = X[:, : self.k]
else:
distances = self.knn.transform(X)
- distances = distances ** self.q
+ distances = distances ** q
dtm = distances.sum(-1) / self.k
- dtm = dtm ** (1.0 / self.q)
+ dtm = dtm ** (1.0 / q)
return dtm