diff options
Diffstat (limited to 'src/python')
-rw-r--r-- | src/python/CMakeLists.txt | 1 | ||||
-rw-r--r-- | src/python/gudhi/sklearn/__init__.py | 0 | ||||
-rw-r--r-- | src/python/gudhi/sklearn/cubical_persistence.py | 64 | ||||
-rw-r--r-- | src/python/test/test_sklearn_cubical_persistence.py | 35 |
4 files changed, 100 insertions, 0 deletions
diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt
index a1440cbc..9c19a5e7 100644
--- a/src/python/CMakeLists.txt
+++ b/src/python/CMakeLists.txt
@@ -262,6 +262,7 @@ if(PYTHONINTERP_FOUND)
   file(COPY "gudhi/weighted_rips_complex.py" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi")
   file(COPY "gudhi/dtm_rips_complex.py" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi")
   file(COPY "gudhi/hera/__init__.py" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi/hera")
+  file(COPY "gudhi/sklearn" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi/")
 
   # Some files for pip package
   file(COPY "introduction.rst" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/")
diff --git a/src/python/gudhi/sklearn/__init__.py b/src/python/gudhi/sklearn/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/src/python/gudhi/sklearn/__init__.py
diff --git a/src/python/gudhi/sklearn/cubical_persistence.py b/src/python/gudhi/sklearn/cubical_persistence.py
new file mode 100644
index 00000000..a58fa77c
--- /dev/null
+++ b/src/python/gudhi/sklearn/cubical_persistence.py
@@ -0,0 +1,64 @@
+from .. import CubicalComplex
+from sklearn.base import TransformerMixin
+
+class CubicalPersistence(TransformerMixin):
+    """
+    This is a class for computing the persistence diagrams from a cubical complex.
+    """
+    # Fast way to find primes and should be enough
+    available_primes_ = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97]
+    def __init__(self, dimensions=None, persistence_dim=0, min_persistence=0):
+        """
+        Constructor for the CubicalPersistence class.
+
+        Parameters:
+            dimensions (list of int): A list of number of top dimensional cells.
+            persistence_dim (int): The returned persistence diagrams dimension. Default value is `0`.
+            min_persistence (float): The minimum persistence value to take into account (strictly greater than
+                `min_persistence`). Default value is `0.0`. Set `min_persistence` to `-1.0` to see all values.
+        """
+        self.dimensions_ = dimensions
+        self.persistence_dim_ = persistence_dim
+        self.min_persistence_ = min_persistence
+
+        # Smallest available prime strictly greater than persistence_dim + 1
+        self.homology_coeff_field_ = next((p for p in self.available_primes_ if p > persistence_dim + 1), None)
+        if self.homology_coeff_field_ is None:
+            raise ValueError("persistence_dim must be less than 96")
+
+    def fit(self, X, y=None):
+        """Nothing to be done; provided for scikit-learn API compatibility (e.g. in a Pipeline)."""
+        return self
+
+    def transform(self, X):
+        """
+        Compute the cubical complex built from X and its persistence diagram.
+
+        Parameters:
+            X (list of double OR numpy.ndarray): Cells filtration values.
+
+        Returns:
+            Persistence diagrams
+        """
+        cubical_complex = CubicalComplex(top_dimensional_cells=X,
+                                         dimensions=self.dimensions_)
+        cubical_complex.compute_persistence(homology_coeff_field=self.homology_coeff_field_,
+                                            min_persistence=self.min_persistence_)
+        self.diagrams_ = cubical_complex.persistence_intervals_in_dimension(self.persistence_dim_)
+        if self.persistence_dim_ == 0:
+            # return all but the last, always [ 0., inf]
+            self.diagrams_ = self.diagrams_[:-1]
+        return self.diagrams_
+
+    def fit_transform(self, X, y=None):
+        """
+        Compute the cubical complex built from X and its persistence diagram.
+
+        Parameters:
+            X (list of double OR numpy.ndarray): Cells filtration values.
+            y: Ignored, only there for scikit-learn API compatibility.
+
+        Returns:
+            Persistence diagrams
+        """
+        return self.transform(X)
diff --git a/src/python/test/test_sklearn_cubical_persistence.py b/src/python/test/test_sklearn_cubical_persistence.py
new file mode 100644
index 00000000..f147ffe3
--- /dev/null
+++ b/src/python/test/test_sklearn_cubical_persistence.py
@@ -0,0 +1,35 @@
+""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+    See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+    Author(s):       Vincent Rouvreau
+
+    Copyright (C) 2021 Inria
+
+    Modification(s):
+      - YYYY/MM Author: Description of the modification
+"""
+
+from gudhi.sklearn.cubical_persistence import CubicalPersistence
+import numpy as np
+
+__author__ = "Vincent Rouvreau"
+__copyright__ = "Copyright (C) 2021 Inria"
+__license__ = "MIT"
+
+def test_simple_constructor_from_top_cells():
+    cp = CubicalPersistence(persistence_dim=0)
+
+    # The first "0" from sklearn.datasets.load_digits()
+    bmp = np.array([[ 0.,  0.,  5., 13.,  9.,  1.,  0.,  0.],
+                    [ 0.,  0., 13., 15., 10., 15.,  5.,  0.],
+                    [ 0.,  3., 15.,  2.,  0., 11.,  8.,  0.],
+                    [ 0.,  4., 12.,  0.,  0.,  8.,  8.,  0.],
+                    [ 0.,  5.,  8.,  0.,  0.,  9.,  8.,  0.],
+                    [ 0.,  4., 11.,  0.,  1., 12.,  7.,  0.],
+                    [ 0.,  2., 14.,  5., 10., 12.,  0.,  0.],
+                    [ 0.,  0.,  6., 13., 10.,  0.,  0.,  0.]])
+
+    np.testing.assert_array_equal(cp.fit_transform(bmp), np.array([[0., 6.], [0., 8.]]))
+
+# from gudhi.representations import PersistenceImage
+# PersistenceImage(bandwidth=50, weight=lambda x: x[1]**2, im_range=[0,256,0,256], resolution=[20, 20])
+# PI.fit_transform([diag])
\ No newline at end of file |