summaryrefslogtreecommitdiff
path: root/src/python/gudhi/clustering
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2020-05-25 20:55:06 +0200
committerMarc Glisse <marc.glisse@inria.fr>2020-05-25 20:55:06 +0200
commitc6f519abe3d2fe424d9982ad8139ff2a86119bca (patch)
tree6cf4de5466cedf30d7b1d8f6e01f111a22075206 /src/python/gudhi/clustering
parent9a7fba8b3dcfbd838ce2ea571fd4e8f06cd8a7bd (diff)
Make specifying an impossible number of clusters a warning
Diffstat (limited to 'src/python/gudhi/clustering')
-rw-r--r--src/python/gudhi/clustering/_tomato.cc18
-rw-r--r--src/python/gudhi/clustering/tomato.py10
2 files changed, 16 insertions, 12 deletions
diff --git a/src/python/gudhi/clustering/_tomato.cc b/src/python/gudhi/clustering/_tomato.cc
index 87bd62e9..638e1259 100644
--- a/src/python/gudhi/clustering/_tomato.cc
+++ b/src/python/gudhi/clustering/_tomato.cc
@@ -195,21 +195,19 @@ auto tomato(Point_index num_points, Neighbors const& neighbors, Density const& d
}
auto merge(py::array_t<Cluster_index, py::array::c_style> children, Cluster_index n_leaves, Cluster_index n_final) {
- // Should this really be an error?
- if (n_final > n_leaves)
- throw std::runtime_error("The number of clusters required is larger than the number of mini-clusters");
+ if (n_final > n_leaves) {
+ std::cerr << "The number of clusters required " << n_final << " is larger than the number of mini-clusters " << n_leaves << '\n';
+ n_final = n_leaves; // or return something special and let Tomato use leaf_labels_?
+ }
py::buffer_info cbuf = children.request();
if ((cbuf.ndim != 2 || cbuf.shape[1] != 2) && (cbuf.ndim != 1 || cbuf.shape[0] != 0))
throw std::runtime_error("internal error: children have to be (n,2) or empty");
const int n_merges = cbuf.shape[0];
Cluster_index* d = (Cluster_index*)cbuf.ptr;
- // Should this really be an error?
- // std::cerr << "n_merges: " << n_merges << ", n_final: " << n_final << ", n_leaves: " << n_leaves << '\n';
- if (n_merges + n_final < n_leaves)
- throw std::runtime_error(std::string("The number of clusters required ") + std::to_string(n_final) +
- " is smaller than the number of connected components " +
- std::to_string(n_leaves - n_merges));
-
+ if (n_merges + n_final < n_leaves) {
+ std::cerr << "The number of clusters required " << n_final << " is smaller than the number of connected components " << n_leaves - n_merges << '\n';
+ n_final = n_leaves - n_merges;
+ }
struct Dat {
Cluster_index parent;
int rank;
diff --git a/src/python/gudhi/clustering/tomato.py b/src/python/gudhi/clustering/tomato.py
index 7e97819b..867c46a1 100644
--- a/src/python/gudhi/clustering/tomato.py
+++ b/src/python/gudhi/clustering/tomato.py
@@ -24,11 +24,11 @@ class Tomato:
cluster labels for each point, at the very bottom of the hierarchy
labels_: ndarray of shape (n_samples,)
cluster labels for each point, after merging
- diagram_: ndarray of shape (n_leaves_,2)
+ diagram_: ndarray of shape (n_leaves_, 2)
persistence diagram (only the finite points)
max_weight_per_cc_: ndarray of shape (n_connected_components,)
maximum of the density function on each connected component. This corresponds to the abscissa of infinite points in the diagram
- children_: ndarray of shape (n_leaves_-1,2)
+ children_: ndarray of shape (n_leaves_-n_connected_components, 2)
The children of each non-leaf node. Values less than n_leaves_ correspond to leaves of the tree. A node i greater than or equal to n_leaves_ is a non-leaf node and has children children_[i - n_leaves_]. Alternatively at the i-th iteration, children[i][0] and children[i][1] are merged to form node n_leaves_ + i
weights_: ndarray of shape (n_samples,)
weights of the points, as computed by the density estimator or provided by the user
@@ -206,6 +206,7 @@ class Tomato:
)
self.n_leaves_ = len(self.max_weight_per_cc_) + len(self.children_)
assert self.leaf_labels_.max() + 1 == len(self.max_weight_per_cc_) + len(self.children_)
+ # TODO: deduplicate this code with the setters below
if self.__merge_threshold:
assert not self.__n_clusters
self.__n_clusters = numpy.count_nonzero(
@@ -215,6 +216,9 @@ class Tomato:
# TODO: set corresponding merge_threshold?
renaming = merge(self.children_, self.n_leaves_, self.__n_clusters)
self.labels_ = renaming[self.leaf_labels_]
+ # In case the user asked for something impossible.
+ # TODO: check for impossible situations before calling merge.
+ self.__n_clusters = self.labels_.max() + 1
else:
self.labels_ = self.leaf_labels_
self.__n_clusters = self.n_leaves_
@@ -269,6 +273,8 @@ class Tomato:
if hasattr(self, "leaf_labels_"):
renaming = merge(self.children_, self.n_leaves_, self.__n_clusters)
self.labels_ = renaming[self.leaf_labels_]
+ # In case the user asked for something impossible
+ self.__n_clusters = self.labels_.max() + 1
@property
def merge_threshold_(self):