From 2b896ce68eb5cf99d698313ca0e9eea3b35a19c6 Mon Sep 17 00:00:00 2001 From: Marc Glisse Date: Mon, 18 May 2020 22:15:07 +0200 Subject: Start refactoring of Tomato --- src/python/gudhi/clustering/tomato.py | 48 ++++++++++++++++++++++++++++++++--- 1 file changed, 44 insertions(+), 4 deletions(-) (limited to 'src/python/gudhi/clustering') diff --git a/src/python/gudhi/clustering/tomato.py b/src/python/gudhi/clustering/tomato.py index 0a2d562b..fa462f8f 100644 --- a/src/python/gudhi/clustering/tomato.py +++ b/src/python/gudhi/clustering/tomato.py @@ -1,4 +1,6 @@ import numpy +from ..point_cloud.knn import KNearestNeighbors +from ..point_cloud.dtm import DTMDensity from ._tomato import * # The fit/predict interface is not so well suited... @@ -37,6 +39,7 @@ class Tomato: def __init__( self, + # FIXME: fold input_type into metric input_type="points", metric=None, graph_type="knn", @@ -65,7 +68,8 @@ class Tomato: merge_threshold (float): minimum prominence of a cluster so it doesn't get merged. symmetrize_graph (bool): whether we should add edges to make the neighborhood graph symmetric. This can be useful with k-NN for small k. Defaults to false. p (float): norm L^p on input points (numpy.inf is supported without gpu). Defaults to 2. - p_DTM (float): order used to compute the distance to measure. Defaults to the dimension, or 2 if input_type is 'distance_matrix'. + q (float): order used to compute the distance to measure. Defaults to dim. Beware that when the dimension is large, this can easily cause overflows. + dim (float): final exponent in DTM density estimation, representing the dimension. Defaults to the dimension, or 2 when the dimension cannot be read from the input (metric is "neighbors" or "precomputed"). n_jobs (int): Number of jobs to schedule for parallel processing of nearest neighbors on the CPU. If -1 is given all processors are used. Default: 1. """ # Should metric='precomputed' mean input_type='distance_matrix'? @@ -99,12 +103,48 @@ class Tomato: input_type = self.input_type_ if input_type == "points": self.points_ = X + + # FIXME: restrict this strongly if input_type == "points" and self.metric_: from sklearn.metrics import pairwise_distances X = pairwise_distances(X, metric=self.metric_, n_jobs=self.params_.get("n_jobs")) input_type = "distance_matrix" + need_knn = 0 + need_knn_ngb = False + need_knn_dist = False + if self.graph_type_ == "knn": + k_graph = self.params_["k"] + need_knn = k_graph + need_knn_ngb = True + if elf.density_type_ in ["DTM", "logDTM"]: + k = self.params_.get("k", 10) # FIXME: What if X has fewer than 10 points? + k_DTM = self.params_.get("k_DTM", k) + need_knn = max(need_knn, k_DTM) + need_knn_dist = True + # if we ask for more neighbors for the graph than the DTM, getting the distances is a slight waste, + # but it looks negligible + if need_knn > 0: + knn = KNearestNeighbors(need_knn, return_index=need_knn_ngb, return_distance=need_knn_dist, **self.params_).fit_transform(X) + if need_knn_ngb: + if need_knn_dist: + knn_ngb = knn[0][:, 0:k_graph] + knn_dist = knn[1] + else: + knn_ngb = knn + if need_knn_dist: + knn_dist = knn + if self.density_type_ in ["DTM", "logDTM"]: + if metric in ["neighbors", "precomputed"]: + dim = self.params_.get("dim", 2) + else: + dim = len(X) + q = self.params_.get("q", dim) + weights = DTMDensity(k=k_DTM, metric="neighbors", dim=dim, q=q).fit_transform(knn_dist) + if self.density_type_ == "logDTM": + weights = numpy.log(weights) + if input_type == "distance_matrix" and self.graph_type_ == "radius": X = numpy.array(X) r = self.params_["r"] @@ -119,7 +159,7 @@ class Tomato: assert density_type == "manual" if input_type == "points" and self.graph_type_ == "knn" and self.density_type_ in {"DTM", "logDTM"}: - q = self.params_.get("p_DTM", len(X[0])) + q = self.params_.get("q", len(X[0])) p = self.params_.get("p", 2) k = self.params_.get("k", 10) k_DTM = self.params_.get("k_DTM", k) @@ -217,7 +257,7 @@ class Tomato: _, self.neighbors_ = kdtree.query(self.points_, k=k, p=p, **qargs) if input_type == "points" and self.graph_type_ != "knn" and self.density_type_ in {"DTM", "logDTM"}: - q = self.params_.get("p_DTM", len(X[0])) + q = self.params_.get("q", len(X[0])) p = self.params_.get("p", 2) k = self.params_.get("k", 10) k_DTM = self.params_.get("k_DTM", k) @@ -265,7 +305,7 @@ class Tomato: weights = -numpy.log(weights) if input_type == "distance_matrix" and self.density_type_ in {"DTM", "logDTM"}: - q = self.params_.get("p_DTM", 2) + q = self.params_.get("q", 2) X = numpy.array(X) k = self.params_.get("k_DTM") if not k: -- cgit v1.2.3