diff options
author | Marc Glisse <marc.glisse@inria.fr> | 2020-03-15 10:56:44 +0100 |
---|---|---|
committer | Marc Glisse <marc.glisse@inria.fr> | 2020-03-15 10:56:44 +0100 |
commit | 5a87b00145499bb5afeadaef7dec476ae5f826d0 (patch) | |
tree | 23fe30f552431a9fe572c1dab03c07dd79df7965 | |
parent | 9b70806578bcf65e1f91e3286eb0be4142371954 (diff) |
sklearn's score_sample returns log(density)
-rw-r--r-- | src/python/gudhi/clustering/tomato.py | 9 |
1 files changed, 5 insertions, 4 deletions
diff --git a/src/python/gudhi/clustering/tomato.py b/src/python/gudhi/clustering/tomato.py index 000fdf3d..6906c5bb 100644 --- a/src/python/gudhi/clustering/tomato.py +++ b/src/python/gudhi/clustering/tomato.py @@ -54,8 +54,8 @@ class Tomato: input_type (str): 'points', 'distance_matrix' or 'neighbors'. metric (None|Callable): If None, use Minkowski of parameter p. graph_type (str): 'manual', 'knn' or 'radius'. Ignored if input_type is 'neighbors'. - density_type (str): 'manual', 'DTM', 'logDTM' or 'kde'. - kde_params (dict): if density_type is 'kde', additional parameters passed directly to sklearn.neighbors.KernelDensity. + density_type (str): 'manual', 'DTM', 'logDTM', 'KDE' or 'logKDE'. + kde_params (dict): if density_type is 'KDE' or 'logKDE', additional parameters passed directly to sklearn.neighbors.KernelDensity. k (int): number of neighbors for a knn graph (including the vertex itself). Defaults to 10. k_DTM (int): number of neighbors for the DTM density estimation (including the vertex itself). Defaults to k. r (float): size of a neighborhood if graph_type is 'radius' @@ -280,13 +280,14 @@ class Tomato: else: weights = -numpy.log(weights) - if self.density_type_ == "kde": + if self.density_type_ == "KDE" or self.density_type_ == "logKDE": # FIXME: replace most assert with raise ValueError("blabla") assert input_type == "points" kde_params = self.params_.get("kde_params", dict()) from sklearn.neighbors import KernelDensity - weights = KernelDensity(**kde_params).fit(X).score_samples(X) + if self.density_type_ == "KDE": + weights = numpy.exp(weights) if self.params_.get("symmetrize_graph"): self.neighbors_ = [set(line) for line in self.neighbors_] |