summary | refs | log | tree | commit | diff
path: root/src/python/gudhi/clustering
diff options
context:
space:
mode:
author: Marc Glisse <marc.glisse@inria.fr> 2020-03-05 07:00:56 +0100
committer: Marc Glisse <marc.glisse@inria.fr> 2020-03-05 07:00:56 +0100
commit: 6c8e59dbece96aeacff53c36809c06c087835905 (patch)
tree: a95b9bc42b77bb5dd4c8f6eb6fed38a5496c6649 /src/python/gudhi/clustering
parent: fa8c487b24b58797398ce3a93a95095b43de23f3 (diff)
Fix merge_threshold
Diffstat (limited to 'src/python/gudhi/clustering')
-rw-r--r-- src/python/gudhi/clustering/tomato.py 14
1 files changed, 10 insertions, 4 deletions
diff --git a/src/python/gudhi/clustering/tomato.py b/src/python/gudhi/clustering/tomato.py
index 0d754a62..ab317beb 100644
--- a/src/python/gudhi/clustering/tomato.py
+++ b/src/python/gudhi/clustering/tomato.py
@@ -4,6 +4,7 @@ from ._tomato import *
# The fit/predict interface is not so well suited...
# TODO: option for a faster, weaker (probabilistic) knn
+
class Tomato:
"""
Clustering
@@ -76,7 +77,7 @@ class Tomato:
self.params_ = params
self.__n_clusters = n_clusters
self.__merge_threshold = merge_threshold
- #self.eliminate_threshold_ = eliminate_threshold
+ # self.eliminate_threshold_ = eliminate_threshold
if n_clusters and merge_threshold:
raise ValueError("Cannot specify both a merge threshold and a number of clusters")
@@ -300,6 +301,11 @@ class Tomato:
)
self.n_leaves_ = len(self.max_density_per_cc_) + len(self.children_)
assert self.leaf_labels_.max() + 1 == len(self.max_density_per_cc_) + len(self.children_)
+ if self.__merge_threshold:
+ assert not self.__n_clusters
+ self.__n_clusters = numpy.count_nonzero(
+ self.diagram_[:, 1] - self.diagram_[:, 0] > self.__merge_threshold
+ ) + len(self.max_density_per_cc_)
if self.__n_clusters:
renaming = merge(self.children_, self.n_leaves_, self.__n_clusters)
self.labels_ = renaming[self.leaf_labels_]
@@ -330,7 +336,7 @@ class Tomato:
# def predict(self, X):
# # X had better be the same as in fit()
- # return labels_
+ # return self.labels_
# Use set_params instead?
@property
@@ -356,8 +362,8 @@ class Tomato:
if merge_threshold == self.__merge_threshold:
return
if hasattr(self, "leaf_labels_"):
- self.n_clusters_ = numpy.count_nonzero(self.diagram_[1] - self.diagram_[0] > merge_threshold) + len(
- max_density_per_cc_
+ self.n_clusters_ = numpy.count_nonzero(self.diagram_[:, 1] - self.diagram_[:, 0] > merge_threshold) + len(
+ self.max_density_per_cc_
)
else:
self.__n_clusters = None