summaryrefslogtreecommitdiff
path: root/src/python/gudhi/point_cloud
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2020-03-28 11:48:43 +0100
committerMarc Glisse <marc.glisse@inria.fr>2020-03-28 11:48:43 +0100
commit35a12b553c85af8ce31629b90a27a7071b0cc379 (patch)
tree795fbe61b893a88cd0ac64c249dad276bcf36de2 /src/python/gudhi/point_cloud
parent68839b95e7751afd04155cd2565cc53362f01fa2 (diff)
Doc tweaks, default DTM exponent
Diffstat (limited to 'src/python/gudhi/point_cloud')
-rw-r--r--src/python/gudhi/point_cloud/dtm.py17
-rw-r--r--src/python/gudhi/point_cloud/knn.py6
2 files changed, 15 insertions, 8 deletions
diff --git a/src/python/gudhi/point_cloud/dtm.py b/src/python/gudhi/point_cloud/dtm.py
index e4096c5e..520cbea8 100644
--- a/src/python/gudhi/point_cloud/dtm.py
+++ b/src/python/gudhi/point_cloud/dtm.py
@@ -15,10 +15,11 @@ class DTM:
Class to compute the distance to the empirical measure defined by a point set.
"""
- def __init__(self, k, q=2, **kwargs):
+ def __init__(self, k, q=None, **kwargs):
"""
Args:
- q (float): order used to compute the distance to measure. Defaults to the dimension, or 2 if input_type is 'distance_matrix'.
+ k (int): number of neighbors (possibly including the point itself).
+ q (float): order used to compute the distance to measure. Defaults to the dimension, or 2 if metric is "neighbors" or "distance_matrix".
kwargs: Same parameters as :class:`~gudhi.point_cloud.knn.KNN`, except that metric="neighbors" means that :func:`transform` expects an array with the distances to the k nearest neighbors.
"""
self.k = k
@@ -31,7 +32,7 @@ class DTM:
def fit(self, X, y=None):
"""
Args:
- X (numpy.array): coordinates for mass points
+ X (numpy.array): coordinates for mass points.
"""
if self.params.setdefault("metric", "euclidean") != "neighbors":
# KNN gives sorted distances, which is unnecessary here.
@@ -45,11 +46,17 @@ class DTM:
Args:
X (numpy.array): coordinates for query points, or distance matrix if metric is "precomputed", or distances to the k nearest neighbors if metric is "neighbors" (if the array has more than k columns, the remaining ones are ignored).
"""
+ q = self.q
+ if q is None:
+ if self.params["metric"] in {"neighbors", "precomputed"}:
+ q = 2
+ else:
+ q = len(X[0])
if self.params["metric"] == "neighbors":
distances = X[:, : self.k]
else:
distances = self.knn.transform(X)
- distances = distances ** self.q
+ distances = distances ** q
dtm = distances.sum(-1) / self.k
- dtm = dtm ** (1.0 / self.q)
+ dtm = dtm ** (1.0 / q)
return dtm
diff --git a/src/python/gudhi/point_cloud/knn.py b/src/python/gudhi/point_cloud/knn.py
index 02448530..31e4fc9f 100644
--- a/src/python/gudhi/point_cloud/knn.py
+++ b/src/python/gudhi/point_cloud/knn.py
@@ -18,7 +18,7 @@ class KNN:
def __init__(self, k, return_index=True, return_distance=False, metric="euclidean", **kwargs):
"""
Args:
- k (int): number of neighbors (including the point itself).
+ k (int): number of neighbors (possibly including the point itself).
return_index (bool): if True, return the index of each neighbor.
return_distance (bool): if True, return the distance to each neighbor.
implementation (str): Choice of the library that does the real work.
@@ -68,7 +68,7 @@ class KNN:
def fit(self, X, y=None):
"""
Args:
- X (numpy.array): coordinates for reference points
+ X (numpy.array): coordinates for reference points.
"""
self.ref_points = X
if self.params["implementation"] == "ckdtree":
@@ -105,7 +105,7 @@ class KNN:
def transform(self, X):
"""
Args:
- X (numpy.array): coordinates for query points, or distance matrix if metric is "precomputed"
+ X (numpy.array): coordinates for query points, or distance matrix if metric is "precomputed".
"""
metric = self.metric
k = self.k