diff options
author | ROUVREAU Vincent <vincent.rouvreau@inria.fr> | 2021-06-07 09:36:19 +0200 |
---|---|---|
committer | ROUVREAU Vincent <vincent.rouvreau@inria.fr> | 2021-06-07 09:36:19 +0200 |
commit | 8813c23e4931e9c955dd0e89547133065429ae0d (patch) | |
tree | 918bb6fe007190f539bfd1b9f663f19f88dece45 | |
parent | 546b059af6c0581d06bfe9cebbe853f2f7bd4589 (diff) |
format file
-rw-r--r-- | src/python/gudhi/sklearn/cubical_persistence.py | 13 |
1 files changed, 8 insertions, 5 deletions
diff --git a/src/python/gudhi/sklearn/cubical_persistence.py b/src/python/gudhi/sklearn/cubical_persistence.py index a7a3d036..f4341bf6 100644 --- a/src/python/gudhi/sklearn/cubical_persistence.py +++ b/src/python/gudhi/sklearn/cubical_persistence.py @@ -1,14 +1,17 @@ from .. import CubicalComplex from sklearn.base import BaseEstimator, TransformerMixin + # joblib is required by scikit-learn from joblib import Parallel, delayed + class CubicalPersistence(BaseEstimator, TransformerMixin): # Fast way to find primes and should be enough _available_primes = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97] """ This is a class for computing the persistence diagrams from a cubical complex. """ + def __init__(self, dimensions=None, persistence_dim=0, min_persistence=0, n_jobs=None): """ Constructor for the CubicalPersistence class. @@ -41,9 +44,10 @@ class CubicalPersistence(BaseEstimator, TransformerMixin): return self def __transform(self, cells): - cubical_complex = CubicalComplex(top_dimensional_cells = cells, dimensions = self.dimensions) - cubical_complex.compute_persistence(homology_coeff_field = self.homology_coeff_field_, - min_persistence = self.min_persistence) + cubical_complex = CubicalComplex(top_dimensional_cells=cells, dimensions=self.dimensions) + cubical_complex.compute_persistence( + homology_coeff_field=self.homology_coeff_field_, min_persistence=self.min_persistence + ) diagrams = cubical_complex.persistence_intervals_in_dimension(self.persistence_dim) return diagrams @@ -59,5 +63,4 @@ class CubicalPersistence(BaseEstimator, TransformerMixin): """ # threads is preferred as cubical construction and persistence computation releases the GIL - return Parallel(n_jobs=self.n_jobs, prefer="threads")( - delayed(self.__transform)(cells) for cells in X) + return Parallel(n_jobs=self.n_jobs, prefer="threads")(delayed(self.__transform)(cells) for cells in X) |