From c6f519abe3d2fe424d9982ad8139ff2a86119bca Mon Sep 17 00:00:00 2001 From: Marc Glisse Date: Mon, 25 May 2020 20:55:06 +0200 Subject: Make specifying an impossible number of clusters a warning --- src/python/gudhi/clustering/_tomato.cc | 18 ++++++++---------- src/python/gudhi/clustering/tomato.py | 10 ++++++++-- 2 files changed, 16 insertions(+), 12 deletions(-) (limited to 'src/python/gudhi/clustering') diff --git a/src/python/gudhi/clustering/_tomato.cc b/src/python/gudhi/clustering/_tomato.cc index 87bd62e9..638e1259 100644 --- a/src/python/gudhi/clustering/_tomato.cc +++ b/src/python/gudhi/clustering/_tomato.cc @@ -195,21 +195,19 @@ auto tomato(Point_index num_points, Neighbors const& neighbors, Density const& d } auto merge(py::array_t children, Cluster_index n_leaves, Cluster_index n_final) { - // Should this really be an error? - if (n_final > n_leaves) - throw std::runtime_error("The number of clusters required is larger than the number of mini-clusters"); + if (n_final > n_leaves) { + std::cerr << "The number of clusters required " << n_final << " is larger than the number of mini-clusters " << n_leaves << '\n'; + n_final = n_leaves; // or return something special and let Tomato use leaf_labels_? + } py::buffer_info cbuf = children.request(); if ((cbuf.ndim != 2 || cbuf.shape[1] != 2) && (cbuf.ndim != 1 || cbuf.shape[0] != 0)) throw std::runtime_error("internal error: children have to be (n,2) or empty"); const int n_merges = cbuf.shape[0]; Cluster_index* d = (Cluster_index*)cbuf.ptr; - // Should this really be an error? - // std::cerr << "n_merges: " << n_merges << ", n_final: " << n_final << ", n_leaves: " << n_leaves << '\n'; - if (n_merges + n_final < n_leaves) - throw std::runtime_error(std::string("The number of clusters required ") + std::to_string(n_final) + - " is smaller than the number of connected components " + - std::to_string(n_leaves - n_merges)); - + if (n_merges + n_final < n_leaves) { + std::cerr << "The number of clusters required " << n_final << " is smaller than the number of connected components " << n_leaves - n_merges << '\n'; + n_final = n_leaves - n_merges; + } struct Dat { Cluster_index parent; int rank; diff --git a/src/python/gudhi/clustering/tomato.py b/src/python/gudhi/clustering/tomato.py index 7e97819b..867c46a1 100644 --- a/src/python/gudhi/clustering/tomato.py +++ b/src/python/gudhi/clustering/tomato.py @@ -24,11 +24,11 @@ class Tomato: cluster labels for each point, at the very bottom of the hierarchy labels_: ndarray of shape (n_samples,) cluster labels for each point, after merging - diagram_: ndarray of shape (n_leaves_,2) + diagram_: ndarray of shape (n_leaves_, 2) persistence diagram (only the finite points) max_weight_per_cc_: ndarray of shape (n_connected_components,) maximum of the density function on each connected component. This corresponds to the abscissa of infinite points in the diagram - children_: ndarray of shape (n_leaves_-1,2) + children_: ndarray of shape (n_leaves_-n_connected_components, 2) The children of each non-leaf node. Values less than n_leaves_ correspond to leaves of the tree. A node i greater than or equal to n_leaves_ is a non-leaf node and has children children_[i - n_leaves_]. Alternatively at the i-th iteration, children[i][0] and children[i][1] are merged to form node n_leaves_ + i weights_: ndarray of shape (n_samples,) weights of the points, as computed by the density estimator or provided by the user @@ -206,6 +206,7 @@ class Tomato: ) self.n_leaves_ = len(self.max_weight_per_cc_) + len(self.children_) assert self.leaf_labels_.max() + 1 == len(self.max_weight_per_cc_) + len(self.children_) + # TODO: deduplicate this code with the setters below if self.__merge_threshold: assert not self.__n_clusters self.__n_clusters = numpy.count_nonzero( @@ -215,6 +216,9 @@ class Tomato: # TODO: set corresponding merge_threshold? renaming = merge(self.children_, self.n_leaves_, self.__n_clusters) self.labels_ = renaming[self.leaf_labels_] + # In case the user asked for something impossible. + # TODO: check for impossible situations before calling merge. + self.__n_clusters = self.labels_.max() + 1 else: self.labels_ = self.leaf_labels_ self.__n_clusters = self.n_leaves_ @@ -269,6 +273,8 @@ class Tomato: if hasattr(self, "leaf_labels_"): renaming = merge(self.children_, self.n_leaves_, self.__n_clusters) self.labels_ = renaming[self.leaf_labels_] + # In case the user asked for something impossible + self.__n_clusters = self.labels_.max() + 1 @property def merge_threshold_(self): -- cgit v1.2.3