summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/python/gudhi/clustering/tomato.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/src/python/gudhi/clustering/tomato.py b/src/python/gudhi/clustering/tomato.py
index 19e6f6e9..fd8f0e98 100644
--- a/src/python/gudhi/clustering/tomato.py
+++ b/src/python/gudhi/clustering/tomato.py
@@ -313,10 +313,11 @@ class Tomato:
self.labels_ = self.leaf_labels_
self.__n_clusters = self.n_leaves_
- def fit_predict(self, X, y=None):
+ def fit_predict(self, X, y=None, weights=None):
"""
+ Equivalent to fit(), and returns the `labels_`.
"""
- self.fit(X)
+ self.fit(X, y, weights)
return self.labels_
# TODO: add argument k or threshold? Have a version where you can click and it shows the line and the corresponding k?