summaryrefslogtreecommitdiff
path: root/src/python/gudhi/sklearn/post_processing.py
blob: 79276e1e7699373ded2fa4109e66ccb1fba82921 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
# Author(s):       Vincent Rouvreau
#
# Copyright (C) 2021 Inria
#
# Modification(s):
#   - YYYY/MM Author: Description of the modification

from sklearn.base import BaseEstimator, TransformerMixin

# joblib is required by scikit-learn
from joblib import Parallel, delayed

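# NOTE: the diagram below shows a parallel dispatch through a `_transform` helper on worker threads;
# the `transform` method of the class below performs the selection sequentially in the calling thread.
# Mermaid sequence diagram - https://mermaid-js.github.io/mermaid-live-editor/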
# sequenceDiagram
#     USER->>DimensionSelector: fit_transform(<br/>[[array( H0(X0) ), array( H1(X0) ), ...],<br/> [array( H0(X1) ), array( H1(X1) ), ...],<br/> ...])
#     DimensionSelector->>thread1: _transform([array( H0(X0) ), array( H1(X0) )], ...)
#     DimensionSelector->>thread2: _transform([array( H0(X1) ), array( H1(X1) )], ...)
#     Note right of DimensionSelector: ...
#     thread1->>DimensionSelector: array( Hn(X0) )
#     thread2->>DimensionSelector: array( Hn(X1) )
#     Note right of DimensionSelector: ...
#     DimensionSelector->>USER: [array( Hn(X0) ), <br/> array( Hn(X1) ), <br/> ...]


class DimensionSelector(BaseEstimator, TransformerMixin):
    """
    Transformer that keeps, for every input entry, only the persistence diagram found at a given
    dimension index.
    """

    def __init__(self, persistence_dimension=0, n_jobs=None):
        """
        Constructor for the DimensionSelector class.

        Parameters:
            persistence_dimension (int): Index used to pick one diagram out of each list of diagrams.
                Default value is `0`.
            n_jobs: Stored on the instance unchanged. Default value is `None`. Neither `fit` nor
                `transform` of this class reads it.
        """
        self.n_jobs = n_jobs
        self.persistence_dimension = persistence_dimension

    def fit(self, X, Y=None):
        """
        No-op: there is nothing to learn. Provided so the selector can be used as a step of a
        scikit-learn Pipeline.

        Parameters:
            X: Ignored.
            Y: Ignored.

        Returns:
            The instance itself.
        """
        return self

    def transform(self, X, Y=None):
        """
        Extract, from each entry of `X`, the persistence diagram stored at index `persistence_dimension`.

        Parameters:
            X (list of list of pairs): List of list of persistence pairs, i.e.
            `[[array( H0(X0) ), array( H1(X0) ), ...], [array( H0(X1) ), array( H1(X1) ), ...], ...]`
            Y: Ignored.

        Returns:
            Persistence diagrams in a specific dimension, i.e.
            `[array( Hn(X0) ), array( Hn(X1) ), ...]`
        """
        selected = []
        for diagrams in X:
            # Plain indexing: an entry too short for the requested dimension raises as usual.
            selected.append(diagrams[self.persistence_dimension])
        return selected