diff options
author | Marc Glisse <marc.glisse@inria.fr> | 2020-05-25 20:55:06 +0200 |
---|---|---|
committer | Marc Glisse <marc.glisse@inria.fr> | 2020-05-25 20:55:06 +0200 |
commit | c6f519abe3d2fe424d9982ad8139ff2a86119bca (patch) | |
tree | 6cf4de5466cedf30d7b1d8f6e01f111a22075206 /src/python/gudhi/clustering/tomato.py | |
parent | 9a7fba8b3dcfbd838ce2ea571fd4e8f06cd8a7bd (diff) |
Make specifying an impossible number of clusters a warning
Diffstat (limited to 'src/python/gudhi/clustering/tomato.py')
-rw-r--r-- | src/python/gudhi/clustering/tomato.py | 10 |
1 files changed, 8 insertions, 2 deletions
diff --git a/src/python/gudhi/clustering/tomato.py b/src/python/gudhi/clustering/tomato.py index 7e97819b..867c46a1 100644 --- a/src/python/gudhi/clustering/tomato.py +++ b/src/python/gudhi/clustering/tomato.py @@ -24,11 +24,11 @@ class Tomato: cluster labels for each point, at the very bottom of the hierarchy labels_: ndarray of shape (n_samples,) cluster labels for each point, after merging - diagram_: ndarray of shape (n_leaves_,2) + diagram_: ndarray of shape (n_leaves_, 2) persistence diagram (only the finite points) max_weight_per_cc_: ndarray of shape (n_connected_components,) maximum of the density function on each connected component. This corresponds to the abscissa of infinite points in the diagram - children_: ndarray of shape (n_leaves_-1,2) + children_: ndarray of shape (n_leaves_-n_connected_components, 2) The children of each non-leaf node. Values less than n_leaves_ correspond to leaves of the tree. A node i greater than or equal to n_leaves_ is a non-leaf node and has children children_[i - n_leaves_]. Alternatively at the i-th iteration, children[i][0] and children[i][1] are merged to form node n_leaves_ + i weights_: ndarray of shape (n_samples,) weights of the points, as computed by the density estimator or provided by the user @@ -206,6 +206,7 @@ class Tomato: ) self.n_leaves_ = len(self.max_weight_per_cc_) + len(self.children_) assert self.leaf_labels_.max() + 1 == len(self.max_weight_per_cc_) + len(self.children_) + # TODO: deduplicate this code with the setters below if self.__merge_threshold: assert not self.__n_clusters self.__n_clusters = numpy.count_nonzero( @@ -215,6 +216,9 @@ class Tomato: # TODO: set corresponding merge_threshold? renaming = merge(self.children_, self.n_leaves_, self.__n_clusters) self.labels_ = renaming[self.leaf_labels_] + # In case the user asked for something impossible. + # TODO: check for impossible situations before calling merge. + self.__n_clusters = self.labels_.max() + 1 else: self.labels_ = self.leaf_labels_ self.__n_clusters = self.n_leaves_ @@ -269,6 +273,8 @@ class Tomato: if hasattr(self, "leaf_labels_"): renaming = merge(self.children_, self.n_leaves_, self.__n_clusters) self.labels_ = renaming[self.leaf_labels_] + # In case the user asked for something impossible + self.__n_clusters = self.labels_.max() + 1 @property def merge_threshold_(self): |