summaryrefslogtreecommitdiff
path: root/src/python/gudhi/clustering
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2020-05-25 22:56:04 +0200
committerMarc Glisse <marc.glisse@inria.fr>2020-05-25 22:56:04 +0200
commit2fb0d594060958804239fcdad5336832ea5133d0 (patch)
treeb9da0c4a87417f9ab168d7d3c9b761d61650ea41 /src/python/gudhi/clustering
parentcaa7c97d812acc3559aaecedc6e44e5f41d8a6af (diff)
Add test
Diffstat (limited to 'src/python/gudhi/clustering')
-rw-r--r--src/python/gudhi/clustering/tomato.py18
1 files changed, 8 insertions, 10 deletions
diff --git a/src/python/gudhi/clustering/tomato.py b/src/python/gudhi/clustering/tomato.py
index 55b64c1d..e3eaa300 100644
--- a/src/python/gudhi/clustering/tomato.py
+++ b/src/python/gudhi/clustering/tomato.py
@@ -18,24 +18,24 @@ class Tomato:
Attributes
----------
n_clusters_: int
- The number of clusters. Writing to it automatically adjusts labels_.
+ The number of clusters. Writing to it automatically adjusts `labels_`.
merge_threshold_: float
- minimum prominence of a cluster so it doesn't get merged. Writing to it automatically adjusts labels_.
+ minimum prominence of a cluster so it doesn't get merged. Writing to it automatically adjusts `labels_`.
n_leaves_: int
number of leaves (unstable clusters) in the hierarchical tree
leaf_labels_: ndarray of shape (n_samples,)
cluster labels for each point, at the very bottom of the hierarchy
labels_: ndarray of shape (n_samples,)
cluster labels for each point, after merging
- diagram_: ndarray of shape (n_leaves_, 2)
+ diagram_: ndarray of shape (`n_leaves_`, 2)
persistence diagram (only the finite points)
max_weight_per_cc_: ndarray of shape (n_connected_components,)
maximum of the density function on each connected component. This corresponds to the abscissa of infinite
points in the diagram
- children_: ndarray of shape (n_leaves_-n_connected_components, 2)
- The children of each non-leaf node. Values less than n_leaves_ correspond to leaves of the tree.
- A node i greater than or equal to n_leaves_ is a non-leaf node and has children children_[i - n_leaves_].
- Alternatively at the i-th iteration, children[i][0] and children[i][1] are merged to form node n_leaves_ + i
+ children_: ndarray of shape (`n_leaves_`-n_connected_components, 2)
+ The children of each non-leaf node. Values less than `n_leaves_` correspond to leaves of the tree.
+ A node i greater than or equal to `n_leaves_` is a non-leaf node and has children children_[i - `n_leaves_`].
+ Alternatively at the i-th iteration, children[i][0] and children[i][1] are merged to form node `n_leaves_` + i
weights_: ndarray of shape (n_samples,)
weights of the points, as computed by the density estimator or provided by the user
params_: dict
@@ -53,8 +53,6 @@ class Tomato:
**params
):
"""
- Each parameter has a corresponding attribute, like self.merge_threshold_, that can be changed later.
-
Args:
graph_type (str): 'manual', 'knn' or 'radius'.
density_type (str): 'manual', 'DTM', 'logDTM', 'KDE' or 'logKDE'.
@@ -223,7 +221,7 @@ class Tomato:
self.neighbors_[j].add(i)
self.weights_ = weights
- self.leaf_labels_, self.children_, self.diagram_, self.max_weight_per_cc_ = doit(list(self.neighbors_), weights)
+ self.leaf_labels_, self.children_, self.diagram_, self.max_weight_per_cc_ = doit(self.neighbors_, weights)
self.n_leaves_ = len(self.max_weight_per_cc_) + len(self.children_)
assert self.leaf_labels_.max() + 1 == len(self.max_weight_per_cc_) + len(self.children_)
# TODO: deduplicate this code with the setters below