diff options
author | Marc Glisse <marc.glisse@inria.fr> | 2020-03-05 07:00:56 +0100 |
---|---|---|
committer | Marc Glisse <marc.glisse@inria.fr> | 2020-03-05 07:00:56 +0100 |
commit | 6c8e59dbece96aeacff53c36809c06c087835905 (patch) | |
tree | a95b9bc42b77bb5dd4c8f6eb6fed38a5496c6649 | |
parent | fa8c487b24b58797398ce3a93a95095b43de23f3 (diff) |
Fix merge_threshold
-rw-r--r-- | src/python/gudhi/clustering/tomato.py | 14 |
1 file changed, 10 insertions, 4 deletions
```diff
diff --git a/src/python/gudhi/clustering/tomato.py b/src/python/gudhi/clustering/tomato.py
index 0d754a62..ab317beb 100644
--- a/src/python/gudhi/clustering/tomato.py
+++ b/src/python/gudhi/clustering/tomato.py
@@ -4,6 +4,7 @@ from ._tomato import *
 # The fit/predict interface is not so well suited...
 # TODO: option for a faster, weaker (probabilistic) knn
 
+
 class Tomato:
     """
     Clustering
@@ -76,7 +77,7 @@ class Tomato:
         self.params_ = params
         self.__n_clusters = n_clusters
         self.__merge_threshold = merge_threshold
-        #self.eliminate_threshold_ = eliminate_threshold
+        # self.eliminate_threshold_ = eliminate_threshold
         if n_clusters and merge_threshold:
             raise ValueError("Cannot specify both a merge threshold and a number of clusters")
 
@@ -300,6 +301,11 @@ class Tomato:
         )
         self.n_leaves_ = len(self.max_density_per_cc_) + len(self.children_)
         assert self.leaf_labels_.max() + 1 == len(self.max_density_per_cc_) + len(self.children_)
+        if self.__merge_threshold:
+            assert not self.__n_clusters
+            self.__n_clusters = numpy.count_nonzero(
+                self.diagram_[:, 1] - self.diagram_[:, 0] > self.__merge_threshold
+            ) + len(self.max_density_per_cc_)
         if self.__n_clusters:
             renaming = merge(self.children_, self.n_leaves_, self.__n_clusters)
             self.labels_ = renaming[self.leaf_labels_]
@@ -330,7 +336,7 @@ class Tomato:
 
     # def predict(self, X):
     #     # X had better be the same as in fit()
-    #     return labels_
+    #     return self.labels_
 
     # Use set_params instead?
     @property
@@ -356,8 +362,8 @@ class Tomato:
         if merge_threshold == self.__merge_threshold:
             return
         if hasattr(self, "leaf_labels_"):
-            self.n_clusters_ = numpy.count_nonzero(self.diagram_[1] - self.diagram_[0] > merge_threshold) + len(
-                max_density_per_cc_
+            self.n_clusters_ = numpy.count_nonzero(self.diagram_[:, 1] - self.diagram_[:, 0] > merge_threshold) + len(
+                self.max_density_per_cc_
             )
         else:
             self.__n_clusters = None
```