summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorVincent Rouvreau <vincent.rouvreau@inria.fr>2021-11-05 11:28:42 +0100
committerVincent Rouvreau <vincent.rouvreau@inria.fr>2021-11-05 11:28:42 +0100
commit8f14977760d05f8f08d2a7babdc197da27a6c53a (patch)
tree9d40c8dcb22812d923961a78e40add26d31ca8ab
parent44a80746c9cc5740e2cf27da52b9fc5fa7e682f1 (diff)
change doc according to proposal
-rw-r--r--src/python/doc/cubical_complex_sklearn_itf_ref.rst88
-rw-r--r--src/python/doc/cubical_complex_sum.inc24
-rw-r--r--src/python/doc/cubical_complex_user.rst95
3 files changed, 100 insertions, 107 deletions
diff --git a/src/python/doc/cubical_complex_sklearn_itf_ref.rst b/src/python/doc/cubical_complex_sklearn_itf_ref.rst
index b5c7a2e5..c585f9ab 100644
--- a/src/python/doc/cubical_complex_sklearn_itf_ref.rst
+++ b/src/python/doc/cubical_complex_sklearn_itf_ref.rst
@@ -2,8 +2,8 @@
.. To get rid of WARNING: document isn't included in any toctree
-Cubical complex persistence scikit-learn like interfaces reference manual
-#########################################################################
+Cubical complex persistence scikit-learn like interface
+#######################################################
.. list-table::
:widths: 40 30 30
@@ -13,8 +13,90 @@ Cubical complex persistence scikit-learn like interfaces reference manual
- :License: MIT
- :Requires: `Scikit-learn <installation.html#scikit-learn>`_
+Cubical complex persistence scikit-learn like interface example
+---------------------------------------------------------------
+
+In this example, hand written digits are used as an input.
+a TDA scikit-learn pipeline is constructed and is composed of:
+
+#. :class:`~gudhi.sklearn.cubical_persistence.CubicalPersistence` that builds a cubical complex from the inputs and
+ returns its persistence diagrams
+#. :class:`~gudhi.representations.DiagramSelector` that removes non-finite persistence diagrams values
+#. :class:`~gudhi.representations.PersistenceImage` that builds the persistence images from persistence diagrams
+#. `SVC <https://scikit-learn.org/stable/modules/generated/sklearn.svm.SVC.html>`_ which is a scikit-learn support
+ vector classifier.
+
+This ML pipeline is trained to detect if the hand written digit is an '8' or not, thanks to the fact that an '8' has
+two holes in :math:`\mathbf{H}_1`, or, like in this example, three connected components in :math:`\mathbf{H}_0`.
+
+.. code-block:: python
+
+ # Standard scientific Python imports
+ import numpy as np
+
+ # Standard scikit-learn imports
+ from sklearn.datasets import fetch_openml
+ from sklearn.pipeline import Pipeline
+ from sklearn.model_selection import train_test_split
+ from sklearn.svm import SVC
+ from sklearn import metrics
+
+ # Import TDA pipeline requirements
+ from gudhi.sklearn.cubical_persistence import CubicalPersistence
+ from gudhi.representations import PersistenceImage, DiagramSelector
+
+ X, y = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)
+
+ # Target is: "is an eight ?"
+ y = (y == "8") * 1
+ print("There are", np.sum(y), "eights out of", len(y), "numbers.")
+
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=0)
+ pipe = Pipeline(
+ [
+ ("cub_pers", CubicalPersistence(persistence_dimension=0, dimensions=[28, 28], n_jobs=-2)),
+ # Or for multiple persistence dimension computation
+ # ("cub_pers", CubicalPersistence(persistence_dimension=[0, 1], dimensions=[28, 28], n_jobs=-2)),
+ # ("H0_diags", DimensionSelector(index=0), # where index is the index in persistence_dimension array
+ ("finite_diags", DiagramSelector(use=True, point_type="finite")),
+ (
+ "pers_img",
+ PersistenceImage(bandwidth=50, weight=lambda x: x[1] ** 2, im_range=[0, 256, 0, 256], resolution=[20, 20]),
+ ),
+ ("svc", SVC()),
+ ]
+ )
+
+ # Learn from the train subset
+ pipe.fit(X_train, y_train)
+ # Predict from the test subset
+ predicted = pipe.predict(X_test)
+
+ print(f"Classification report for TDA pipeline {pipe}:\n" f"{metrics.classification_report(y_test, predicted)}\n")
+
+.. code-block:: none
+
+ There are 6825 eights out of 70000 numbers.
+ Classification report for TDA pipeline Pipeline(steps=[('cub_pers',
+ CubicalPersistence(dimensions=[28, 28], n_jobs=-2)),
+ ('finite_diags', DiagramSelector(use=True)),
+ ('pers_img',
+ PersistenceImage(bandwidth=50, im_range=[0, 256, 0, 256],
+ weight=<function <lambda> at 0x7f3e54137ae8>)),
+ ('svc', SVC())]):
+ precision recall f1-score support
+
+ 0 0.97 0.99 0.98 25284
+ 1 0.92 0.68 0.78 2716
+
+ accuracy 0.96 28000
+ macro avg 0.94 0.84 0.88 28000
+ weighted avg 0.96 0.96 0.96 28000
+
+Cubical complex persistence scikit-learn like interface reference
+-----------------------------------------------------------------
.. autoclass:: gudhi.sklearn.cubical_persistence.CubicalPersistence
:members:
:special-members: __init__
- :show-inheritance:
+ :show-inheritance: \ No newline at end of file
diff --git a/src/python/doc/cubical_complex_sum.inc b/src/python/doc/cubical_complex_sum.inc
index 2a1bde8d..e2fd55bb 100644
--- a/src/python/doc/cubical_complex_sum.inc
+++ b/src/python/doc/cubical_complex_sum.inc
@@ -1,13 +1,17 @@
.. table::
:widths: 30 40 30
- +--------------------------------------------------------------------------+----------------------------------------------------------------------+-----------------------------+
- | .. figure:: | The cubical complex represents a grid as a cell complex with | :Author: Pawel Dlotko |
- | ../../doc/Bitmap_cubical_complex/Cubical_complex_representation.png | cells of all dimensions. | :Since: GUDHI 2.0.0 |
- | :alt: Cubical complex representation | | :License: MIT |
- | :figclass: align-center | | |
- +--------------------------------------------------------------------------+----------------------------------------------------------------------+-----------------------------+
- | * :doc:`cubical_complex_user` | * :doc:`cubical_complex_ref` |
- | | * :doc:`periodic_cubical_complex_ref` |
- | | * :doc:`cubical_complex_sklearn_itf_ref` |
- +--------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------+
+ +--------------------------------------------------------------------------+--------------------------------------------------------------+-------------------------------------------------------------+
+ | .. figure:: | The cubical complex represents a grid as a cell complex with | :Author: Pawel Dlotko |
+ | ../../doc/Bitmap_cubical_complex/Cubical_complex_representation.png | cells of all dimensions. | :Since: GUDHI 2.0.0 |
+ | :alt: Cubical complex representation | | :License: MIT |
+ | :figclass: align-center | | |
+ +--------------------------------------------------------------------------+--------------------------------------------------------------+-------------------------------------------------------------+
+ | * :doc:`cubical_complex_user` | * :doc:`cubical_complex_ref` |
+ | | * :doc:`periodic_cubical_complex_ref` |
+ +--------------------------------------------------------------------------+--------------------------------------------------------------+-------------------------------------------------------------+
+ | .. image:: | * :doc:`cubical_complex_sklearn_itf_ref` | :Requires: `Scikit-learn <installation.html#scikit-learn>`_ |
+ | img/sklearn.png | | |
+ | :target: https://scikit-learn.org | | |
+ | :height: 30 | | |
+ +--------------------------------------------------------------------------+--------------------------------------------------------------+-------------------------------------------------------------+
diff --git a/src/python/doc/cubical_complex_user.rst b/src/python/doc/cubical_complex_user.rst
index e62a4395..42a23875 100644
--- a/src/python/doc/cubical_complex_user.rst
+++ b/src/python/doc/cubical_complex_user.rst
@@ -7,19 +7,7 @@ Cubical complex user manual
Definition
----------
-.. list-table::
- :widths: 25 50 25
- :header-rows: 0
-
- * - :Author: Pawel Dlotko
- - :Since: GUDHI 2.0.0
- - :License: MIT
- * - :doc:`cubical_complex_user`
- - * :doc:`cubical_complex_ref`
- * :doc:`periodic_cubical_complex_ref`
- * :doc:`cubical_complex_sklearn_itf_ref`
- -
-
+.. include:: cubical_complex_sum.inc
The cubical complex is an example of a structured complex useful in computational mathematics (specially rigorous
numerics) and image analysis.
@@ -169,84 +157,3 @@ Tutorial
This `notebook <https://github.com/GUDHI/TDA-tutorial/blob/master/Tuto-GUDHI-cubical-complexes.ipynb>`_
explains how to represent sublevels sets of functions using cubical complexes.
-
-Scikit-learn like interface example
------------------------------------
-
-In this example, hand written digits are used as an input.
-a TDA scikit-learn pipeline is constructed and is composed of:
-
-#. :class:`~gudhi.sklearn.cubical_persistence.CubicalPersistence` that builds a cubical complex from the inputs and
- returns its persistence diagrams
-#. :class:`~gudhi.representations.DiagramSelector` that removes non-finite persistence diagrams values
-#. :class:`~gudhi.representations.PersistenceImage` that builds the persistence images from persistence diagrams
-#. `SVC <https://scikit-learn.org/stable/modules/generated/sklearn.svm.SVC.html>`_ which is a scikit-learn support
- vector classifier.
-
-This ML pipeline is trained to detect if the hand written digit is an '8' or not, thanks to the fact that an '8' has
-two holes in :math:`\mathbf{H}_1`, or, like in this example, three connected components in :math:`\mathbf{H}_0`.
-
-.. code-block:: python
-
- # Standard scientific Python imports
- import numpy as np
-
- # Standard scikit-learn imports
- from sklearn.datasets import fetch_openml
- from sklearn.pipeline import Pipeline
- from sklearn.model_selection import train_test_split
- from sklearn.svm import SVC
- from sklearn import metrics
-
- # Import TDA pipeline requirements
- from gudhi.sklearn.cubical_persistence import CubicalPersistence
- from gudhi.representations import PersistenceImage, DiagramSelector
-
- X, y = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)
-
- # Target is: "is an eight ?"
- y = (y == "8") * 1
- print("There are", np.sum(y), "eights out of", len(y), "numbers.")
-
- X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=0)
- pipe = Pipeline(
- [
- ("cub_pers", CubicalPersistence(persistence_dimension=0, dimensions=[28, 28], n_jobs=-2)),
- # Or for multiple persistence dimension computation
- # ("cub_pers", CubicalPersistence(persistence_dimension=[0, 1], dimensions=[28, 28], n_jobs=-2)),
- # ("H0_diags", DimensionSelector(index=0), # where index is the index in persistence_dimension array
- ("finite_diags", DiagramSelector(use=True, point_type="finite")),
- (
- "pers_img",
- PersistenceImage(bandwidth=50, weight=lambda x: x[1] ** 2, im_range=[0, 256, 0, 256], resolution=[20, 20]),
- ),
- ("svc", SVC()),
- ]
- )
-
- # Learn from the train subset
- pipe.fit(X_train, y_train)
- # Predict from the test subset
- predicted = pipe.predict(X_test)
-
- print(f"Classification report for TDA pipeline {pipe}:\n" f"{metrics.classification_report(y_test, predicted)}\n")
-
-
-.. code-block:: none
-
- There are 6825 eights out of 70000 numbers.
- Classification report for TDA pipeline Pipeline(steps=[('cub_pers',
- CubicalPersistence(dimensions=[28, 28], n_jobs=-2)),
- ('finite_diags', DiagramSelector(use=True)),
- ('pers_img',
- PersistenceImage(bandwidth=50, im_range=[0, 256, 0, 256],
- weight=<function <lambda> at 0x7f3e54137ae8>)),
- ('svc', SVC())]):
- precision recall f1-score support
-
- 0 0.97 0.99 0.98 25284
- 1 0.92 0.68 0.78 2716
-
- accuracy 0.96 28000
- macro avg 0.94 0.84 0.88 28000
- weighted avg 0.96 0.96 0.96 28000 \ No newline at end of file