diff options
author | wreise <wojciech.reise@epfl.ch> | 2022-10-12 15:55:22 +0200 |
---|---|---|
committer | wreise <wojciech.reise@epfl.ch> | 2022-10-12 15:55:22 +0200 |
commit | 4aac9e03c400bd43f237504cf4ff9d25f041e473 (patch) | |
tree | d0b64e09fdd439cf3b7d119887a07dadca82938d /src/python/gudhi/representations | |
parent | 059ff0c42a069c744ed121c948bc3d39b5cc7f10 (diff) |
Clean argpartition
Diffstat (limited to 'src/python/gudhi/representations')
-rw-r--r-- | src/python/gudhi/representations/vector_methods.py | 20 |
1 files changed, 14 insertions, 6 deletions
diff --git a/src/python/gudhi/representations/vector_methods.py b/src/python/gudhi/representations/vector_methods.py index 5ea4ea48..3a91eccd 100644 --- a/src/python/gudhi/representations/vector_methods.py +++ b/src/python/gudhi/representations/vector_methods.py @@ -124,7 +124,7 @@ def _automatic_sample_range(sample_range, X, y): return sample_range -def trim_on_edges(x, are_endpoints_nan): +def _trim_on_edges(x, are_endpoints_nan): if are_endpoints_nan[0]: x = x[1:] if are_endpoints_nan[1]: @@ -159,7 +159,7 @@ class Landscape(BaseEstimator, TransformerMixin): """ self.sample_range = _automatic_sample_range(np.array(self.sample_range), X, y) self.im_range = np.linspace(self.sample_range[0], self.sample_range[1], self.new_resolution) - self.im_range = trim_on_edges(self.im_range, self.nan_in_range) + self.im_range = _trim_on_edges(self.im_range, self.nan_in_range) return self def transform(self, X): @@ -178,9 +178,17 @@ class Landscape(BaseEstimator, TransformerMixin): for diag in X: midpoints, heights = (diag[:, 0] + diag[:, 1]) / 2., (diag[:, 1] - diag[:, 0]) / 2. tent_functions = np.maximum(heights[None, :] - np.abs(x_values[:, None] - midpoints[None, :]), 0) - tent_functions.partition(diag.shape[0] - self.num_landscapes, axis=1) - landscapes = tent_functions[:, -self.num_landscapes:][:, ::-1].T - + n_points = diag.shape[0] + # Get indices of largest elements: can't take more than n_points - 1 (the last ones are in the right position) + argpartition = np.argpartition(-tent_functions, min(self.num_landscapes, n_points-1), axis=1) + landscapes = np.take_along_axis(tent_functions, argpartition, axis=1) + landscapes = landscapes[:, :min(self.num_landscapes, n_points)].T + + # Complete the array with zeros to get the right number of landscapes + if self.num_landscapes > n_points: + landscapes = np.concatenate([ + landscapes, np.zeros((self.num_landscapes-n_points, *landscapes.shape[1:])) + ], axis=0) landscapes = np.sqrt(2) * np.ravel(landscapes) Xfit.append(landscapes) @@ -225,7 +233,7 @@ class Silhouette(BaseEstimator, TransformerMixin): """ self.sample_range = _automatic_sample_range(np.array(self.sample_range), X, y) self.im_range = np.linspace(self.sample_range[0], self.sample_range[1], self.new_resolution) - self.im_range = trim_on_edges(self.im_range, self.nan_in_range) + self.im_range = _trim_on_edges(self.im_range, self.nan_in_range) return self def transform(self, X): |