summary refs log tree commit diff
diff options
context:
space:
mode:
author ROUVREAU Vincent <vincent.rouvreau@inria.fr> 2021-05-31 12:23:29 +0200
committer ROUVREAU Vincent <vincent.rouvreau@inria.fr> 2021-05-31 12:23:29 +0200
commit 5b75186ace327ddc17eb6f06c0ba2485c93235ec (patch)
tree f263e4e6a1f0034ec30ef7f999eab81942373346
parent 8859128da7386955b00658ff5d71659a5de08c46 (diff)
code review + parallelization
-rw-r--r-- src/python/CMakeLists.txt 5
-rw-r--r-- src/python/gudhi/sklearn/cubical_persistence.py 52
-rw-r--r-- src/python/test/test_sklearn_cubical_persistence.py 25
3 files changed, 45 insertions, 37 deletions
diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt
index 9c19a5e7..727efbdb 100644
--- a/src/python/CMakeLists.txt
+++ b/src/python/CMakeLists.txt
@@ -536,6 +536,11 @@ if(PYTHONINTERP_FOUND)
add_gudhi_py_test(test_dtm_rips_complex)
endif()
+ # sklearn
+ if(SKLEARN_FOUND)
+ add_gudhi_py_test(test_sklearn_cubical_persistence)
+ endif()
+
# Set missing or not modules
set(GUDHI_MODULES ${GUDHI_MODULES} "python" CACHE INTERNAL "GUDHI_MODULES")
diff --git a/src/python/gudhi/sklearn/cubical_persistence.py b/src/python/gudhi/sklearn/cubical_persistence.py
index a58fa77c..809f5d4b 100644
--- a/src/python/gudhi/sklearn/cubical_persistence.py
+++ b/src/python/gudhi/sklearn/cubical_persistence.py
@@ -1,13 +1,15 @@
from .. import CubicalComplex
-from sklearn.base import TransformerMixin
+from sklearn.base import BaseEstimator, TransformerMixin
+# joblib is required by scikit-learn
+from joblib import Parallel, delayed
-class CubicalPersistence(TransformerMixin):
+class CubicalPersistence(BaseEstimator, TransformerMixin):
# Fast way to find primes and should be enough
available_primes_ = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97]
"""
This is a class for computing the persistence diagrams from a cubical complex.
"""
- def __init__(self, dimensions=None, persistence_dim=0, min_persistence=0):
+ def __init__(self, dimensions=None, persistence_dim=0, min_persistence=0, n_jobs=None):
"""
Constructor for the CubicalPersistence class.
@@ -16,9 +18,10 @@ class CubicalPersistence(TransformerMixin):
persistence_dim (int): The returned persistence diagrams dimension. Default value is `0`.
min_persistence (float): The minimum persistence value to take into account (strictly greater than
`min_persistence`). Default value is `0.0`. Set `min_persistence` to `-1.0` to see all values.
+ n_jobs (int): Number of jobs passed to `joblib.Parallel`, cf. https://joblib.readthedocs.io/en/latest/generated/joblib.Parallel.html
"""
- self.dimensions_ = dimensions
- self.persistence_dim_ = persistence_dim
+ self.dimensions = dimensions
+ self.persistence_dim = persistence_dim
self.homology_coeff_field_ = None
for dim in self.available_primes_:
@@ -28,37 +31,36 @@ class CubicalPersistence(TransformerMixin):
if self.homology_coeff_field_ == None:
raise ValueError("persistence_dim must be less than 96")
- self.min_persistence_ = min_persistence
+ self.min_persistence = min_persistence
+ self.n_jobs = n_jobs
- def transform(self, X):
+ def fit(self, X, Y=None):
"""
- Compute all the cubical complexes and their persistence diagrams.
-
- Parameters:
- X (list of double OR numpy.ndarray): Cells filtration values.
-
- Returns:
- Persistence diagrams
+ Nothing to be done.
"""
- cubical_complex = CubicalComplex(top_dimensional_cells = X,
- dimensions = self.dimensions_)
+ return self
+
+ def __transform(self, cells):
+ cubical_complex = CubicalComplex(top_dimensional_cells = cells, dimensions = self.dimensions)
cubical_complex.compute_persistence(homology_coeff_field = self.homology_coeff_field_,
- min_persistence = self.min_persistence_)
- self.diagrams_ = cubical_complex.persistence_intervals_in_dimension(self.persistence_dim_)
- if self.persistence_dim_ == 0:
+ min_persistence = self.min_persistence)
+ diagrams = cubical_complex.persistence_intervals_in_dimension(self.persistence_dim)
+ if self.persistence_dim == 0:
# return all but the last, always [ 0., inf]
- self.diagrams_ = self.diagrams_[:-1]
- return self.diagrams_
+ diagrams = diagrams[:-1]
+ return diagrams
- def fit_transform(self, X):
+ def transform(self, X, Y=None):
"""
Compute all the cubical complexes and their persistence diagrams.
Parameters:
- X (list of double OR numpy.ndarray): Cells filtration values.
+ X (list of list of double OR list of numpy.ndarray): List of cells filtration values.
Returns:
Persistence diagrams
"""
- self.transform(X)
- return self.diagrams_
+
+ # threads are preferred as cubical construction and persistence computation release the GIL
+ return Parallel(n_jobs=self.n_jobs, prefer="threads")(
+ delayed(self.__transform)(cells) for cells in X)
diff --git a/src/python/test/test_sklearn_cubical_persistence.py b/src/python/test/test_sklearn_cubical_persistence.py
index f147ffe3..134611c9 100644
--- a/src/python/test/test_sklearn_cubical_persistence.py
+++ b/src/python/test/test_sklearn_cubical_persistence.py
@@ -10,26 +10,27 @@
from gudhi.sklearn.cubical_persistence import CubicalPersistence
import numpy as np
+from sklearn import datasets
__author__ = "Vincent Rouvreau"
__copyright__ = "Copyright (C) 2021 Inria"
__license__ = "MIT"
def test_simple_constructor_from_top_cells():
+ cells = datasets.load_digits().images[0]
cp = CubicalPersistence(persistence_dim = 0)
+ np.testing.assert_array_equal(cp._CubicalPersistence__transform(cells),
+ np.array([[0., 6.], [0., 8.]]))
- # The first "0" from sklearn.datasets.load_digits()
- bmp = np.array([[ 0., 0., 5., 13., 9., 1., 0., 0.],
- [ 0., 0., 13., 15., 10., 15., 5., 0.],
- [ 0., 3., 15., 2., 0., 11., 8., 0.],
- [ 0., 4., 12., 0., 0., 8., 8., 0.],
- [ 0., 5., 8., 0., 0., 9., 8., 0.],
- [ 0., 4., 11., 0., 1., 12., 7., 0.],
- [ 0., 2., 14., 5., 10., 12., 0., 0.],
- [ 0., 0., 6., 13., 10., 0., 0., 0.]])
+def test_simple_constructor_from_top_cells_list():
+ digits = datasets.load_digits().images[:10]
+ cp = CubicalPersistence(persistence_dim = 0, n_jobs=-2)
- assert cp.fit_transform(bmp) == np.array([[0., 6.], [0., 8.]])
+ diags = cp.fit_transform(digits)
+ assert len(diags) == 10
+ np.testing.assert_array_equal(diags[0],
+ np.array([[0., 6.], [0., 8.]]))
# from gudhi.representations import PersistenceImage
-# PersistenceImage(bandwidth=50, weight=lambda x: x[1]**2, im_range=[0,256,0,256], resolution=[20, 20])
-# PI.fit_transform([diag]) \ No newline at end of file
+# pi = PersistenceImage(bandwidth=50, weight=lambda x: x[1]**2, im_range=[0,256,0,256], resolution=[20, 20])
+# pi.fit_transform(diags) \ No newline at end of file