From 8c122a8c92285dd89844720c9cf04d001db491d0 Mon Sep 17 00:00:00 2001 From: Marc Glisse Date: Mon, 25 May 2020 17:32:17 +0200 Subject: bugs --- src/python/gudhi/clustering/tomato.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) (limited to 'src/python/gudhi/clustering') diff --git a/src/python/gudhi/clustering/tomato.py b/src/python/gudhi/clustering/tomato.py index 18425700..2b4d9242 100644 --- a/src/python/gudhi/clustering/tomato.py +++ b/src/python/gudhi/clustering/tomato.py @@ -132,10 +132,9 @@ class Tomato: elif need_knn_dist: knn_dist = knn if self.density_type_ in ["DTM", "logDTM"]: - if metric == "precomputed": - dim = self.params_.get("dim", 2) - else: - dim = len(X) + dim = self.params_.get("dim") + if dim is None: + dim = len(X[0]) if metric != "precomputed" else 2 q = self.params_.get("q", dim) weights = DTMDensity(k=k_DTM, metric="neighbors", dim=dim, q=q).fit_transform(knn_dist) if self.density_type_ == "logDTM": @@ -144,7 +143,7 @@ class Tomato: if self.graph_type_ == "radius": if metric in ["minkowski", "euclidean", "manhattan", "chebyshev"]: from scipy.spatial import cKDTree - tree t = cKDTree(X) + tree = cKDTree(X) # TODO: handle "l1" and "l2" aliases? p = self.params_.get("p") if metric == "euclidean": @@ -159,7 +158,7 @@ class Tomato: elif p is None: p = 2 # the default eps = self.params_.get("eps", 0) - self.neighbors_ = t.query_ball_tree(t, r=self.params_["r"], p=p, eps=eps) + self.neighbors_ = tree.query_ball_tree(tree, r=self.params_["r"], p=p, eps=eps) # TODO: sklearn's NearestNeighbors.radius_neighbors can handle more metrics efficiently via its BallTree (don't bother with the _graph variant, it just calls radius_neighbors). elif metric != "precomputed": -- cgit v1.2.3