summaryrefslogtreecommitdiff
path: root/src/python
diff options
context:
space:
mode:
Diffstat (limited to 'src/python')
-rw-r--r--src/python/CMakeLists.txt1
-rw-r--r--src/python/gudhi/sklearn/__init__.py0
-rw-r--r--src/python/gudhi/sklearn/cubical_persistence.py64
-rw-r--r--src/python/test/test_sklearn_cubical_persistence.py35
4 files changed, 100 insertions, 0 deletions
diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt
index a1440cbc..9c19a5e7 100644
--- a/src/python/CMakeLists.txt
+++ b/src/python/CMakeLists.txt
@@ -262,6 +262,7 @@ if(PYTHONINTERP_FOUND)
file(COPY "gudhi/weighted_rips_complex.py" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi")
file(COPY "gudhi/dtm_rips_complex.py" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi")
file(COPY "gudhi/hera/__init__.py" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi/hera")
+ file(COPY "gudhi/sklearn" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi/")
# Some files for pip package
file(COPY "introduction.rst" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/")
diff --git a/src/python/gudhi/sklearn/__init__.py b/src/python/gudhi/sklearn/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/src/python/gudhi/sklearn/__init__.py
diff --git a/src/python/gudhi/sklearn/cubical_persistence.py b/src/python/gudhi/sklearn/cubical_persistence.py
new file mode 100644
index 00000000..a58fa77c
--- /dev/null
+++ b/src/python/gudhi/sklearn/cubical_persistence.py
@@ -0,0 +1,64 @@
+from .. import CubicalComplex
+from sklearn.base import TransformerMixin
+
+class CubicalPersistence(TransformerMixin):
+ # Fast way to find primes and should be enough
+ available_primes_ = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97]
+ """
+ This is a class for computing the persistence diagrams from a cubical complex.
+ """
+ def __init__(self, dimensions=None, persistence_dim=0, min_persistence=0):
+ """
+ Constructor for the CubicalPersistence class.
+
+ Parameters:
+ dimensions (list of int): A list of number of top dimensional cells.
+ persistence_dim (int): The returned persistence diagrams dimension. Default value is `0`.
+ min_persistence (float): The minimum persistence value to take into account (strictly greater than
+ `min_persistence`). Default value is `0.0`. Sets `min_persistence` to `-1.0` to see all values.
+ """
+ self.dimensions_ = dimensions
+ self.persistence_dim_ = persistence_dim
+
+ self.homology_coeff_field_ = None
+ for dim in self.available_primes_:
+ if dim > persistence_dim + 1:
+ self.homology_coeff_field_ = dim
+ break
+ if self.homology_coeff_field_ == None:
+ raise ValueError("persistence_dim must be less than 96")
+
+ self.min_persistence_ = min_persistence
+
+ def transform(self, X):
+ """
+ Compute all the cubical complexes and their persistence diagrams.
+
+ Parameters:
+ X (list of double OR numpy.ndarray): Cells filtration values.
+
+ Returns:
+ Persistence diagrams
+ """
+ cubical_complex = CubicalComplex(top_dimensional_cells = X,
+ dimensions = self.dimensions_)
+ cubical_complex.compute_persistence(homology_coeff_field = self.homology_coeff_field_,
+ min_persistence = self.min_persistence_)
+ self.diagrams_ = cubical_complex.persistence_intervals_in_dimension(self.persistence_dim_)
+ if self.persistence_dim_ == 0:
+ # return all but the last, always [ 0., inf]
+ self.diagrams_ = self.diagrams_[:-1]
+ return self.diagrams_
+
+ def fit_transform(self, X):
+ """
+ Compute all the cubical complexes and their persistence diagrams.
+
+ Parameters:
+ X (list of double OR numpy.ndarray): Cells filtration values.
+
+ Returns:
+ Persistence diagrams
+ """
+ self.transform(X)
+ return self.diagrams_
diff --git a/src/python/test/test_sklearn_cubical_persistence.py b/src/python/test/test_sklearn_cubical_persistence.py
new file mode 100644
index 00000000..f147ffe3
--- /dev/null
+++ b/src/python/test/test_sklearn_cubical_persistence.py
@@ -0,0 +1,35 @@
+""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+ See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+ Author(s): Vincent Rouvreau
+
+ Copyright (C) 2021 Inria
+
+ Modification(s):
+ - YYYY/MM Author: Description of the modification
+"""
+
+from gudhi.sklearn.cubical_persistence import CubicalPersistence
+import numpy as np
+
+__author__ = "Vincent Rouvreau"
+__copyright__ = "Copyright (C) 2021 Inria"
+__license__ = "MIT"
+
+def test_simple_constructor_from_top_cells():
+ cp = CubicalPersistence(persistence_dim = 0)
+
+ # The first "0" from sklearn.datasets.load_digits()
+ bmp = np.array([[ 0., 0., 5., 13., 9., 1., 0., 0.],
+ [ 0., 0., 13., 15., 10., 15., 5., 0.],
+ [ 0., 3., 15., 2., 0., 11., 8., 0.],
+ [ 0., 4., 12., 0., 0., 8., 8., 0.],
+ [ 0., 5., 8., 0., 0., 9., 8., 0.],
+ [ 0., 4., 11., 0., 1., 12., 7., 0.],
+ [ 0., 2., 14., 5., 10., 12., 0., 0.],
+ [ 0., 0., 6., 13., 10., 0., 0., 0.]])
+
+ assert cp.fit_transform(bmp) == np.array([[0., 6.], [0., 8.]])
+
+# from gudhi.representations import PersistenceImage
+# PersistenceImage(bandwidth=50, weight=lambda x: x[1]**2, im_range=[0,256,0,256], resolution=[20, 20])
+# PI.fit_transform([diag]) \ No newline at end of file