summaryrefslogtreecommitdiff
path: root/src/python/gudhi/point_cloud
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2020-03-28 15:39:15 +0100
committerMarc Glisse <marc.glisse@inria.fr>2020-03-28 15:39:15 +0100
commitdd9457649d0d197bbed6402200e0f2f55655680e (patch)
treead854cb4e02c506a8e3c979319ce7a86b146c4ae /src/python/gudhi/point_cloud
parent75286efcf311f0c7c46a7039970d663f60953e14 (diff)
Default param of 2 for DTM
Diffstat (limited to 'src/python/gudhi/point_cloud')
-rw-r--r--src/python/gudhi/point_cloud/dtm.py14
1 files changed, 4 insertions, 10 deletions
diff --git a/src/python/gudhi/point_cloud/dtm.py b/src/python/gudhi/point_cloud/dtm.py
index 678524f2..c26ba844 100644
--- a/src/python/gudhi/point_cloud/dtm.py
+++ b/src/python/gudhi/point_cloud/dtm.py
@@ -15,11 +15,11 @@ class DTM:
Class to compute the distance to the empirical measure defined by a point set.
"""
- def __init__(self, k, q=None, **kwargs):
+ def __init__(self, k, q=2, **kwargs):
"""
Args:
k (int): number of neighbors (possibly including the point itself).
- q (float): order used to compute the distance to measure. Defaults to the dimension, or 2 if metric is "neighbors" or "distance_matrix".
+ q (float): order used to compute the distance to measure. Defaults to 2.
kwargs: same parameters as :class:`~gudhi.point_cloud.knn.KNN`, except that metric="neighbors" means that :func:`transform` expects an array with the distances to the k nearest neighbors.
"""
self.k = k
@@ -44,19 +44,13 @@ class DTM:
Args:
X (numpy.array): coordinates for query points, or distance matrix if metric is "precomputed", or distances to the k nearest neighbors if metric is "neighbors" (if the array has more than k columns, the remaining ones are ignored).
"""
- q = self.q
- if q is None:
- if self.params["metric"] in {"neighbors", "precomputed"}:
- q = 2
- else:
- q = len(X[0])
if self.params["metric"] == "neighbors":
distances = X[:, : self.k]
else:
distances = self.knn.transform(X)
- distances = distances ** q
+ distances = distances ** self.q
dtm = distances.sum(-1) / self.k
- dtm = dtm ** (1.0 / q)
+ dtm = dtm ** (1.0 / self.q)
# We compute too many powers, 1/p in knn then q in dtm, 1/q in dtm then q or some log in the caller.
# Add option to skip the final root?
return dtm