author    ROUVREAU Vincent <vincent.rouvreau@inria.fr>    2021-06-04 11:56:59 +0200
committer ROUVREAU Vincent <vincent.rouvreau@inria.fr>    2021-06-04 11:56:59 +0200
commit    546b059af6c0581d06bfe9cebbe853f2f7bd4589 (patch)
tree      f99d9cf618223ad26735f370c3aedbe6c27765af
parent    4a64eef12722de3faa8ac73416aaea91658e20b6 (diff)
Add a more relevant example inspired from https://dioscuri-tda.org/Paris_TDA_Tutorial_2021.html
-rw-r--r--  src/python/doc/cubical_complex_user.rst | 66
1 file changed, 43 insertions(+), 23 deletions(-)
diff --git a/src/python/doc/cubical_complex_user.rst b/src/python/doc/cubical_complex_user.rst
index 12971243..ebecb592 100644
--- a/src/python/doc/cubical_complex_user.rst
+++ b/src/python/doc/cubical_complex_user.rst
@@ -173,36 +173,56 @@ explains how to represent sublevel sets of functions using cubical complexes.
Scikit-learn like interface example
-----------------------------------
-.. plot::
- :include-source:
+.. code-block:: python
# Standard scientific Python imports
- import matplotlib.pyplot as plt
- from sklearn import datasets
+ import numpy as np
+ # Standard scikit-learn imports
+ from sklearn.datasets import fetch_openml
+ from sklearn.pipeline import Pipeline
+ from sklearn.model_selection import train_test_split
+ from sklearn.svm import SVC
+ from sklearn import metrics
- # Import cubical persistence computation scikit-learn interfaces
+ # Import TDA pipeline requirements
from gudhi.sklearn.cubical_persistence import CubicalPersistence
- # Import persistence representation
from gudhi.representations import PersistenceImage, DiagramSelector
- # Get the first 10 images from scikit-learn hand digits dataset
- digits = datasets.load_digits().images[:10]
- targets = datasets.load_digits().target[:10]
+ X, y = fetch_openml('mnist_784', version=1, return_X_y=True, as_frame=False)
- # TDA pipeline
- cub = CubicalPersistence(persistence_dim = 0, n_jobs=-2)
- diags = cub.fit_transform(digits)
+ # Target is: "is it an eight?"
+ y = (y == '8') * 1
+ print('There are', np.sum(y), 'eights out of', len(y), 'numbers.')
- finite = DiagramSelector(use=True, point_type="finite")
- finite_diags = finite.fit_transform(diags)
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=0)
+ pipe = Pipeline([('cub_pers', CubicalPersistence(persistence_dim = 0, dimensions=[28,28], n_jobs=-2)),
+                  ('finite_diags', DiagramSelector(use=True, point_type="finite")),
+                  ('pers_img', PersistenceImage(bandwidth=50,
+                                                weight=lambda x: x[1]**2,
+                                                im_range=[0,256,0,256],
+                                                resolution=[20,20])),
+                  ('svc', SVC())])
- persim = PersistenceImage(im_range=[0,16,0,16], resolution=[16, 16])
- pers_images = persim.fit_transform(finite_diags)
+ pipe.fit(X_train, y_train)
+ predicted = pipe.predict(X_test)
- # Display persistence images
- _, axes = plt.subplots(nrows=1, ncols=10, figsize=(15, 3))
- for ax, image, label in zip(axes, pers_images, targets):
- ax.set_axis_off()
- ax.imshow(image.reshape(16, 16), cmap=plt.cm.gray_r, interpolation='nearest')
- ax.set_title('Target: %i' % label)
- plt.show()
+ print(f"Classification report for TDA pipeline {pipe}:\n"
+ f"{metrics.classification_report(y_test, predicted)}\n")
+
+.. code-block:: none
+
+ There are 6825 eights out of 70000 numbers.
+ Classification report for TDA pipeline Pipeline(steps=[('cub_pers',
+ CubicalPersistence(dimensions=[28, 28], n_jobs=-2)),
+ ('finite_diags', DiagramSelector(use=True)),
+ ('pers_img',
+ PersistenceImage(bandwidth=50, im_range=[0, 256, 0, 256],
+ weight=<function <lambda> at 0x7f3e54137ae8>)),
+ ('svc', SVC())]):
+               precision    recall  f1-score   support
+
+            0       0.97      0.99      0.98     25284
+            1       0.92      0.68      0.78      2716
+
+     accuracy                           0.96     28000
+    macro avg       0.94      0.84      0.88     28000
+ weighted avg       0.96      0.96      0.96     28000
\ No newline at end of file
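
For readers who want to try the pieces interactively before running the full MNIST pipeline above, here is a minimal sketch (not part of the commit) that exercises only the first two stages, CubicalPersistence and DiagramSelector, on a single synthetic 28x28 image. It assumes the same constructor arguments as the example in the diff; the random image is only a stand-in for a flattened MNIST row.

    # Minimal sketch of the first two pipeline stages on one synthetic image.
    # The constructor arguments mirror the example above; the random image is
    # only a placeholder for a flattened MNIST row.
    import numpy as np
    from gudhi.sklearn.cubical_persistence import CubicalPersistence
    from gudhi.representations import DiagramSelector

    rng = np.random.default_rng(0)
    X = rng.integers(0, 256, size=(1, 28 * 28))  # one flattened 28x28 "image"

    # 0-dimensional persistence; each row is reshaped to 28x28, as in the pipeline
    cub = CubicalPersistence(persistence_dim=0, dimensions=[28, 28], n_jobs=-2)
    diags = cub.fit_transform(X)

    # Keep only finite (birth, death) pairs, as the pipeline does before
    # computing persistence images
    finite = DiagramSelector(use=True, point_type="finite")
    finite_diags = finite.fit_transform(diags)
    print(finite_diags[0])  # array of (birth, death) pairs for the input image

The finite diagrams produced here are what PersistenceImage turns into fixed-size vectors for the SVC classifier in the pipeline above.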