author     Marc Glisse <marc.glisse@inria.fr>   2020-05-24 22:40:44 +0200
committer  Marc Glisse <marc.glisse@inria.fr>   2020-05-24 22:40:44 +0200
commit     9ff9055a93b5bc5c402519bd0bc8c85bf97d6d84 (patch)
tree       1d7c48e7a6a7c23f919066f764c779ea1f65d9c2 /src/python/gudhi/clustering
parent     753198e6d70e761df0c92617dfef7b338de4ba82 (diff)
Move metric to params
Diffstat (limited to 'src/python/gudhi/clustering')
-rw-r--r--   src/python/gudhi/clustering/tomato.py   10
1 file changed, 5 insertions(+), 5 deletions(-)
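
In practice the public call hardly changes: the constructor simply no longer has a dedicated metric argument, and the metric travels with the other optional keyword parameters collected into self.params_, where it is read back with a default of "minkowski". A minimal usage sketch (made-up data; the metric value is purely illustrative and assumes it is supported by the chosen graph construction):

import numpy
from gudhi.clustering.tomato import Tomato

points = numpy.random.rand(100, 2)
# 'metric' is no longer a named constructor argument; it is forwarded with
# the other optional parameters and stored in self.params_.
t = Tomato(graph_type="knn", density_type="logDTM", metric="euclidean")
t.fit(points)
print(t.n_clusters_, t.labels_[:10])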
diff --git a/src/python/gudhi/clustering/tomato.py b/src/python/gudhi/clustering/tomato.py
index 1c586f4f..29f30481 100644
--- a/src/python/gudhi/clustering/tomato.py
+++ b/src/python/gudhi/clustering/tomato.py
@@ -41,7 +41,6 @@ class Tomato:
self,
graph_type="knn",
density_type="logDTM",
- metric="minkowski",
n_clusters=None,
merge_threshold=None,
# eliminate_threshold=None,
@@ -52,9 +51,9 @@ class Tomato:
Each parameter has a corresponding attribute, like self.merge_threshold_, that can be changed later.
Args:
- metric (str|Callable): Describes how to interpret the argument of `fit()`. Defaults to Minkowski of parameter p.
graph_type (str): 'manual', 'knn' or 'radius'.
density_type (str): 'manual', 'DTM', 'logDTM', 'KDE' or 'logKDE'.
+ metric (str|Callable): metric used when calculating the distance between instances in a feature array. Defaults to Minkowski of parameter p.
kde_params (dict): if density_type is 'KDE' or 'logKDE', additional parameters passed directly to sklearn.neighbors.KernelDensity.
k (int): number of neighbors for a knn graph (including the vertex itself). Defaults to 10.
k_DTM (int): number of neighbors for the DTM density estimation (including the vertex itself). Defaults to k.
@@ -71,7 +70,6 @@ class Tomato:
"""
# Should metric='precomputed' mean input_type='distance_matrix'?
# Should we be able to pass metric='minkowski' (what None does currently)?
- self.metric_ = metric
self.graph_type_ = graph_type
self.density_type_ = density_type
self.params_ = params
@@ -102,7 +100,7 @@ class Tomato:
# FIXME: uniformize "message 'option'" vs 'message "option"'
assert density_type == "manual", 'If graph_type is "manual", density_type must be as well'
else:
- metric = self.metric_
+ metric = self.params_.get("metric", "minkowski")
if metric != "precomputed":
self.points_ = X
@@ -163,6 +161,7 @@ class Tomato:
eps = self.params_.get("eps", 0)
self.neighbors_ = t.query_ball_tree(t, r=self.params_["r"], p=p, eps=eps)
+ # TODO: sklearn's NearestNeighbors can handle more metrics efficiently via its BallTree
elif metric != "precomputed":
from sklearn.metrics import pairwise_distances
@@ -176,8 +175,9 @@ class Tomato:
self.neighbors_ = [numpy.flatnonzero(l <= r) for l in X]
if self.density_type_ in {"KDE", "logKDE"}:
- assert metric not in ["neighbors", "precomputed"], "Scikit-learn's KernelDensity requires point coordinates"
+ assert graph_type != "manual" and metric != "precomputed", "Scikit-learn's KernelDensity requires point coordinates"
kde_params = self.params_.get("kde_params", dict())
+ kde_params.setdefault("metric", metric)
from sklearn.neighbors import KernelDensity
weights = KernelDensity(**kde_params).fit(self.points_).score_samples(self.points_)
if self.density_type_ == "KDE":
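
The last hunk also changes how the metric reaches scikit-learn's KernelDensity: kde_params.setdefault("metric", metric) fills in the graph metric only when the caller has not chosen one, so an explicit kde_params entry still wins. A standalone sketch of that defaulting behavior (made-up values, not the library's exact code path):

import numpy
from sklearn.neighbors import KernelDensity

points = numpy.random.rand(50, 3)
metric = "manhattan"                     # metric used to build the neighborhood graph
kde_params = {"bandwidth": 0.2}          # the caller did not pick a KDE metric
kde_params.setdefault("metric", metric)  # so the graph metric is reused
# A caller-supplied kde_params["metric"] would be left untouched by setdefault.
weights = KernelDensity(**kde_params).fit(points).score_samples(points)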