diff options
m--------- | ext/gudhi-deploy | 0 | ||||
-rw-r--r-- | src/Cech_complex/include/gudhi/Cech_complex_blocker.h | 3 | ||||
-rw-r--r-- | src/Contraction/include/gudhi/Skeleton_blocker_contractor.h | 21 | ||||
-rw-r--r-- | src/python/CMakeLists.txt | 11 | ||||
-rw-r--r-- | src/python/doc/cubical_complex_sklearn_itf_ref.rst | 102 | ||||
-rw-r--r-- | src/python/doc/cubical_complex_sum.inc | 31 | ||||
-rw-r--r-- | src/python/doc/cubical_complex_user.rst | 11 | ||||
-rw-r--r-- | src/python/doc/img/sklearn.png | bin | 0 -> 9368 bytes | |||
-rw-r--r-- | src/python/gudhi/representations/preprocessing.py | 57 | ||||
-rw-r--r-- | src/python/gudhi/representations/vector_methods.py | 18 | ||||
-rw-r--r-- | src/python/gudhi/sklearn/__init__.py | 0 | ||||
-rw-r--r-- | src/python/gudhi/sklearn/cubical_persistence.py | 110 | ||||
-rwxr-xr-x | src/python/test/test_representations.py | 21 | ||||
-rw-r--r-- | src/python/test/test_representations_preprocessing.py | 39 | ||||
-rw-r--r-- | src/python/test/test_sklearn_cubical_persistence.py | 59 |
15 files changed, 437 insertions, 46 deletions
diff --git a/ext/gudhi-deploy b/ext/gudhi-deploy -Subproject e9e9a4878731853d2d3149a5eac30df338a8197 +Subproject 290ade1086bedbc96a35df886cadecabbf4072e diff --git a/src/Cech_complex/include/gudhi/Cech_complex_blocker.h b/src/Cech_complex/include/gudhi/Cech_complex_blocker.h index 62b0d25c..e78e37b7 100644 --- a/src/Cech_complex/include/gudhi/Cech_complex_blocker.h +++ b/src/Cech_complex/include/gudhi/Cech_complex_blocker.h @@ -128,7 +128,8 @@ class Cech_blocker { #ifdef DEBUG_TRACES if (radius > cc_ptr_->max_radius()) std::clog << "radius > max_radius => expansion is blocked\n"; #endif // DEBUG_TRACES - sc_ptr_->assign_filtration(sh, radius); + // Check that the filtration to be assigned (radius) would be valid + if (radius > sc_ptr_->filtration(sh)) sc_ptr_->assign_filtration(sh, radius); return (radius > cc_ptr_->max_radius()); } diff --git a/src/Contraction/include/gudhi/Skeleton_blocker_contractor.h b/src/Contraction/include/gudhi/Skeleton_blocker_contractor.h index 56b76318..6911ca2e 100644 --- a/src/Contraction/include/gudhi/Skeleton_blocker_contractor.h +++ b/src/Contraction/include/gudhi/Skeleton_blocker_contractor.h @@ -171,8 +171,13 @@ typename GeometricSimplifiableComplex::Vertex_handle> { Self const* algorithm_; }; +#if CGAL_VERSION_NR < 1050500000 typedef CGAL::Modifiable_priority_queue<Edge_handle, Compare_cost, Undirected_edge_id> PQ; - typedef typename PQ::handle pq_handle; +#else + typedef CGAL::Modifiable_priority_queue<Edge_handle, Compare_cost, Undirected_edge_id, CGAL::CGAL_BOOST_PENDING_RELAXED_HEAP> PQ; +#endif + + typedef bool pq_handle; // An Edge_data is associated with EVERY edge in the complex (collapsible or not). @@ -196,7 +201,7 @@ typename GeometricSimplifiableComplex::Vertex_handle> { } bool is_in_PQ() const { - return PQHandle_ != PQ::null_handle(); + return PQHandle_ != false; } void set_PQ_handle(pq_handle h) { @@ -204,7 +209,7 @@ typename GeometricSimplifiableComplex::Vertex_handle> { } void reset_PQ_handle() { - PQHandle_ = PQ::null_handle(); + PQHandle_ = false; } private: @@ -238,16 +243,22 @@ typename GeometricSimplifiableComplex::Vertex_handle> { } void insert_in_PQ(Edge_handle edge, Edge_data& data) { - data.set_PQ_handle(heap_PQ_->push(edge)); + heap_PQ_->push(edge); + data.set_PQ_handle(true); ++current_num_edges_heap_; } void update_in_PQ(Edge_handle edge, Edge_data& data) { +#if CGAL_VERSION_NR < 1050500000 data.set_PQ_handle(heap_PQ_->update(edge, data.PQ_handle())); +#else + heap_PQ_->update(edge); +#endif } void remove_from_PQ(Edge_handle edge, Edge_data& data) { - data.set_PQ_handle(heap_PQ_->erase(edge, data.PQ_handle())); + heap_PQ_->erase(edge); + data.set_PQ_handle(false); --current_num_edges_heap_; } diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt index 65adef75..5f323935 100644 --- a/src/python/CMakeLists.txt +++ b/src/python/CMakeLists.txt @@ -298,6 +298,7 @@ if(PYTHONINTERP_FOUND) file(COPY "gudhi/dtm_rips_complex.py" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi") file(COPY "gudhi/hera/__init__.py" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi/hera") file(COPY "gudhi/datasets" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi" FILES_MATCHING PATTERN "*.py") + file(COPY "gudhi/sklearn" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi/") # Some files for pip package @@ -574,6 +575,11 @@ if(PYTHONINTERP_FOUND) add_gudhi_py_test(test_betti_curve_representations) endif() + # Representations preprocessing + if(SKLEARN_FOUND) + add_gudhi_py_test(test_representations_preprocessing) + endif() + # Time Delay add_gudhi_py_test(test_time_delay) @@ -603,6 +609,11 @@ if(PYTHONINTERP_FOUND) add_gudhi_py_test(test_remote_datasets) endif() + # sklearn + if(SKLEARN_FOUND) + add_gudhi_py_test(test_sklearn_cubical_persistence) + endif() + # persistence graphical tools if(MATPLOTLIB_FOUND) add_gudhi_py_test(test_persistence_graphical_tools) diff --git a/src/python/doc/cubical_complex_sklearn_itf_ref.rst b/src/python/doc/cubical_complex_sklearn_itf_ref.rst new file mode 100644 index 00000000..05ffdd0c --- /dev/null +++ b/src/python/doc/cubical_complex_sklearn_itf_ref.rst @@ -0,0 +1,102 @@ +:orphan: + +.. To get rid of WARNING: document isn't included in any toctree + +Cubical complex persistence scikit-learn like interface +####################################################### + +.. list-table:: + :width: 100% + :header-rows: 0 + + * - :Since: GUDHI 3.5.0 + - :License: MIT + - :Requires: `Scikit-learn <installation.html#scikit-learn>`_ + +Cubical complex persistence scikit-learn like interface example +--------------------------------------------------------------- + +In this example, hand written digits are used as an input. +a TDA scikit-learn pipeline is constructed and is composed of: + +#. :class:`~gudhi.sklearn.cubical_persistence.CubicalPersistence` that builds a cubical complex from the inputs and + returns its persistence diagrams +#. :class:`~gudhi.representations.preprocessing.DiagramSelector` that removes non-finite persistence diagrams values +#. :class:`~gudhi.representations.vector_methods.PersistenceImage` that builds the persistence images from persistence diagrams +#. `SVC <https://scikit-learn.org/stable/modules/generated/sklearn.svm.SVC.html>`_ which is a scikit-learn support + vector classifier. + +This ML pipeline is trained to detect if the hand written digit is an '8' or not, thanks to the fact that an '8' has +two holes in :math:`\mathbf{H}_1`, or, like in this example, three connected components in :math:`\mathbf{H}_0`. + +.. code-block:: python + + # Standard scientific Python imports + import numpy as np + + # Standard scikit-learn imports + from sklearn.datasets import fetch_openml + from sklearn.pipeline import Pipeline + from sklearn.model_selection import train_test_split + from sklearn.svm import SVC + from sklearn import metrics + + # Import TDA pipeline requirements + from gudhi.sklearn.cubical_persistence import CubicalPersistence + from gudhi.representations import PersistenceImage, DiagramSelector + + X, y = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False) + + # Target is: "is an eight ?" + y = (y == "8") * 1 + print("There are", np.sum(y), "eights out of", len(y), "numbers.") + + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=0) + pipe = Pipeline( + [ + ("cub_pers", CubicalPersistence(homology_dimensions=0, newshape=[28, 28], n_jobs=-2)), + # Or for multiple persistence dimension computation + # ("cub_pers", CubicalPersistence(homology_dimensions=[0, 1], newshape=[28, 28], n_jobs=-2)), + # ("H0_diags", DimensionSelector(index=0), # where index is the index in homology_dimensions array + ("finite_diags", DiagramSelector(use=True, point_type="finite")), + ( + "pers_img", + PersistenceImage(bandwidth=50, weight=lambda x: x[1] ** 2, im_range=[0, 256, 0, 256], resolution=[20, 20]), + ), + ("svc", SVC()), + ] + ) + + # Learn from the train subset + pipe.fit(X_train, y_train) + # Predict from the test subset + predicted = pipe.predict(X_test) + + print(f"Classification report for TDA pipeline {pipe}:\n" f"{metrics.classification_report(y_test, predicted)}\n") + +.. code-block:: none + + There are 6825 eights out of 70000 numbers. + Classification report for TDA pipeline Pipeline(steps=[('cub_pers', + CubicalPersistence(newshape=[28, 28], n_jobs=-2)), + ('finite_diags', DiagramSelector(use=True)), + ('pers_img', + PersistenceImage(bandwidth=50, im_range=[0, 256, 0, 256], + weight=<function <lambda> at 0x7f3e54137ae8>)), + ('svc', SVC())]): + precision recall f1-score support + + 0 0.97 0.99 0.98 25284 + 1 0.92 0.68 0.78 2716 + + accuracy 0.96 28000 + macro avg 0.94 0.84 0.88 28000 + weighted avg 0.96 0.96 0.96 28000 + +Cubical complex persistence scikit-learn like interface reference +----------------------------------------------------------------- + +.. autoclass:: gudhi.sklearn.cubical_persistence.CubicalPersistence + :members: + :special-members: __init__ + :show-inheritance:
\ No newline at end of file diff --git a/src/python/doc/cubical_complex_sum.inc b/src/python/doc/cubical_complex_sum.inc index 90ec9fc2..f1cf25d4 100644 --- a/src/python/doc/cubical_complex_sum.inc +++ b/src/python/doc/cubical_complex_sum.inc @@ -1,17 +1,20 @@ .. table:: :widths: 30 40 30 - +--------------------------------------------------------------------------+----------------------------------------------------------------------+---------------------------------------------------------+ - | .. figure:: | The cubical complex represents a grid as a cell complex with | :Author: Pawel Dlotko | - | ../../doc/Bitmap_cubical_complex/Cubical_complex_representation.png | cells of all dimensions. | | - | :alt: Cubical complex representation | | :Since: GUDHI 2.0.0 | - | :figclass: align-center | | | - | | | :License: MIT | - | | | | - +--------------------------------------------------------------------------+----------------------------------------------------------------------+---------------------------------------------------------+ - | * :doc:`cubical_complex_user` | * :doc:`cubical_complex_ref` | - | | * :doc:`periodic_cubical_complex_ref` | - +--------------------------------------------------------------------------+----------------------------------------------------------------------+---------------------------------------------------------+ - | | * :doc:`cubical_complex_tflow_itf_ref` | :requires: `TensorFlow <installation.html#tensorflow>`_ | - | | | | - +--------------------------------------------------------------------------+----------------------------------------------------------------------+---------------------------------------------------------+ + +--------------------------------------------------------------------------+--------------------------------------------------------------+-------------------------------------------------------------+ + | .. figure:: | The cubical complex represents a grid as a cell complex with | :Author: Pawel Dlotko | + | ../../doc/Bitmap_cubical_complex/Cubical_complex_representation.png | cells of all dimensions. | :Since: GUDHI 2.0.0 | + | :alt: Cubical complex representation | | :License: MIT | + | :figclass: align-center | | | + +--------------------------------------------------------------------------+--------------------------------------------------------------+-------------------------------------------------------------+ + | * :doc:`cubical_complex_user` | * :doc:`cubical_complex_ref` | + | | * :doc:`periodic_cubical_complex_ref` | + +--------------------------------------------------------------------------+--------------------------------------------------------------+-------------------------------------------------------------+ + | | * :doc:`cubical_complex_tflow_itf_ref` | :requires: `TensorFlow <installation.html#tensorflow>`_ | + | | | | + +--------------------------------------------------------------------------+--------------------------------------------------------------+-------------------------------------------------------------+ + | .. image:: | * :doc:`cubical_complex_sklearn_itf_ref` | :Requires: `Scikit-learn <installation.html#scikit-learn>`_ | + | img/sklearn.png | | | + | :target: https://scikit-learn.org | | | + | :height: 30 | | | + +--------------------------------------------------------------------------+--------------------------------------------------------------+-------------------------------------------------------------+ diff --git a/src/python/doc/cubical_complex_user.rst b/src/python/doc/cubical_complex_user.rst index 6a211347..42a23875 100644 --- a/src/python/doc/cubical_complex_user.rst +++ b/src/python/doc/cubical_complex_user.rst @@ -7,14 +7,7 @@ Cubical complex user manual Definition ---------- -===================================== ===================================== ===================================== -:Author: Pawel Dlotko :Since: GUDHI PYTHON 2.0.0 :License: GPL v3 -===================================== ===================================== ===================================== - -+---------------------------------------------+----------------------------------------------------------------------+ -| :doc:`cubical_complex_user` | * :doc:`cubical_complex_ref` | -| | * :doc:`periodic_cubical_complex_ref` | -+---------------------------------------------+----------------------------------------------------------------------+ +.. include:: cubical_complex_sum.inc The cubical complex is an example of a structured complex useful in computational mathematics (specially rigorous numerics) and image analysis. @@ -163,4 +156,4 @@ Tutorial -------- This `notebook <https://github.com/GUDHI/TDA-tutorial/blob/master/Tuto-GUDHI-cubical-complexes.ipynb>`_ -explains how to represent sublevels sets of functions using cubical complexes.
\ No newline at end of file +explains how to represent sublevels sets of functions using cubical complexes. diff --git a/src/python/doc/img/sklearn.png b/src/python/doc/img/sklearn.png Binary files differnew file mode 100644 index 00000000..d1fecbbf --- /dev/null +++ b/src/python/doc/img/sklearn.png diff --git a/src/python/gudhi/representations/preprocessing.py b/src/python/gudhi/representations/preprocessing.py index a8545349..8722e162 100644 --- a/src/python/gudhi/representations/preprocessing.py +++ b/src/python/gudhi/representations/preprocessing.py @@ -1,10 +1,11 @@ # This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT. # See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details. -# Author(s): Mathieu Carrière +# Author(s): Mathieu Carrière, Vincent Rouvreau # # Copyright (C) 2018-2019 Inria # # Modification(s): +# - 2021/10 Vincent Rouvreau: Add DimensionSelector # - YYYY/MM Author: Description of the modification import numpy as np @@ -75,7 +76,7 @@ class Clamping(BaseEstimator, TransformerMixin): Constructor for the Clamping class. Parameters: - limit (double): clamping value (default np.inf). + limit (float): clamping value (default np.inf). """ self.minimum = minimum self.maximum = maximum @@ -234,7 +235,7 @@ class ProminentPoints(BaseEstimator, TransformerMixin): use (bool): whether to use the class or not (default False). location (string): either "upper" or "lower" (default "upper"). Whether to keep the points that are far away ("upper") or close ("lower") to the diagonal. num_pts (int): cardinality threshold (default 10). If location == "upper", keep the top **num_pts** points that are the farthest away from the diagonal. If location == "lower", keep the top **num_pts** points that are the closest to the diagonal. - threshold (double): distance-to-diagonal threshold (default -1). If location == "upper", keep the points that are at least at a distance **threshold** from the diagonal. If location == "lower", keep the points that are at most at a distance **threshold** from the diagonal. + threshold (float): distance-to-diagonal threshold (default -1). If location == "upper", keep the points that are at least at a distance **threshold** from the diagonal. If location == "lower", keep the points that are at most at a distance **threshold** from the diagonal. """ self.num_pts = num_pts self.threshold = threshold @@ -317,7 +318,7 @@ class DiagramSelector(BaseEstimator, TransformerMixin): Parameters: use (bool): whether to use the class or not (default False). - limit (double): second coordinate value that is the criterion for being an essential point (default numpy.inf). + limit (float): second coordinate value that is the criterion for being an essential point (default numpy.inf). point_type (string): either "finite" or "essential". The type of the points that are going to be extracted. """ self.use, self.limit, self.point_type = use, limit, point_type @@ -363,3 +364,51 @@ class DiagramSelector(BaseEstimator, TransformerMixin): n x 2 numpy array: extracted persistence diagram. """ return self.fit_transform([diag])[0] + + +# Mermaid sequence diagram - https://mermaid-js.github.io/mermaid-live-editor/ +# sequenceDiagram +# USER->>DimensionSelector: fit_transform(<br/>[[array( Hi(X0) ), array( Hj(X0) ), ...],<br/> [array( Hi(X1) ), array( Hj(X1) ), ...],<br/> ...]) +# DimensionSelector->>thread1: _transform([array( Hi(X0) ), array( Hj(X0) )], ...) +# DimensionSelector->>thread2: _transform([array( Hi(X1) ), array( Hj(X1) )], ...) +# Note right of DimensionSelector: ... +# thread1->>DimensionSelector: array( Hn(X0) ) +# thread2->>DimensionSelector: array( Hn(X1) ) +# Note right of DimensionSelector: ... +# DimensionSelector->>USER: [array( Hn(X0) ), <br/> array( Hn(X1) ), <br/> ...] + +class DimensionSelector(BaseEstimator, TransformerMixin): + """ + This is a class to select persistence diagrams in a specific dimension from its index. + """ + + def __init__(self, index=0): + """ + Constructor for the DimensionSelector class. + + Parameters: + index (int): The returned persistence diagrams dimension index. Default value is `0`. + """ + self.index = index + + def fit(self, X, Y=None): + """ + Nothing to be done, but useful when included in a scikit-learn Pipeline. + """ + return self + + def transform(self, X, Y=None): + """ + Select persistence diagrams from its dimension. + + Parameters: + X (list of list of tuple): List of list of persistence pairs, i.e. + `[[array( Hi(X0) ), array( Hj(X0) ), ...], [array( Hi(X1) ), array( Hj(X1) ), ...], ...]` + + Returns: + list of tuple: + Persistence diagrams in a specific dimension. i.e. if `index` was set to `m` and `Hn` is at index `m` of + the input, it returns `[array( Hn(X0) ), array( Hn(X1), ...]` + """ + + return [persistence[self.index] for persistence in X] diff --git a/src/python/gudhi/representations/vector_methods.py b/src/python/gudhi/representations/vector_methods.py index f8078d03..69ff5e1e 100644 --- a/src/python/gudhi/representations/vector_methods.py +++ b/src/python/gudhi/representations/vector_methods.py @@ -508,26 +508,20 @@ class Entropy(BaseEstimator, TransformerMixin): new_X = BirthPersistenceTransform().fit_transform(X) for i in range(num_diag): - orig_diagram, diagram, num_pts_in_diag = X[i], new_X[i], X[i].shape[0] - try: - new_diagram = DiagramScaler(use=True, scalers=[([1], MaxAbsScaler())]).fit_transform([diagram])[0] - except ValueError: - # Empty persistence diagram case - https://github.com/GUDHI/gudhi-devel/issues/507 - assert len(diagram) == 0 - new_diagram = np.empty(shape = [0, 2]) - + orig_diagram, new_diagram, num_pts_in_diag = X[i], new_X[i], X[i].shape[0] + + p = new_diagram[:,1] + p = p/np.sum(p) if self.mode == "scalar": - ent = - np.sum( np.multiply(new_diagram[:,1], np.log(new_diagram[:,1])) ) + ent = -np.dot(p, np.log(p)) Xfit.append(np.array([[ent]])) - else: ent = np.zeros(self.resolution) for j in range(num_pts_in_diag): [px,py] = orig_diagram[j,:2] min_idx = np.clip(np.ceil((px - self.sample_range[0]) / step_x).astype(int), 0, self.resolution) max_idx = np.clip(np.ceil((py - self.sample_range[0]) / step_x).astype(int), 0, self.resolution) - for k in range(min_idx, max_idx): - ent[k] += (-1) * new_diagram[j,1] * np.log(new_diagram[j,1]) + ent[min_idx:max_idx]-=p[j]*np.log(p[j]) if self.normalized: ent = ent / np.linalg.norm(ent, ord=1) Xfit.append(np.reshape(ent,[1,-1])) diff --git a/src/python/gudhi/sklearn/__init__.py b/src/python/gudhi/sklearn/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/src/python/gudhi/sklearn/__init__.py diff --git a/src/python/gudhi/sklearn/cubical_persistence.py b/src/python/gudhi/sklearn/cubical_persistence.py new file mode 100644 index 00000000..672af278 --- /dev/null +++ b/src/python/gudhi/sklearn/cubical_persistence.py @@ -0,0 +1,110 @@ +# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT. +# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details. +# Author(s): Vincent Rouvreau +# +# Copyright (C) 2021 Inria +# +# Modification(s): +# - YYYY/MM Author: Description of the modification + +from .. import CubicalComplex +from sklearn.base import BaseEstimator, TransformerMixin + +import numpy as np +# joblib is required by scikit-learn +from joblib import Parallel, delayed + +# Mermaid sequence diagram - https://mermaid-js.github.io/mermaid-live-editor/ +# sequenceDiagram +# USER->>CubicalPersistence: fit_transform(X) +# CubicalPersistence->>thread1: _tranform(X[0]) +# CubicalPersistence->>thread2: _tranform(X[1]) +# Note right of CubicalPersistence: ... +# thread1->>CubicalPersistence: [array( H0(X[0]) ), array( H1(X[0]) )] +# thread2->>CubicalPersistence: [array( H0(X[1]) ), array( H1(X[1]) )] +# Note right of CubicalPersistence: ... +# CubicalPersistence->>USER: [[array( H0(X[0]) ), array( H1(X[0]) )],<br/> [array( H0(X[1]) ), array( H1(X[1]) )],<br/> ...] + + +class CubicalPersistence(BaseEstimator, TransformerMixin): + """ + This is a class for computing the persistence diagrams from a cubical complex. + """ + + def __init__( + self, + homology_dimensions, + newshape=None, + homology_coeff_field=11, + min_persistence=0.0, + n_jobs=None, + ): + """ + Constructor for the CubicalPersistence class. + + Parameters: + homology_dimensions (int or list of int): The returned persistence diagrams dimension(s). + Short circuit the use of :class:`~gudhi.representations.preprocessing.DimensionSelector` when only one + dimension matters (in other words, when `homology_dimensions` is an int). + newshape (tuple of ints): If cells filtration values require to be reshaped + (cf. :func:`~gudhi.sklearn.cubical_persistence.CubicalPersistence.transform`), set `newshape` + to perform `numpy.reshape(X, newshape, order='C')` in + :func:`~gudhi.sklearn.cubical_persistence.CubicalPersistence.transform` method. + homology_coeff_field (int): The homology coefficient field. Must be a prime number. Default value is 11. + min_persistence (float): The minimum persistence value to take into account (strictly greater than + `min_persistence`). Default value is `0.0`. Set `min_persistence` to `-1.0` to see all values. + n_jobs (int): cf. https://joblib.readthedocs.io/en/latest/generated/joblib.Parallel.html + """ + self.homology_dimensions = homology_dimensions + self.newshape = newshape + self.homology_coeff_field = homology_coeff_field + self.min_persistence = min_persistence + self.n_jobs = n_jobs + + def fit(self, X, Y=None): + """ + Nothing to be done, but useful when included in a scikit-learn Pipeline. + """ + return self + + def __transform(self, cells): + cubical_complex = CubicalComplex(top_dimensional_cells=cells) + cubical_complex.compute_persistence( + homology_coeff_field=self.homology_coeff_field, min_persistence=self.min_persistence + ) + return [ + cubical_complex.persistence_intervals_in_dimension(dim) for dim in self.homology_dimensions + ] + + def __transform_only_this_dim(self, cells): + cubical_complex = CubicalComplex(top_dimensional_cells=cells) + cubical_complex.compute_persistence( + homology_coeff_field=self.homology_coeff_field, min_persistence=self.min_persistence + ) + return cubical_complex.persistence_intervals_in_dimension(self.homology_dimensions) + + def transform(self, X, Y=None): + """Compute all the cubical complexes and their associated persistence diagrams. + + :param X: List of cells filtration values (`numpy.reshape(X, newshape, order='C'` if `newshape` is set with a tuple of ints). + :type X: list of list of float OR list of numpy.ndarray + + :return: Persistence diagrams in the format: + + - If `homology_dimensions` was set to `n`: `[array( Hn(X[0]) ), array( Hn(X[1]) ), ...]` + - If `homology_dimensions` was set to `[i, j]`: `[[array( Hi(X[0]) ), array( Hj(X[0]) )], [array( Hi(X[1]) ), array( Hj(X[1]) )], ...]` + :rtype: list of (,2) array_like or list of list of (,2) array_like + """ + if self.newshape is not None: + X = np.reshape(X, self.newshape, order='C') + + # Depends on homology_dimensions is an integer or a list of integer (else case) + if isinstance(self.homology_dimensions, int): + # threads is preferred as cubical construction and persistence computation releases the GIL + return Parallel(n_jobs=self.n_jobs, prefer="threads")( + delayed(self.__transform_only_this_dim)(cells) for cells in X + ) + else: + # threads is preferred as cubical construction and persistence computation releases the GIL + return Parallel(n_jobs=self.n_jobs, prefer="threads")(delayed(self.__transform)(cells) for cells in X) + diff --git a/src/python/test/test_representations.py b/src/python/test/test_representations.py index d219ce7a..4a455bb6 100755 --- a/src/python/test/test_representations.py +++ b/src/python/test/test_representations.py @@ -152,7 +152,26 @@ def test_vectorization_empty_diagrams(): scv = Entropy(mode="vector", normalized=False, resolution=random_resolution)(empty_diag) assert not np.any(scv) assert scv.shape[0] == random_resolution - + +def test_entropy_miscalculation(): + diag_ex = np.array([[0.0,1.0], [0.0,1.0], [0.0,2.0]]) + def pe(pd): + l = pd[:,1] - pd[:,0] + l = l/sum(l) + return -np.dot(l, np.log(l)) + sce = Entropy(mode="scalar") + assert [[pe(diag_ex)]] == sce.fit_transform([diag_ex]) + sce = Entropy(mode="vector", resolution=4, normalized=False) + pef = [-1/4*np.log(1/4)-1/4*np.log(1/4)-1/2*np.log(1/2), + -1/4*np.log(1/4)-1/4*np.log(1/4)-1/2*np.log(1/2), + -1/2*np.log(1/2), + 0.0] + assert all(([pef] == sce.fit_transform([diag_ex]))[0]) + sce = Entropy(mode="vector", resolution=4, normalized=True) + pefN = (sce.fit_transform([diag_ex]))[0] + area = np.linalg.norm(pefN, ord=1) + assert area==1 + def test_kernel_empty_diagrams(): empty_diag = np.empty(shape = [0, 2]) assert SlicedWassersteinDistance(num_directions=100)(empty_diag, empty_diag) == 0. diff --git a/src/python/test/test_representations_preprocessing.py b/src/python/test/test_representations_preprocessing.py new file mode 100644 index 00000000..838cf30c --- /dev/null +++ b/src/python/test/test_representations_preprocessing.py @@ -0,0 +1,39 @@ +""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT. + See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details. + Author(s): Vincent Rouvreau + + Copyright (C) 2021 Inria + + Modification(s): + - YYYY/MM Author: Description of the modification +""" + +from gudhi.representations.preprocessing import DimensionSelector +import numpy as np +import pytest + +H0_0 = np.array([0.0, 0.0]) +H1_0 = np.array([1.0, 0.0]) +H0_1 = np.array([0.0, 1.0]) +H1_1 = np.array([1.0, 1.0]) +H0_2 = np.array([0.0, 2.0]) +H1_2 = np.array([1.0, 2.0]) + + +def test_dimension_selector(): + X = [[H0_0, H1_0], [H0_1, H1_1], [H0_2, H1_2]] + ds = DimensionSelector(index=0) + h0 = ds.fit_transform(X) + np.testing.assert_array_equal(h0[0], H0_0) + np.testing.assert_array_equal(h0[1], H0_1) + np.testing.assert_array_equal(h0[2], H0_2) + + ds = DimensionSelector(index=1) + h1 = ds.fit_transform(X) + np.testing.assert_array_equal(h1[0], H1_0) + np.testing.assert_array_equal(h1[1], H1_1) + np.testing.assert_array_equal(h1[2], H1_2) + + ds = DimensionSelector(index=2) + with pytest.raises(IndexError): + h2 = ds.fit_transform([[H0_0, H1_0], [H0_1, H1_1], [H0_2, H1_2]]) diff --git a/src/python/test/test_sklearn_cubical_persistence.py b/src/python/test/test_sklearn_cubical_persistence.py new file mode 100644 index 00000000..1c05a215 --- /dev/null +++ b/src/python/test/test_sklearn_cubical_persistence.py @@ -0,0 +1,59 @@ +""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT. + See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details. + Author(s): Vincent Rouvreau + + Copyright (C) 2021 Inria + + Modification(s): + - YYYY/MM Author: Description of the modification +""" + +from gudhi.sklearn.cubical_persistence import CubicalPersistence +import numpy as np +from sklearn import datasets + +CUBICAL_PERSISTENCE_H0_IMG0 = np.array([[0.0, 6.0], [0.0, 8.0], [0.0, np.inf]]) + + +def test_simple_constructor_from_top_cells(): + cells = datasets.load_digits().images[0] + cp = CubicalPersistence(homology_dimensions=0) + np.testing.assert_array_equal(cp._CubicalPersistence__transform_only_this_dim(cells), CUBICAL_PERSISTENCE_H0_IMG0) + cp = CubicalPersistence(homology_dimensions=[0, 2]) + diags = cp._CubicalPersistence__transform(cells) + assert len(diags) == 2 + np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0) + + +def test_simple_constructor_from_top_cells_list(): + digits = datasets.load_digits().images[:10] + cp = CubicalPersistence(homology_dimensions=0, n_jobs=-2) + + diags = cp.fit_transform(digits) + assert len(diags) == 10 + np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0) + + cp = CubicalPersistence(homology_dimensions=[0, 1], n_jobs=-1) + diagsH0H1 = cp.fit_transform(digits) + assert len(diagsH0H1) == 10 + for idx in range(10): + np.testing.assert_array_equal(diags[idx], diagsH0H1[idx][0]) + +def test_simple_constructor_from_flattened_cells(): + cells = datasets.load_digits().images[0] + # Not squared (extended) flatten cells + flat_cells = np.hstack((cells, np.zeros((cells.shape[0], 2)))).flatten() + + cp = CubicalPersistence(homology_dimensions=0, newshape=[-1, 8, 10]) + diags = cp.fit_transform([flat_cells]) + + np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0) + + # Not squared (extended) non-flatten cells + cells = np.hstack((cells, np.zeros((cells.shape[0], 2)))) + + # The aim of this second part of the test is to resize even if not mandatory + cp = CubicalPersistence(homology_dimensions=0, newshape=[-1, 8, 10]) + diags = cp.fit_transform([cells]) + + np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0) |