diff options
author | wreise <wojciech.reise@epfl.ch> | 2022-05-25 16:43:15 +0200 |
---|---|---|
committer | wreise <wojciech.reise@epfl.ch> | 2022-05-25 16:43:15 +0200 |
commit | 912156b36da1dce1f73f8d2a63cc18e67c173d54 (patch) | |
tree | 4daa44739d3db171f49c0ad32408e16700c58c9d /src/python/gudhi/representations/vector_methods.py | |
parent | e8d0cbc3311765900e098b472608dc40b84d07d8 (diff) |
Move the initialisation of the Genred method to the constructor of Silhouette
Diffstat (limited to 'src/python/gudhi/representations/vector_methods.py')
-rw-r--r-- | src/python/gudhi/representations/vector_methods.py | 21 |
1 files changed, 10 insertions, 11 deletions
diff --git a/src/python/gudhi/representations/vector_methods.py b/src/python/gudhi/representations/vector_methods.py index 55dc2c5b..c250c98c 100644 --- a/src/python/gudhi/representations/vector_methods.py +++ b/src/python/gudhi/representations/vector_methods.py @@ -238,6 +238,15 @@ class Silhouette(BaseEstimator, TransformerMixin): self.weight, self.resolution, self.sample_range = weight, resolution, sample_range self.im_range = None + silhouette_formula = "normalized_weights * ReLU(heights - Abs(x_values - midpoints))" + variables = [ + "normalized_weights = Vi(1)", + "heights = Vi(1)", + "midpoints = Vi(1)", + "x_values = Vj(1)", + ] + self.silhouette = Genred(silhouette_formula, variables, reduction_op="Sum", axis=0) + def fit(self, X, y=None): """ Fit the Silhouette class on a list of persistence diagrams: if any of the values in **sample_range** is numpy.nan, replace it with the corresponding value computed on the given list of persistence diagrams. @@ -260,16 +269,6 @@ class Silhouette(BaseEstimator, TransformerMixin): Returns: numpy array with shape (number of diagrams) x (**resolution**): output persistence silhouettes. """ - - silhouette_formula = "normalized_weights * ReLU(heights - Abs(x_values - midpoints))" - variables = [ - "normalized_weights = Vi(1)", - "heights = Vi(1)", - "midpoints = Vi(1)", - "x_values = Vj(1)", - ] - silhouette = Genred(silhouette_formula, variables, reduction_op="Sum", axis=0) - silhouettes_list = [] x_values = self.im_range for i, diag in enumerate(X): @@ -278,7 +277,7 @@ class Silhouette(BaseEstimator, TransformerMixin): weights /= np.sum(weights) silhouettes_list.append( - np.sqrt(2) * silhouette(weights[:, None], heights[:, None], + np.sqrt(2) * self.silhouette(weights[:, None], heights[:, None], midpoints[:, None], x_values[:, None])[:, 0] ) |