summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorROUVREAU Vincent <vincent.rouvreau@inria.fr>2021-08-09 10:38:31 +0200
committerROUVREAU Vincent <vincent.rouvreau@inria.fr>2021-08-09 10:38:31 +0200
commit5c35605763273cb34efe4227b6d748992e99ab09 (patch)
treefe2c61809e63fe53165c2252639f53724a6ebb6d
parent91d72a69f2f04676fbd671af3dc2f3040c9f1c48 (diff)
Make CubicalPersistence return all dimensions. The DimensionSelector post-processing step can then select the desired dimension
-rw-r--r--src/python/CMakeLists.txt1
-rw-r--r--src/python/doc/cubical_complex_user.rst2
-rw-r--r--src/python/gudhi/sklearn/cubical_persistence.py49
-rw-r--r--src/python/gudhi/sklearn/post_processing.py61
-rw-r--r--src/python/test/test_sklearn_cubical_persistence.py21
-rw-r--r--src/python/test/test_sklearn_post_processing.py48
6 files changed, 167 insertions, 15 deletions
diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt
index a91aab37..b38bb9aa 100644
--- a/src/python/CMakeLists.txt
+++ b/src/python/CMakeLists.txt
@@ -546,6 +546,7 @@ if(PYTHONINTERP_FOUND)
# sklearn
if(SKLEARN_FOUND)
add_gudhi_py_test(test_sklearn_cubical_persistence)
+ add_gudhi_py_test(test_sklearn_post_processing)
endif()
diff --git a/src/python/doc/cubical_complex_user.rst b/src/python/doc/cubical_complex_user.rst
index 3fd9fd84..a140a279 100644
--- a/src/python/doc/cubical_complex_user.rst
+++ b/src/python/doc/cubical_complex_user.rst
@@ -211,7 +211,7 @@ two holes in :math:`\mathbf{H}_1`, or, like in this example, three connected com
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=0)
pipe = Pipeline(
[
- ("cub_pers", CubicalPersistence(persistence_dim=0, dimensions=[28, 28], n_jobs=-2)),
+ ("cub_pers", CubicalPersistence(only_this_dim=0, dimensions=[28, 28], n_jobs=-2)),
("finite_diags", DiagramSelector(use=True, point_type="finite")),
(
"pers_img",
diff --git a/src/python/gudhi/sklearn/cubical_persistence.py b/src/python/gudhi/sklearn/cubical_persistence.py
index 9af683d7..7b77000d 100644
--- a/src/python/gudhi/sklearn/cubical_persistence.py
+++ b/src/python/gudhi/sklearn/cubical_persistence.py
@@ -13,27 +13,44 @@ from sklearn.base import BaseEstimator, TransformerMixin
# joblib is required by scikit-learn
from joblib import Parallel, delayed
+# Mermaid sequence diagram - https://mermaid-js.github.io/mermaid-live-editor/
+# sequenceDiagram
+# USER->>CubicalPersistence: fit_transform(X)
+# CubicalPersistence->>thread1: __transform(X[0])
+# CubicalPersistence->>thread2: __transform(X[1])
+# Note right of CubicalPersistence: ...
+# thread1->>CubicalPersistence: [array( H0(X[0]) ), array( H1(X[0]) )]
+# thread2->>CubicalPersistence: [array( H0(X[1]) ), array( H1(X[1]) )]
+# Note right of CubicalPersistence: ...
+# CubicalPersistence->>USER: [[array( H0(X[0]) ), array( H1(X[0]) )],<br/> [array( H0(X[1]) ), array( H1(X[1]) )],<br/> ...]
+
class CubicalPersistence(BaseEstimator, TransformerMixin):
"""
This is a class for computing the persistence diagrams from a cubical complex.
"""
- def __init__(self, dimensions=None, persistence_dim=0, homology_coeff_field=11, min_persistence=0., n_jobs=None):
+ def __init__(self, dimensions=None, max_persistence_dimension=0, only_this_dim=-1, homology_coeff_field=11, min_persistence=0., n_jobs=None):
"""
Constructor for the CubicalPersistence class.
Parameters:
dimensions (list of int): A list of number of top dimensional cells if cells filtration values will require
to be reshaped (cf. :func:`~gudhi.sklearn.cubical_persistence.CubicalPersistence.transform`)
- persistence_dim (int): The returned persistence diagrams dimension. Default value is `0`.
+ max_persistence_dimension (int): The returned persistence diagrams maximal dimension. Default value is `0`.
+ Ignored if `only_this_dim` is set.
+ only_this_dim (int): The single dimension of the returned persistence diagrams. Default value is `-1`
+ (not set). If `only_this_dim` is set, `max_persistence_dimension` will be ignored.
+ This short-circuits the use of :class:`~gudhi.sklearn.post_processing.DimensionSelector` when only one
+ dimension matters.
homology_coeff_field (int): The homology coefficient field. Must be a prime number. Default value is 11.
min_persistence (float): The minimum persistence value to take into account (strictly greater than
`min_persistence`). Default value is `0.0`. Sets `min_persistence` to `-1.0` to see all values.
n_jobs (int): cf. https://joblib.readthedocs.io/en/latest/generated/joblib.Parallel.html
"""
self.dimensions = dimensions
- self.persistence_dim = persistence_dim
+ self.max_persistence_dimension = max_persistence_dimension
+ self.only_this_dim = only_this_dim
self.homology_coeff_field = homology_coeff_field
self.min_persistence = min_persistence
self.n_jobs = n_jobs
@@ -49,8 +66,14 @@ class CubicalPersistence(BaseEstimator, TransformerMixin):
cubical_complex.compute_persistence(
homology_coeff_field=self.homology_coeff_field, min_persistence=self.min_persistence
)
- diagrams = cubical_complex.persistence_intervals_in_dimension(self.persistence_dim)
- return diagrams
+ return [cubical_complex.persistence_intervals_in_dimension(dim) for dim in range(self.max_persistence_dimension + 1)]
+
+ def __transform_only_this_dim(self, cells):
+ cubical_complex = CubicalComplex(top_dimensional_cells=cells, dimensions=self.dimensions)
+ cubical_complex.compute_persistence(
+ homology_coeff_field=self.homology_coeff_field, min_persistence=self.min_persistence
+ )
+ return cubical_complex.persistence_intervals_in_dimension(self.only_this_dim)
def transform(self, X, Y=None):
"""
@@ -58,12 +81,18 @@ class CubicalPersistence(BaseEstimator, TransformerMixin):
Parameters:
X (list of list of double OR list of numpy.ndarray): List of cells filtration values that can be flatten if
- dimensions is set in the constructor, or already with the correct shape in a numpy.ndarray (and
- dimensions must not be set).
+ `dimensions` is set in the constructor, or already with the correct shape in a numpy.ndarray (and
+ `dimensions` must not be set).
Returns:
- Persistence diagrams
+ Persistence diagrams in the format:
+ - If `only_this_dim` was set to `n`: `[array( Hn(X[0]) ), array( Hn(X[1]) ), ...]`
+ - else: `[[array( H0(X[0]) ), array( H1(X[0]) ), ...], [array( H0(X[1]) ), array( H1(X[1]) ), ...], ...]`
"""
- # threads is preferred as cubical construction and persistence computation releases the GIL
- return Parallel(n_jobs=self.n_jobs, prefer="threads")(delayed(self.__transform)(cells) for cells in X)
+ if self.only_this_dim == -1:
+ # threads are preferred as cubical construction and persistence computation release the GIL
+ return Parallel(n_jobs=self.n_jobs, prefer="threads")(delayed(self.__transform)(cells) for cells in X)
+ else:
+ # threads are preferred as cubical construction and persistence computation release the GIL
+ return Parallel(n_jobs=self.n_jobs, prefer="threads")(delayed(self.__transform_only_this_dim)(cells) for cells in X)
diff --git a/src/python/gudhi/sklearn/post_processing.py b/src/python/gudhi/sklearn/post_processing.py
new file mode 100644
index 00000000..79276e1e
--- /dev/null
+++ b/src/python/gudhi/sklearn/post_processing.py
@@ -0,0 +1,61 @@
+# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+# Author(s): Vincent Rouvreau
+#
+# Copyright (C) 2021 Inria
+#
+# Modification(s):
+# - YYYY/MM Author: Description of the modification
+
+from sklearn.base import BaseEstimator, TransformerMixin
+
+# joblib is required by scikit-learn
+from joblib import Parallel, delayed
+
+# Mermaid sequence diagram - https://mermaid-js.github.io/mermaid-live-editor/
+# sequenceDiagram
+# USER->>DimensionSelector: fit_transform(<br/>[[array( H0(X0) ), array( H1(X0) ), ...],<br/> [array( H0(X1) ), array( H1(X1) ), ...],<br/> ...])
+# DimensionSelector->>thread1: _transform([array( H0(X0) ), array( H1(X0) )], ...)
+# DimensionSelector->>thread2: _transform([array( H0(X1) ), array( H1(X1) )], ...)
+# Note right of DimensionSelector: ...
+# thread1->>DimensionSelector: array( Hn(X0) )
+# thread2->>DimensionSelector: array( Hn(X1) )
+# Note right of DimensionSelector: ...
+# DimensionSelector->>USER: [array( Hn(X0) ), <br/> array( Hn(X1) ), <br/> ...]
+
+
+class DimensionSelector(BaseEstimator, TransformerMixin):
+ """
+ This is a class to select persistence diagrams in a specific dimension.
+ """
+
+ def __init__(self, persistence_dimension=0, n_jobs=None):
+ """
+ Constructor for the DimensionSelector class.
+
+ Parameters:
+ persistence_dimension (int): The returned persistence diagrams dimension. Default value is `0`.
+ """
+ self.persistence_dimension = persistence_dimension
+ self.n_jobs = n_jobs
+
+ def fit(self, X, Y=None):
+ """
+ Nothing to be done, but useful when included in a scikit-learn Pipeline.
+ """
+ return self
+
+ def transform(self, X, Y=None):
+ """
+ Select, for each input, the persistence diagram of the configured dimension.
+
+ Parameters:
+ X (list of list of pairs): List of list of persistence pairs, i.e.
+ `[[array( H0(X0) ), array( H1(X0) ), ...], [array( H0(X1) ), array( H1(X1) ), ...], ...]`
+
+ Returns:
+ Persistence diagrams in a specific dimension, i.e.
+ `[array( Hn(X0) ), array( Hn(X1) ), ...]`
+ """
+
+ return [persistence[self.persistence_dimension] for persistence in X]
diff --git a/src/python/test/test_sklearn_cubical_persistence.py b/src/python/test/test_sklearn_cubical_persistence.py
index c0082547..506985f1 100644
--- a/src/python/test/test_sklearn_cubical_persistence.py
+++ b/src/python/test/test_sklearn_cubical_persistence.py
@@ -16,17 +16,30 @@ __author__ = "Vincent Rouvreau"
__copyright__ = "Copyright (C) 2021 Inria"
__license__ = "MIT"
+CUBICAL_PERSISTENCE_H0_IMG0 = np.array([[0., 6.], [0., 8.], [ 0., np.inf]])
+
def test_simple_constructor_from_top_cells():
cells = datasets.load_digits().images[0]
- cp = CubicalPersistence(persistence_dim = 0)
+ cp = CubicalPersistence(only_this_dim = 0)
np.testing.assert_array_equal(cp._CubicalPersistence__transform(cells),
- np.array([[0., 6.], [0., 8.], [ 0., np.inf]]))
+ [CUBICAL_PERSISTENCE_H0_IMG0])
+ cp = CubicalPersistence(max_persistence_dimension = 2)
+ diags = cp._CubicalPersistence__transform(cells)
+ assert len(diags) == 3
+ np.testing.assert_array_equal(diags[0],
+ CUBICAL_PERSISTENCE_H0_IMG0)
def test_simple_constructor_from_top_cells_list():
digits = datasets.load_digits().images[:10]
- cp = CubicalPersistence(persistence_dim = 0, n_jobs=-2)
+ cp = CubicalPersistence(only_this_dim = 0, n_jobs=-2)
diags = cp.fit_transform(digits)
assert len(diags) == 10
np.testing.assert_array_equal(diags[0],
- np.array([[0., 6.], [0., 8.], [ 0., np.inf]]))
+ CUBICAL_PERSISTENCE_H0_IMG0)
+
+ cp = CubicalPersistence(max_persistence_dimension = 1, n_jobs=-1)
+ diagsH0H1 = cp.fit_transform(digits)
+ assert len(diagsH0H1) == 10
+ for idx in range(10):
+ np.testing.assert_array_equal(diags[idx], diagsH0H1[idx][0])
diff --git a/src/python/test/test_sklearn_post_processing.py b/src/python/test/test_sklearn_post_processing.py
new file mode 100644
index 00000000..3a251d34
--- /dev/null
+++ b/src/python/test/test_sklearn_post_processing.py
@@ -0,0 +1,48 @@
+""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+ See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+ Author(s): Vincent Rouvreau
+
+ Copyright (C) 2021 Inria
+
+ Modification(s):
+ - YYYY/MM Author: Description of the modification
+"""
+
+from gudhi.sklearn.post_processing import DimensionSelector
+import numpy as np
+import pytest
+
+__author__ = "Vincent Rouvreau"
+__copyright__ = "Copyright (C) 2021 Inria"
+__license__ = "MIT"
+
+H0_0 = np.array([0., 0.])
+H1_0 = np.array([1., 0.])
+H0_1 = np.array([0., 1.])
+H1_1 = np.array([1., 1.])
+H0_2 = np.array([0., 2.])
+H1_2 = np.array([1., 2.])
+
+def test_dimension_selector():
+ X = [[H0_0, H1_0], [H0_1, H1_1], [H0_2, H1_2]]
+ ds = DimensionSelector(persistence_dimension = 0, n_jobs=-2)
+ h0 = ds.fit_transform(X)
+ np.testing.assert_array_equal(h0[0],
+ H0_0)
+ np.testing.assert_array_equal(h0[1],
+ H0_1)
+ np.testing.assert_array_equal(h0[2],
+ H0_2)
+
+ ds = DimensionSelector(persistence_dimension = 1, n_jobs=-1)
+ h1 = ds.fit_transform(X)
+ np.testing.assert_array_equal(h1[0],
+ H1_0)
+ np.testing.assert_array_equal(h1[1],
+ H1_1)
+ np.testing.assert_array_equal(h1[2],
+ H1_2)
+
+ ds = DimensionSelector(persistence_dimension = 2, n_jobs=-2)
+ with pytest.raises(IndexError):
+ h2 = ds.fit_transform([[H0_0, H1_0], [H0_1, H1_1], [H0_2, H1_2]])