diff options
author | ROUVREAU Vincent <vincent.rouvreau@inria.fr> | 2021-08-09 10:38:31 +0200 |
---|---|---|
committer | ROUVREAU Vincent <vincent.rouvreau@inria.fr> | 2021-08-09 10:38:31 +0200 |
commit | 5c35605763273cb34efe4227b6d748992e99ab09 (patch) | |
tree | fe2c61809e63fe53165c2252639f53724a6ebb6d | |
parent | 91d72a69f2f04676fbd671af3dc2f3040c9f1c48 (diff) |
Make CubicalPersistence returns all dimensions. Post processing DimensionSelector can select the desired dimension
-rw-r--r-- | src/python/CMakeLists.txt | 1 | ||||
-rw-r--r-- | src/python/doc/cubical_complex_user.rst | 2 | ||||
-rw-r--r-- | src/python/gudhi/sklearn/cubical_persistence.py | 49 | ||||
-rw-r--r-- | src/python/gudhi/sklearn/post_processing.py | 61 | ||||
-rw-r--r-- | src/python/test/test_sklearn_cubical_persistence.py | 21 | ||||
-rw-r--r-- | src/python/test/test_sklearn_post_processing.py | 48 |
6 files changed, 167 insertions, 15 deletions
diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt index a91aab37..b38bb9aa 100644 --- a/src/python/CMakeLists.txt +++ b/src/python/CMakeLists.txt @@ -546,6 +546,7 @@ if(PYTHONINTERP_FOUND) # sklearn if(SKLEARN_FOUND) add_gudhi_py_test(test_sklearn_cubical_persistence) + add_gudhi_py_test(test_sklearn_post_processing) endif() diff --git a/src/python/doc/cubical_complex_user.rst b/src/python/doc/cubical_complex_user.rst index 3fd9fd84..a140a279 100644 --- a/src/python/doc/cubical_complex_user.rst +++ b/src/python/doc/cubical_complex_user.rst @@ -211,7 +211,7 @@ two holes in :math:`\mathbf{H}_1`, or, like in this example, three connected com X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=0) pipe = Pipeline( [ - ("cub_pers", CubicalPersistence(persistence_dim=0, dimensions=[28, 28], n_jobs=-2)), + ("cub_pers", CubicalPersistence(only_this_dim=0, dimensions=[28, 28], n_jobs=-2)), ("finite_diags", DiagramSelector(use=True, point_type="finite")), ( "pers_img", diff --git a/src/python/gudhi/sklearn/cubical_persistence.py b/src/python/gudhi/sklearn/cubical_persistence.py index 9af683d7..7b77000d 100644 --- a/src/python/gudhi/sklearn/cubical_persistence.py +++ b/src/python/gudhi/sklearn/cubical_persistence.py @@ -13,27 +13,44 @@ from sklearn.base import BaseEstimator, TransformerMixin # joblib is required by scikit-learn from joblib import Parallel, delayed +# Mermaid sequence diagram - https://mermaid-js.github.io/mermaid-live-editor/ +# sequenceDiagram +# USER->>CubicalPersistence: fit_transform(X) +# CubicalPersistence->>thread1: _tranform(X[0]) +# CubicalPersistence->>thread2: _tranform(X[1]) +# Note right of CubicalPersistence: ... +# thread1->>CubicalPersistence: [array( H0(X[0]) ), array( H1(X[0]) )] +# thread2->>CubicalPersistence: [array( H0(X[1]) ), array( H1(X[1]) )] +# Note right of CubicalPersistence: ... +# CubicalPersistence->>USER: [[array( H0(X[0]) ), array( H1(X[0]) )],<br/> [array( H0(X[1]) ), array( H1(X[1]) )],<br/> ...] + class CubicalPersistence(BaseEstimator, TransformerMixin): """ This is a class for computing the persistence diagrams from a cubical complex. """ - def __init__(self, dimensions=None, persistence_dim=0, homology_coeff_field=11, min_persistence=0., n_jobs=None): + def __init__(self, dimensions=None, max_persistence_dimension=0, only_this_dim=-1, homology_coeff_field=11, min_persistence=0., n_jobs=None): """ Constructor for the CubicalPersistence class. Parameters: dimensions (list of int): A list of number of top dimensional cells if cells filtration values will require to be reshaped (cf. :func:`~gudhi.sklearn.cubical_persistence.CubicalPersistence.transform`) - persistence_dim (int): The returned persistence diagrams dimension. Default value is `0`. + max_persistence_dimension (int): The returned persistence diagrams maximal dimension. Default value is `0`. + Ignored if `only_this_dim` is set. + only_this_dim (int): The returned persistence diagrams dimension. If `only_this_dim` is set, + `max_persistence_dimension` will be ignored. + Short circuit the use of :class:`~gudhi.sklearn.post_processing.DimensionSelector` when only one + dimension matters. homology_coeff_field (int): The homology coefficient field. Must be a prime number. Default value is 11. min_persistence (float): The minimum persistence value to take into account (strictly greater than `min_persistence`). Default value is `0.0`. Sets `min_persistence` to `-1.0` to see all values. n_jobs (int): cf. https://joblib.readthedocs.io/en/latest/generated/joblib.Parallel.html """ self.dimensions = dimensions - self.persistence_dim = persistence_dim + self.max_persistence_dimension = max_persistence_dimension + self.only_this_dim = only_this_dim self.homology_coeff_field = homology_coeff_field self.min_persistence = min_persistence self.n_jobs = n_jobs @@ -49,8 +66,14 @@ class CubicalPersistence(BaseEstimator, TransformerMixin): cubical_complex.compute_persistence( homology_coeff_field=self.homology_coeff_field, min_persistence=self.min_persistence ) - diagrams = cubical_complex.persistence_intervals_in_dimension(self.persistence_dim) - return diagrams + return [cubical_complex.persistence_intervals_in_dimension(dim) for dim in range(self.max_persistence_dimension + 1)] + + def __transform_only_this_dim(self, cells): + cubical_complex = CubicalComplex(top_dimensional_cells=cells, dimensions=self.dimensions) + cubical_complex.compute_persistence( + homology_coeff_field=self.homology_coeff_field, min_persistence=self.min_persistence + ) + return cubical_complex.persistence_intervals_in_dimension(self.only_this_dim) def transform(self, X, Y=None): """ @@ -58,12 +81,18 @@ class CubicalPersistence(BaseEstimator, TransformerMixin): Parameters: X (list of list of double OR list of numpy.ndarray): List of cells filtration values that can be flatten if - dimensions is set in the constructor, or already with the correct shape in a numpy.ndarray (and - dimensions must not be set). + `dimensions` is set in the constructor, or already with the correct shape in a numpy.ndarray (and + `dimensions` must not be set). Returns: - Persistence diagrams + Persistence diagrams in the format: + - If `only_this_dim` was set to `n`: `[array( Hn(X[0]) ), array( Hn(X[1]) ), ...]` + - else: `[[array( H0(X[0]) ), array( H1(X[0]) ), ...], [array( H0(X[1]) ), array( H1(X[1]) ), ...], ...]` """ - # threads is preferred as cubical construction and persistence computation releases the GIL - return Parallel(n_jobs=self.n_jobs, prefer="threads")(delayed(self.__transform)(cells) for cells in X) + if self.only_this_dim == -1: + # threads is preferred as cubical construction and persistence computation releases the GIL + return Parallel(n_jobs=self.n_jobs, prefer="threads")(delayed(self.__transform)(cells) for cells in X) + else: + # threads is preferred as cubical construction and persistence computation releases the GIL + return Parallel(n_jobs=self.n_jobs, prefer="threads")(delayed(self.__transform_only_this_dim)(cells) for cells in X) diff --git a/src/python/gudhi/sklearn/post_processing.py b/src/python/gudhi/sklearn/post_processing.py new file mode 100644 index 00000000..79276e1e --- /dev/null +++ b/src/python/gudhi/sklearn/post_processing.py @@ -0,0 +1,61 @@ +# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT. +# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details. +# Author(s): Vincent Rouvreau +# +# Copyright (C) 2021 Inria +# +# Modification(s): +# - YYYY/MM Author: Description of the modification + +from sklearn.base import BaseEstimator, TransformerMixin + +# joblib is required by scikit-learn +from joblib import Parallel, delayed + +# Mermaid sequence diagram - https://mermaid-js.github.io/mermaid-live-editor/ +# sequenceDiagram +# USER->>DimensionSelector: fit_transform(<br/>[[array( H0(X0) ), array( H1(X0) ), ...],<br/> [array( H0(X1) ), array( H1(X1) ), ...],<br/> ...]) +# DimensionSelector->>thread1: _transform([array( H0(X0) ), array( H1(X0) )], ...) +# DimensionSelector->>thread2: _transform([array( H0(X1) ), array( H1(X1) )], ...) +# Note right of DimensionSelector: ... +# thread1->>DimensionSelector: array( Hn(X0) ) +# thread2->>DimensionSelector: array( Hn(X1) ) +# Note right of DimensionSelector: ... +# DimensionSelector->>USER: [array( Hn(X0) ), <br/> array( Hn(X1) ), <br/> ...] + + +class DimensionSelector(BaseEstimator, TransformerMixin): + """ + This is a class to select persistence diagrams in a specific dimension. + """ + + def __init__(self, persistence_dimension=0, n_jobs=None): + """ + Constructor for the DimensionSelector class. + + Parameters: + persistence_dimension (int): The returned persistence diagrams dimension. Default value is `0`. + """ + self.persistence_dimension = persistence_dimension + self.n_jobs = n_jobs + + def fit(self, X, Y=None): + """ + Nothing to be done, but useful when included in a scikit-learn Pipeline. + """ + return self + + def transform(self, X, Y=None): + """ + Select persistence diagrams from its dimension. + + Parameters: + X (list of list of pairs): List of list of persistence pairs, i.e. + `[[array( H0(X0) ), array( H1(X0) ), ...], [array( H0(X1) ), array( H1(X1) ), ...], ...]` + + Returns: + Persistence diagrams in a specific dimension, i.e. + `[array( Hn(X0) ), array( Hn(X1), ...]` + """ + + return [persistence[self.persistence_dimension] for persistence in X] diff --git a/src/python/test/test_sklearn_cubical_persistence.py b/src/python/test/test_sklearn_cubical_persistence.py index c0082547..506985f1 100644 --- a/src/python/test/test_sklearn_cubical_persistence.py +++ b/src/python/test/test_sklearn_cubical_persistence.py @@ -16,17 +16,30 @@ __author__ = "Vincent Rouvreau" __copyright__ = "Copyright (C) 2021 Inria" __license__ = "MIT" +CUBICAL_PERSISTENCE_H0_IMG0 = np.array([[0., 6.], [0., 8.], [ 0., np.inf]]) + def test_simple_constructor_from_top_cells(): cells = datasets.load_digits().images[0] - cp = CubicalPersistence(persistence_dim = 0) + cp = CubicalPersistence(only_this_dim = 0) np.testing.assert_array_equal(cp._CubicalPersistence__transform(cells), - np.array([[0., 6.], [0., 8.], [ 0., np.inf]])) + [CUBICAL_PERSISTENCE_H0_IMG0]) + cp = CubicalPersistence(max_persistence_dimension = 2) + diags = cp._CubicalPersistence__transform(cells) + assert len(diags) == 3 + np.testing.assert_array_equal(diags[0], + CUBICAL_PERSISTENCE_H0_IMG0) def test_simple_constructor_from_top_cells_list(): digits = datasets.load_digits().images[:10] - cp = CubicalPersistence(persistence_dim = 0, n_jobs=-2) + cp = CubicalPersistence(only_this_dim = 0, n_jobs=-2) diags = cp.fit_transform(digits) assert len(diags) == 10 np.testing.assert_array_equal(diags[0], - np.array([[0., 6.], [0., 8.], [ 0., np.inf]])) + CUBICAL_PERSISTENCE_H0_IMG0) + + cp = CubicalPersistence(max_persistence_dimension = 1, n_jobs=-1) + diagsH0H1 = cp.fit_transform(digits) + assert len(diagsH0H1) == 10 + for idx in range(10): + np.testing.assert_array_equal(diags[idx], diagsH0H1[idx][0]) diff --git a/src/python/test/test_sklearn_post_processing.py b/src/python/test/test_sklearn_post_processing.py new file mode 100644 index 00000000..3a251d34 --- /dev/null +++ b/src/python/test/test_sklearn_post_processing.py @@ -0,0 +1,48 @@ +""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT. + See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details. + Author(s): Vincent Rouvreau + + Copyright (C) 2021 Inria + + Modification(s): + - YYYY/MM Author: Description of the modification +""" + +from gudhi.sklearn.post_processing import DimensionSelector +import numpy as np +import pytest + +__author__ = "Vincent Rouvreau" +__copyright__ = "Copyright (C) 2021 Inria" +__license__ = "MIT" + +H0_0 = np.array([0., 0.]) +H1_0 = np.array([1., 0.]) +H0_1 = np.array([0., 1.]) +H1_1 = np.array([1., 1.]) +H0_2 = np.array([0., 2.]) +H1_2 = np.array([1., 2.]) + +def test_dimension_selector(): + X = [[H0_0, H1_0], [H0_1, H1_1], [H0_2, H1_2]] + ds = DimensionSelector(persistence_dimension = 0, n_jobs=-2) + h0 = ds.fit_transform(X) + np.testing.assert_array_equal(h0[0], + H0_0) + np.testing.assert_array_equal(h0[1], + H0_1) + np.testing.assert_array_equal(h0[2], + H0_2) + + ds = DimensionSelector(persistence_dimension = 1, n_jobs=-1) + h1 = ds.fit_transform(X) + np.testing.assert_array_equal(h1[0], + H1_0) + np.testing.assert_array_equal(h1[1], + H1_1) + np.testing.assert_array_equal(h1[2], + H1_2) + + ds = DimensionSelector(persistence_dimension = 2, n_jobs=-2) + with pytest.raises(IndexError): + h2 = ds.fit_transform([[H0_0, H1_0], [H0_1, H1_1], [H0_2, H1_2]]) |