summaryrefslogtreecommitdiff
path: root/src/python/gudhi/clustering
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2020-05-21 20:00:33 +0200
committerMarc Glisse <marc.glisse@inria.fr>2020-05-21 20:00:33 +0200
commit8cd34b4f0a6f5ffe8cfc16d2bb5856e5f6400216 (patch)
tree30d9a83f1975a092c484e4bebf038c364d92aa13 /src/python/gudhi/clustering
parent8fff4339aab542c29e8672c720152198c6647615 (diff)
Random fixes
Diffstat (limited to 'src/python/gudhi/clustering')
-rw-r--r--src/python/gudhi/clustering/tomato.py12
1 files changed, 7 insertions, 5 deletions
diff --git a/src/python/gudhi/clustering/tomato.py b/src/python/gudhi/clustering/tomato.py
index 5bfb9f68..75296ab3 100644
--- a/src/python/gudhi/clustering/tomato.py
+++ b/src/python/gudhi/clustering/tomato.py
@@ -118,7 +118,7 @@ class Tomato:
k_graph = self.params_["k"]
need_knn = k_graph
need_knn_ngb = True
- if elf.density_type_ in ["DTM", "logDTM"]:
+ if self.density_type_ in ["DTM", "logDTM"]:
k = self.params_.get("k", 10) # FIXME: What if X has fewer than 10 points?
k_DTM = self.params_.get("k_DTM", k)
need_knn = max(need_knn, k_DTM)
@@ -126,20 +126,22 @@ class Tomato:
# if we ask for more neighbors for the graph than the DTM, getting the distances is a slight waste,
# but it looks negligible
if need_knn > 0:
- knn = KNearestNeighbors(need_knn, return_index=need_knn_ngb, return_distance=need_knn_dist, **self.params_).fit_transform(X)
+ knn_args = dict(self.params_)
+ knn_args["k"] = need_knn
+ knn = KNearestNeighbors(return_index=need_knn_ngb, return_distance=need_knn_dist, **knn_args).fit_transform(X)
if need_knn_ngb:
if need_knn_dist:
self.neighbors_ = knn[0][:, 0:k_graph]
knn_dist = knn[1]
else:
self.neighbors_ = knn
- if need_knn_dist:
+ elif need_knn_dist:
knn_dist = knn
if self.density_type_ in ["DTM", "logDTM"]:
- if metric in ["neighbors", "precomputed"]:
+ if self.metric_ in ["neighbors", "precomputed"]:
dim = self.params_.get("dim", 2)
else:
- dim = len(X)
+ dim = len(X) # FIXME for distance matrix
q = self.params_.get("q", dim)
weights = DTMDensity(k=k_DTM, metric="neighbors", dim=dim, q=q).fit_transform(knn_dist)
if self.density_type_ == "logDTM":