summaryrefslogtreecommitdiff
path: root/src/python
diff options
context:
space:
mode:
authorwreise <wojciech.reise@epfl.ch>2022-10-07 18:14:30 +0200
committerwreise <wojciech.reise@epfl.ch>2022-10-07 18:14:30 +0200
commit2ec6c2457d44a06deb45f0243a9c587b284daeba (patch)
treea04b43dcdc4eb90a2727509cad5c659dd46d14bc /src/python
parentd9dfffdb580ab865a829fce851779f33fa47e4f7 (diff)
parent524718d63a8f633dbcc4fe7db3fe920ebd7e972c (diff)
Merge branch 'master' into optimize_silhouettes
Diffstat (limited to 'src/python')
-rw-r--r--src/python/CMakeLists.txt11
-rw-r--r--src/python/doc/cubical_complex_sklearn_itf_ref.rst102
-rw-r--r--src/python/doc/cubical_complex_sum.inc33
-rw-r--r--src/python/doc/cubical_complex_user.rst11
-rw-r--r--src/python/doc/differentiation_sum.inc4
-rw-r--r--src/python/doc/img/sklearn.pngbin0 -> 9368 bytes
-rw-r--r--src/python/doc/img/tensorflow.pngbin0 -> 3846 bytes
-rw-r--r--src/python/doc/persistent_cohomology_user.rst29
-rw-r--r--src/python/doc/rips_complex_sum.inc6
-rw-r--r--src/python/doc/rips_complex_user.rst8
-rw-r--r--src/python/doc/simplex_tree_sum.inc6
-rw-r--r--src/python/gudhi/representations/preprocessing.py57
-rw-r--r--src/python/gudhi/sklearn/__init__.py0
-rw-r--r--src/python/gudhi/sklearn/cubical_persistence.py110
-rw-r--r--src/python/gudhi/tensorflow/cubical_layer.py2
-rw-r--r--src/python/test/test_representations_preprocessing.py39
-rw-r--r--src/python/test/test_sklearn_cubical_persistence.py59
17 files changed, 424 insertions, 53 deletions
diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt
index 65adef75..5f323935 100644
--- a/src/python/CMakeLists.txt
+++ b/src/python/CMakeLists.txt
@@ -298,6 +298,7 @@ if(PYTHONINTERP_FOUND)
file(COPY "gudhi/dtm_rips_complex.py" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi")
file(COPY "gudhi/hera/__init__.py" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi/hera")
file(COPY "gudhi/datasets" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi" FILES_MATCHING PATTERN "*.py")
+ file(COPY "gudhi/sklearn" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi/")
# Some files for pip package
@@ -574,6 +575,11 @@ if(PYTHONINTERP_FOUND)
add_gudhi_py_test(test_betti_curve_representations)
endif()
+ # Representations preprocessing
+ if(SKLEARN_FOUND)
+ add_gudhi_py_test(test_representations_preprocessing)
+ endif()
+
# Time Delay
add_gudhi_py_test(test_time_delay)
@@ -603,6 +609,11 @@ if(PYTHONINTERP_FOUND)
add_gudhi_py_test(test_remote_datasets)
endif()
+ # sklearn
+ if(SKLEARN_FOUND)
+ add_gudhi_py_test(test_sklearn_cubical_persistence)
+ endif()
+
# persistence graphical tools
if(MATPLOTLIB_FOUND)
add_gudhi_py_test(test_persistence_graphical_tools)
diff --git a/src/python/doc/cubical_complex_sklearn_itf_ref.rst b/src/python/doc/cubical_complex_sklearn_itf_ref.rst
new file mode 100644
index 00000000..90ae9ccd
--- /dev/null
+++ b/src/python/doc/cubical_complex_sklearn_itf_ref.rst
@@ -0,0 +1,102 @@
+:orphan:
+
+.. To get rid of WARNING: document isn't included in any toctree
+
+Cubical complex persistence scikit-learn like interface
+#######################################################
+
+.. list-table::
+ :width: 100%
+ :header-rows: 0
+
+ * - :Since: GUDHI 3.6.0
+ - :License: MIT
+ - :Requires: `Scikit-learn <installation.html#scikit-learn>`_
+
+Cubical complex persistence scikit-learn like interface example
+---------------------------------------------------------------
+
+In this example, handwritten digits are used as an input.
+A TDA scikit-learn pipeline is constructed and is composed of:
+
+#. :class:`~gudhi.sklearn.cubical_persistence.CubicalPersistence` that builds a cubical complex from the inputs and
+ returns its persistence diagrams
+#. :class:`~gudhi.representations.preprocessing.DiagramSelector` that removes the non-finite points from the persistence diagrams
+#. :class:`~gudhi.representations.vector_methods.PersistenceImage` that builds the persistence images from persistence diagrams
+#. `SVC <https://scikit-learn.org/stable/modules/generated/sklearn.svm.SVC.html>`_ which is a scikit-learn support
+ vector classifier.
+
+This ML pipeline is trained to detect whether a handwritten digit is an '8' or not, thanks to the fact that an '8' has
+two holes in :math:`\mathbf{H}_1`, or, like in this example, three connected components in :math:`\mathbf{H}_0`.
+
+.. code-block:: python
+
+ # Standard scientific Python imports
+ import numpy as np
+
+ # Standard scikit-learn imports
+ from sklearn.datasets import fetch_openml
+ from sklearn.pipeline import Pipeline
+ from sklearn.model_selection import train_test_split
+ from sklearn.svm import SVC
+ from sklearn import metrics
+
+ # Import TDA pipeline requirements
+ from gudhi.sklearn.cubical_persistence import CubicalPersistence
+ from gudhi.representations import PersistenceImage, DiagramSelector
+
+ X, y = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)
+
+ # Target is: "is an eight ?"
+ y = (y == "8") * 1
+ print("There are", np.sum(y), "eights out of", len(y), "numbers.")
+
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=0)
+ pipe = Pipeline(
+ [
+ ("cub_pers", CubicalPersistence(homology_dimensions=0, newshape=[-1, 28, 28], n_jobs=-2)),
+ # Or for multiple persistence dimension computation
+ # ("cub_pers", CubicalPersistence(homology_dimensions=[0, 1], newshape=[-1, 28, 28])),
+ # ("H0_diags", DimensionSelector(index=0)), # where index is the index in homology_dimensions array
+ ("finite_diags", DiagramSelector(use=True, point_type="finite")),
+ (
+ "pers_img",
+ PersistenceImage(bandwidth=50, weight=lambda x: x[1] ** 2, im_range=[0, 256, 0, 256], resolution=[20, 20]),
+ ),
+ ("svc", SVC()),
+ ]
+ )
+
+ # Learn from the train subset
+ pipe.fit(X_train, y_train)
+ # Predict from the test subset
+ predicted = pipe.predict(X_test)
+
+ print(f"Classification report for TDA pipeline {pipe}:\n" f"{metrics.classification_report(y_test, predicted)}\n")
+
+.. code-block:: none
+
+ There are 6825 eights out of 70000 numbers.
+ Classification report for TDA pipeline Pipeline(steps=[('cub_pers',
+ CubicalPersistence(newshape=[28, 28], n_jobs=-2)),
+ ('finite_diags', DiagramSelector(use=True)),
+ ('pers_img',
+ PersistenceImage(bandwidth=50, im_range=[0, 256, 0, 256],
+ weight=<function <lambda> at 0x7f3e54137ae8>)),
+ ('svc', SVC())]):
+ precision recall f1-score support
+
+ 0 0.97 0.99 0.98 25284
+ 1 0.92 0.68 0.78 2716
+
+ accuracy 0.96 28000
+ macro avg 0.94 0.84 0.88 28000
+ weighted avg 0.96 0.96 0.96 28000
+
+Cubical complex persistence scikit-learn like interface reference
+-----------------------------------------------------------------
+
+.. autoclass:: gudhi.sklearn.cubical_persistence.CubicalPersistence
+ :members:
+ :special-members: __init__
+ :show-inheritance: \ No newline at end of file
diff --git a/src/python/doc/cubical_complex_sum.inc b/src/python/doc/cubical_complex_sum.inc
index 90ec9fc2..b27843e5 100644
--- a/src/python/doc/cubical_complex_sum.inc
+++ b/src/python/doc/cubical_complex_sum.inc
@@ -1,17 +1,22 @@
.. table::
:widths: 30 40 30
- +--------------------------------------------------------------------------+----------------------------------------------------------------------+---------------------------------------------------------+
- | .. figure:: | The cubical complex represents a grid as a cell complex with | :Author: Pawel Dlotko |
- | ../../doc/Bitmap_cubical_complex/Cubical_complex_representation.png | cells of all dimensions. | |
- | :alt: Cubical complex representation | | :Since: GUDHI 2.0.0 |
- | :figclass: align-center | | |
- | | | :License: MIT |
- | | | |
- +--------------------------------------------------------------------------+----------------------------------------------------------------------+---------------------------------------------------------+
- | * :doc:`cubical_complex_user` | * :doc:`cubical_complex_ref` |
- | | * :doc:`periodic_cubical_complex_ref` |
- +--------------------------------------------------------------------------+----------------------------------------------------------------------+---------------------------------------------------------+
- | | * :doc:`cubical_complex_tflow_itf_ref` | :requires: `TensorFlow <installation.html#tensorflow>`_ |
- | | | |
- +--------------------------------------------------------------------------+----------------------------------------------------------------------+---------------------------------------------------------+
+ +--------------------------------------------------------------------------+--------------------------------------------------------------+-------------------------------------------------------------+
+ | .. figure:: | The cubical complex represents a grid as a cell complex with | :Author: Pawel Dlotko |
+ | ../../doc/Bitmap_cubical_complex/Cubical_complex_representation.png | cells of all dimensions. | :Since: GUDHI 2.0.0 |
+ | :alt: Cubical complex representation | | :License: MIT |
+ | :figclass: align-center | | |
+ +--------------------------------------------------------------------------+--------------------------------------------------------------+-------------------------------------------------------------+
+ | * :doc:`cubical_complex_user` | * :doc:`cubical_complex_ref` |
+ | | * :doc:`periodic_cubical_complex_ref` |
+ +--------------------------------------------------------------------------+--------------------------------------------------------------+-------------------------------------------------------------+
+ | .. image:: | * :doc:`cubical_complex_tflow_itf_ref` | :requires: `TensorFlow <installation.html#tensorflow>`_ |
+ | img/tensorflow.png | | |
+ | :target: https://www.tensorflow.org | | |
+ | :height: 30 | | |
+ +--------------------------------------------------------------------------+--------------------------------------------------------------+-------------------------------------------------------------+
+ | .. image:: | * :doc:`cubical_complex_sklearn_itf_ref` | :Requires: `Scikit-learn <installation.html#scikit-learn>`_ |
+ | img/sklearn.png | | |
+ | :target: https://scikit-learn.org | | |
+ | :height: 30 | | |
+ +--------------------------------------------------------------------------+--------------------------------------------------------------+-------------------------------------------------------------+
diff --git a/src/python/doc/cubical_complex_user.rst b/src/python/doc/cubical_complex_user.rst
index 6a211347..42a23875 100644
--- a/src/python/doc/cubical_complex_user.rst
+++ b/src/python/doc/cubical_complex_user.rst
@@ -7,14 +7,7 @@ Cubical complex user manual
Definition
----------
-===================================== ===================================== =====================================
-:Author: Pawel Dlotko :Since: GUDHI PYTHON 2.0.0 :License: GPL v3
-===================================== ===================================== =====================================
-
-+---------------------------------------------+----------------------------------------------------------------------+
-| :doc:`cubical_complex_user` | * :doc:`cubical_complex_ref` |
-| | * :doc:`periodic_cubical_complex_ref` |
-+---------------------------------------------+----------------------------------------------------------------------+
+.. include:: cubical_complex_sum.inc
The cubical complex is an example of a structured complex useful in computational mathematics (specially rigorous
numerics) and image analysis.
@@ -163,4 +156,4 @@ Tutorial
--------
This `notebook <https://github.com/GUDHI/TDA-tutorial/blob/master/Tuto-GUDHI-cubical-complexes.ipynb>`_
-explains how to represent sublevels sets of functions using cubical complexes. \ No newline at end of file
+explains how to represent sublevel sets of functions using cubical complexes.
diff --git a/src/python/doc/differentiation_sum.inc b/src/python/doc/differentiation_sum.inc
index 3aec33df..140cf180 100644
--- a/src/python/doc/differentiation_sum.inc
+++ b/src/python/doc/differentiation_sum.inc
@@ -1,8 +1,8 @@
.. list-table::
- :widths: 40 30 30
+ :width: 100%
:header-rows: 0
- * - :Since: GUDHI 3.5.0
+ * - :Since: GUDHI 3.6.0
- :License: MIT
- :Requires: `TensorFlow <installation.html#tensorflow>`_
diff --git a/src/python/doc/img/sklearn.png b/src/python/doc/img/sklearn.png
new file mode 100644
index 00000000..d1fecbbf
--- /dev/null
+++ b/src/python/doc/img/sklearn.png
Binary files differ
diff --git a/src/python/doc/img/tensorflow.png b/src/python/doc/img/tensorflow.png
new file mode 100644
index 00000000..a75f3f5b
--- /dev/null
+++ b/src/python/doc/img/tensorflow.png
Binary files differ
diff --git a/src/python/doc/persistent_cohomology_user.rst b/src/python/doc/persistent_cohomology_user.rst
index a3f294b2..39744b95 100644
--- a/src/python/doc/persistent_cohomology_user.rst
+++ b/src/python/doc/persistent_cohomology_user.rst
@@ -6,19 +6,24 @@ Persistent cohomology user manual
=================================
Definition
----------
-===================================== ===================================== =====================================
-:Author: Clément Maria :Since: GUDHI PYTHON 2.0.0 :License: GPL v3
-===================================== ===================================== =====================================
-
-+-----------------------------------------------------------------+-----------------------------------------------------------------------+
-| :doc:`persistent_cohomology_user` | Please refer to each data structure that contains persistence |
-| | feature for reference: |
-| | |
-| | * :doc:`simplex_tree_ref` |
-| | * :doc:`cubical_complex_ref` |
-| | * :doc:`periodic_cubical_complex_ref` |
-+-----------------------------------------------------------------+-----------------------------------------------------------------------+
+.. list-table::
+ :width: 100%
+ :header-rows: 0
+
+ * - :Author: Clément Maria
+ - :Since: GUDHI 2.0.0
+ - :License: MIT
+
+.. list-table::
+ :width: 100%
+ :header-rows: 0
+
+ * - :doc:`persistent_cohomology_user`
+ - Please refer to each data structure that contains persistence feature for reference:
+ * :doc:`simplex_tree_ref`
+ * :doc:`cubical_complex_ref`
+ * :doc:`periodic_cubical_complex_ref`
Computation of persistent cohomology using the algorithm of :cite:`DBLP:journals/dcg/SilvaMV11` and
:cite:`DBLP:conf/compgeom/DeyFW14` and the Compressed Annotation Matrix implementation of
diff --git a/src/python/doc/rips_complex_sum.inc b/src/python/doc/rips_complex_sum.inc
index 6931ebee..2b125e54 100644
--- a/src/python/doc/rips_complex_sum.inc
+++ b/src/python/doc/rips_complex_sum.inc
@@ -12,6 +12,8 @@
+----------------------------------------------------------------+------------------------------------------------------------------------+----------------------------------------------------------------------------------+
| * :doc:`rips_complex_user` | * :doc:`rips_complex_ref` |
+----------------------------------------------------------------+------------------------------------------------------------------------+----------------------------------------------------------------------------------+
- | | * :doc:`rips_complex_tflow_itf_ref` | :requires: `TensorFlow <installation.html#tensorflow>`_ |
- | | | |
+ | .. image:: | * :doc:`rips_complex_tflow_itf_ref` | :requires: `TensorFlow <installation.html#tensorflow>`_ |
+ | img/tensorflow.png | | |
+ | :target: https://www.tensorflow.org | | |
+ | :height: 30 | | |
+----------------------------------------------------------------+------------------------------------------------------------------------+----------------------------------------------------------------------------------+
diff --git a/src/python/doc/rips_complex_user.rst b/src/python/doc/rips_complex_user.rst
index 27d218d4..c41a7803 100644
--- a/src/python/doc/rips_complex_user.rst
+++ b/src/python/doc/rips_complex_user.rst
@@ -7,13 +7,7 @@ Rips complex user manual
Definition
----------
-================================================================================ ================================ ======================
-:Authors: Clément Maria, Pawel Dlotko, Vincent Rouvreau, Marc Glisse, Yuichi Ike :Since: GUDHI 2.0.0 :License: GPL v3
-================================================================================ ================================ ======================
-
-+-------------------------------------------+----------------------------------------------------------------------+
-| :doc:`rips_complex_user` | :doc:`rips_complex_ref` |
-+-------------------------------------------+----------------------------------------------------------------------+
+.. include:: rips_complex_sum.inc
The `Rips complex <https://en.wikipedia.org/wiki/Vietoris%E2%80%93Rips_complex>`_ is a simplicial complex that
generalizes proximity (:math:`\varepsilon`-ball) graphs to higher dimensions. The vertices correspond to the input
diff --git a/src/python/doc/simplex_tree_sum.inc b/src/python/doc/simplex_tree_sum.inc
index 3ad1292c..6b534c9e 100644
--- a/src/python/doc/simplex_tree_sum.inc
+++ b/src/python/doc/simplex_tree_sum.inc
@@ -11,6 +11,8 @@
+----------------------------------------------------------------+------------------------------------------------------------------------+---------------------------------------------------------+
| * :doc:`simplex_tree_user` | * :doc:`simplex_tree_ref` |
+----------------------------------------------------------------+------------------------------------------------------------------------+---------------------------------------------------------+
- | | * :doc:`ls_simplex_tree_tflow_itf_ref` | :requires: `TensorFlow <installation.html#tensorflow>`_ |
- | | | |
+ | .. image:: | * :doc:`ls_simplex_tree_tflow_itf_ref` | :requires: `TensorFlow <installation.html#tensorflow>`_ |
+ | img/tensorflow.png | | |
+ | :target: https://www.tensorflow.org | | |
+ | :height: 30 | | |
+----------------------------------------------------------------+------------------------------------------------------------------------+---------------------------------------------------------+
diff --git a/src/python/gudhi/representations/preprocessing.py b/src/python/gudhi/representations/preprocessing.py
index a8545349..8722e162 100644
--- a/src/python/gudhi/representations/preprocessing.py
+++ b/src/python/gudhi/representations/preprocessing.py
@@ -1,10 +1,11 @@
# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
-# Author(s): Mathieu Carrière
+# Author(s): Mathieu Carrière, Vincent Rouvreau
#
# Copyright (C) 2018-2019 Inria
#
# Modification(s):
+# - 2021/10 Vincent Rouvreau: Add DimensionSelector
# - YYYY/MM Author: Description of the modification
import numpy as np
@@ -75,7 +76,7 @@ class Clamping(BaseEstimator, TransformerMixin):
Constructor for the Clamping class.
Parameters:
- limit (double): clamping value (default np.inf).
+ limit (float): clamping value (default np.inf).
"""
self.minimum = minimum
self.maximum = maximum
@@ -234,7 +235,7 @@ class ProminentPoints(BaseEstimator, TransformerMixin):
use (bool): whether to use the class or not (default False).
location (string): either "upper" or "lower" (default "upper"). Whether to keep the points that are far away ("upper") or close ("lower") to the diagonal.
num_pts (int): cardinality threshold (default 10). If location == "upper", keep the top **num_pts** points that are the farthest away from the diagonal. If location == "lower", keep the top **num_pts** points that are the closest to the diagonal.
- threshold (double): distance-to-diagonal threshold (default -1). If location == "upper", keep the points that are at least at a distance **threshold** from the diagonal. If location == "lower", keep the points that are at most at a distance **threshold** from the diagonal.
+ threshold (float): distance-to-diagonal threshold (default -1). If location == "upper", keep the points that are at least at a distance **threshold** from the diagonal. If location == "lower", keep the points that are at most at a distance **threshold** from the diagonal.
"""
self.num_pts = num_pts
self.threshold = threshold
@@ -317,7 +318,7 @@ class DiagramSelector(BaseEstimator, TransformerMixin):
Parameters:
use (bool): whether to use the class or not (default False).
- limit (double): second coordinate value that is the criterion for being an essential point (default numpy.inf).
+ limit (float): second coordinate value that is the criterion for being an essential point (default numpy.inf).
point_type (string): either "finite" or "essential". The type of the points that are going to be extracted.
"""
self.use, self.limit, self.point_type = use, limit, point_type
@@ -363,3 +364,51 @@ class DiagramSelector(BaseEstimator, TransformerMixin):
n x 2 numpy array: extracted persistence diagram.
"""
return self.fit_transform([diag])[0]
+
+
+# Mermaid sequence diagram - https://mermaid-js.github.io/mermaid-live-editor/
+# sequenceDiagram
+# USER->>DimensionSelector: fit_transform(<br/>[[array( Hi(X0) ), array( Hj(X0) ), ...],<br/> [array( Hi(X1) ), array( Hj(X1) ), ...],<br/> ...])
+# DimensionSelector->>thread1: _transform([array( Hi(X0) ), array( Hj(X0) )], ...)
+# DimensionSelector->>thread2: _transform([array( Hi(X1) ), array( Hj(X1) )], ...)
+# Note right of DimensionSelector: ...
+# thread1->>DimensionSelector: array( Hn(X0) )
+# thread2->>DimensionSelector: array( Hn(X1) )
+# Note right of DimensionSelector: ...
+# DimensionSelector->>USER: [array( Hn(X0) ), <br/> array( Hn(X1) ), <br/> ...]
+
+class DimensionSelector(BaseEstimator, TransformerMixin):
+ """
+ This is a class to select the persistence diagrams of a specific dimension, identified by its index.
+ """
+
+ def __init__(self, index=0):
+ """
+ Constructor for the DimensionSelector class.
+
+ Parameters:
+ index (int): The returned persistence diagrams dimension index. Default value is `0`.
+ """
+ self.index = index
+
+ def fit(self, X, Y=None):
+ """
+ Nothing to be done, but useful when included in a scikit-learn Pipeline.
+ """
+ return self
+
+ def transform(self, X, Y=None):
+ """
+ Select the persistence diagrams of the dimension located at `index`.
+
+ Parameters:
+ X (list of list of tuple): List of list of persistence pairs, i.e.
+ `[[array( Hi(X0) ), array( Hj(X0) ), ...], [array( Hi(X1) ), array( Hj(X1) ), ...], ...]`
+
+ Returns:
+ list of tuple:
+ Persistence diagrams in a specific dimension. i.e. if `index` was set to `m` and `Hn` is at index `m` of
+ the input, it returns `[array( Hn(X0) ), array( Hn(X1) ), ...]`
+ """
+
+ return [persistence[self.index] for persistence in X]
diff --git a/src/python/gudhi/sklearn/__init__.py b/src/python/gudhi/sklearn/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/src/python/gudhi/sklearn/__init__.py
diff --git a/src/python/gudhi/sklearn/cubical_persistence.py b/src/python/gudhi/sklearn/cubical_persistence.py
new file mode 100644
index 00000000..672af278
--- /dev/null
+++ b/src/python/gudhi/sklearn/cubical_persistence.py
@@ -0,0 +1,110 @@
+# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+# Author(s): Vincent Rouvreau
+#
+# Copyright (C) 2021 Inria
+#
+# Modification(s):
+# - YYYY/MM Author: Description of the modification
+
+from .. import CubicalComplex
+from sklearn.base import BaseEstimator, TransformerMixin
+
+import numpy as np
+# joblib is required by scikit-learn
+from joblib import Parallel, delayed
+
+# Mermaid sequence diagram - https://mermaid-js.github.io/mermaid-live-editor/
+# sequenceDiagram
+# USER->>CubicalPersistence: fit_transform(X)
+# CubicalPersistence->>thread1: __transform(X[0])
+# CubicalPersistence->>thread2: __transform(X[1])
+# Note right of CubicalPersistence: ...
+# thread1->>CubicalPersistence: [array( H0(X[0]) ), array( H1(X[0]) )]
+# thread2->>CubicalPersistence: [array( H0(X[1]) ), array( H1(X[1]) )]
+# Note right of CubicalPersistence: ...
+# CubicalPersistence->>USER: [[array( H0(X[0]) ), array( H1(X[0]) )],<br/> [array( H0(X[1]) ), array( H1(X[1]) )],<br/> ...]
+
+
+class CubicalPersistence(BaseEstimator, TransformerMixin):
+ """
+ This is a class for computing the persistence diagrams from a cubical complex.
+ """
+
+ def __init__(
+ self,
+ homology_dimensions,
+ newshape=None,
+ homology_coeff_field=11,
+ min_persistence=0.0,
+ n_jobs=None,
+ ):
+ """
+ Constructor for the CubicalPersistence class.
+
+ Parameters:
+ homology_dimensions (int or list of int): The returned persistence diagrams dimension(s).
+ Short circuit the use of :class:`~gudhi.representations.preprocessing.DimensionSelector` when only one
+ dimension matters (in other words, when `homology_dimensions` is an int).
+ newshape (tuple of ints): If the cells filtration values need to be reshaped
+ (cf. :func:`~gudhi.sklearn.cubical_persistence.CubicalPersistence.transform`), set `newshape`
+ to perform `numpy.reshape(X, newshape, order='C')` in
+ :func:`~gudhi.sklearn.cubical_persistence.CubicalPersistence.transform` method.
+ homology_coeff_field (int): The homology coefficient field. Must be a prime number. Default value is 11.
+ min_persistence (float): The minimum persistence value to take into account (strictly greater than
+ `min_persistence`). Default value is `0.0`. Set `min_persistence` to `-1.0` to see all values.
+ n_jobs (int): cf. https://joblib.readthedocs.io/en/latest/generated/joblib.Parallel.html
+ """
+ self.homology_dimensions = homology_dimensions
+ self.newshape = newshape
+ self.homology_coeff_field = homology_coeff_field
+ self.min_persistence = min_persistence
+ self.n_jobs = n_jobs
+
+ def fit(self, X, Y=None):
+ """
+ Nothing to be done, but useful when included in a scikit-learn Pipeline.
+ """
+ return self
+
+ def __transform(self, cells):
+ cubical_complex = CubicalComplex(top_dimensional_cells=cells)
+ cubical_complex.compute_persistence(
+ homology_coeff_field=self.homology_coeff_field, min_persistence=self.min_persistence
+ )
+ return [
+ cubical_complex.persistence_intervals_in_dimension(dim) for dim in self.homology_dimensions
+ ]
+
+ def __transform_only_this_dim(self, cells):
+ cubical_complex = CubicalComplex(top_dimensional_cells=cells)
+ cubical_complex.compute_persistence(
+ homology_coeff_field=self.homology_coeff_field, min_persistence=self.min_persistence
+ )
+ return cubical_complex.persistence_intervals_in_dimension(self.homology_dimensions)
+
+ def transform(self, X, Y=None):
+ """Compute all the cubical complexes and their associated persistence diagrams.
+
+ :param X: List of cells filtration values (`numpy.reshape(X, newshape, order='C')` if `newshape` is set with a tuple of ints).
+ :type X: list of list of float OR list of numpy.ndarray
+
+ :return: Persistence diagrams in the format:
+
+ - If `homology_dimensions` was set to `n`: `[array( Hn(X[0]) ), array( Hn(X[1]) ), ...]`
+ - If `homology_dimensions` was set to `[i, j]`: `[[array( Hi(X[0]) ), array( Hj(X[0]) )], [array( Hi(X[1]) ), array( Hj(X[1]) )], ...]`
+ :rtype: list of (,2) array_like or list of list of (,2) array_like
+ """
+ if self.newshape is not None:
+ X = np.reshape(X, self.newshape, order='C')
+
+ # Depends on whether homology_dimensions is an integer or a list of integers (else case)
+ if isinstance(self.homology_dimensions, int):
+ # threads are preferred as cubical construction and persistence computation release the GIL
+ return Parallel(n_jobs=self.n_jobs, prefer="threads")(
+ delayed(self.__transform_only_this_dim)(cells) for cells in X
+ )
+ else:
+ # threads are preferred as cubical construction and persistence computation release the GIL
+ return Parallel(n_jobs=self.n_jobs, prefer="threads")(delayed(self.__transform)(cells) for cells in X)
+
diff --git a/src/python/gudhi/tensorflow/cubical_layer.py b/src/python/gudhi/tensorflow/cubical_layer.py
index 3304e719..5df2c370 100644
--- a/src/python/gudhi/tensorflow/cubical_layer.py
+++ b/src/python/gudhi/tensorflow/cubical_layer.py
@@ -18,7 +18,7 @@ def _Cubical(Xflat, Xdim, dimensions, homology_coeff_field):
cc = CubicalComplex(dimensions=Xdim[::-1], top_dimensional_cells=Xflat)
cc.compute_persistence(homology_coeff_field=homology_coeff_field)
- # Retrieve and ouput image indices/pixels corresponding to positive and negative simplices
+ # Retrieve and output image indices/pixels corresponding to positive and negative simplices
cof_pp = cc.cofaces_of_persistence_pairs()
L_cofs = []
diff --git a/src/python/test/test_representations_preprocessing.py b/src/python/test/test_representations_preprocessing.py
new file mode 100644
index 00000000..838cf30c
--- /dev/null
+++ b/src/python/test/test_representations_preprocessing.py
@@ -0,0 +1,39 @@
+""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+ See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+ Author(s): Vincent Rouvreau
+
+ Copyright (C) 2021 Inria
+
+ Modification(s):
+ - YYYY/MM Author: Description of the modification
+"""
+
+from gudhi.representations.preprocessing import DimensionSelector
+import numpy as np
+import pytest
+
+H0_0 = np.array([0.0, 0.0])
+H1_0 = np.array([1.0, 0.0])
+H0_1 = np.array([0.0, 1.0])
+H1_1 = np.array([1.0, 1.0])
+H0_2 = np.array([0.0, 2.0])
+H1_2 = np.array([1.0, 2.0])
+
+
+def test_dimension_selector():
+ X = [[H0_0, H1_0], [H0_1, H1_1], [H0_2, H1_2]]
+ ds = DimensionSelector(index=0)
+ h0 = ds.fit_transform(X)
+ np.testing.assert_array_equal(h0[0], H0_0)
+ np.testing.assert_array_equal(h0[1], H0_1)
+ np.testing.assert_array_equal(h0[2], H0_2)
+
+ ds = DimensionSelector(index=1)
+ h1 = ds.fit_transform(X)
+ np.testing.assert_array_equal(h1[0], H1_0)
+ np.testing.assert_array_equal(h1[1], H1_1)
+ np.testing.assert_array_equal(h1[2], H1_2)
+
+ ds = DimensionSelector(index=2)
+ with pytest.raises(IndexError):
+ h2 = ds.fit_transform([[H0_0, H1_0], [H0_1, H1_1], [H0_2, H1_2]])
diff --git a/src/python/test/test_sklearn_cubical_persistence.py b/src/python/test/test_sklearn_cubical_persistence.py
new file mode 100644
index 00000000..1c05a215
--- /dev/null
+++ b/src/python/test/test_sklearn_cubical_persistence.py
@@ -0,0 +1,59 @@
+""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+ See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+ Author(s): Vincent Rouvreau
+
+ Copyright (C) 2021 Inria
+
+ Modification(s):
+ - YYYY/MM Author: Description of the modification
+"""
+
+from gudhi.sklearn.cubical_persistence import CubicalPersistence
+import numpy as np
+from sklearn import datasets
+
+CUBICAL_PERSISTENCE_H0_IMG0 = np.array([[0.0, 6.0], [0.0, 8.0], [0.0, np.inf]])
+
+
+def test_simple_constructor_from_top_cells():
+ cells = datasets.load_digits().images[0]
+ cp = CubicalPersistence(homology_dimensions=0)
+ np.testing.assert_array_equal(cp._CubicalPersistence__transform_only_this_dim(cells), CUBICAL_PERSISTENCE_H0_IMG0)
+ cp = CubicalPersistence(homology_dimensions=[0, 2])
+ diags = cp._CubicalPersistence__transform(cells)
+ assert len(diags) == 2
+ np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0)
+
+
+def test_simple_constructor_from_top_cells_list():
+ digits = datasets.load_digits().images[:10]
+ cp = CubicalPersistence(homology_dimensions=0, n_jobs=-2)
+
+ diags = cp.fit_transform(digits)
+ assert len(diags) == 10
+ np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0)
+
+ cp = CubicalPersistence(homology_dimensions=[0, 1], n_jobs=-1)
+ diagsH0H1 = cp.fit_transform(digits)
+ assert len(diagsH0H1) == 10
+ for idx in range(10):
+ np.testing.assert_array_equal(diags[idx], diagsH0H1[idx][0])
+
+def test_simple_constructor_from_flattened_cells():
+ cells = datasets.load_digits().images[0]
+ # Not squared (extended) flattened cells
+ flat_cells = np.hstack((cells, np.zeros((cells.shape[0], 2)))).flatten()
+
+ cp = CubicalPersistence(homology_dimensions=0, newshape=[-1, 8, 10])
+ diags = cp.fit_transform([flat_cells])
+
+ np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0)
+
+ # Not squared (extended) non-flattened cells
+ cells = np.hstack((cells, np.zeros((cells.shape[0], 2))))
+
+ # The aim of this second part of the test is to reshape even when it is not mandatory
+ cp = CubicalPersistence(homology_dimensions=0, newshape=[-1, 8, 10])
+ diags = cp.fit_transform([cells])
+
+ np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0)