summaryrefslogtreecommitdiff
path: root/src/python/gudhi
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2020-05-18 23:55:03 +0200
committerMarc Glisse <marc.glisse@inria.fr>2020-05-18 23:55:03 +0200
commit2cc9b9c608bf90e8d10029b8be97562801a6cb54 (patch)
treeb3ccbd852bf8d2b21adb91f13a2edafa7e752b84 /src/python/gudhi
parent2b896ce68eb5cf99d698313ca0e9eea3b35a19c6 (diff)
parent2287b727126ffb9fc47869ac9ed6b6bd61c6605a (diff)
Merge branch 'dtmdensity' into tomato2
Diffstat (limited to 'src/python/gudhi')
-rw-r--r--src/python/gudhi/point_cloud/dtm.py23
1 files changed, 17 insertions, 6 deletions
diff --git a/src/python/gudhi/point_cloud/dtm.py b/src/python/gudhi/point_cloud/dtm.py
index 88f197e7..d836c28d 100644
--- a/src/python/gudhi/point_cloud/dtm.py
+++ b/src/python/gudhi/point_cloud/dtm.py
@@ -85,7 +85,8 @@ class DTMDensity:
def __init__(self, k=None, weights=None, q=None, dim=None, normalize=False, n_samples=None, **kwargs):
"""
Args:
- k (int): number of neighbors (possibly including the point itself).
+ k (int): number of neighbors (possibly including the point itself). Optional if it can be guessed
+ from weights or metric="neighbors".
weights (numpy.array): weights of each of the k neighbors, optional. They are supposed to sum to 1.
q (float): order used to compute the distance to measure. Defaults to dim.
dim (float): final exponent representing the dimension. Defaults to the dimension, and must be specified
@@ -98,9 +99,12 @@ class DTMDensity:
:func:`transform` expects an array with the distances to the k nearest neighbors.
"""
if weights is None:
- assert k is not None, "Must specify k or weights"
self.k = k
- self.weights = np.full(k, 1.0 / k)
+ if k is None:
+ assert kwargs.get("metric") == "neighbors", 'Must specify k or weights, unless metric is "neighbors"'
+ self.weights = None
+ else:
+ self.weights = np.full(k, 1.0 / k)
else:
self.weights = weights
self.k = len(weights)
@@ -145,14 +149,21 @@ class DTMDensity:
dim = len(X[0])
if q is None:
q = dim
+ k = self.k
+ weights = self.weights
if self.params["metric"] == "neighbors":
- distances = np.asarray(X)[:, : self.k]
+ distances = np.asarray(X)
+ if weights is None:
+ k = distances.shape[1]
+ weights = np.full(k, 1.0 / k)
+ else:
+ distances = distances[:, :k]
else:
distances = self.knn.transform(X)
distances = distances ** q
- dtm = (distances * self.weights).sum(-1)
+ dtm = (distances * weights).sum(-1)
if self.normalize:
- dtm /= (np.arange(1, self.k + 1) ** (q / dim) * self.weights).sum()
+ dtm /= (np.arange(1, k + 1) ** (q / dim) * weights).sum()
density = dtm ** (-dim / q)
if self.normalize:
import math