summary | refs | log | tree | commit | diff
path: root/src/python/gudhi/clustering/tomato.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/gudhi/clustering/tomato.py')
-rw-r--r-- src/python/gudhi/clustering/tomato.py 10
1 files changed, 8 insertions, 2 deletions
diff --git a/src/python/gudhi/clustering/tomato.py b/src/python/gudhi/clustering/tomato.py
index 7e97819b..867c46a1 100644
--- a/src/python/gudhi/clustering/tomato.py
+++ b/src/python/gudhi/clustering/tomato.py
@@ -24,11 +24,11 @@ class Tomato:
cluster labels for each point, at the very bottom of the hierarchy
labels_: ndarray of shape (n_samples,)
cluster labels for each point, after merging
- diagram_: ndarray of shape (n_leaves_,2)
+ diagram_: ndarray of shape (n_leaves_, 2)
persistence diagram (only the finite points)
max_weight_per_cc_: ndarray of shape (n_connected_components,)
maximum of the density function on each connected component. This corresponds to the abscissa of infinite points in the diagram
- children_: ndarray of shape (n_leaves_-1,2)
+ children_: ndarray of shape (n_leaves_-n_connected_components, 2)
The children of each non-leaf node. Values less than n_leaves_ correspond to leaves of the tree. A node i greater than or equal to n_leaves_ is a non-leaf node and has children children_[i - n_leaves_]. Alternatively at the i-th iteration, children[i][0] and children[i][1] are merged to form node n_leaves_ + i
weights_: ndarray of shape (n_samples,)
weights of the points, as computed by the density estimator or provided by the user
@@ -206,6 +206,7 @@ class Tomato:
)
self.n_leaves_ = len(self.max_weight_per_cc_) + len(self.children_)
assert self.leaf_labels_.max() + 1 == len(self.max_weight_per_cc_) + len(self.children_)
+ # TODO: deduplicate this code with the setters below
if self.__merge_threshold:
assert not self.__n_clusters
self.__n_clusters = numpy.count_nonzero(
@@ -215,6 +216,9 @@ class Tomato:
# TODO: set corresponding merge_threshold?
renaming = merge(self.children_, self.n_leaves_, self.__n_clusters)
self.labels_ = renaming[self.leaf_labels_]
+ # In case the user asked for something impossible.
+ # TODO: check for impossible situations before calling merge.
+ self.__n_clusters = self.labels_.max() + 1
else:
self.labels_ = self.leaf_labels_
self.__n_clusters = self.n_leaves_
@@ -269,6 +273,8 @@ class Tomato:
if hasattr(self, "leaf_labels_"):
renaming = merge(self.children_, self.n_leaves_, self.__n_clusters)
self.labels_ = renaming[self.leaf_labels_]
+ # In case the user asked for something impossible
+ self.__n_clusters = self.labels_.max() + 1
@property
def merge_threshold_(self):