From 546b059af6c0581d06bfe9cebbe853f2f7bd4589 Mon Sep 17 00:00:00 2001
From: ROUVREAU Vincent
Date: Fri, 4 Jun 2021 11:56:59 +0200
Subject: Add a more relevant example inspired from
 https://dioscuri-tda.org/Paris_TDA_Tutorial_2021.html

---
 src/python/doc/cubical_complex_user.rst | 66 +++++++++++++++++++++------------
 1 file changed, 43 insertions(+), 23 deletions(-)

(limited to 'src/python/doc/cubical_complex_user.rst')

diff --git a/src/python/doc/cubical_complex_user.rst b/src/python/doc/cubical_complex_user.rst
index 12971243..ebecb592 100644
--- a/src/python/doc/cubical_complex_user.rst
+++ b/src/python/doc/cubical_complex_user.rst
@@ -173,36 +173,56 @@ explains how to represent sublevels sets of functions using cubical complexes.
 
 Scikit-learn like interface example
 -----------------------------------
-.. plot::
-   :include-source:
+.. code-block:: python
 
     # Standard scientific Python imports
-    import matplotlib.pyplot as plt
-    from sklearn import datasets
+    import numpy as np
+    # Standard scikit-learn imports
+    from sklearn.datasets import fetch_openml
+    from sklearn.pipeline import Pipeline
+    from sklearn.model_selection import train_test_split
+    from sklearn.svm import SVC
+    from sklearn import metrics
 
-    # Import cubical persistence computation scikit-learn interfaces
+    # Import TDA pipeline requirements
     from gudhi.sklearn.cubical_persistence import CubicalPersistence
-    # Import persistence representation
     from gudhi.representations import PersistenceImage, DiagramSelector
 
-    # Get the first 10 images from scikit-learn hand digits dataset
-    digits = datasets.load_digits().images[:10]
-    targets = datasets.load_digits().target[:10]
+    X, y = fetch_openml('mnist_784', version=1, return_X_y=True, as_frame=False)
 
-    # TDA pipeline
-    cub = CubicalPersistence(persistence_dim = 0, n_jobs=-2)
-    diags = cub.fit_transform(digits)
+    # Target is: "is it an eight?"
+    y = (y == '8') * 1
+    print('There are', np.sum(y), 'eights out of', len(y), 'numbers.')
 
-    finite = DiagramSelector(use=True, point_type="finite")
-    finite_diags = finite.fit_transform(diags)
+    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=0)
+    pipe = Pipeline([('cub_pers', CubicalPersistence(persistence_dim=0, dimensions=[28, 28], n_jobs=-2)),
+                     ('finite_diags', DiagramSelector(use=True, point_type="finite")),
+                     ('pers_img', PersistenceImage(bandwidth=50,
+                                                   weight=lambda x: x[1]**2,
+                                                   im_range=[0, 256, 0, 256],
+                                                   resolution=[20, 20])),
+                     ('svc', SVC())])
 
-    persim = PersistenceImage(im_range=[0,16,0,16], resolution=[16, 16])
-    pers_images = persim.fit_transform(finite_diags)
+    # Fit the pipeline on the train subset before predicting on the test subset
+    pipe.fit(X_train, y_train)
+    predicted = pipe.predict(X_test)
 
-    # Display persistence images
-    _, axes = plt.subplots(nrows=1, ncols=10, figsize=(15, 3))
-    for ax, image, label in zip(axes, pers_images, targets):
-        ax.set_axis_off()
-        ax.imshow(image.reshape(16, 16), cmap=plt.cm.gray_r, interpolation='nearest')
-        ax.set_title('Target: %i' % label)
-    plt.show()
+    print(f"Classification report for TDA pipeline {pipe}:\n"
+          f"{metrics.classification_report(y_test, predicted)}\n")
+
+.. code-block:: none
+
+    There are 6825 eights out of 70000 numbers.
+    Classification report for TDA pipeline Pipeline(steps=[('cub_pers',
+                     CubicalPersistence(dimensions=[28, 28], n_jobs=-2)),
+                    ('finite_diags', DiagramSelector(use=True)),
+                    ('pers_img',
+                     PersistenceImage(bandwidth=50, im_range=[0, 256, 0, 256],
+                                      weight=<function <lambda> at 0x7f3e54137ae8>)),
+                    ('svc', SVC())]):
+                  precision    recall  f1-score   support
+
+               0       0.97      0.99      0.98     25284
+               1       0.92      0.68      0.78      2716
+
+        accuracy                           0.96     28000
+       macro avg       0.94      0.84      0.88     28000
+    weighted avg       0.96      0.96      0.96     28000
\ No newline at end of file
--
cgit v1.2.3
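
For readers who want to see what the new 'cub_pers' step produces on its own, here is a minimal standalone sketch. It reuses the constructor arguments from the patch (persistence_dim=0, dimensions=[28, 28]); the synthetic input image and the assumption that fit_transform returns one array of (birth, death) pairs per input image are illustrative only, not part of the commit.

    # Minimal sketch (not from the commit above): compute the dimension-0
    # cubical persistence diagram of a single flattened 28x28 image,
    # mirroring the 'cub_pers' step of the pipeline in the patch.
    import numpy as np
    from gudhi.sklearn.cubical_persistence import CubicalPersistence

    # One synthetic grayscale image, flattened like an MNIST row
    rng = np.random.default_rng(0)
    img = rng.integers(0, 256, size=(1, 28 * 28))

    cub = CubicalPersistence(persistence_dim=0, dimensions=[28, 28])
    diags = cub.fit_transform(img)

    # Expected: one diagram per input image, as an array of (birth, death) pairs
    print(diags[0].shape)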