summaryrefslogtreecommitdiff
path: root/src/python/gudhi/point_cloud
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2020-05-18 23:54:02 +0200
committerMarc Glisse <marc.glisse@inria.fr>2020-05-18 23:54:02 +0200
commit2287b727126ffb9fc47869ac9ed6b6bd61c6605a (patch)
treea4bd39b51dd3e59cd18d6b634d007bd97a635fdd /src/python/gudhi/point_cloud
parent5631b0d1d9f7cc7e033e40fb9b94c8fe473f6082 (diff)
Infer k when we pass the distances to the nearest neighbors
Diffstat (limited to 'src/python/gudhi/point_cloud')
-rw-r--r--src/python/gudhi/point_cloud/dtm.py23
1 files changed, 17 insertions, 6 deletions
diff --git a/src/python/gudhi/point_cloud/dtm.py b/src/python/gudhi/point_cloud/dtm.py
index 88f197e7..d836c28d 100644
--- a/src/python/gudhi/point_cloud/dtm.py
+++ b/src/python/gudhi/point_cloud/dtm.py
@@ -85,7 +85,8 @@ class DTMDensity:
def __init__(self, k=None, weights=None, q=None, dim=None, normalize=False, n_samples=None, **kwargs):
"""
Args:
- k (int): number of neighbors (possibly including the point itself).
+ k (int): number of neighbors (possibly including the point itself). Optional if it can be guessed
+ from weights or metric="neighbors".
weights (numpy.array): weights of each of the k neighbors, optional. They are supposed to sum to 1.
q (float): order used to compute the distance to measure. Defaults to dim.
dim (float): final exponent representing the dimension. Defaults to the dimension, and must be specified
@@ -98,9 +99,12 @@ class DTMDensity:
:func:`transform` expects an array with the distances to the k nearest neighbors.
"""
if weights is None:
- assert k is not None, "Must specify k or weights"
self.k = k
- self.weights = np.full(k, 1.0 / k)
+ if k is None:
+ assert kwargs.get("metric") == "neighbors", 'Must specify k or weights, unless metric is "neighbors"'
+ self.weights = None
+ else:
+ self.weights = np.full(k, 1.0 / k)
else:
self.weights = weights
self.k = len(weights)
@@ -145,14 +149,21 @@ class DTMDensity:
dim = len(X[0])
if q is None:
q = dim
+ k = self.k
+ weights = self.weights
if self.params["metric"] == "neighbors":
- distances = np.asarray(X)[:, : self.k]
+ distances = np.asarray(X)
+ if weights is None:
+ k = distances.shape[1]
+ weights = np.full(k, 1.0 / k)
+ else:
+ distances = distances[:, :k]
else:
distances = self.knn.transform(X)
distances = distances ** q
- dtm = (distances * self.weights).sum(-1)
+ dtm = (distances * weights).sum(-1)
if self.normalize:
- dtm /= (np.arange(1, self.k + 1) ** (q / dim) * self.weights).sum()
+ dtm /= (np.arange(1, k + 1) ** (q / dim) * weights).sum()
density = dtm ** (-dim / q)
if self.normalize:
import math