summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2020-05-18 23:54:02 +0200
committerMarc Glisse <marc.glisse@inria.fr>2020-05-18 23:54:02 +0200
commit2287b727126ffb9fc47869ac9ed6b6bd61c6605a (patch)
treea4bd39b51dd3e59cd18d6b634d007bd97a635fdd
parent5631b0d1d9f7cc7e033e40fb9b94c8fe473f6082 (diff)
Infer k when we pass the distances to the nearest neighbors
-rw-r--r--src/python/gudhi/point_cloud/dtm.py23
-rwxr-xr-xsrc/python/test/test_dtm.py4
2 files changed, 21 insertions, 6 deletions
diff --git a/src/python/gudhi/point_cloud/dtm.py b/src/python/gudhi/point_cloud/dtm.py
index 88f197e7..d836c28d 100644
--- a/src/python/gudhi/point_cloud/dtm.py
+++ b/src/python/gudhi/point_cloud/dtm.py
@@ -85,7 +85,8 @@ class DTMDensity:
def __init__(self, k=None, weights=None, q=None, dim=None, normalize=False, n_samples=None, **kwargs):
"""
Args:
- k (int): number of neighbors (possibly including the point itself).
+ k (int): number of neighbors (possibly including the point itself). Optional if it can be guessed
+ from weights or metric="neighbors".
weights (numpy.array): weights of each of the k neighbors, optional. They are supposed to sum to 1.
q (float): order used to compute the distance to measure. Defaults to dim.
dim (float): final exponent representing the dimension. Defaults to the dimension, and must be specified
@@ -98,9 +99,12 @@ class DTMDensity:
:func:`transform` expects an array with the distances to the k nearest neighbors.
"""
if weights is None:
- assert k is not None, "Must specify k or weights"
self.k = k
- self.weights = np.full(k, 1.0 / k)
+ if k is None:
+ assert kwargs.get("metric") == "neighbors", 'Must specify k or weights, unless metric is "neighbors"'
+ self.weights = None
+ else:
+ self.weights = np.full(k, 1.0 / k)
else:
self.weights = weights
self.k = len(weights)
@@ -145,14 +149,21 @@ class DTMDensity:
dim = len(X[0])
if q is None:
q = dim
+ k = self.k
+ weights = self.weights
if self.params["metric"] == "neighbors":
- distances = np.asarray(X)[:, : self.k]
+ distances = np.asarray(X)
+ if weights is None:
+ k = distances.shape[1]
+ weights = np.full(k, 1.0 / k)
+ else:
+ distances = distances[:, :k]
else:
distances = self.knn.transform(X)
distances = distances ** q
- dtm = (distances * self.weights).sum(-1)
+ dtm = (distances * weights).sum(-1)
if self.normalize:
- dtm /= (np.arange(1, self.k + 1) ** (q / dim) * self.weights).sum()
+ dtm /= (np.arange(1, k + 1) ** (q / dim) * weights).sum()
density = dtm ** (-dim / q)
if self.normalize:
import math
diff --git a/src/python/test/test_dtm.py b/src/python/test/test_dtm.py
index 8ab0cc44..8d400c7e 100755
--- a/src/python/test/test_dtm.py
+++ b/src/python/test/test_dtm.py
@@ -82,3 +82,7 @@ def test_density():
density = DTMDensity(k=2, metric="neighbors", dim=1).fit_transform(distances)
expected = numpy.array([2.0, 1.0, 0.5])
assert density == pytest.approx(expected)
+ distances = [[0, 1], [2, 0], [1, 3]]
+ density = DTMDensity(metric="neighbors", dim=1).fit_transform(distances)
+ expected = numpy.array([2.0, 1.0, 0.5])
+ assert density == pytest.approx(expected)