summaryrefslogtreecommitdiff
path: root/src/python/gudhi/representations/preprocessing.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/gudhi/representations/preprocessing.py')
-rw-r--r--src/python/gudhi/representations/preprocessing.py57
1 files changed, 53 insertions, 4 deletions
diff --git a/src/python/gudhi/representations/preprocessing.py b/src/python/gudhi/representations/preprocessing.py
index a8545349..8722e162 100644
--- a/src/python/gudhi/representations/preprocessing.py
+++ b/src/python/gudhi/representations/preprocessing.py
@@ -1,10 +1,11 @@
# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
-# Author(s): Mathieu Carrière
+# Author(s): Mathieu Carrière, Vincent Rouvreau
#
# Copyright (C) 2018-2019 Inria
#
# Modification(s):
+# - 2021/10 Vincent Rouvreau: Add DimensionSelector
# - YYYY/MM Author: Description of the modification
import numpy as np
@@ -75,7 +76,7 @@ class Clamping(BaseEstimator, TransformerMixin):
Constructor for the Clamping class.
Parameters:
- limit (double): clamping value (default np.inf).
+ limit (float): clamping value (default np.inf).
"""
self.minimum = minimum
self.maximum = maximum
@@ -234,7 +235,7 @@ class ProminentPoints(BaseEstimator, TransformerMixin):
use (bool): whether to use the class or not (default False).
location (string): either "upper" or "lower" (default "upper"). Whether to keep the points that are far away ("upper") or close ("lower") to the diagonal.
num_pts (int): cardinality threshold (default 10). If location == "upper", keep the top **num_pts** points that are the farthest away from the diagonal. If location == "lower", keep the top **num_pts** points that are the closest to the diagonal.
- threshold (double): distance-to-diagonal threshold (default -1). If location == "upper", keep the points that are at least at a distance **threshold** from the diagonal. If location == "lower", keep the points that are at most at a distance **threshold** from the diagonal.
+ threshold (float): distance-to-diagonal threshold (default -1). If location == "upper", keep the points that are at least at a distance **threshold** from the diagonal. If location == "lower", keep the points that are at most at a distance **threshold** from the diagonal.
"""
self.num_pts = num_pts
self.threshold = threshold
@@ -317,7 +318,7 @@ class DiagramSelector(BaseEstimator, TransformerMixin):
Parameters:
use (bool): whether to use the class or not (default False).
- limit (double): second coordinate value that is the criterion for being an essential point (default numpy.inf).
+ limit (float): second coordinate value that is the criterion for being an essential point (default numpy.inf).
point_type (string): either "finite" or "essential". The type of the points that are going to be extracted.
"""
self.use, self.limit, self.point_type = use, limit, point_type
@@ -363,3 +364,51 @@ class DiagramSelector(BaseEstimator, TransformerMixin):
n x 2 numpy array: extracted persistence diagram.
"""
return self.fit_transform([diag])[0]
+
+
+# Mermaid sequence diagram - https://mermaid-js.github.io/mermaid-live-editor/
+# sequenceDiagram
+# USER->>DimensionSelector: fit_transform(<br/>[[array( Hi(X0) ), array( Hj(X0) ), ...],<br/> [array( Hi(X1) ), array( Hj(X1) ), ...],<br/> ...])
+# DimensionSelector->>thread1: _transform([array( Hi(X0) ), array( Hj(X0) )], ...)
+# DimensionSelector->>thread2: _transform([array( Hi(X1) ), array( Hj(X1) )], ...)
+# Note right of DimensionSelector: ...
+# thread1->>DimensionSelector: array( Hn(X0) )
+# thread2->>DimensionSelector: array( Hn(X1) )
+# Note right of DimensionSelector: ...
+# DimensionSelector->>USER: [array( Hn(X0) ), <br/> array( Hn(X1) ), <br/> ...]
+
+class DimensionSelector(BaseEstimator, TransformerMixin):
+ """
+ This is a class to select persistence diagrams in a specific dimension from its index.
+ """
+
+ def __init__(self, index=0):
+ """
+ Constructor for the DimensionSelector class.
+
+ Parameters:
+ index (int): The returned persistence diagrams dimension index. Default value is `0`.
+ """
+ self.index = index
+
+ def fit(self, X, Y=None):
+ """
+ Nothing to be done, but useful when included in a scikit-learn Pipeline.
+ """
+ return self
+
+ def transform(self, X, Y=None):
+ """
+ Select persistence diagrams from its dimension.
+
+ Parameters:
+ X (list of list of tuple): List of list of persistence pairs, i.e.
+ `[[array( Hi(X0) ), array( Hj(X0) ), ...], [array( Hi(X1) ), array( Hj(X1) ), ...], ...]`
+
+ Returns:
+ list of tuple:
+ Persistence diagrams in a specific dimension. i.e. if `index` was set to `m` and `Hn` is at index `m` of
+ the input, it returns `[array( Hn(X0) ), array( Hn(X1), ...]`
+ """
+
+ return [persistence[self.index] for persistence in X]