summaryrefslogtreecommitdiff
path: root/src/python/gudhi/representations/vector_methods.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/gudhi/representations/vector_methods.py')
-rw-r--r--src/python/gudhi/representations/vector_methods.py11
1 files changed, 8 insertions, 3 deletions
diff --git a/src/python/gudhi/representations/vector_methods.py b/src/python/gudhi/representations/vector_methods.py
index 5a45f179..a576267c 100644
--- a/src/python/gudhi/representations/vector_methods.py
+++ b/src/python/gudhi/representations/vector_methods.py
@@ -672,9 +672,14 @@ class Atol(BaseEstimator, TransformerMixin):
measures_concat = np.concatenate(X)
self.quantiser.fit(X=measures_concat, sample_weight=sample_weight)
self.centers = self.quantiser.cluster_centers_
- dist_centers = pairwise.pairwise_distances(self.centers)
- np.fill_diagonal(dist_centers, np.inf)
- self.inertias = np.min(dist_centers, axis=0)/2
+ if self.quantiser.n_clusters == 1:
+ dist_centers = pairwise.pairwise_distances(measures_concat)
+ np.fill_diagonal(dist_centers, 0)
+ self.inertias = np.max(dist_centers)/2
+ else:
+ dist_centers = pairwise.pairwise_distances(self.centers)
+ np.fill_diagonal(dist_centers, np.inf)
+ self.inertias = np.min(dist_centers, axis=0)/2
return self
def __call__(self, measure, sample_weight=None):