summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorROUVREAU Vincent <vincent.rouvreau@inria.fr>2021-06-07 14:57:02 +0200
committerROUVREAU Vincent <vincent.rouvreau@inria.fr>2021-06-07 14:57:02 +0200
commitb7de9c211e9cfe361aa7bba9be32b88570972c38 (patch)
tree3f868ec53e5323311865bee3c191d5f9bb47f8cd
parent8813c23e4931e9c955dd0e89547133065429ae0d (diff)
Improve documentation
-rw-r--r--src/python/doc/cubical_complex_user.rst45
-rw-r--r--src/python/gudhi/sklearn/cubical_persistence.py18
2 files changed, 48 insertions, 15 deletions
diff --git a/src/python/doc/cubical_complex_user.rst b/src/python/doc/cubical_complex_user.rst
index ebecb592..3fd9fd84 100644
--- a/src/python/doc/cubical_complex_user.rst
+++ b/src/python/doc/cubical_complex_user.rst
@@ -173,10 +173,24 @@ explains how to represent sublevels sets of functions using cubical complexes.
Scikit-learn like interface example
-----------------------------------
+In this example, hand written digits are used as an input.
+a TDA scikit-learn pipeline is constructed and is composed of:
+
+#. :class:`~gudhi.sklearn.cubical_persistence.CubicalPersistence` that builds a cubical complex from the inputs and
+ returns its persistence diagrams
+#. :class:`~gudhi.representations.DiagramSelector` that removes non-finite persistence diagrams values
+#. :class:`~gudhi.representations.PersistenceImage` that builds the persistence images from persistence diagrams
+#. `SVC <https://scikit-learn.org/stable/modules/generated/sklearn.svm.SVC.html>`_ which is a scikit-learn support
+ vector classifier.
+
+This ML pipeline is trained to detect if the hand written digit is an '8' or not, thanks to the fact that an '8' has
+two holes in :math:`\mathbf{H}_1`, or, like in this example, three connected components in :math:`\mathbf{H}_0`.
+
.. code-block:: python
# Standard scientific Python imports
import numpy as np
+
# Standard scikit-learn imports
from sklearn.datasets import fetch_openml
from sklearn.pipeline import Pipeline
@@ -188,25 +202,32 @@ Scikit-learn like interface example
from gudhi.sklearn.cubical_persistence import CubicalPersistence
from gudhi.representations import PersistenceImage, DiagramSelector
- X, y = fetch_openml('mnist_784', version=1, return_X_y=True, as_frame=False)
+ X, y = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)
# Target is: "is an eight ?"
- y = (y == '8') * 1
- print('There are', np.sum(y), 'eights out of', len(y), 'numbers.')
+ y = (y == "8") * 1
+ print("There are", np.sum(y), "eights out of", len(y), "numbers.")
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=0)
- pipe = Pipeline([('cub_pers', CubicalPersistence(persistence_dim = 0, dimensions=[28,28], n_jobs=-2)),
- ('finite_diags', DiagramSelector(use=True, point_type="finite")),
- ('pers_img', PersistenceImage(bandwidth=50,
- weight=lambda x: x[1]**2,
- im_range=[0,256,0,256],
- resolution=[20,20])),
- ('svc', SVC())])
+ pipe = Pipeline(
+ [
+ ("cub_pers", CubicalPersistence(persistence_dim=0, dimensions=[28, 28], n_jobs=-2)),
+ ("finite_diags", DiagramSelector(use=True, point_type="finite")),
+ (
+ "pers_img",
+ PersistenceImage(bandwidth=50, weight=lambda x: x[1] ** 2, im_range=[0, 256, 0, 256], resolution=[20, 20]),
+ ),
+ ("svc", SVC()),
+ ]
+ )
+ # Learn from the train subset
+ pipe.fit(X_train, y_train)
+ # Predict from the test subset
predicted = pipe.predict(X_test)
- print(f"Classification report for TDA pipeline {pipe}:\n"
- f"{metrics.classification_report(y_test, predicted)}\n")
+ print(f"Classification report for TDA pipeline {pipe}:\n" f"{metrics.classification_report(y_test, predicted)}\n")
+
.. code-block:: none
diff --git a/src/python/gudhi/sklearn/cubical_persistence.py b/src/python/gudhi/sklearn/cubical_persistence.py
index f4341bf6..251e240f 100644
--- a/src/python/gudhi/sklearn/cubical_persistence.py
+++ b/src/python/gudhi/sklearn/cubical_persistence.py
@@ -1,3 +1,12 @@
+# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+# Author(s): Vincent Rouvreau
+#
+# Copyright (C) 2021 Inria
+#
+# Modification(s):
+# - YYYY/MM Author: Description of the modification
+
from .. import CubicalComplex
from sklearn.base import BaseEstimator, TransformerMixin
@@ -17,7 +26,8 @@ class CubicalPersistence(BaseEstimator, TransformerMixin):
Constructor for the CubicalPersistence class.
Parameters:
- dimensions (list of int): A list of number of top dimensional cells.
+ dimensions (list of int): A list of number of top dimensional cells if cells filtration values will require
+ to be reshaped (cf. :func:`~gudhi.sklearn.cubical_persistence.CubicalPersistence.transform`)
persistence_dim (int): The returned persistence diagrams dimension. Default value is `0`.
min_persistence (float): The minimum persistence value to take into account (strictly greater than
`min_persistence`). Default value is `0.0`. Sets `min_persistence` to `-1.0` to see all values.
@@ -39,7 +49,7 @@ class CubicalPersistence(BaseEstimator, TransformerMixin):
def fit(self, X, Y=None):
"""
- Nothing to be done.
+ Nothing to be done, but useful when included in a scikit-learn Pipeline.
"""
return self
@@ -56,7 +66,9 @@ class CubicalPersistence(BaseEstimator, TransformerMixin):
Compute all the cubical complexes and their associated persistence diagrams.
Parameters:
- X (list of list of double OR list of numpy.ndarray): List of cells filtration values.
+ X (list of list of double OR list of numpy.ndarray): List of cells filtration values that can be flatten if
+ dimensions is set in the constructor, or already with the correct shape in a numpy.ndarray (and
+ dimensions must not be set).
Returns:
Persistence diagrams