summaryrefslogtreecommitdiff
path: root/src/python/gudhi/clustering
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2020-05-18 22:15:07 +0200
committerMarc Glisse <marc.glisse@inria.fr>2020-05-18 22:15:07 +0200
commit2b896ce68eb5cf99d698313ca0e9eea3b35a19c6 (patch)
treeb8e5e22de8b777a1f30b8b629d9c97e3ce43fdda /src/python/gudhi/clustering
parent05415ffc3e1e8b00623de2088093d84e3030b0f1 (diff)
Start refactoring of Tomato
Diffstat (limited to 'src/python/gudhi/clustering')
-rw-r--r--src/python/gudhi/clustering/tomato.py48
1 files changed, 44 insertions, 4 deletions
diff --git a/src/python/gudhi/clustering/tomato.py b/src/python/gudhi/clustering/tomato.py
index 0a2d562b..fa462f8f 100644
--- a/src/python/gudhi/clustering/tomato.py
+++ b/src/python/gudhi/clustering/tomato.py
@@ -1,4 +1,6 @@
import numpy
+from ..point_cloud.knn import KNearestNeighbors
+from ..point_cloud.dtm import DTMDensity
from ._tomato import *
# The fit/predict interface is not so well suited...
@@ -37,6 +39,7 @@ class Tomato:
def __init__(
self,
+ # FIXME: fold input_type into metric
input_type="points",
metric=None,
graph_type="knn",
@@ -65,7 +68,8 @@ class Tomato:
merge_threshold (float): minimum prominence of a cluster so it doesn't get merged.
symmetrize_graph (bool): whether we should add edges to make the neighborhood graph symmetric. This can be useful with k-NN for small k. Defaults to false.
p (float): norm L^p on input points (numpy.inf is supported without gpu). Defaults to 2.
- p_DTM (float): order used to compute the distance to measure. Defaults to the dimension, or 2 if input_type is 'distance_matrix'.
+ q (float): order used to compute the distance to measure. Defaults to dim. Beware that when the dimension is large, this can easily cause overflows.
+ dim (float): final exponent in DTM density estimation, representing the dimension. Defaults to the dimension, or 2 when the dimension cannot be read from the input (metric is "neighbors" or "precomputed").
n_jobs (int): Number of jobs to schedule for parallel processing of nearest neighbors on the CPU. If -1 is given all processors are used. Default: 1.
"""
# Should metric='precomputed' mean input_type='distance_matrix'?
@@ -99,12 +103,48 @@ class Tomato:
input_type = self.input_type_
if input_type == "points":
self.points_ = X
+
+ # FIXME: restrict this strongly
if input_type == "points" and self.metric_:
from sklearn.metrics import pairwise_distances
X = pairwise_distances(X, metric=self.metric_, n_jobs=self.params_.get("n_jobs"))
input_type = "distance_matrix"
+ need_knn = 0
+ need_knn_ngb = False
+ need_knn_dist = False
+ if self.graph_type_ == "knn":
+ k_graph = self.params_["k"]
+ need_knn = k_graph
+ need_knn_ngb = True
+ if elf.density_type_ in ["DTM", "logDTM"]:
+ k = self.params_.get("k", 10) # FIXME: What if X has fewer than 10 points?
+ k_DTM = self.params_.get("k_DTM", k)
+ need_knn = max(need_knn, k_DTM)
+ need_knn_dist = True
+ # if we ask for more neighbors for the graph than the DTM, getting the distances is a slight waste,
+ # but it looks negligible
+ if need_knn > 0:
+ knn = KNearestNeighbors(need_knn, return_index=need_knn_ngb, return_distance=need_knn_dist, **self.params_).fit_transform(X)
+ if need_knn_ngb:
+ if need_knn_dist:
+ knn_ngb = knn[0][:, 0:k_graph]
+ knn_dist = knn[1]
+ else:
+ knn_ngb = knn
+ if need_knn_dist:
+ knn_dist = knn
+ if self.density_type_ in ["DTM", "logDTM"]:
+ if metric in ["neighbors", "precomputed"]:
+ dim = self.params_.get("dim", 2)
+ else:
+ dim = len(X)
+ q = self.params_.get("q", dim)
+ weights = DTMDensity(k=k_DTM, metric="neighbors", dim=dim, q=q).fit_transform(knn_dist)
+ if self.density_type_ == "logDTM":
+ weights = numpy.log(weights)
+
if input_type == "distance_matrix" and self.graph_type_ == "radius":
X = numpy.array(X)
r = self.params_["r"]
@@ -119,7 +159,7 @@ class Tomato:
assert density_type == "manual"
if input_type == "points" and self.graph_type_ == "knn" and self.density_type_ in {"DTM", "logDTM"}:
- q = self.params_.get("p_DTM", len(X[0]))
+ q = self.params_.get("q", len(X[0]))
p = self.params_.get("p", 2)
k = self.params_.get("k", 10)
k_DTM = self.params_.get("k_DTM", k)
@@ -217,7 +257,7 @@ class Tomato:
_, self.neighbors_ = kdtree.query(self.points_, k=k, p=p, **qargs)
if input_type == "points" and self.graph_type_ != "knn" and self.density_type_ in {"DTM", "logDTM"}:
- q = self.params_.get("p_DTM", len(X[0]))
+ q = self.params_.get("q", len(X[0]))
p = self.params_.get("p", 2)
k = self.params_.get("k", 10)
k_DTM = self.params_.get("k_DTM", k)
@@ -265,7 +305,7 @@ class Tomato:
weights = -numpy.log(weights)
if input_type == "distance_matrix" and self.density_type_ in {"DTM", "logDTM"}:
- q = self.params_.get("p_DTM", 2)
+ q = self.params_.get("q", 2)
X = numpy.array(X)
k = self.params_.get("k_DTM")
if not k: