diff options
author | martinroyer <16647869+martinroyer@users.noreply.github.com> | 2020-06-11 21:17:27 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-06-11 21:17:27 +0200 |
commit | ec1c3ad11aeb46a67926a615fd5c00fbc70b501e (patch) | |
tree | 905a83d496bfe043ff2d542caef3911176ddc304 /src/python | |
parent | a90843c6bf5f7f05392c4262efb60e94ccfb0e48 (diff) |
case n_centers = 1
Diffstat (limited to 'src/python')
-rw-r--r-- | src/python/gudhi/representations/vector_methods.py | 11 |
1 files changed, 8 insertions, 3 deletions
diff --git a/src/python/gudhi/representations/vector_methods.py b/src/python/gudhi/representations/vector_methods.py index 5a45f179..a576267c 100644 --- a/src/python/gudhi/representations/vector_methods.py +++ b/src/python/gudhi/representations/vector_methods.py @@ -672,9 +672,14 @@ class Atol(BaseEstimator, TransformerMixin): measures_concat = np.concatenate(X) self.quantiser.fit(X=measures_concat, sample_weight=sample_weight) self.centers = self.quantiser.cluster_centers_ - dist_centers = pairwise.pairwise_distances(self.centers) - np.fill_diagonal(dist_centers, np.inf) - self.inertias = np.min(dist_centers, axis=0)/2 + if self.quantiser.n_clusters == 1: + dist_centers = pairwise.pairwise_distances(measures_concat) + np.fill_diagonal(dist_centers, 0) + self.inertias = np.max(dist_centers)/2 + else: + dist_centers = pairwise.pairwise_distances(self.centers) + np.fill_diagonal(dist_centers, np.inf) + self.inertias = np.min(dist_centers, axis=0)/2 return self def __call__(self, measure, sample_weight=None): |