summaryrefslogtreecommitdiff
path: root/src/python/gudhi/representations/vector_methods.py
diff options
context:
space:
mode:
authorwreise <wojciech.reise@epfl.ch>2022-10-12 15:55:22 +0200
committerwreise <wojciech.reise@epfl.ch>2022-10-12 15:55:22 +0200
commit4aac9e03c400bd43f237504cf4ff9d25f041e473 (patch)
treed0b64e09fdd439cf3b7d119887a07dadca82938d /src/python/gudhi/representations/vector_methods.py
parent059ff0c42a069c744ed121c948bc3d39b5cc7f10 (diff)
Clean argpartition
Diffstat (limited to 'src/python/gudhi/representations/vector_methods.py')
-rw-r--r--src/python/gudhi/representations/vector_methods.py20
1 files changed, 14 insertions, 6 deletions
diff --git a/src/python/gudhi/representations/vector_methods.py b/src/python/gudhi/representations/vector_methods.py
index 5ea4ea48..3a91eccd 100644
--- a/src/python/gudhi/representations/vector_methods.py
+++ b/src/python/gudhi/representations/vector_methods.py
@@ -124,7 +124,7 @@ def _automatic_sample_range(sample_range, X, y):
return sample_range
-def trim_on_edges(x, are_endpoints_nan):
+def _trim_on_edges(x, are_endpoints_nan):
if are_endpoints_nan[0]:
x = x[1:]
if are_endpoints_nan[1]:
@@ -159,7 +159,7 @@ class Landscape(BaseEstimator, TransformerMixin):
"""
self.sample_range = _automatic_sample_range(np.array(self.sample_range), X, y)
self.im_range = np.linspace(self.sample_range[0], self.sample_range[1], self.new_resolution)
- self.im_range = trim_on_edges(self.im_range, self.nan_in_range)
+ self.im_range = _trim_on_edges(self.im_range, self.nan_in_range)
return self
def transform(self, X):
@@ -178,9 +178,17 @@ class Landscape(BaseEstimator, TransformerMixin):
for diag in X:
midpoints, heights = (diag[:, 0] + diag[:, 1]) / 2., (diag[:, 1] - diag[:, 0]) / 2.
tent_functions = np.maximum(heights[None, :] - np.abs(x_values[:, None] - midpoints[None, :]), 0)
- tent_functions.partition(diag.shape[0] - self.num_landscapes, axis=1)
- landscapes = tent_functions[:, -self.num_landscapes:][:, ::-1].T
-
+ n_points = diag.shape[0]
+ # Get indices of largest elements: can't take more than n_points - 1 (the last ones are in the right position)
+ argpartition = np.argpartition(-tent_functions, min(self.num_landscapes, n_points-1), axis=1)
+ landscapes = np.take_along_axis(tent_functions, argpartition, axis=1)
+ landscapes = landscapes[:, :min(self.num_landscapes, n_points)].T
+
+ # Complete the array with zeros to get the right number of landscapes
+ if self.num_landscapes > n_points:
+ landscapes = np.concatenate([
+ landscapes, np.zeros((self.num_landscapes-n_points, *landscapes.shape[1:]))
+ ], axis=0)
landscapes = np.sqrt(2) * np.ravel(landscapes)
Xfit.append(landscapes)
@@ -225,7 +233,7 @@ class Silhouette(BaseEstimator, TransformerMixin):
"""
self.sample_range = _automatic_sample_range(np.array(self.sample_range), X, y)
self.im_range = np.linspace(self.sample_range[0], self.sample_range[1], self.new_resolution)
- self.im_range = trim_on_edges(self.im_range, self.nan_in_range)
+ self.im_range = _trim_on_edges(self.im_range, self.nan_in_range)
return self
def transform(self, X):