From 9ff9055a93b5bc5c402519bd0bc8c85bf97d6d84 Mon Sep 17 00:00:00 2001 From: Marc Glisse Date: Sun, 24 May 2020 22:40:44 +0200 Subject: Move metric to params --- src/python/gudhi/clustering/tomato.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'src/python/gudhi/clustering/tomato.py') diff --git a/src/python/gudhi/clustering/tomato.py b/src/python/gudhi/clustering/tomato.py index 1c586f4f..29f30481 100644 --- a/src/python/gudhi/clustering/tomato.py +++ b/src/python/gudhi/clustering/tomato.py @@ -41,7 +41,6 @@ class Tomato: self, graph_type="knn", density_type="logDTM", - metric="minkowski", n_clusters=None, merge_threshold=None, # eliminate_threshold=None, @@ -52,9 +51,9 @@ class Tomato: Each parameter has a corresponding attribute, like self.merge_threshold_, that can be changed later. Args: - metric (str|Callable): Describes how to interpret the argument of `fit()`. Defaults to Minkowski of parameter p. graph_type (str): 'manual', 'knn' or 'radius'. density_type (str): 'manual', 'DTM', 'logDTM', 'KDE' or 'logKDE'. + metric (str|Callable): metric used when calculating the distance between instances in a feature array. Defaults to Minkowski of parameter p. kde_params (dict): if density_type is 'KDE' or 'logKDE', additional parameters passed directly to sklearn.neighbors.KernelDensity. k (int): number of neighbors for a knn graph (including the vertex itself). Defaults to 10. k_DTM (int): number of neighbors for the DTM density estimation (including the vertex itself). Defaults to k. @@ -71,7 +70,6 @@ class Tomato: """ # Should metric='precomputed' mean input_type='distance_matrix'? # Should we be able to pass metric='minkowski' (what None does currently)? - self.metric_ = metric self.graph_type_ = graph_type self.density_type_ = density_type self.params_ = params @@ -102,7 +100,7 @@ class Tomato: # FIXME: uniformize "message 'option'" vs 'message "option"' assert density_type == "manual", 'If graph_type is "manual", density_type must be as well' else: - metric = self.metric_ + metric = self.params_.get("metric", "minkowski") if metric != "precomputed": self.points_ = X @@ -163,6 +161,7 @@ class Tomato: eps = self.params_.get("eps", 0) self.neighbors_ = t.query_ball_tree(t, r=self.params_["r"], p=p, eps=eps) + # TODO: sklearn's NearestNeighbors can handle more metrics efficiently via its BallTree elif metric != "precomputed": from sklearn.metrics import pairwise_distances @@ -176,8 +175,9 @@ class Tomato: self.neighbors_ = [numpy.flatnonzero(l <= r) for l in X] if self.density_type_ in {"KDE", "logKDE"}: - assert metric not in ["neighbors", "precomputed"], "Scikit-learn's KernelDensity requires point coordinates" + assert graph_type != "manual" and metric != "precomputed", "Scikit-learn's KernelDensity requires point coordinates" kde_params = self.params_.get("kde_params", dict()) + kde_params.setdefault("metric", metric) from sklearn.neighbors import KernelDensity weights = KernelDensity(**kde_params).fit(self.points_).score_samples(self.points_) if self.density_type_ == "KDE": -- cgit v1.2.3