:orphan:

.. To get rid of WARNING: document isn't included in any toctree

Cubical complex persistence scikit-learn like interface
#######################################################

.. list-table::
   :width: 100%
   :header-rows: 0

   * - :Since: GUDHI 3.6.0
     - :License: MIT
     - :Requires: `Scikit-learn <installation.html#scikit-learn>`_

Cubical complex persistence scikit-learn like interface example
---------------------------------------------------------------

In this example, handwritten digits are used as input.
A TDA scikit-learn pipeline is constructed, composed of:

#. :class:`~gudhi.sklearn.cubical_persistence.CubicalPersistence` that builds a cubical complex from the inputs and
   returns its persistence diagrams
#. :class:`~gudhi.representations.preprocessing.DiagramSelector` that removes the points with non-finite coordinates from
   the persistence diagrams
#. :class:`~gudhi.representations.vector_methods.PersistenceImage` that turns each persistence diagram into a persistence
   image, flattened into a fixed-size vector
#. `SVC <https://scikit-learn.org/stable/modules/generated/sklearn.svm.SVC.html>`_ which is a scikit-learn support
   vector classifier.
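
To make steps 2 and 3 more concrete, here is a simplified pure-Python sketch of what they do to a single diagram. It is
not GUDHI's implementation: the helper names ``finite_points`` and ``persistence_image`` are hypothetical, and we assume
(as the ``weight=lambda x: x[1] ** 2`` used in the example below suggests) that each point is first moved to
(birth, persistence) coordinates, then weighted and spread with a Gaussian of standard deviation ``bandwidth``, sampled
on a ``resolution`` grid covering ``im_range``.

.. code-block:: python

    import math


    def finite_points(diagram):
        """Drop points whose death is infinite (the role of DiagramSelector here)."""
        return [(birth, death) for birth, death in diagram if math.isfinite(death)]


    def persistence_image(diagram, bandwidth, weight, im_range, resolution):
        """Simplified persistence image, returned as a flat list.

        Each (birth, death) point is mapped to (birth, persistence), weighted by
        weight(point) and spread with an isotropic Gaussian. The sum is sampled on
        an nx-by-ny grid spanning im_range = [x_min, x_max, y_min, y_max].
        Values are ordered row by row: y in the outer loop, x in the inner loop.
        """
        x_min, x_max, y_min, y_max = im_range
        nx, ny = resolution
        xs = [x_min + i * (x_max - x_min) / (nx - 1) for i in range(nx)]
        ys = [y_min + j * (y_max - y_min) / (ny - 1) for j in range(ny)]
        points = [(birth, death - birth) for birth, death in diagram]
        weights = [weight(p) for p in points]
        two_var = 2 * bandwidth ** 2
        image = []
        for y in ys:
            for x in xs:
                value = 0.0
                for (px, py), w in zip(points, weights):
                    dist2 = (px - x) ** 2 + (py - y) ** 2
                    # Normalised 2D Gaussian: exp(-d^2 / 2s^2) / (2 pi s^2)
                    value += w * math.exp(-dist2 / two_var) / (math.pi * two_var)
                image.append(value)
        return image


    # Hypothetical H0 diagram: two finite points and one essential (infinite) point.
    diagram = [(0.0, 255.0), (0.0, 255.0), (0.0, math.inf)]
    finite = finite_points(diagram)
    vector = persistence_image(
        finite,
        bandwidth=50,
        weight=lambda x: x[1] ** 2,  # weight by squared persistence
        im_range=[0, 256, 0, 256],
        resolution=[20, 20],
    )
    print(len(finite), len(vector))  # 2 400

With ``resolution=[20, 20]`` the output always has 400 values, whatever the number of points in the diagram. This
fixed size is what lets a standard classifier such as ``SVC`` consume diagrams of varying sizes.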

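This ML pipeline is trained to detect whether a handwritten digit is an '8' or not. This is possible because an '8' has
two holes: they appear as two features in :math:`\mathbf{H}_1` or, as used in this example, as three connected
components in :math:`\mathbf{H}_0` (the background around the digit plus the two regions enclosed by its strokes).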
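
The :math:`\mathbf{H}_0` argument can be checked on a toy image. The sketch below is not how GUDHI computes cubical
persistence: it is a small union-find implementation of 0-dimensional sublevel-set persistence on a pixel grid, using
the elder rule (when two components merge, the one born later dies). We assume that pixels touching only at a corner
are adjacent (8-connectivity), and the 7x7 '8' below, with background value 0 and stroke value 255, is made up for
illustration.

.. code-block:: python

    import math


    def h0_sublevel_persistence(image):
        """0-dimensional sublevel-set persistence of a 2D grid (list of rows).

        Pixels are added in increasing value order; already-added neighbours
        (8-connectivity) are merged with union-find. Zero-length intervals are
        dropped; components that never merge get an infinite death.
        """
        parent, birth = {}, {}

        def find(p):
            while parent[p] != p:
                parent[p] = parent[parent[p]]  # path halving
                p = parent[p]
            return p

        diagram = []
        order = sorted(
            (value, r, c) for r, row in enumerate(image) for c, value in enumerate(row)
        )
        for value, r, c in order:
            parent[(r, c)] = (r, c)
            birth[(r, c)] = value
            for dr in (-1, 0, 1):
                for dc in (-1, 0, 1):
                    q = (r + dr, c + dc)
                    if q not in parent:  # outside the grid or not added yet
                        continue
                    a, b = find((r, c)), find(q)
                    if a == b:
                        continue
                    # Elder rule: the component born later dies at the current value.
                    young, old = (a, b) if birth[a] >= birth[b] else (b, a)
                    if birth[young] < value:
                        diagram.append((birth[young], value))
                    parent[young] = old
        roots = {find(p) for p in parent}
        diagram.extend((birth[p], math.inf) for p in roots)
        return sorted(diagram)


    S = 255  # stroke value; background pixels are 0
    EIGHT = [
        [0, 0, 0, 0, 0, 0, 0],
        [0, S, S, S, S, S, 0],
        [0, S, 0, 0, 0, S, 0],
        [0, S, S, S, S, S, 0],
        [0, S, 0, 0, 0, S, 0],
        [0, S, S, S, S, S, 0],
        [0, 0, 0, 0, 0, 0, 0],
    ]
    print(h0_sublevel_persistence(EIGHT))  # [(0, 255), (0, 255), (0, inf)]

The two finite points are the two enclosed regions merging into the outer background once the stroke value is reached;
the point at infinity is the background itself, which never dies. Only the finite points remain once non-finite values
are removed.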

.. code-block:: python

    # Standard scientific Python imports
    import numpy as np
    
    # Standard scikit-learn imports
    from sklearn.datasets import fetch_openml
    from sklearn.pipeline import Pipeline
    from sklearn.model_selection import train_test_split
    from sklearn.svm import SVC
    from sklearn import metrics
    
    # Import TDA pipeline requirements
    from gudhi.sklearn.cubical_persistence import CubicalPersistence
    from gudhi.representations import PersistenceImage, DiagramSelector
    
    X, y = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)
    
    # Binary target: is the digit an eight?
    y = (y == "8") * 1
    print("There are", np.sum(y), "eights out of", len(y), "numbers.")
    
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=0)
    pipe = Pipeline(
        [
            ("cub_pers", CubicalPersistence(homology_dimensions=0, newshape=[28, 28], n_jobs=-2)),
            # Or, to compute several homology dimensions at once
            # (DimensionSelector must then be imported from gudhi.representations):
            # ("cub_pers", CubicalPersistence(homology_dimensions=[0, 1], newshape=[28, 28], n_jobs=-2)),
            # ("H0_diags", DimensionSelector(index=0)),  # index: position in the homology_dimensions list
            ("finite_diags", DiagramSelector(use=True, point_type="finite")),
            (
                "pers_img",
                PersistenceImage(bandwidth=50, weight=lambda x: x[1] ** 2, im_range=[0, 256, 0, 256], resolution=[20, 20]),
            ),
            ("svc", SVC()),
        ]
    )
    
    # Learn from the train subset
    pipe.fit(X_train, y_train)
    # Predict from the test subset
    predicted = pipe.predict(X_test)
    
    print(f"Classification report for TDA pipeline {pipe}:\n" f"{metrics.classification_report(y_test, predicted)}\n")

.. code-block:: none

    There are 6825 eights out of 70000 numbers.
    Classification report for TDA pipeline Pipeline(steps=[('cub_pers',
                     CubicalPersistence(newshape=[28, 28], n_jobs=-2)),
                    ('finite_diags', DiagramSelector(use=True)),
                    ('pers_img',
                     PersistenceImage(bandwidth=50, im_range=[0, 256, 0, 256],
                                      weight=<function <lambda> at 0x7f3e54137ae8>)),
                    ('svc', SVC())]):
                  precision    recall  f1-score   support

               0       0.97      0.99      0.98     25284
               1       0.92      0.68      0.78      2716

        accuracy                           0.96     28000
       macro avg       0.94      0.84      0.88     28000
    weighted avg       0.96      0.96      0.96     28000
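
Only about 10% of the digits are eights (6825 out of 70000), so the 0.96 accuracy is best read against a trivial
classifier that always answers "not an eight". The short computation below uses only numbers printed in the report
above: the test supports, which add up to 40% of the 70000 digits because of ``test_size=0.4``, and the class 1
precision and recall, from which the F1-score is recomputed as their harmonic mean. scikit-learn computes the report
from unrounded values, so recomputing from the rounded ones may differ in the last digit.

.. code-block:: python

    # Numbers copied from the classification report above.
    n_digits = 70000
    support_not_eight = 25284  # class 0 support in the test subset
    support_eight = 2716  # class 1 support in the test subset
    n_test = support_not_eight + support_eight
    assert n_test == n_digits * 4 // 10  # test_size=0.4

    # Accuracy of a classifier that always predicts "not an eight" (class 0).
    baseline_accuracy = support_not_eight / n_test


    def f1(precision, recall):
        """F1-score: harmonic mean of precision and recall."""
        return 2 * precision * recall / (precision + recall)


    print(f"baseline accuracy: {baseline_accuracy:.3f}")  # baseline accuracy: 0.903
    print(f"class 1 F1: {f1(0.92, 0.68):.2f}")  # class 1 F1: 0.78

The pipeline's 0.96 accuracy therefore improves on a 0.903 baseline, while the class 1 recall of 0.68 means that about
a third of the eights in the test subset are still missed.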

Cubical complex persistence scikit-learn like interface reference
-----------------------------------------------------------------

.. autoclass:: gudhi.sklearn.cubical_persistence.CubicalPersistence
   :members:
   :special-members: __init__
   :show-inheritance: