From 1218f4540d51859b9527d2dd436ea8c50c429d68 Mon Sep 17 00:00:00 2001 From: Marc Glisse Date: Thu, 28 May 2020 18:40:11 +0200 Subject: Théo's comments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/python/gudhi/clustering/tomato.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/python/gudhi/clustering/tomato.py b/src/python/gudhi/clustering/tomato.py index d5c5daac..76b6a3c0 100644 --- a/src/python/gudhi/clustering/tomato.py +++ b/src/python/gudhi/clustering/tomato.py @@ -103,8 +103,11 @@ class Tomato: def fit(self, X, y=None, weights=None): """ Args: - X ((n,d)-array of float|(n,n)-array of float|Sequence[Iterable[int]]): coordinates of the points, or distance matrix (full, not just a triangle) if metric is "precomputed", or list of neighbors for each point (points are represented by their index, starting from 0) if graph_type is "manual". + X ((n,d)-array of float|(n,n)-array of float|Sequence[Iterable[int]]): coordinates of the points, + or distance matrix (full, not just a triangle) if metric is "precomputed", or list of neighbors + for each point (points are represented by their index, starting from 0) if graph_type is "manual". weights (ndarray of shape (n_samples)): if density_type is 'manual', a density estimate at each point + y: Not used, present here for API consistency with scikit-learn by convention. """ # TODO: First detect if this is a new call with the same data (only threshold changed?) # TODO: less code duplication (subroutines?), less spaghetti, but don't compute neighbors twice if not needed. Clear error message for missing or contradictory parameters. @@ -229,6 +232,7 @@ class Tomato: self.neighbors_[j].add(i) self.weights_ = weights + # This is where the main computation happens self.leaf_labels_, self.children_, self.diagram_, self.max_weight_per_cc_ = hierarchy(self.neighbors_, weights) self.n_leaves_ = len(self.max_weight_per_cc_) + len(self.children_) assert self.leaf_labels_.max() + 1 == len(self.max_weight_per_cc_) + len(self.children_) @@ -281,10 +285,6 @@ class Tomato: ) plt.show() - # def predict(self, X): - # # X had better be the same as in fit() - # return self.labels_ - # Use set_params instead? @property def n_clusters_(self): -- cgit v1.2.3