summaryrefslogtreecommitdiff
path: root/src/python/gudhi/sklearn/cubical_persistence.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/gudhi/sklearn/cubical_persistence.py')
-rw-r--r--src/python/gudhi/sklearn/cubical_persistence.py40
1 files changed, 19 insertions, 21 deletions
diff --git a/src/python/gudhi/sklearn/cubical_persistence.py b/src/python/gudhi/sklearn/cubical_persistence.py
index 329c9435..454cdd07 100644
--- a/src/python/gudhi/sklearn/cubical_persistence.py
+++ b/src/python/gudhi/sklearn/cubical_persistence.py
@@ -33,8 +33,7 @@ class CubicalPersistence(BaseEstimator, TransformerMixin):
def __init__(
self,
dimensions=None,
- max_persistence_dimension=0,
- only_this_dim=-1,
+ persistence_dimension=-1,
homology_coeff_field=11,
min_persistence=0.0,
n_jobs=None,
@@ -45,20 +44,16 @@ class CubicalPersistence(BaseEstimator, TransformerMixin):
Parameters:
dimensions (list of int): A list of number of top dimensional cells if cells filtration values will require
to be reshaped (cf. :func:`~gudhi.sklearn.cubical_persistence.CubicalPersistence.transform`)
- max_persistence_dimension (int): The returned persistence diagrams maximal dimension. Default value is `0`.
- Ignored if `only_this_dim` is set.
- only_this_dim (int): The returned persistence diagrams dimension. If `only_this_dim` is set,
- `max_persistence_dimension` will be ignored.
- Short circuit the use of :class:`~gudhi.sklearn.post_processing.DimensionSelector` when only one
- dimension matters.
+ persistence_dimension (int or list of int): The returned persistence diagrams dimension(s).
+ Short circuit the use of :class:`~gudhi.representations.preprocessing.DimensionSelector` when only one
+ dimension matters (in other words, when `persistence_dimension` is an int).
homology_coeff_field (int): The homology coefficient field. Must be a prime number. Default value is 11.
min_persistence (float): The minimum persistence value to take into account (strictly greater than
- `min_persistence`). Default value is `0.0`. Sets `min_persistence` to `-1.0` to see all values.
+ `min_persistence`). Default value is `0.0`. Set `min_persistence` to `-1.0` to see all values.
n_jobs (int): cf. https://joblib.readthedocs.io/en/latest/generated/joblib.Parallel.html
"""
self.dimensions = dimensions
- self.max_persistence_dimension = max_persistence_dimension
- self.only_this_dim = only_this_dim
+ self.persistence_dimension = persistence_dimension
self.homology_coeff_field = homology_coeff_field
self.min_persistence = min_persistence
self.n_jobs = n_jobs
@@ -75,7 +70,7 @@ class CubicalPersistence(BaseEstimator, TransformerMixin):
homology_coeff_field=self.homology_coeff_field, min_persistence=self.min_persistence
)
return [
- cubical_complex.persistence_intervals_in_dimension(dim) for dim in range(self.max_persistence_dimension + 1)
+ cubical_complex.persistence_intervals_in_dimension(dim) for dim in self.persistence_dimension
]
def __transform_only_this_dim(self, cells):
@@ -83,28 +78,31 @@ class CubicalPersistence(BaseEstimator, TransformerMixin):
cubical_complex.compute_persistence(
homology_coeff_field=self.homology_coeff_field, min_persistence=self.min_persistence
)
- return cubical_complex.persistence_intervals_in_dimension(self.only_this_dim)
+ return cubical_complex.persistence_intervals_in_dimension(self.persistence_dimension)
def transform(self, X, Y=None):
"""
Compute all the cubical complexes and their associated persistence diagrams.
Parameters:
- X (list of list of double OR list of numpy.ndarray): List of cells filtration values that can be flatten if
- `dimensions` is set in the constructor, or already with the correct shape in a numpy.ndarray (and
+ X (list of list of double OR list of numpy.ndarray): List of cells filtration values that should be flattened
+ if `dimensions` is set in the constructor, or already with the correct shape in a numpy.ndarray (and
`dimensions` must not be set).
Returns:
+ list of pairs or list of list of pairs:
Persistence diagrams in the format:
- - If `only_this_dim` was set to `n`: `[array( Hn(X[0]) ), array( Hn(X[1]) ), ...]`
- - else: `[[array( H0(X[0]) ), array( H1(X[0]) ), ...], [array( H0(X[1]) ), array( H1(X[1]) ), ...], ...]`
+ - If `persistence_dimension` was set to `n`: `[array( Hn(X[0]) ), array( Hn(X[1]) ), ...]`
+ - If `persistence_dimension` was set to `[i, j]`: `[[array( Hi(X[0]) ), array( Hj(X[0]) )], [array( Hi(X[1]) ), array( Hj(X[1]) )], ...]`
"""
- if self.only_this_dim == -1:
- # threads is preferred as cubical construction and persistence computation releases the GIL
- return Parallel(n_jobs=self.n_jobs, prefer="threads")(delayed(self.__transform)(cells) for cells in X)
- else:
+ # Branch on whether persistence_dimension is an integer or a list of integers (else case)
+ if isinstance(self.persistence_dimension, int):
# threads is preferred as cubical construction and persistence computation releases the GIL
return Parallel(n_jobs=self.n_jobs, prefer="threads")(
delayed(self.__transform_only_this_dim)(cells) for cells in X
)
+ else:
+ # threads is preferred as cubical construction and persistence computation releases the GIL
+ return Parallel(n_jobs=self.n_jobs, prefer="threads")(delayed(self.__transform)(cells) for cells in X)
+