summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorVincent Rouvreau <vincent.rouvreau@inria.fr>2022-06-20 17:07:51 +0200
committerVincent Rouvreau <vincent.rouvreau@inria.fr>2022-06-20 17:07:51 +0200
commit8d7244b510eee0d7927c521117198ef286028acf (patch)
tree4ff8da12f39e52f955e506e0e20a1e16b0aa8f8e
parent3bfc066548fbdbd9b1dc06b39d8ccecd2ce4d0b5 (diff)
code review: rename homology_dimensions argument. Use and document numpy reshape instead of cubical dimension argument
-rw-r--r--src/python/doc/cubical_complex_sklearn_itf_ref.rst6
-rw-r--r--src/python/gudhi/sklearn/cubical_persistence.py38
-rw-r--r--src/python/test/test_sklearn_cubical_persistence.py21
3 files changed, 39 insertions, 26 deletions
diff --git a/src/python/doc/cubical_complex_sklearn_itf_ref.rst b/src/python/doc/cubical_complex_sklearn_itf_ref.rst
index 8248343b..94862541 100644
--- a/src/python/doc/cubical_complex_sklearn_itf_ref.rst
+++ b/src/python/doc/cubical_complex_sklearn_itf_ref.rst
@@ -54,10 +54,10 @@ two holes in :math:`\mathbf{H}_1`, or, like in this example, three connected com
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=0)
pipe = Pipeline(
[
- ("cub_pers", CubicalPersistence(persistence_dimension=0, newshape=[28, 28], n_jobs=-2)),
+ ("cub_pers", CubicalPersistence(homology_dimensions=0, newshape=[28, 28], n_jobs=-2)),
# Or for multiple persistence dimension computation
- # ("cub_pers", CubicalPersistence(persistence_dimension=[0, 1], newshape=[28, 28], n_jobs=-2)),
- # ("H0_diags", DimensionSelector(index=0), # where index is the index in persistence_dimension array
+ # ("cub_pers", CubicalPersistence(homology_dimensions=[0, 1], newshape=[28, 28], n_jobs=-2)),
+ # ("H0_diags", DimensionSelector(index=0)), # where index is the index in homology_dimensions array
("finite_diags", DiagramSelector(use=True, point_type="finite")),
(
"pers_img",
diff --git a/src/python/gudhi/sklearn/cubical_persistence.py b/src/python/gudhi/sklearn/cubical_persistence.py
index dc7be7f5..06e9128b 100644
--- a/src/python/gudhi/sklearn/cubical_persistence.py
+++ b/src/python/gudhi/sklearn/cubical_persistence.py
@@ -10,6 +10,7 @@
from .. import CubicalComplex
from sklearn.base import BaseEstimator, TransformerMixin
+import numpy as np
# joblib is required by scikit-learn
from joblib import Parallel, delayed
@@ -33,7 +34,7 @@ class CubicalPersistence(BaseEstimator, TransformerMixin):
def __init__(
self,
newshape=None,
- persistence_dimension=-1,
+ homology_dimensions=-1,
homology_coeff_field=11,
min_persistence=0.0,
n_jobs=None,
@@ -42,18 +43,20 @@ class CubicalPersistence(BaseEstimator, TransformerMixin):
Constructor for the CubicalPersistence class.
Parameters:
- newshape (list of int): A list of number of top dimensional cells if cells filtration values will require
- to be reshaped (cf. :func:`~gudhi.sklearn.cubical_persistence.CubicalPersistence.transform`)
- persistence_dimension (int or list of int): The returned persistence diagrams dimension(s).
+ newshape (tuple of ints): If cells filtration values require to be reshaped
+ (cf. :func:`~gudhi.sklearn.cubical_persistence.CubicalPersistence.transform`), set `newshape`
+ to perform `numpy.reshape(X, newshape, order='C')` in
+ :func:`~gudhi.sklearn.cubical_persistence.CubicalPersistence.transform` method.
+ homology_dimensions (int or list of int): The returned persistence diagrams dimension(s).
Short circuit the use of :class:`~gudhi.representations.preprocessing.DimensionSelector` when only one
- dimension matters (in other words, when `persistence_dimension` is an int).
+ dimension matters (in other words, when `homology_dimensions` is an int).
homology_coeff_field (int): The homology coefficient field. Must be a prime number. Default value is 11.
min_persistence (float): The minimum persistence value to take into account (strictly greater than
`min_persistence`). Default value is `0.0`. Set `min_persistence` to `-1.0` to see all values.
n_jobs (int): cf. https://joblib.readthedocs.io/en/latest/generated/joblib.Parallel.html
"""
self.newshape = newshape
- self.persistence_dimension = persistence_dimension
+ self.homology_dimensions = homology_dimensions
self.homology_coeff_field = homology_coeff_field
self.min_persistence = min_persistence
self.n_jobs = n_jobs
@@ -65,37 +68,38 @@ class CubicalPersistence(BaseEstimator, TransformerMixin):
return self
def __transform(self, cells):
- cubical_complex = CubicalComplex(top_dimensional_cells=cells, dimensions=self.newshape)
+ cubical_complex = CubicalComplex(top_dimensional_cells=cells)
cubical_complex.compute_persistence(
homology_coeff_field=self.homology_coeff_field, min_persistence=self.min_persistence
)
return [
- cubical_complex.persistence_intervals_in_dimension(dim) for dim in self.persistence_dimension
+ cubical_complex.persistence_intervals_in_dimension(dim) for dim in self.homology_dimensions
]
def __transform_only_this_dim(self, cells):
- cubical_complex = CubicalComplex(top_dimensional_cells=cells, dimensions=self.newshape)
+ cubical_complex = CubicalComplex(top_dimensional_cells=cells)
cubical_complex.compute_persistence(
homology_coeff_field=self.homology_coeff_field, min_persistence=self.min_persistence
)
- return cubical_complex.persistence_intervals_in_dimension(self.persistence_dimension)
+ return cubical_complex.persistence_intervals_in_dimension(self.homology_dimensions)
def transform(self, X, Y=None):
"""Compute all the cubical complexes and their associated persistence diagrams.
- :param X: List of cells filtration values that should be flatten if `newshape` is set in the constructor, or
- already with the correct shape in a numpy.ndarray (and `newshape` must not be set).
+ :param X: List of cells filtration values (`numpy.reshape(X, newshape, order='C')` is applied if `newshape` is set with a tuple of ints).
:type X: list of list of float OR list of numpy.ndarray
:return: Persistence diagrams in the format:
- - If `persistence_dimension` was set to `n`: `[array( Hn(X[0]) ), array( Hn(X[1]) ), ...]`
- - If `persistence_dimension` was set to `[i, j]`: `[[array( Hi(X[0]) ), array( Hj(X[0]) )], [array( Hi(X[1]) ), array( Hj(X[1]) )], ...]`
+ - If `homology_dimensions` was set to `n`: `[array( Hn(X[0]) ), array( Hn(X[1]) ), ...]`
+ - If `homology_dimensions` was set to `[i, j]`: `[[array( Hi(X[0]) ), array( Hj(X[0]) )], [array( Hi(X[1]) ), array( Hj(X[1]) )], ...]`
:rtype: list of tuple or list of list of tuple
"""
-
- # Depends on persistence_dimension is an integer or a list of integer (else case)
- if isinstance(self.persistence_dimension, int):
+ if self.newshape is not None:
+ X = np.reshape(X, self.newshape, order='C')
+
+ # Depends on whether homology_dimensions is an integer or a list of integers (else case)
+ if isinstance(self.homology_dimensions, int):
# threads is preferred as cubical construction and persistence computation releases the GIL
return Parallel(n_jobs=self.n_jobs, prefer="threads")(
delayed(self.__transform_only_this_dim)(cells) for cells in X
diff --git a/src/python/test/test_sklearn_cubical_persistence.py b/src/python/test/test_sklearn_cubical_persistence.py
index 69b65dde..1c05a215 100644
--- a/src/python/test/test_sklearn_cubical_persistence.py
+++ b/src/python/test/test_sklearn_cubical_persistence.py
@@ -17,9 +17,9 @@ CUBICAL_PERSISTENCE_H0_IMG0 = np.array([[0.0, 6.0], [0.0, 8.0], [0.0, np.inf]])
def test_simple_constructor_from_top_cells():
cells = datasets.load_digits().images[0]
- cp = CubicalPersistence(persistence_dimension=0)
+ cp = CubicalPersistence(homology_dimensions=0)
np.testing.assert_array_equal(cp._CubicalPersistence__transform_only_this_dim(cells), CUBICAL_PERSISTENCE_H0_IMG0)
- cp = CubicalPersistence(persistence_dimension=[0, 2])
+ cp = CubicalPersistence(homology_dimensions=[0, 2])
diags = cp._CubicalPersistence__transform(cells)
assert len(diags) == 2
np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0)
@@ -27,13 +27,13 @@ def test_simple_constructor_from_top_cells():
def test_simple_constructor_from_top_cells_list():
digits = datasets.load_digits().images[:10]
- cp = CubicalPersistence(persistence_dimension=0, n_jobs=-2)
+ cp = CubicalPersistence(homology_dimensions=0, n_jobs=-2)
diags = cp.fit_transform(digits)
assert len(diags) == 10
np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0)
- cp = CubicalPersistence(persistence_dimension=[0, 1], n_jobs=-1)
+ cp = CubicalPersistence(homology_dimensions=[0, 1], n_jobs=-1)
diagsH0H1 = cp.fit_transform(digits)
assert len(diagsH0H1) == 10
for idx in range(10):
@@ -42,9 +42,18 @@ def test_simple_constructor_from_top_cells_list():
def test_simple_constructor_from_flattened_cells():
cells = datasets.load_digits().images[0]
# Not squared (extended) flatten cells
- cells = np.hstack((cells, np.zeros((cells.shape[0], 2)))).flatten()
+ flat_cells = np.hstack((cells, np.zeros((cells.shape[0], 2)))).flatten()
- cp = CubicalPersistence(persistence_dimension=0, newshape=[10, 8])
+ cp = CubicalPersistence(homology_dimensions=0, newshape=[-1, 8, 10])
+ diags = cp.fit_transform([flat_cells])
+
+ np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0)
+
+ # Not squared (extended) non-flatten cells
+ cells = np.hstack((cells, np.zeros((cells.shape[0], 2))))
+
+ # The aim of this second part of the test is to resize even if not mandatory
+ cp = CubicalPersistence(homology_dimensions=0, newshape=[-1, 8, 10])
diags = cp.fit_transform([cells])
np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0)