From c6bc508fe2f101f37fb7e1a940f3869122f7da71 Mon Sep 17 00:00:00 2001 From: Marc Glisse Date: Mon, 25 May 2020 18:19:05 +0200 Subject: Use radius as default KDE bandwidth --- src/python/gudhi/clustering/tomato.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) (limited to 'src/python/gudhi/clustering/tomato.py') diff --git a/src/python/gudhi/clustering/tomato.py b/src/python/gudhi/clustering/tomato.py index 824b5544..7e67c7fd 100644 --- a/src/python/gudhi/clustering/tomato.py +++ b/src/python/gudhi/clustering/tomato.py @@ -57,7 +57,7 @@ class Tomato: kde_params (dict): if density_type is 'KDE' or 'logKDE', additional parameters passed directly to sklearn.neighbors.KernelDensity. k (int): number of neighbors for a knn graph (including the vertex itself). Defaults to 10. k_DTM (int): number of neighbors for the DTM density estimation (including the vertex itself). Defaults to k. - r (float): size of a neighborhood if graph_type is 'radius'. + r (float): size of a neighborhood if graph_type is 'radius'. Also used as default bandwidth in kde_params. eps (float): (1+eps) approximation factor when computing distances (ignored in many cases). n_clusters (int): number of clusters requested. Defaults to None, i.e. no merging occurs and we get the maximal number of clusters. merge_threshold (float): minimum prominence of a cluster so it doesn't get merged. @@ -174,9 +174,12 @@ class Tomato: self.neighbors_ = [numpy.flatnonzero(l <= r) for l in X] if self.density_type_ in {"KDE", "logKDE"}: - assert graph_type != "manual" and metric != "precomputed", "Scikit-learn's KernelDensity requires point coordinates" - kde_params = self.params_.get("kde_params", dict()) + assert self.graph_type_ != "manual" and metric != "precomputed", "Scikit-learn's KernelDensity requires point coordinates" + kde_params = dict(self.params_.get("kde_params", dict())) kde_params.setdefault("metric", metric) + r = self.params_.get("r") + if r is not None: + kde_params.setdefault("bandwidth", r) from sklearn.neighbors import KernelDensity weights = KernelDensity(**kde_params).fit(self.points_).score_samples(self.points_) if self.density_type_ == "KDE": -- cgit v1.2.3