summaryrefslogtreecommitdiff
path: root/src/python/gudhi/clustering
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2020-03-09 13:39:27 +0100
committerMarc Glisse <marc.glisse@inria.fr>2020-03-09 13:39:27 +0100
commitcb265d9262da5beb233ffdec3694d2dd15a9f2fd (patch)
tree61736abd4f58d75248df761b8fd7fb7b1d9abe73 /src/python/gudhi/clustering
parente8851febb643821cb60023a6d4e8759236115f79 (diff)
Allow passing weights to fit_predict
Diffstat (limited to 'src/python/gudhi/clustering')
-rw-r--r--src/python/gudhi/clustering/tomato.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/src/python/gudhi/clustering/tomato.py b/src/python/gudhi/clustering/tomato.py
index 19e6f6e9..fd8f0e98 100644
--- a/src/python/gudhi/clustering/tomato.py
+++ b/src/python/gudhi/clustering/tomato.py
@@ -313,10 +313,11 @@ class Tomato:
self.labels_ = self.leaf_labels_
self.__n_clusters = self.n_leaves_
- def fit_predict(self, X, y=None):
+ def fit_predict(self, X, y=None, weights=None):
"""
+ Equivalent to fit(), and returns the `labels_`.
"""
- self.fit(X)
+ self.fit(X, y, weights)
return self.labels_
# TODO: add argument k or threshold? Have a version where you can click and it shows the line and the corresponding k?