summaryrefslogtreecommitdiff
path: root/src/python/gudhi/clustering/tomato.py
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2020-03-09 15:19:48 +0100
committerMarc Glisse <marc.glisse@inria.fr>2020-03-09 15:19:48 +0100
commit6aa18ec3b382f045ba7b97c7cbd6462e1de892ef (patch)
tree037b5e9251f9373dc7fd5e5492988dac3924809f /src/python/gudhi/clustering/tomato.py
parent1bb20075bea223734dfbd0750e3d787f00388f29 (diff)
Make fit return self
Diffstat (limited to 'src/python/gudhi/clustering/tomato.py')
-rw-r--r--src/python/gudhi/clustering/tomato.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/src/python/gudhi/clustering/tomato.py b/src/python/gudhi/clustering/tomato.py
index e3d814d1..000fdf3d 100644
--- a/src/python/gudhi/clustering/tomato.py
+++ b/src/python/gudhi/clustering/tomato.py
@@ -312,13 +312,13 @@ class Tomato:
else:
self.labels_ = self.leaf_labels_
self.__n_clusters = self.n_leaves_
+ return self
def fit_predict(self, X, y=None, weights=None):
"""
Equivalent to fit(), and returns the `labels_`.
"""
- self.fit(X, y, weights)
- return self.labels_
+ return self.fit(X, y, weights).labels_
# TODO: add argument k or threshold? Have a version where you can click and it shows the line and the corresponding k?
def plot_diagram(self):