diff options
author | ROUVREAU Vincent <vincent.rouvreau@inria.fr> | 2021-05-31 12:23:29 +0200 |
---|---|---|
committer | ROUVREAU Vincent <vincent.rouvreau@inria.fr> | 2021-05-31 12:23:29 +0200 |
commit | 5b75186ace327ddc17eb6f06c0ba2485c93235ec (patch) | |
tree | f263e4e6a1f0034ec30ef7f999eab81942373346 | |
parent | 8859128da7386955b00658ff5d71659a5de08c46 (diff) |
code review + parallelization
-rw-r--r-- | src/python/CMakeLists.txt | 5 | ||||
-rw-r--r-- | src/python/gudhi/sklearn/cubical_persistence.py | 52 | ||||
-rw-r--r-- | src/python/test/test_sklearn_cubical_persistence.py | 25 |
3 files changed, 45 insertions, 37 deletions
diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt index 9c19a5e7..727efbdb 100644 --- a/src/python/CMakeLists.txt +++ b/src/python/CMakeLists.txt @@ -536,6 +536,11 @@ if(PYTHONINTERP_FOUND) add_gudhi_py_test(test_dtm_rips_complex) endif() + # sklearn + if(SKLEARN_FOUND) + add_gudhi_py_test(test_sklearn_cubical_persistence) + endif() + # Set missing or not modules set(GUDHI_MODULES ${GUDHI_MODULES} "python" CACHE INTERNAL "GUDHI_MODULES") diff --git a/src/python/gudhi/sklearn/cubical_persistence.py b/src/python/gudhi/sklearn/cubical_persistence.py index a58fa77c..809f5d4b 100644 --- a/src/python/gudhi/sklearn/cubical_persistence.py +++ b/src/python/gudhi/sklearn/cubical_persistence.py @@ -1,13 +1,15 @@ from .. import CubicalComplex -from sklearn.base import TransformerMixin +from sklearn.base import BaseEstimator, TransformerMixin +# joblib is required by scikit-learn +from joblib import Parallel, delayed -class CubicalPersistence(TransformerMixin): +class CubicalPersistence(BaseEstimator, TransformerMixin): # Fast way to find primes and should be enough available_primes_ = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97] """ This is a class for computing the persistence diagrams from a cubical complex. """ - def __init__(self, dimensions=None, persistence_dim=0, min_persistence=0): + def __init__(self, dimensions=None, persistence_dim=0, min_persistence=0, n_jobs=None): """ Constructor for the CubicalPersistence class. @@ -16,9 +18,10 @@ class CubicalPersistence(TransformerMixin): persistence_dim (int): The returned persistence diagrams dimension. Default value is `0`. min_persistence (float): The minimum persistence value to take into account (strictly greater than `min_persistence`). Default value is `0.0`. Sets `min_persistence` to `-1.0` to see all values. + n_jobs (int): cf. https://joblib.readthedocs.io/en/latest/generated/joblib.Parallel.html """ - self.dimensions_ = dimensions - self.persistence_dim_ = persistence_dim + self.dimensions = dimensions + self.persistence_dim = persistence_dim self.homology_coeff_field_ = None for dim in self.available_primes_: @@ -28,37 +31,36 @@ class CubicalPersistence(TransformerMixin): if self.homology_coeff_field_ == None: raise ValueError("persistence_dim must be less than 96") - self.min_persistence_ = min_persistence + self.min_persistence = min_persistence + self.n_jobs = n_jobs - def transform(self, X): + def fit(self, X, Y=None): """ - Compute all the cubical complexes and their persistence diagrams. - - Parameters: - X (list of double OR numpy.ndarray): Cells filtration values. - - Returns: - Persistence diagrams + Nothing to be done. """ - cubical_complex = CubicalComplex(top_dimensional_cells = X, - dimensions = self.dimensions_) + return self + + def __transform(self, cells): + cubical_complex = CubicalComplex(top_dimensional_cells = cells, dimensions = self.dimensions) cubical_complex.compute_persistence(homology_coeff_field = self.homology_coeff_field_, - min_persistence = self.min_persistence_) - self.diagrams_ = cubical_complex.persistence_intervals_in_dimension(self.persistence_dim_) - if self.persistence_dim_ == 0: + min_persistence = self.min_persistence) + diagrams = cubical_complex.persistence_intervals_in_dimension(self.persistence_dim) + if self.persistence_dim == 0: # return all but the last, always [ 0., inf] - self.diagrams_ = self.diagrams_[:-1] - return self.diagrams_ + diagrams = diagrams[:-1] + return diagrams - def fit_transform(self, X): + def transform(self, X, Y=None): """ Compute all the cubical complexes and their persistence diagrams. Parameters: - X (list of double OR numpy.ndarray): Cells filtration values. + X (list of list of double OR list of numpy.ndarray): List of cells filtration values. Returns: Persistence diagrams """ - self.transform(X) - return self.diagrams_ + + # threads is preferred as cubical construction and persistence computation releases the GIL + return Parallel(n_jobs=self.n_jobs, prefer="threads")( + delayed(self.__transform)(cells) for cells in X) diff --git a/src/python/test/test_sklearn_cubical_persistence.py b/src/python/test/test_sklearn_cubical_persistence.py index f147ffe3..134611c9 100644 --- a/src/python/test/test_sklearn_cubical_persistence.py +++ b/src/python/test/test_sklearn_cubical_persistence.py @@ -10,26 +10,27 @@ from gudhi.sklearn.cubical_persistence import CubicalPersistence import numpy as np +from sklearn import datasets __author__ = "Vincent Rouvreau" __copyright__ = "Copyright (C) 2021 Inria" __license__ = "MIT" def test_simple_constructor_from_top_cells(): + cells = datasets.load_digits().images[0] cp = CubicalPersistence(persistence_dim = 0) + np.testing.assert_array_equal(cp._CubicalPersistence__transform(cells), + np.array([[0., 6.], [0., 8.]])) - # The first "0" from sklearn.datasets.load_digits() - bmp = np.array([[ 0., 0., 5., 13., 9., 1., 0., 0.], - [ 0., 0., 13., 15., 10., 15., 5., 0.], - [ 0., 3., 15., 2., 0., 11., 8., 0.], - [ 0., 4., 12., 0., 0., 8., 8., 0.], - [ 0., 5., 8., 0., 0., 9., 8., 0.], - [ 0., 4., 11., 0., 1., 12., 7., 0.], - [ 0., 2., 14., 5., 10., 12., 0., 0.], - [ 0., 0., 6., 13., 10., 0., 0., 0.]]) +def test_simple_constructor_from_top_cells_list(): + digits = datasets.load_digits().images[:10] + cp = CubicalPersistence(persistence_dim = 0, n_jobs=-2) - assert cp.fit_transform(bmp) == np.array([[0., 6.], [0., 8.]]) + diags = cp.fit_transform(digits) + assert len(diags) == 10 + np.testing.assert_array_equal(diags[0], + np.array([[0., 6.], [0., 8.]])) # from gudhi.representations import PersistenceImage -# PersistenceImage(bandwidth=50, weight=lambda x: x[1]**2, im_range=[0,256,0,256], resolution=[20, 20]) -# PI.fit_transform([diag])
\ No newline at end of file +# pi = PersistenceImage(bandwidth=50, weight=lambda x: x[1]**2, im_range=[0,256,0,256], resolution=[20, 20]) +# pi.fit_transform(diags)
\ No newline at end of file |