summaryrefslogtreecommitdiff
path: root/src/python/gudhi/representations/vector_methods.py
diff options
context:
space:
mode:
authorwreise <wojciech.reise@epfl.ch>2022-08-05 22:19:30 +0200
committerwreise <wojciech.reise@epfl.ch>2022-08-05 22:19:30 +0200
commit60e57f9c86a7aae67c2931200066aba059ec2721 (patch)
treed2ee408a1710356de1267833a28f8b94a1e588b5 /src/python/gudhi/representations/vector_methods.py
parent42b18e60e418f4078cd6406dcc202b696798c844 (diff)
Test the numpy version
Diffstat (limited to 'src/python/gudhi/representations/vector_methods.py')
-rw-r--r--src/python/gudhi/representations/vector_methods.py42
1 files changed, 11 insertions, 31 deletions
diff --git a/src/python/gudhi/representations/vector_methods.py b/src/python/gudhi/representations/vector_methods.py
index b0843120..7f311b3b 100644
--- a/src/python/gudhi/representations/vector_methods.py
+++ b/src/python/gudhi/representations/vector_methods.py
@@ -140,16 +140,6 @@ class Landscape(BaseEstimator, TransformerMixin):
self.nan_in_range = np.isnan(np.array(self.sample_range))
self.new_resolution = self.resolution + self.nan_in_range.sum()
- landscape_formula = "(-1)*ReLU(heights - Abs(x_values - midpoints))"
- variables = [
- "heights = Vi(1)",
- "midpoints = Vi(1)",
- "x_values = Vj(1)",
- ]
- from pykeops.numpy import Genred
- self.landscape = Genred(landscape_formula, variables, reduction_op="KMin",
- axis=0, opt_arg=self.num_landscapes)
-
def fit(self, X, y=None):
"""
Fit the Landscape class on a list of persistence diagrams: if any of the values in **sample_range** is numpy.nan, replace it with the corresponding value computed on the given list of persistence diagrams.
@@ -178,13 +168,13 @@ class Landscape(BaseEstimator, TransformerMixin):
midpoints, heights = (diag[:, 0] + diag[:, 1]) / 2., (diag[:, 1] - diag[:, 0]) / 2.
tent_functions = np.maximum(heights[None, :] - np.abs(x_values[:, None] - midpoints[None, :]), 0)
tent_functions.partition(diag.shape[0] - self.num_landscapes, axis=1)
- landscapes = np.sort(tent_functions[-self.num_landscapes:, :])[::-1].T
+ landscapes = np.sort(tent_functions, axis=1)[:, -self.num_landscapes:][:, ::-1].T
if self.nan_in_range[0]:
- landscapes = landscapes[:,1:]
+ landscapes = landscapes[:, 1:]
if self.nan_in_range[1]:
- landscapes = landscapes[:,:-1]
- landscapes = np.sqrt(2)*np.reshape(landscapes, [1, -1])
+ landscapes = landscapes[:, :-1]
+ landscapes = np.sqrt(2) * np.ravel(landscapes)
Xfit.append(landscapes)
return np.stack(Xfit, axis=0)
@@ -217,16 +207,6 @@ class Silhouette(BaseEstimator, TransformerMixin):
self.weight, self.resolution, self.sample_range = weight, resolution, sample_range
self.im_range = None
- from pykeops.numpy import Genred
- silhouette_formula = "normalized_weights * ReLU(heights - Abs(x_values - midpoints))"
- variables = [
- "normalized_weights = Vi(1)",
- "heights = Vi(1)",
- "midpoints = Vi(1)",
- "x_values = Vj(1)",
- ]
- self.silhouette = Genred(silhouette_formula, variables, reduction_op="Sum", axis=0)
-
def fit(self, X, y=None):
"""
Fit the Silhouette class on a list of persistence diagrams: if any of the values in **sample_range** is numpy.nan, replace it with the corresponding value computed on the given list of persistence diagrams.
@@ -249,19 +229,19 @@ class Silhouette(BaseEstimator, TransformerMixin):
Returns:
numpy array with shape (number of diagrams) x (**resolution**): output persistence silhouettes.
"""
- silhouettes_list = []
+ Xfit = []
x_values = self.im_range
+
for i, diag in enumerate(X):
midpoints, heights = (diag[:, 0] + diag[:, 1]) / 2., (diag[:, 1] - diag[:, 0]) / 2.
weights = np.array([self.weight(point) for point in diag])
- weights /= np.sum(weights)
+ total_weight = np.sum(weights)
- silhouettes_list.append(
- np.sqrt(2) * self.silhouette(weights[:, None], heights[:, None],
- midpoints[:, None], x_values[:, None])[:, 0]
- )
+ tent_functions = np.maximum(heights[None, :] - np.abs(x_values[:, None] - midpoints[None, :]), 0)
+ silhouette = np.sum(weights[None, :] / total_weight * tent_functions, axis=1)
+ Xfit.append(silhouette * np.sqrt(2))
- return np.stack(silhouettes_list, axis=0)
+ return np.stack(Xfit, axis=0)
def __call__(self, diag):
"""