diff options
author | Marc Glisse <marc.glisse@inria.fr> | 2020-05-25 18:19:05 +0200 |
---|---|---|
committer | Marc Glisse <marc.glisse@inria.fr> | 2020-05-25 18:19:05 +0200 |
commit | c6bc508fe2f101f37fb7e1a940f3869122f7da71 (patch) | |
tree | ca3c8dcea1ea5048954bb7ca59ca471421dcb500 /src/python/gudhi/clustering | |
parent | ee56ee7814367c8b7437c8a8d9a0be32877c3196 (diff) |
Use radius as default KDE bandwidth
Diffstat (limited to 'src/python/gudhi/clustering')
-rw-r--r-- | src/python/gudhi/clustering/tomato.py | 9 |
1 files changed, 6 insertions, 3 deletions
diff --git a/src/python/gudhi/clustering/tomato.py b/src/python/gudhi/clustering/tomato.py index 824b5544..7e67c7fd 100644 --- a/src/python/gudhi/clustering/tomato.py +++ b/src/python/gudhi/clustering/tomato.py @@ -57,7 +57,7 @@ class Tomato: kde_params (dict): if density_type is 'KDE' or 'logKDE', additional parameters passed directly to sklearn.neighbors.KernelDensity. k (int): number of neighbors for a knn graph (including the vertex itself). Defaults to 10. k_DTM (int): number of neighbors for the DTM density estimation (including the vertex itself). Defaults to k. - r (float): size of a neighborhood if graph_type is 'radius'. + r (float): size of a neighborhood if graph_type is 'radius'. Also used as default bandwidth in kde_params. eps (float): (1+eps) approximation factor when computing distances (ignored in many cases). n_clusters (int): number of clusters requested. Defaults to None, i.e. no merging occurs and we get the maximal number of clusters. merge_threshold (float): minimum prominence of a cluster so it doesn't get merged. @@ -174,9 +174,12 @@ class Tomato: self.neighbors_ = [numpy.flatnonzero(l <= r) for l in X] if self.density_type_ in {"KDE", "logKDE"}: - assert graph_type != "manual" and metric != "precomputed", "Scikit-learn's KernelDensity requires point coordinates" - kde_params = self.params_.get("kde_params", dict()) + assert self.graph_type_ != "manual" and metric != "precomputed", "Scikit-learn's KernelDensity requires point coordinates" + kde_params = dict(self.params_.get("kde_params", dict())) kde_params.setdefault("metric", metric) + r = self.params_.get("r") + if r is not None: + kde_params.setdefault("bandwidth", r) from sklearn.neighbors import KernelDensity weights = KernelDensity(**kde_params).fit(self.points_).score_samples(self.points_) if self.density_type_ == "KDE": |