diff options
Diffstat (limited to 'src/python/gudhi/representations/preprocessing.py')
-rw-r--r-- | src/python/gudhi/representations/preprocessing.py | 57 |
1 files changed, 53 insertions, 4 deletions
diff --git a/src/python/gudhi/representations/preprocessing.py b/src/python/gudhi/representations/preprocessing.py index a8545349..bd8c2774 100644 --- a/src/python/gudhi/representations/preprocessing.py +++ b/src/python/gudhi/representations/preprocessing.py @@ -1,10 +1,11 @@ # This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT. # See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details. -# Author(s): Mathieu Carrière +# Author(s): Mathieu Carrière, Vincent Rouvreau # # Copyright (C) 2018-2019 Inria # # Modification(s): +# - 2021/10 Vincent Rouvreau: Add DimensionSelector # - YYYY/MM Author: Description of the modification import numpy as np @@ -75,7 +76,7 @@ class Clamping(BaseEstimator, TransformerMixin): Constructor for the Clamping class. Parameters: - limit (double): clamping value (default np.inf). + limit (float): clamping value (default np.inf). """ self.minimum = minimum self.maximum = maximum @@ -234,7 +235,7 @@ class ProminentPoints(BaseEstimator, TransformerMixin): use (bool): whether to use the class or not (default False). location (string): either "upper" or "lower" (default "upper"). Whether to keep the points that are far away ("upper") or close ("lower") to the diagonal. num_pts (int): cardinality threshold (default 10). If location == "upper", keep the top **num_pts** points that are the farthest away from the diagonal. If location == "lower", keep the top **num_pts** points that are the closest to the diagonal. - threshold (double): distance-to-diagonal threshold (default -1). If location == "upper", keep the points that are at least at a distance **threshold** from the diagonal. If location == "lower", keep the points that are at most at a distance **threshold** from the diagonal. + threshold (float): distance-to-diagonal threshold (default -1). If location == "upper", keep the points that are at least at a distance **threshold** from the diagonal. If location == "lower", keep the points that are at most at a distance **threshold** from the diagonal. """ self.num_pts = num_pts self.threshold = threshold @@ -317,7 +318,7 @@ class DiagramSelector(BaseEstimator, TransformerMixin): Parameters: use (bool): whether to use the class or not (default False). - limit (double): second coordinate value that is the criterion for being an essential point (default numpy.inf). + limit (float): second coordinate value that is the criterion for being an essential point (default numpy.inf). point_type (string): either "finite" or "essential". The type of the points that are going to be extracted. """ self.use, self.limit, self.point_type = use, limit, point_type @@ -363,3 +364,51 @@ class DiagramSelector(BaseEstimator, TransformerMixin): n x 2 numpy array: extracted persistence diagram. """ return self.fit_transform([diag])[0] + + +# Mermaid sequence diagram - https://mermaid-js.github.io/mermaid-live-editor/ +# sequenceDiagram +# USER->>DimensionSelector: fit_transform(<br/>[[array( Hi(X0) ), array( Hj(X0) ), ...],<br/> [array( Hi(X1) ), array( Hj(X1) ), ...],<br/> ...]) +# DimensionSelector->>thread1: _transform([array( Hi(X0) ), array( Hj(X0) )], ...) +# DimensionSelector->>thread2: _transform([array( Hi(X1) ), array( Hj(X1) )], ...) +# Note right of DimensionSelector: ... +# thread1->>DimensionSelector: array( Hn(X0) ) +# thread2->>DimensionSelector: array( Hn(X1) ) +# Note right of DimensionSelector: ... +# DimensionSelector->>USER: [array( Hn(X0) ), <br/> array( Hn(X1) ), <br/> ...] + +class DimensionSelector(BaseEstimator, TransformerMixin): + """ + This is a class to select persistence diagrams in a specific dimension from its index. + """ + + def __init__(self, index=0): + """ + Constructor for the DimensionSelector class. + + Parameters: + index (int): The returned persistence diagrams dimension index. Default value is `0`. + """ + self.index = index + + def fit(self, X, Y=None): + """ + Nothing to be done, but useful when included in a scikit-learn Pipeline. + """ + return self + + def transform(self, X, Y=None): + """ + Select persistence diagrams from its dimension. + + Parameters: + X (list of list of pairs): List of list of persistence pairs, i.e. + `[[array( Hi(X0) ), array( Hj(X0) ), ...], [array( Hi(X1) ), array( Hj(X1) ), ...], ...]` + + Returns: + list of pairs: + Persistence diagrams in a specific dimension. i.e. if `index` was set to `m` and `Hn` is at index `n` of + the input, it returns `[array( Hn(X0) ), array( Hn(X1), ...]` + """ + + return [persistence[self.index] for persistence in X] |