diff options
author | Marc Glisse <marc.glisse@inria.fr> | 2020-02-29 10:27:01 +0100 |
---|---|---|
committer | Marc Glisse <marc.glisse@inria.fr> | 2020-02-29 10:27:01 +0100 |
commit | 7063d82f2fc25ba2819adf7c2dbf430d4f012626 (patch) | |
tree | af56be7c5548f79a89765f2a9e86b1d9f688d473 /src/python/gudhi | |
parent | 23f0949dd204a4f4b0fec5527b64b5d5eabbebf8 (diff) |
Reformat
Diffstat (limited to 'src/python/gudhi')
-rw-r--r-- | src/python/gudhi/clustering/tomato.py | 11 |
1 files changed, 6 insertions, 5 deletions
diff --git a/src/python/gudhi/clustering/tomato.py b/src/python/gudhi/clustering/tomato.py index 467dd17e..dc004bde 100644 --- a/src/python/gudhi/clustering/tomato.py +++ b/src/python/gudhi/clustering/tomato.py @@ -38,8 +38,8 @@ class Tomato: density_type="DTM", n_clusters=None, merge_threshold=None, -# eliminate_threshold=None, -# eliminate_threshold (float): minimum max weight of a cluster so it doesn't get eliminated + # eliminate_threshold=None, + # eliminate_threshold (float): minimum max weight of a cluster so it doesn't get eliminated **params ): """ @@ -93,8 +93,9 @@ class Tomato: input_type = self.input_type_ if input_type == "points" and self.metric_: from sklearn.metrics import pairwise_distances - X = pairwise_distances(X,metric=self.metric_,n_jobs=self.params_.get("n_jobs")) - input_type="distance_matrix" + + X = pairwise_distances(X, metric=self.metric_, n_jobs=self.params_.get("n_jobs")) + input_type = "distance_matrix" if input_type == "distance_matrix" and self.graph_type_ == "radius": X = numpy.array(X) @@ -266,7 +267,7 @@ class Tomato: if not k: k = self.params_["k"] q = self.params_.get("p_DTM", 2) - weights = (numpy.partition(X, k - 1)[:,0:k] ** q).sum(-1) + weights = (numpy.partition(X, k - 1)[:, 0:k] ** q).sum(-1) if self.density_type_ == "DTM": dim = len(self.points_[0]) weights = weights ** (-dim / q) |