summaryrefslogtreecommitdiff
path: root/src/python/gudhi/clustering
diff options
context:
space:
mode:
author Marc Glisse <marc.glisse@inria.fr> 2020-05-25 18:53:08 +0200
committer Marc Glisse <marc.glisse@inria.fr> 2020-05-25 18:53:08 +0200
commit 87a142db9e133fbd8f08d9bcc70a51e2a907aa35 (patch)
tree 53cdc31513f4522c3b30287f02f27d7a9cd506b9 /src/python/gudhi/clustering
parent fc7da6849c40cc0caef0e86e452f6d1e2c8320d0 (diff)
Document attribute weights_
Diffstat (limited to 'src/python/gudhi/clustering')
-rw-r--r--src/python/gudhi/clustering/tomato.py13
1 files changed, 8 insertions, 5 deletions
diff --git a/src/python/gudhi/clustering/tomato.py b/src/python/gudhi/clustering/tomato.py
index fcb4b234..c4da9deb 100644
--- a/src/python/gudhi/clustering/tomato.py
+++ b/src/python/gudhi/clustering/tomato.py
@@ -4,8 +4,7 @@ from ..point_cloud.dtm import DTMDensity
from ._tomato import *
# The fit/predict interface is not so well suited...
-# TODO: option for a faster, weaker (probabilistic) knn
-
+# FIXME: decide whether these values are called weights, densities, filtration values, etc., and use that name consistently.
class Tomato:
"""
@@ -21,14 +20,16 @@ class Tomato:
minimum prominence of a cluster so it doesn't get merged. Writing to it automatically adjusts labels_.
n_leaves_: int
number of leaves (unstable clusters) in the hierarchical tree
- leaf_labels_: ndarray of shape (n_samples)
+ leaf_labels_: ndarray of shape (n_samples,)
cluster labels for each point, at the very bottom of the hierarchy
- labels_: ndarray of shape (n_samples)
+ labels_: ndarray of shape (n_samples,)
cluster labels for each point, after merging
diagram_: ndarray of shape (n_leaves_,2)
persistence diagram (only the finite points)
children_: ndarray of shape (n_leaves_-1,2)
The children of each non-leaf node. Values less than n_leaves_ correspond to leaves of the tree. A node i greater than or equal to n_leaves_ is a non-leaf node and has children children_[i - n_leaves_]. Alternatively, at the i-th iteration, children_[i][0] and children_[i][1] are merged to form node n_leaves_ + i.
+ weights_: ndarray of shape (n_samples,)
+ weights of the points, as computed by the density estimator or provided by the user
params_: dict
Parameters like metric, etc
"""
@@ -180,12 +181,14 @@ class Tomato:
self.neighbors_ = [numpy.flatnonzero(l <= r) for l in X]
if self.density_type_ in {"KDE", "logKDE"}:
+ # Slow...
assert self.graph_type_ != "manual" and metric != "precomputed", "Scikit-learn's KernelDensity requires point coordinates"
kde_params = dict(self.params_.get("kde_params", dict()))
kde_params.setdefault("metric", metric)
r = self.params_.get("r")
if r is not None:
kde_params.setdefault("bandwidth", r)
+ # Should we default rtol to eps?
from sklearn.neighbors import KernelDensity
weights = KernelDensity(**kde_params).fit(self.points_).score_samples(self.points_)
if self.density_type_ == "KDE":
@@ -199,7 +202,7 @@ class Tomato:
for j in line:
self.neighbors_[j].add(i)
- self.weights_ = weights # TODO remove
+ self.weights_ = weights
self.leaf_labels_, self.children_, self.diagram_, self.max_density_per_cc_ = doit(
list(self.neighbors_), weights
)