summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorROUVREAU Vincent <vincent.rouvreau@inria.fr>2021-06-01 19:12:50 +0200
committerROUVREAU Vincent <vincent.rouvreau@inria.fr>2021-06-01 19:12:50 +0200
commit4a64eef12722de3faa8ac73416aaea91658e20b6 (patch)
treeed0c155822ae5e37398f3b528165bf895826b96e
parent5b75186ace327ddc17eb6f06c0ba2485c93235ec (diff)
Add cubical scikit learn interface documentation and example
-rw-r--r--src/python/doc/cubical_complex_sum.inc7
-rw-r--r--src/python/doc/cubical_complex_user.rst58
-rw-r--r--src/python/gudhi/sklearn/cubical_persistence.py9
-rw-r--r--src/python/test/test_sklearn_cubical_persistence.py8
4 files changed, 58 insertions, 24 deletions
diff --git a/src/python/doc/cubical_complex_sum.inc b/src/python/doc/cubical_complex_sum.inc
index 87db184d..2a1bde8d 100644
--- a/src/python/doc/cubical_complex_sum.inc
+++ b/src/python/doc/cubical_complex_sum.inc
@@ -3,12 +3,11 @@
+--------------------------------------------------------------------------+----------------------------------------------------------------------+-----------------------------+
| .. figure:: | The cubical complex represents a grid as a cell complex with | :Author: Pawel Dlotko |
- | ../../doc/Bitmap_cubical_complex/Cubical_complex_representation.png | cells of all dimensions. | |
- | :alt: Cubical complex representation | | :Since: GUDHI 2.0.0 |
+ | ../../doc/Bitmap_cubical_complex/Cubical_complex_representation.png | cells of all dimensions. | :Since: GUDHI 2.0.0 |
+ | :alt: Cubical complex representation | | :License: MIT |
| :figclass: align-center | | |
- | | | :License: MIT |
- | | | |
+--------------------------------------------------------------------------+----------------------------------------------------------------------+-----------------------------+
| * :doc:`cubical_complex_user` | * :doc:`cubical_complex_ref` |
| | * :doc:`periodic_cubical_complex_ref` |
+ | | * :doc:`cubical_complex_sklearn_itf_ref` |
+--------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------+
diff --git a/src/python/doc/cubical_complex_user.rst b/src/python/doc/cubical_complex_user.rst
index 6a211347..12971243 100644
--- a/src/python/doc/cubical_complex_user.rst
+++ b/src/python/doc/cubical_complex_user.rst
@@ -7,14 +7,19 @@ Cubical complex user manual
Definition
----------
-===================================== ===================================== =====================================
-:Author: Pawel Dlotko :Since: GUDHI PYTHON 2.0.0 :License: GPL v3
-===================================== ===================================== =====================================
+.. list-table::
+ :widths: 25 50 25
+ :header-rows: 0
+
+ * - :Author: Pawel Dlotko
+ - :Since: GUDHI 2.0.0
+ - :License: MIT
+ * - :doc:`cubical_complex_user`
+ - * :doc:`cubical_complex_ref`
+ * :doc:`periodic_cubical_complex_ref`
+ * :doc:`cubical_complex_sklearn_itf_ref`
+ -
-+---------------------------------------------+----------------------------------------------------------------------+
-| :doc:`cubical_complex_user` | * :doc:`cubical_complex_ref` |
-| | * :doc:`periodic_cubical_complex_ref` |
-+---------------------------------------------+----------------------------------------------------------------------+
The cubical complex is an example of a structured complex useful in computational mathematics (specially rigorous
numerics) and image analysis.
@@ -163,4 +168,41 @@ Tutorial
--------
This `notebook <https://github.com/GUDHI/TDA-tutorial/blob/master/Tuto-GUDHI-cubical-complexes.ipynb>`_
-explains how to represent sublevels sets of functions using cubical complexes. \ No newline at end of file
+explains how to represent sublevels sets of functions using cubical complexes.
+
+Scikit-learn like interface example
+-----------------------------------
+
+.. plot::
+ :include-source:
+
+ # Standard scientific Python imports
+ import matplotlib.pyplot as plt
+ from sklearn import datasets
+
+ # Import cubical persistence computation scikit-learn interfaces
+ from gudhi.sklearn.cubical_persistence import CubicalPersistence
+ # Import persistence representation
+ from gudhi.representations import PersistenceImage, DiagramSelector
+
+ # Get the first 10 images from scikit-learn hand digits dataset
+ digits = datasets.load_digits().images[:10]
+ targets = datasets.load_digits().target[:10]
+
+ # TDA pipeline
+ cub = CubicalPersistence(persistence_dim = 0, n_jobs=-2)
+ diags = cub.fit_transform(digits)
+
+ finite = DiagramSelector(use=True, point_type="finite")
+ finite_diags = finite.fit_transform(diags)
+
+ persim = PersistenceImage(im_range=[0,16,0,16], resolution=[16, 16])
+ pers_images = persim.fit_transform(finite_diags)
+
+ # Display persistence images
+ _, axes = plt.subplots(nrows=1, ncols=10, figsize=(15, 3))
+ for ax, image, label in zip(axes, pers_images, targets):
+ ax.set_axis_off()
+ ax.imshow(image.reshape(16, 16), cmap=plt.cm.gray_r, interpolation='nearest')
+ ax.set_title('Target: %i' % label)
+ plt.show()
diff --git a/src/python/gudhi/sklearn/cubical_persistence.py b/src/python/gudhi/sklearn/cubical_persistence.py
index 809f5d4b..a7a3d036 100644
--- a/src/python/gudhi/sklearn/cubical_persistence.py
+++ b/src/python/gudhi/sklearn/cubical_persistence.py
@@ -5,7 +5,7 @@ from joblib import Parallel, delayed
class CubicalPersistence(BaseEstimator, TransformerMixin):
# Fast way to find primes and should be enough
- available_primes_ = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97]
+ _available_primes = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97]
"""
This is a class for computing the persistence diagrams from a cubical complex.
"""
@@ -24,7 +24,7 @@ class CubicalPersistence(BaseEstimator, TransformerMixin):
self.persistence_dim = persistence_dim
self.homology_coeff_field_ = None
- for dim in self.available_primes_:
+ for dim in self._available_primes:
if dim > persistence_dim + 1:
self.homology_coeff_field_ = dim
break
@@ -45,14 +45,11 @@ class CubicalPersistence(BaseEstimator, TransformerMixin):
cubical_complex.compute_persistence(homology_coeff_field = self.homology_coeff_field_,
min_persistence = self.min_persistence)
diagrams = cubical_complex.persistence_intervals_in_dimension(self.persistence_dim)
- if self.persistence_dim == 0:
- # return all but the last, always [ 0., inf]
- diagrams = diagrams[:-1]
return diagrams
def transform(self, X, Y=None):
"""
- Compute all the cubical complexes and their persistence diagrams.
+ Compute all the cubical complexes and their associated persistence diagrams.
Parameters:
X (list of list of double OR list of numpy.ndarray): List of cells filtration values.
diff --git a/src/python/test/test_sklearn_cubical_persistence.py b/src/python/test/test_sklearn_cubical_persistence.py
index 134611c9..c0082547 100644
--- a/src/python/test/test_sklearn_cubical_persistence.py
+++ b/src/python/test/test_sklearn_cubical_persistence.py
@@ -20,7 +20,7 @@ def test_simple_constructor_from_top_cells():
cells = datasets.load_digits().images[0]
cp = CubicalPersistence(persistence_dim = 0)
np.testing.assert_array_equal(cp._CubicalPersistence__transform(cells),
- np.array([[0., 6.], [0., 8.]]))
+ np.array([[0., 6.], [0., 8.], [ 0., np.inf]]))
def test_simple_constructor_from_top_cells_list():
digits = datasets.load_digits().images[:10]
@@ -29,8 +29,4 @@ def test_simple_constructor_from_top_cells_list():
diags = cp.fit_transform(digits)
assert len(diags) == 10
np.testing.assert_array_equal(diags[0],
- np.array([[0., 6.], [0., 8.]]))
-
-# from gudhi.representations import PersistenceImage
-# pi = PersistenceImage(bandwidth=50, weight=lambda x: x[1]**2, im_range=[0,256,0,256], resolution=[20, 20])
-# pi.fit_transform(diags) \ No newline at end of file
+ np.array([[0., 6.], [0., 8.], [ 0., np.inf]]))