summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorwreise <wojciech.reise@epfl.ch>2022-10-15 18:45:42 +0200
committerwreise <wojciech.reise@epfl.ch>2022-10-15 18:45:42 +0200
commitcd7dea8627f4b1c624e88d5ff28b32d1602f5e39 (patch)
treef7602049d89cfec1f3511f4875823a7d85a20e93
parent74617f0673aa13bce47833c366321a8838a7d123 (diff)
Treat the case when there are less points than landscape layers
-rw-r--r--src/python/gudhi/representations/vector_methods.py13
1 files changed, 7 insertions, 6 deletions
diff --git a/src/python/gudhi/representations/vector_methods.py b/src/python/gudhi/representations/vector_methods.py
index 6267e077..a169aee8 100644
--- a/src/python/gudhi/representations/vector_methods.py
+++ b/src/python/gudhi/representations/vector_methods.py
@@ -179,14 +179,15 @@ class Landscape(BaseEstimator, TransformerMixin):
midpoints, heights = (diag[:, 0] + diag[:, 1]) / 2., (diag[:, 1] - diag[:, 0]) / 2.
tent_functions = np.maximum(heights[None, :] - np.abs(x_values[:, None] - midpoints[None, :]), 0)
n_points = diag.shape[0]
- tent_functions.partition(n_points-self.num_landscapes, axis=1)
- landscapes = np.sort(tent_functions[:, -self.num_landscapes:], axis=1)[:, ::-1].T
-
# Complete the array with zeros to get the right number of landscapes
if self.num_landscapes > n_points:
- landscapes = np.concatenate([
- landscapes, np.zeros((self.num_landscapes-n_points, *landscapes.shape[1:]))
- ], axis=0)
+ tent_functions = np.concatenate(
+ [tent_functions, np.zeros((tent_functions.shape[0], self.num_landscapes-n_points))],
+ axis=1
+ )
+ tent_functions.partition(tent_functions.shape[1]-self.num_landscapes, axis=1)
+ landscapes = np.sort(tent_functions[:, -self.num_landscapes:], axis=1)[:, ::-1].T
+
landscapes = np.sqrt(2) * np.ravel(landscapes)
Xfit.append(landscapes)