summaryrefslogtreecommitdiff
path: root/src/python/gudhi/clustering
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2020-03-15 10:56:44 +0100
committerMarc Glisse <marc.glisse@inria.fr>2020-03-15 10:56:44 +0100
commit5a87b00145499bb5afeadaef7dec476ae5f826d0 (patch)
tree23fe30f552431a9fe572c1dab03c07dd79df7965 /src/python/gudhi/clustering
parent9b70806578bcf65e1f91e3286eb0be4142371954 (diff)
sklearn's score_sample returns log(density)
Diffstat (limited to 'src/python/gudhi/clustering')
-rw-r--r--src/python/gudhi/clustering/tomato.py9
1 files changed, 5 insertions, 4 deletions
diff --git a/src/python/gudhi/clustering/tomato.py b/src/python/gudhi/clustering/tomato.py
index 000fdf3d..6906c5bb 100644
--- a/src/python/gudhi/clustering/tomato.py
+++ b/src/python/gudhi/clustering/tomato.py
@@ -54,8 +54,8 @@ class Tomato:
input_type (str): 'points', 'distance_matrix' or 'neighbors'.
metric (None|Callable): If None, use Minkowski of parameter p.
graph_type (str): 'manual', 'knn' or 'radius'. Ignored if input_type is 'neighbors'.
- density_type (str): 'manual', 'DTM', 'logDTM' or 'kde'.
- kde_params (dict): if density_type is 'kde', additional parameters passed directly to sklearn.neighbors.KernelDensity.
+ density_type (str): 'manual', 'DTM', 'logDTM', 'KDE' or 'logKDE'.
+ kde_params (dict): if density_type is 'KDE' or 'logKDE', additional parameters passed directly to sklearn.neighbors.KernelDensity.
k (int): number of neighbors for a knn graph (including the vertex itself). Defaults to 10.
k_DTM (int): number of neighbors for the DTM density estimation (including the vertex itself). Defaults to k.
r (float): size of a neighborhood if graph_type is 'radius'
@@ -280,13 +280,14 @@ class Tomato:
else:
weights = -numpy.log(weights)
- if self.density_type_ == "kde":
+ if self.density_type_ == "KDE" or self.density_type_ == "logKDE":
# FIXME: replace most assert with raise ValueError("blabla")
assert input_type == "points"
kde_params = self.params_.get("kde_params", dict())
from sklearn.neighbors import KernelDensity
-
weights = KernelDensity(**kde_params).fit(X).score_samples(X)
+ if self.density_type_ == "KDE":
+ weights = numpy.exp(weights)
if self.params_.get("symmetrize_graph"):
self.neighbors_ = [set(line) for line in self.neighbors_]