From b84b3f006805eb69e03983301a550ddcb8050769 Mon Sep 17 00:00:00 2001 From: Marc Glisse Date: Tue, 28 Apr 2020 20:55:57 +0200 Subject: Always save points_ --- src/python/gudhi/clustering/tomato.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) (limited to 'src/python/gudhi/clustering/tomato.py') diff --git a/src/python/gudhi/clustering/tomato.py b/src/python/gudhi/clustering/tomato.py index 5f1f8e24..07006d7c 100644 --- a/src/python/gudhi/clustering/tomato.py +++ b/src/python/gudhi/clustering/tomato.py @@ -98,6 +98,8 @@ class Tomato: raise ValueError("If density_type is 'manual', you must provide weights to fit()") input_type = self.input_type_ + if input_type == "points": + self.points_ = X if input_type == "points" and self.metric_: from sklearn.metrics import pairwise_distances @@ -118,7 +120,6 @@ class Tomato: assert density_type == "manual" if input_type == "points" and self.graph_type_ == "knn" and self.density_type_ in {"DTM", "logDTM"}: - self.points_ = X q = self.params_.get("p_DTM", len(X[0])) p = self.params_.get("p", 2) k = self.params_.get("k", 10) @@ -189,7 +190,6 @@ class Tomato: weights = -numpy.log(weights) if input_type == "points" and self.graph_type_ == "knn" and self.density_type_ not in {"DTM", "logDTM"}: - self.points_ = X p = self.params_.get("p", 2) k = self.params_.get("k", 10) if self.params_.get("gpu"): @@ -218,7 +218,6 @@ class Tomato: _, self.neighbors_ = kdtree.query(self.points_, k=k, p=p, **qargs) if input_type == "points" and self.graph_type_ != "knn" and self.density_type_ in {"DTM", "logDTM"}: - self.points_ = X q = self.params_.get("p_DTM", len(X[0])) p = self.params_.get("p", 2) k = self.params_.get("k", 10) @@ -275,7 +274,10 @@ class Tomato: k = self.params_["k"] weights = (numpy.partition(X, k - 1)[:, 0:k] ** q).sum(-1) if self.density_type_ == "DTM": - dim = len(self.points_[0]) + try: + dim = len(self.points_[0]) + except AttributeError: + dim = 2 weights = weights ** (-dim / q) else: weights = -numpy.log(weights) -- cgit v1.2.3