summaryrefslogtreecommitdiff
path: root/src/python/gudhi/representations/vector_methods.py
diff options
context:
space:
mode:
authorwreise <wojciech.reise@epfl.ch>2022-05-25 16:43:15 +0200
committerwreise <wojciech.reise@epfl.ch>2022-05-25 16:43:15 +0200
commit912156b36da1dce1f73f8d2a63cc18e67c173d54 (patch)
tree4daa44739d3db171f49c0ad32408e16700c58c9d /src/python/gudhi/representations/vector_methods.py
parente8d0cbc3311765900e098b472608dc40b84d07d8 (diff)
Move the initialisation of the Genred method to the constructor of Silhouette
Diffstat (limited to 'src/python/gudhi/representations/vector_methods.py')
-rw-r--r--src/python/gudhi/representations/vector_methods.py21
1 files changed, 10 insertions, 11 deletions
diff --git a/src/python/gudhi/representations/vector_methods.py b/src/python/gudhi/representations/vector_methods.py
index 55dc2c5b..c250c98c 100644
--- a/src/python/gudhi/representations/vector_methods.py
+++ b/src/python/gudhi/representations/vector_methods.py
@@ -238,6 +238,15 @@ class Silhouette(BaseEstimator, TransformerMixin):
self.weight, self.resolution, self.sample_range = weight, resolution, sample_range
self.im_range = None
+ silhouette_formula = "normalized_weights * ReLU(heights - Abs(x_values - midpoints))"
+ variables = [
+ "normalized_weights = Vi(1)",
+ "heights = Vi(1)",
+ "midpoints = Vi(1)",
+ "x_values = Vj(1)",
+ ]
+ self.silhouette = Genred(silhouette_formula, variables, reduction_op="Sum", axis=0)
+
def fit(self, X, y=None):
"""
Fit the Silhouette class on a list of persistence diagrams: if any of the values in **sample_range** is numpy.nan, replace it with the corresponding value computed on the given list of persistence diagrams.
@@ -260,16 +269,6 @@ class Silhouette(BaseEstimator, TransformerMixin):
Returns:
numpy array with shape (number of diagrams) x (**resolution**): output persistence silhouettes.
"""
-
- silhouette_formula = "normalized_weights * ReLU(heights - Abs(x_values - midpoints))"
- variables = [
- "normalized_weights = Vi(1)",
- "heights = Vi(1)",
- "midpoints = Vi(1)",
- "x_values = Vj(1)",
- ]
- silhouette = Genred(silhouette_formula, variables, reduction_op="Sum", axis=0)
-
silhouettes_list = []
x_values = self.im_range
for i, diag in enumerate(X):
@@ -278,7 +277,7 @@ class Silhouette(BaseEstimator, TransformerMixin):
weights /= np.sum(weights)
silhouettes_list.append(
- np.sqrt(2) * silhouette(weights[:, None], heights[:, None],
+ np.sqrt(2) * self.silhouette(weights[:, None], heights[:, None],
midpoints[:, None], x_values[:, None])[:, 0]
)