summaryrefslogtreecommitdiff
path: root/src/python/gudhi/clustering/tomato.py
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2020-02-29 10:27:01 +0100
committerMarc Glisse <marc.glisse@inria.fr>2020-02-29 10:27:01 +0100
commit7063d82f2fc25ba2819adf7c2dbf430d4f012626 (patch)
treeaf56be7c5548f79a89765f2a9e86b1d9f688d473 /src/python/gudhi/clustering/tomato.py
parent23f0949dd204a4f4b0fec5527b64b5d5eabbebf8 (diff)
Reformat
Diffstat (limited to 'src/python/gudhi/clustering/tomato.py')
-rw-r--r--src/python/gudhi/clustering/tomato.py11
1 files changed, 6 insertions, 5 deletions
diff --git a/src/python/gudhi/clustering/tomato.py b/src/python/gudhi/clustering/tomato.py
index 467dd17e..dc004bde 100644
--- a/src/python/gudhi/clustering/tomato.py
+++ b/src/python/gudhi/clustering/tomato.py
@@ -38,8 +38,8 @@ class Tomato:
density_type="DTM",
n_clusters=None,
merge_threshold=None,
-# eliminate_threshold=None,
-# eliminate_threshold (float): minimum max weight of a cluster so it doesn't get eliminated
+ # eliminate_threshold=None,
+ # eliminate_threshold (float): minimum max weight of a cluster so it doesn't get eliminated
**params
):
"""
@@ -93,8 +93,9 @@ class Tomato:
input_type = self.input_type_
if input_type == "points" and self.metric_:
from sklearn.metrics import pairwise_distances
- X = pairwise_distances(X,metric=self.metric_,n_jobs=self.params_.get("n_jobs"))
- input_type="distance_matrix"
+
+ X = pairwise_distances(X, metric=self.metric_, n_jobs=self.params_.get("n_jobs"))
+ input_type = "distance_matrix"
if input_type == "distance_matrix" and self.graph_type_ == "radius":
X = numpy.array(X)
@@ -266,7 +267,7 @@ class Tomato:
if not k:
k = self.params_["k"]
q = self.params_.get("p_DTM", 2)
- weights = (numpy.partition(X, k - 1)[:,0:k] ** q).sum(-1)
+ weights = (numpy.partition(X, k - 1)[:, 0:k] ** q).sum(-1)
if self.density_type_ == "DTM":
dim = len(self.points_[0])
weights = weights ** (-dim / q)