summary | refs | log | tree | commit | diff
path: root/src/python/gudhi/clustering
diff options
context:
space:
mode:
author: Marc Glisse <marc.glisse@inria.fr> 2020-05-25 18:19:05 +0200
committer: Marc Glisse <marc.glisse@inria.fr> 2020-05-25 18:19:05 +0200
commit: c6bc508fe2f101f37fb7e1a940f3869122f7da71 (patch)
tree: ca3c8dcea1ea5048954bb7ca59ca471421dcb500 /src/python/gudhi/clustering
parent: ee56ee7814367c8b7437c8a8d9a0be32877c3196 (diff)
Use radius as default KDE bandwidth
Diffstat (limited to 'src/python/gudhi/clustering')
-rw-r--r-- src/python/gudhi/clustering/tomato.py 9
1 files changed, 6 insertions, 3 deletions
diff --git a/src/python/gudhi/clustering/tomato.py b/src/python/gudhi/clustering/tomato.py
index 824b5544..7e67c7fd 100644
--- a/src/python/gudhi/clustering/tomato.py
+++ b/src/python/gudhi/clustering/tomato.py
@@ -57,7 +57,7 @@ class Tomato:
kde_params (dict): if density_type is 'KDE' or 'logKDE', additional parameters passed directly to sklearn.neighbors.KernelDensity.
k (int): number of neighbors for a knn graph (including the vertex itself). Defaults to 10.
k_DTM (int): number of neighbors for the DTM density estimation (including the vertex itself). Defaults to k.
- r (float): size of a neighborhood if graph_type is 'radius'.
+ r (float): size of a neighborhood if graph_type is 'radius'. Also used as default bandwidth in kde_params.
eps (float): (1+eps) approximation factor when computing distances (ignored in many cases).
n_clusters (int): number of clusters requested. Defaults to None, i.e. no merging occurs and we get the maximal number of clusters.
merge_threshold (float): minimum prominence of a cluster so it doesn't get merged.
@@ -174,9 +174,12 @@ class Tomato:
self.neighbors_ = [numpy.flatnonzero(l <= r) for l in X]
if self.density_type_ in {"KDE", "logKDE"}:
- assert graph_type != "manual" and metric != "precomputed", "Scikit-learn's KernelDensity requires point coordinates"
- kde_params = self.params_.get("kde_params", dict())
+ assert self.graph_type_ != "manual" and metric != "precomputed", "Scikit-learn's KernelDensity requires point coordinates"
+ kde_params = dict(self.params_.get("kde_params", dict()))
kde_params.setdefault("metric", metric)
+ r = self.params_.get("r")
+ if r is not None:
+ kde_params.setdefault("bandwidth", r)
from sklearn.neighbors import KernelDensity
weights = KernelDensity(**kde_params).fit(self.points_).score_samples(self.points_)
if self.density_type_ == "KDE":