summaryrefslogtreecommitdiff
path: root/src/python/gudhi/sklearn
diff options
context:
space:
mode:
authorROUVREAU Vincent <vincent.rouvreau@inria.fr>2021-07-02 13:51:10 +0200
committerROUVREAU Vincent <vincent.rouvreau@inria.fr>2021-07-02 13:51:10 +0200
commit91d72a69f2f04676fbd671af3dc2f3040c9f1c48 (patch)
tree7a038d1b8d723ebb38d32a7d79ad2b7c02fe8baa /src/python/gudhi/sklearn
parentf2b8bbccd5cbfa2a0fbb23bdb72e965196a2c05c (diff)
code review: bad homology_coeff_field management
Diffstat (limited to 'src/python/gudhi/sklearn')
-rw-r--r--src/python/gudhi/sklearn/cubical_persistence.py17
1 files changed, 4 insertions, 13 deletions
diff --git a/src/python/gudhi/sklearn/cubical_persistence.py b/src/python/gudhi/sklearn/cubical_persistence.py
index 251e240f..9af683d7 100644
--- a/src/python/gudhi/sklearn/cubical_persistence.py
+++ b/src/python/gudhi/sklearn/cubical_persistence.py
@@ -15,13 +15,11 @@ from joblib import Parallel, delayed
class CubicalPersistence(BaseEstimator, TransformerMixin):
- # Fast way to find primes and should be enough
- _available_primes = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97]
"""
This is a class for computing the persistence diagrams from a cubical complex.
"""
- def __init__(self, dimensions=None, persistence_dim=0, min_persistence=0, n_jobs=None):
+ def __init__(self, dimensions=None, persistence_dim=0, homology_coeff_field=11, min_persistence=0., n_jobs=None):
"""
Constructor for the CubicalPersistence class.
@@ -29,21 +27,14 @@ class CubicalPersistence(BaseEstimator, TransformerMixin):
dimensions (list of int): A list of number of top dimensional cells if cells filtration values will require
to be reshaped (cf. :func:`~gudhi.sklearn.cubical_persistence.CubicalPersistence.transform`)
persistence_dim (int): The returned persistence diagrams dimension. Default value is `0`.
+ homology_coeff_field (int): The homology coefficient field. Must be a prime number. Default value is 11.
min_persistence (float): The minimum persistence value to take into account (strictly greater than
`min_persistence`). Default value is `0.0`. Sets `min_persistence` to `-1.0` to see all values.
n_jobs (int): cf. https://joblib.readthedocs.io/en/latest/generated/joblib.Parallel.html
"""
self.dimensions = dimensions
self.persistence_dim = persistence_dim
-
- self.homology_coeff_field_ = None
- for dim in self._available_primes:
- if dim > persistence_dim + 1:
- self.homology_coeff_field_ = dim
- break
- if self.homology_coeff_field_ == None:
- raise ValueError("persistence_dim must be less than 96")
-
+ self.homology_coeff_field = homology_coeff_field
self.min_persistence = min_persistence
self.n_jobs = n_jobs
@@ -56,7 +47,7 @@ class CubicalPersistence(BaseEstimator, TransformerMixin):
def __transform(self, cells):
cubical_complex = CubicalComplex(top_dimensional_cells=cells, dimensions=self.dimensions)
cubical_complex.compute_persistence(
- homology_coeff_field=self.homology_coeff_field_, min_persistence=self.min_persistence
+ homology_coeff_field=self.homology_coeff_field, min_persistence=self.min_persistence
)
diagrams = cubical_complex.persistence_intervals_in_dimension(self.persistence_dim)
return diagrams