summaryrefslogtreecommitdiff
path: root/src/python/gudhi/representations/vector_methods.py
diff options
context:
space:
mode:
authormartinroyer <16647869+martinroyer@users.noreply.github.com>2020-06-11 21:17:27 +0200
committerGitHub <noreply@github.com>2020-06-11 21:17:27 +0200
commitec1c3ad11aeb46a67926a615fd5c00fbc70b501e (patch)
tree905a83d496bfe043ff2d542caef3911176ddc304 /src/python/gudhi/representations/vector_methods.py
parenta90843c6bf5f7f05392c4262efb60e94ccfb0e48 (diff)
case n_centers = 1
Diffstat (limited to 'src/python/gudhi/representations/vector_methods.py')
-rw-r--r--src/python/gudhi/representations/vector_methods.py11
1 files changed, 8 insertions, 3 deletions
diff --git a/src/python/gudhi/representations/vector_methods.py b/src/python/gudhi/representations/vector_methods.py
index 5a45f179..a576267c 100644
--- a/src/python/gudhi/representations/vector_methods.py
+++ b/src/python/gudhi/representations/vector_methods.py
@@ -672,9 +672,14 @@ class Atol(BaseEstimator, TransformerMixin):
measures_concat = np.concatenate(X)
self.quantiser.fit(X=measures_concat, sample_weight=sample_weight)
self.centers = self.quantiser.cluster_centers_
- dist_centers = pairwise.pairwise_distances(self.centers)
- np.fill_diagonal(dist_centers, np.inf)
- self.inertias = np.min(dist_centers, axis=0)/2
+ if self.quantiser.n_clusters == 1:
+ dist_centers = pairwise.pairwise_distances(measures_concat)
+ np.fill_diagonal(dist_centers, 0)
+ self.inertias = np.max(dist_centers)/2
+ else:
+ dist_centers = pairwise.pairwise_distances(self.centers)
+ np.fill_diagonal(dist_centers, np.inf)
+ self.inertias = np.min(dist_centers, axis=0)/2
return self
def __call__(self, measure, sample_weight=None):