author | ROUVREAU Vincent <vincent.rouvreau@inria.fr> | 2021-06-07 14:57:02 +0200 |
---|---|---|
committer | ROUVREAU Vincent <vincent.rouvreau@inria.fr> | 2021-06-07 14:57:02 +0200 |
commit | b7de9c211e9cfe361aa7bba9be32b88570972c38 | |
tree | 3f868ec53e5323311865bee3c191d5f9bb47f8cd /src/python/doc/cubical_complex_user.rst | |
parent | 8813c23e4931e9c955dd0e89547133065429ae0d |
Improve documentation
Diffstat (limited to 'src/python/doc/cubical_complex_user.rst')
-rw-r--r-- | src/python/doc/cubical_complex_user.rst | 45 |
1 files changed, 33 insertions, 12 deletions
diff --git a/src/python/doc/cubical_complex_user.rst b/src/python/doc/cubical_complex_user.rst
index ebecb592..3fd9fd84 100644
--- a/src/python/doc/cubical_complex_user.rst
+++ b/src/python/doc/cubical_complex_user.rst
@@ -173,10 +173,24 @@ explains how to represent sublevels sets of functions using cubical complexes.
 Scikit-learn like interface example
 -----------------------------------
 
+In this example, hand written digits are used as an input.
+a TDA scikit-learn pipeline is constructed and is composed of:
+
+#. :class:`~gudhi.sklearn.cubical_persistence.CubicalPersistence` that builds a cubical complex from the inputs and
+   returns its persistence diagrams
+#. :class:`~gudhi.representations.DiagramSelector` that removes non-finite persistence diagrams values
+#. :class:`~gudhi.representations.PersistenceImage` that builds the persistence images from persistence diagrams
+#. `SVC <https://scikit-learn.org/stable/modules/generated/sklearn.svm.SVC.html>`_ which is a scikit-learn support
+   vector classifier.
+
+This ML pipeline is trained to detect if the hand written digit is an '8' or not, thanks to the fact that an '8' has
+two holes in :math:`\mathbf{H}_1`, or, like in this example, three connected components in :math:`\mathbf{H}_0`.
+
 .. code-block:: python
 
     # Standard scientific Python imports
     import numpy as np
+
     # Standard scikit-learn imports
     from sklearn.datasets import fetch_openml
     from sklearn.pipeline import Pipeline
@@ -188,25 +202,32 @@ Scikit-learn like interface example
     from gudhi.sklearn.cubical_persistence import CubicalPersistence
     from gudhi.representations import PersistenceImage, DiagramSelector
 
-    X, y = fetch_openml('mnist_784', version=1, return_X_y=True, as_frame=False)
+    X, y = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)
 
     # Target is: "is an eight ?"
-    y = (y == '8') * 1
-    print('There are', np.sum(y), 'eights out of', len(y), 'numbers.')
+    y = (y == "8") * 1
+    print("There are", np.sum(y), "eights out of", len(y), "numbers.")
 
     X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=0)
-    pipe = Pipeline([('cub_pers', CubicalPersistence(persistence_dim = 0, dimensions=[28,28], n_jobs=-2)),
-                     ('finite_diags', DiagramSelector(use=True, point_type="finite")),
-                     ('pers_img', PersistenceImage(bandwidth=50,
-                                                   weight=lambda x: x[1]**2,
-                                                   im_range=[0,256,0,256],
-                                                   resolution=[20,20])),
-                     ('svc', SVC())])
+    pipe = Pipeline(
+        [
+            ("cub_pers", CubicalPersistence(persistence_dim=0, dimensions=[28, 28], n_jobs=-2)),
+            ("finite_diags", DiagramSelector(use=True, point_type="finite")),
+            (
+                "pers_img",
+                PersistenceImage(bandwidth=50, weight=lambda x: x[1] ** 2, im_range=[0, 256, 0, 256], resolution=[20, 20]),
+            ),
+            ("svc", SVC()),
+        ]
+    )
 
+    # Learn from the train subset
+    pipe.fit(X_train, y_train)
+    # Predict from the test subset
     predicted = pipe.predict(X_test)
 
-    print(f"Classification report for TDA pipeline {pipe}:\n"
-          f"{metrics.classification_report(y_test, predicted)}\n")
+    print(f"Classification report for TDA pipeline {pipe}:\n" f"{metrics.classification_report(y_test, predicted)}\n")
+
 
 .. code-block:: none
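Note on the documentation added by this commit: the claim that an '8' shows three connected components in :math:`\mathbf{H}_0` can be checked by hand on a toy image. The sketch below is not part of the commit and does not use GUDHI. It assumes a hypothetical 7x7 binary image where ``1`` marks the bright stroke and ``0`` the dark background, and it counts the background regions with a plain flood fill. These regions are what enters first in a sublevel-set filtration of a digit drawn bright on dark. The names ``EIGHT``, ``ZERO`` and ``count_background_components`` are illustrative only.

```python
from collections import deque

# Hypothetical toy digits: "1" = bright stroke, "0" = dark background.
# The strokes are solid, so 4- and 8-connectivity of the background agree.
EIGHT = [
    "0000000",
    "0111110",
    "0100010",
    "0111110",
    "0100010",
    "0111110",
    "0000000",
]

ZERO = [
    "0000000",
    "0111110",
    "0100010",
    "0100010",
    "0100010",
    "0111110",
    "0000000",
]


def count_background_components(image):
    """Count 4-connected regions of background ("0") pixels with a BFS flood fill."""
    rows, cols = len(image), len(image[0])
    seen = set()
    components = 0
    for r in range(rows):
        for c in range(cols):
            if image[r][c] != "0" or (r, c) in seen:
                continue
            # A new, unvisited background pixel starts a new component.
            components += 1
            seen.add((r, c))
            queue = deque([(r, c)])
            while queue:
                cr, cc = queue.popleft()
                for dr, dc in ((1, 0), (-1, 0), (0, 1), (0, -1)):
                    nr, nc = cr + dr, cc + dc
                    inside = 0 <= nr < rows and 0 <= nc < cols
                    if inside and image[nr][nc] == "0" and (nr, nc) not in seen:
                        seen.add((nr, nc))
                        queue.append((nr, nc))
    return components


print("eight:", count_background_components(EIGHT))  # outside + two loops
print("zero:", count_background_components(ZERO))    # outside + one loop
```

On these toy images the '8' yields three background regions (the outside plus the two loops) and the '0' yields two, which is the kind of difference the pipeline's dimension-0 persistence step is described as exploiting. Real MNIST digits are grey-scale and noisy, so the actual diagrams contain additional short-lived components.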