summary | refs | log | tree | commit | diff
path: root/src/python/gudhi/clustering
diff options
context:
space:
mode:
author: Marc Glisse <marc.glisse@inria.fr> 2020-05-28 18:40:11 +0200
committer: Marc Glisse <marc.glisse@inria.fr> 2020-05-28 18:40:11 +0200
commit: 1218f4540d51859b9527d2dd436ea8c50c429d68 (patch)
tree: e48cf5343ade9b5a06775c92a425c6e257a5fba1 /src/python/gudhi/clustering
parent: 3a421e5e981a0d637bfc8c0cb9da66e8750e2a8c (diff)
Théo's comments
Diffstat (limited to 'src/python/gudhi/clustering')
-rw-r--r-- src/python/gudhi/clustering/tomato.py 10
1 files changed, 5 insertions, 5 deletions
diff --git a/src/python/gudhi/clustering/tomato.py b/src/python/gudhi/clustering/tomato.py
index d5c5daac..76b6a3c0 100644
--- a/src/python/gudhi/clustering/tomato.py
+++ b/src/python/gudhi/clustering/tomato.py
@@ -103,8 +103,11 @@ class Tomato:
def fit(self, X, y=None, weights=None):
"""
Args:
- X ((n,d)-array of float|(n,n)-array of float|Sequence[Iterable[int]]): coordinates of the points, or distance matrix (full, not just a triangle) if metric is "precomputed", or list of neighbors for each point (points are represented by their index, starting from 0) if graph_type is "manual".
+ X ((n,d)-array of float|(n,n)-array of float|Sequence[Iterable[int]]): coordinates of the points,
+ or distance matrix (full, not just a triangle) if metric is "precomputed", or list of neighbors
+ for each point (points are represented by their index, starting from 0) if graph_type is "manual".
weights (ndarray of shape (n_samples)): if density_type is 'manual', a density estimate at each point
+ y: Not used, present here for API consistency with scikit-learn by convention.
"""
# TODO: First detect if this is a new call with the same data (only threshold changed?)
# TODO: less code duplication (subroutines?), less spaghetti, but don't compute neighbors twice if not needed. Clear error message for missing or contradictory parameters.
@@ -229,6 +232,7 @@ class Tomato:
self.neighbors_[j].add(i)
self.weights_ = weights
+ # This is where the main computation happens
self.leaf_labels_, self.children_, self.diagram_, self.max_weight_per_cc_ = hierarchy(self.neighbors_, weights)
self.n_leaves_ = len(self.max_weight_per_cc_) + len(self.children_)
assert self.leaf_labels_.max() + 1 == len(self.max_weight_per_cc_) + len(self.children_)
@@ -281,10 +285,6 @@ class Tomato:
)
plt.show()
- # def predict(self, X):
- # # X had better be the same as in fit()
- # return self.labels_
-
# Use set_params instead?
@property
def n_clusters_(self):