author     Marc Glisse <marc.glisse@inria.fr>   2020-02-28 23:45:24 +0100
committer  Marc Glisse <marc.glisse@inria.fr>   2020-02-28 23:45:24 +0100
commit     23f0949dd204a4f4b0fec5527b64b5d5eabbebf8 (patch)
tree       d153c698c315221fca7b841168c87d55e8d0dff6 /src/python/gudhi/clustering
parent     f3b7d742580b3c93f8d1d70952a7809f6f52ca80 (diff)
metric==Callable
Diffstat (limited to 'src/python/gudhi/clustering')
-rw-r--r--  src/python/gudhi/clustering/tomato.py  32
1 file changed, 19 insertions, 13 deletions
diff --git a/src/python/gudhi/clustering/tomato.py b/src/python/gudhi/clustering/tomato.py
index 5257e487..467dd17e 100644
--- a/src/python/gudhi/clustering/tomato.py
+++ b/src/python/gudhi/clustering/tomato.py
@@ -34,11 +34,12 @@ class Tomato:
self,
input_type="points",
metric=None,
- graph_type=None,
- density_type="manual",
+ graph_type="knn",
+ density_type="DTM",
n_clusters=None,
merge_threshold=None,
- eliminate_threshold=None,
+# eliminate_threshold=None,
+# eliminate_threshold (float): minimum max weight of a cluster so it doesn't get eliminated
**params
):
"""
@@ -46,7 +47,7 @@ class Tomato:
Args:
input_type (str): 'points', 'distance_matrix' or 'neighbors'.
- metric (str or callable): FIXME ???
+ metric (None|Callable): If None, use Minkowski of parameter p.
graph_type (str): 'manual', 'knn' or 'radius'. Ignored if input_type is 'neighbors'.
density_type (str): 'manual', 'DTM', 'logDTM' or 'kde'.
kde_params (dict): if density_type is 'kde', additional parameters passed directly to sklearn.neighbors.KernelDensity.
@@ -57,7 +58,6 @@ class Tomato:
gpu (bool): enable use of CUDA (through pykeops) to compute k nearest neighbors. This is useful when the dimension becomes large (10+) but the number of points remains low (less than a million).
n_clusters (int): number of clusters requested. Defaults to ???
merge_threshold (float): minimum prominence of a cluster so it doesn't get merged.
- eliminate_threshold (float): minimum height of a cluster so it doesn't get eliminated
symmetrize_graph (bool): whether we should add edges to make the neighborhood graph symmetric. This can be useful with k-NN for small k. Defaults to false.
p (float): norm L^p on input points (numpy.inf is supported without gpu). Defaults to 2.
p_DTM (float): order used to compute the distance to measure. Defaults to 2.
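
As a rough usage sketch of the updated signature documented above: the toy data, the L1 callable, and the k / k_DTM values are illustrative choices, and the scikit-learn-style labels_ attribute is assumed rather than shown in this patch.

```python
import numpy as np
from gudhi.clustering.tomato import Tomato

# Toy point cloud; any (n_samples, dim) array works.
X = np.random.default_rng(0).random((100, 2))

# A callable metric takes two points and returns their distance (here L1).
def l1(a, b):
    return np.abs(a - b).sum()

# metric=None would fall back to Minkowski of parameter p (default 2); with a
# callable, fit() precomputes a full distance matrix through scikit-learn.
t = Tomato(metric=l1, graph_type="knn", density_type="DTM",
           n_clusters=2, k=10, k_DTM=10)
t.fit(X)
print(t.labels_)  # per-point cluster labels (scikit-learn-style attribute, assumed)
```
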
@@ -90,20 +90,26 @@ class Tomato:
if density_type == "manual":
raise ValueError("If density_type is 'manual', you must provide weights to fit()")
- if self.input_type_ == "distance_matrix" and self.graph_type_ == "radius":
+ input_type = self.input_type_
+ if input_type == "points" and self.metric_:
+ from sklearn.metrics import pairwise_distances
+ X = pairwise_distances(X,metric=self.metric_,n_jobs=self.params_.get("n_jobs"))
+ input_type="distance_matrix"
+
+ if input_type == "distance_matrix" and self.graph_type_ == "radius":
X = numpy.array(X)
r = self.params_["r"]
self.neighbors_ = [numpy.nonzero(l <= r) for l in X]
- if self.input_type_ == "distance_matrix" and self.graph_type_ == "knn":
+ if input_type == "distance_matrix" and self.graph_type_ == "knn":
k = self.params_["k"]
self.neighbors_ = numpy.argpartition(X, k - 1)[:, 0:k]
- if self.input_type_ == "neighbors":
+ if input_type == "neighbors":
self.neighbors_ = X
assert density_type == "manual"
- if self.input_type_ == "points" and self.graph_type_ == "knn" and self.density_type_ in {"DTM", "logDTM"}:
+ if input_type == "points" and self.graph_type_ == "knn" and self.density_type_ in {"DTM", "logDTM"}:
self.points_ = X
q = self.params_.get("p_DTM", 2)
p = self.params_.get("p", 2)
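
For reference, the distance precomputation added in this hunk reduces to a single scikit-learn call; a minimal standalone sketch (toy data and an L1 callable chosen for illustration, not part of the patch):

```python
import numpy as np
from sklearn.metrics import pairwise_distances

X = np.random.default_rng(1).random((5, 3))

# With a callable metric, pairwise_distances evaluates it on every pair of rows
# and returns an (n, n) matrix, which fit() then handles as a 'distance_matrix'.
def l1(a, b):
    return np.abs(a - b).sum()

D = pairwise_distances(X, metric=l1, n_jobs=None)
print(D.shape)  # (5, 5)
```
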
@@ -174,7 +180,7 @@ class Tomato:
# We ignore exponents, which become constant factors with log
weights = -numpy.log(weights)
- if self.input_type_ == "points" and self.graph_type_ == "knn" and self.density_type_ not in {"DTM", "logDTM"}:
+ if input_type == "points" and self.graph_type_ == "knn" and self.density_type_ not in {"DTM", "logDTM"}:
self.points_ = X
p = self.params_.get("p", 2)
k = self.params_.get("k", 10)
@@ -203,7 +209,7 @@ class Tomato:
qargs = {k: v for k, v in self.params_.items() if k in {"eps", "n_jobs"}}
_, self.neighbors_ = kdtree.query(self.points_, k=k, p=p, **qargs)
- if self.input_type_ == "points" and self.graph_type_ != "knn" and self.density_type_ in {"DTM", "logDTM"}:
+ if input_type == "points" and self.graph_type_ != "knn" and self.density_type_ in {"DTM", "logDTM"}:
self.points_ = X
q = self.params_.get("p_DTM", 2)
p = self.params_.get("p", 2)
@@ -253,7 +259,7 @@ class Tomato:
else:
weights = -numpy.log(weights)
- if self.input_type_ == "distance_matrix" and self.density_type_ in {"DTM", "logDTM"}:
+ if input_type == "distance_matrix" and self.density_type_ in {"DTM", "logDTM"}:
q = self.params_.get("p_DTM", 2)
X = numpy.array(X)
k = self.params_.get("k_DTM")
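
The k_DTM neighbors read here feed a distance-to-measure estimate; as a rough sketch of the usual formula DTM_q(x) = ((1/k) sum_{i<=k} d_i(x)^q)^(1/q) computed from a full distance matrix (an illustration of the definition, not the exact code of this file):

```python
import numpy as np

def dtm_from_distance_matrix(D, k, q=2):
    # k smallest distances per row (the zero self-distance counts as one of them here).
    nearest = np.partition(D, k - 1, axis=1)[:, :k]
    # Empirical distance to measure of order q.
    return (np.mean(nearest ** q, axis=1)) ** (1.0 / q)
```
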
@@ -269,7 +275,7 @@ class Tomato:
if self.density_type_ == "kde":
# FIXME: replace most assert with raise ValueError("blabla")
- assert self.input_type_ == "points"
+ assert input_type == "points"
kde_params = self.params_.get("kde_params", dict())
from sklearn.neighbors import KernelDensity
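
Outside of Tomato, the kde branch shown above amounts to forwarding kde_params to scikit-learn's KernelDensity; a minimal sketch (the bandwidth value is arbitrary, and using score_samples as log-weights is an assumption about the rest of this branch):

```python
import numpy as np
from sklearn.neighbors import KernelDensity

X = np.random.default_rng(2).random((50, 2))

# kde_params, e.g. {"bandwidth": 0.3}, is forwarded verbatim to KernelDensity.
kde = KernelDensity(**{"bandwidth": 0.3}).fit(X)

# score_samples returns log-densities, a natural choice of (log) weights.
log_weights = kde.score_samples(X)
```
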