summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorROUVREAU Vincent <vincent.rouvreau@inria.fr>2021-06-07 09:36:19 +0200
committerROUVREAU Vincent <vincent.rouvreau@inria.fr>2021-06-07 09:36:19 +0200
commit8813c23e4931e9c955dd0e89547133065429ae0d (patch)
tree918bb6fe007190f539bfd1b9f663f19f88dece45
parent546b059af6c0581d06bfe9cebbe853f2f7bd4589 (diff)
format file
-rw-r--r--src/python/gudhi/sklearn/cubical_persistence.py13
1 files changed, 8 insertions, 5 deletions
diff --git a/src/python/gudhi/sklearn/cubical_persistence.py b/src/python/gudhi/sklearn/cubical_persistence.py
index a7a3d036..f4341bf6 100644
--- a/src/python/gudhi/sklearn/cubical_persistence.py
+++ b/src/python/gudhi/sklearn/cubical_persistence.py
@@ -1,14 +1,17 @@
from .. import CubicalComplex
from sklearn.base import BaseEstimator, TransformerMixin
+
# joblib is required by scikit-learn
from joblib import Parallel, delayed
+
class CubicalPersistence(BaseEstimator, TransformerMixin):
# Fast way to find primes and should be enough
_available_primes = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97]
"""
This is a class for computing the persistence diagrams from a cubical complex.
"""
+
def __init__(self, dimensions=None, persistence_dim=0, min_persistence=0, n_jobs=None):
"""
Constructor for the CubicalPersistence class.
@@ -41,9 +44,10 @@ class CubicalPersistence(BaseEstimator, TransformerMixin):
return self
def __transform(self, cells):
- cubical_complex = CubicalComplex(top_dimensional_cells = cells, dimensions = self.dimensions)
- cubical_complex.compute_persistence(homology_coeff_field = self.homology_coeff_field_,
- min_persistence = self.min_persistence)
+ cubical_complex = CubicalComplex(top_dimensional_cells=cells, dimensions=self.dimensions)
+ cubical_complex.compute_persistence(
+ homology_coeff_field=self.homology_coeff_field_, min_persistence=self.min_persistence
+ )
diagrams = cubical_complex.persistence_intervals_in_dimension(self.persistence_dim)
return diagrams
@@ -59,5 +63,4 @@ class CubicalPersistence(BaseEstimator, TransformerMixin):
"""
# threads is preferred as cubical construction and persistence computation releases the GIL
- return Parallel(n_jobs=self.n_jobs, prefer="threads")(
- delayed(self.__transform)(cells) for cells in X)
+ return Parallel(n_jobs=self.n_jobs, prefer="threads")(delayed(self.__transform)(cells) for cells in X)