author     Gard Spreemann <gspr@nonempty.org>  2022-08-17 14:31:07 +0200
committer  Gard Spreemann <gspr@nonempty.org>  2022-08-17 14:31:07 +0200
commit     e65c2f6750eca50a2333801a25f401b318e49ef7 (patch)
tree       ea24364b96a3dafe237d21d836a0857675d257b1 /src/python
parent     dbc404626955aee632fa47ee7a4d4c3add7d6188 (diff)
parent     de8bd5109fcdc6d4d200c74685bab031d953d2af (diff)
Merge tag 'tags/gudhi-release-3.6.0' into dfsg/latest
Diffstat (limited to 'src/python')
-rw-r--r--  src/python/CMakeLists.txt  81
-rw-r--r--  src/python/doc/alpha_complex_ref.rst  1
-rw-r--r--  src/python/doc/alpha_complex_sum.inc  24
-rw-r--r--  src/python/doc/alpha_complex_user.rst  107
-rw-r--r--  src/python/doc/cubical_complex_sklearn_itf_ref.rst  102
-rw-r--r--  src/python/doc/cubical_complex_sum.inc  30
-rw-r--r--  src/python/doc/cubical_complex_tflow_itf_ref.rst  40
-rw-r--r--  src/python/doc/cubical_complex_user.rst  11
-rw-r--r--  src/python/doc/datasets.inc (renamed from src/python/doc/datasets_generators.inc)  4
-rw-r--r--  src/python/doc/datasets.rst (renamed from src/python/doc/datasets_generators.rst)  36
-rw-r--r--  src/python/doc/differentiation_sum.inc  12
-rw-r--r--  src/python/doc/img/bunny.png  bin  0 -> 48040 bytes
-rw-r--r--  src/python/doc/img/sklearn.png  bin  0 -> 9368 bytes
-rw-r--r--  src/python/doc/img/spiral_2d.png  bin  0 -> 279276 bytes
-rw-r--r--  src/python/doc/img/tensorflow.png  bin  0 -> 3846 bytes
-rw-r--r--  src/python/doc/index.rst  6
-rw-r--r--  src/python/doc/installation.rst  105
-rw-r--r--  src/python/doc/ls_simplex_tree_tflow_itf_ref.rst  53
-rw-r--r--  src/python/doc/nerve_gic_complex_user.rst  2
-rw-r--r--  src/python/doc/persistence_graphical_tools_user.rst  2
-rw-r--r--  src/python/doc/persistent_cohomology_user.rst  29
-rw-r--r--  src/python/doc/rips_complex_sum.inc  7
-rw-r--r--  src/python/doc/rips_complex_tflow_itf_ref.rst  48
-rw-r--r--  src/python/doc/rips_complex_user.rst  8
-rw-r--r--  src/python/doc/simplex_tree_sum.inc  25
-rwxr-xr-x  src/python/example/alpha_complex_diagram_persistence_from_off_file_example.py  55
-rwxr-xr-x  src/python/example/alpha_rips_persistence_bottleneck_distance.py  110
-rwxr-xr-x  src/python/example/plot_alpha_complex.py  5
-rwxr-xr-x  src/python/example/rips_complex_diagram_persistence_from_correlation_matrix_file_example.py  2
-rw-r--r--  src/python/gudhi/__init__.py.in  4
-rw-r--r--  src/python/gudhi/alpha_complex.pyx  91
-rw-r--r--  src/python/gudhi/datasets/remote.py  223
-rw-r--r--  src/python/gudhi/hera/wasserstein.cc  2
-rw-r--r--  src/python/gudhi/persistence_graphical_tools.py  349
-rw-r--r--  src/python/gudhi/representations/preprocessing.py  57
-rw-r--r--  src/python/gudhi/representations/vector_methods.py  177
-rw-r--r--  src/python/gudhi/simplex_tree.pxd  9
-rw-r--r--  src/python/gudhi/simplex_tree.pyx  89
-rw-r--r--  src/python/gudhi/sklearn/__init__.py  0
-rw-r--r--  src/python/gudhi/sklearn/cubical_persistence.py  110
-rw-r--r--  src/python/gudhi/tensorflow/__init__.py  5
-rw-r--r--  src/python/gudhi/tensorflow/cubical_layer.py  82
-rw-r--r--  src/python/gudhi/tensorflow/lower_star_simplex_tree_layer.py  87
-rw-r--r--  src/python/gudhi/tensorflow/rips_layer.py  93
-rw-r--r--  src/python/gudhi/wasserstein/barycenter.py  6
-rw-r--r--  src/python/gudhi/weighted_rips_complex.py  6
-rw-r--r--  src/python/include/Alpha_complex_factory.h  118
-rw-r--r--  src/python/include/Alpha_complex_interface.h  62
-rw-r--r--  src/python/include/Persistent_cohomology_interface.h  40
-rw-r--r--  src/python/include/Simplex_tree_interface.h  46
-rw-r--r--  src/python/pyproject.toml  2
-rw-r--r--  src/python/setup.py.in  4
-rwxr-xr-x  src/python/test/test_alpha_complex.py  179
-rwxr-xr-x  src/python/test/test_betti_curve_representations.py  59
-rw-r--r--  src/python/test/test_diff.py  78
-rwxr-xr-x  src/python/test/test_dtm.py  16
-rw-r--r--  src/python/test/test_persistence_graphical_tools.py  121
-rwxr-xr-x  src/python/test/test_reader_utils.py  33
-rw-r--r--  src/python/test/test_remote_datasets.py  87
-rwxr-xr-x  src/python/test/test_representations.py  24
-rw-r--r--  src/python/test/test_representations_preprocessing.py  39
-rwxr-xr-x  src/python/test/test_simplex_tree.py  135
-rw-r--r--  src/python/test/test_sklearn_cubical_persistence.py  59
-rwxr-xr-x  src/python/test/test_subsampling.py  4
64 files changed, 2636 insertions, 765 deletions
diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt
index 1314b444..5f323935 100644
--- a/src/python/CMakeLists.txt
+++ b/src/python/CMakeLists.txt
@@ -70,6 +70,7 @@ if(PYTHONINTERP_FOUND)
set(GUDHI_PYTHON_MODULES "${GUDHI_PYTHON_MODULES}'euclidean_strong_witness_complex', ")
# Modules that should not be auto-imported in __init__.py
set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'representations', ")
+ set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'tensorflow', ")
set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'wasserstein', ")
set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'point_cloud', ")
set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'weighted_rips_complex', ")
@@ -148,10 +149,6 @@ if(PYTHONINTERP_FOUND)
add_gudhi_debug_info("Eigen3 version ${EIGEN3_VERSION}")
# No problem, even if no CGAL found
set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-DCGAL_EIGEN3_ENABLED', ")
- set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-DGUDHI_USE_EIGEN3', ")
- set(GUDHI_USE_EIGEN3 "True")
- else (EIGEN3_FOUND)
- set(GUDHI_USE_EIGEN3 "False")
endif (EIGEN3_FOUND)
set(GUDHI_CYTHON_MODULES "${GUDHI_CYTHON_MODULES}'off_reader', ")
@@ -171,13 +168,25 @@ if(PYTHONINTERP_FOUND)
set(GUDHI_CYTHON_MODULES "${GUDHI_CYTHON_MODULES}'nerve_gic', ")
endif ()
if (NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 4.11.0)
- set(GUDHI_CYTHON_MODULES "${GUDHI_CYTHON_MODULES}'alpha_complex', ")
set(GUDHI_CYTHON_MODULES "${GUDHI_CYTHON_MODULES}'subsampling', ")
set(GUDHI_CYTHON_MODULES "${GUDHI_CYTHON_MODULES}'tangential_complex', ")
set(GUDHI_CYTHON_MODULES "${GUDHI_CYTHON_MODULES}'euclidean_witness_complex', ")
set(GUDHI_CYTHON_MODULES "${GUDHI_CYTHON_MODULES}'euclidean_strong_witness_complex', ")
endif ()
+ if (NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 5.1.0)
+ set(GUDHI_CYTHON_MODULES "${GUDHI_CYTHON_MODULES}'alpha_complex', ")
+ endif ()
+ # from windows vcpkg eigen 3.4.0#2 : build fails with
+ # error C2440: '<function-style-cast>': cannot convert from 'Eigen::EigenBase<Derived>::Index' to '__gmp_expr<mpq_t,mpq_t>'
+ # cf. https://gitlab.com/libeigen/eigen/-/issues/2476
+ # Workaround is to compile with '-DEIGEN_DEFAULT_DENSE_INDEX_TYPE=int'
+ if (FORCE_EIGEN_DEFAULT_DENSE_INDEX_TYPE_TO_INT)
+ set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-DEIGEN_DEFAULT_DENSE_INDEX_TYPE=int', ")
+ endif()
+
+
+ add_gudhi_debug_info("Boost version ${Boost_VERSION}")
if(CGAL_FOUND)
if(NOT CGAL_VERSION VERSION_LESS 5.3.0)
# CGAL_HEADER_ONLY has been dropped for CGAL >= 5.3. Only the header-only version is supported.
@@ -212,13 +221,14 @@ if(PYTHONINTERP_FOUND)
endif(NOT GMP_LIBRARIES_DIR)
add_GUDHI_PYTHON_lib_dir(${GMP_LIBRARIES_DIR})
message("** Add gmp ${GMP_LIBRARIES_DIR}")
+ # When FORCE_CGAL_NOT_TO_BUILD_WITH_GMPXX is set, not defining CGAL_USE_GMPXX is sufficient
if(GMPXX_FOUND)
add_gudhi_debug_info("GMPXX_LIBRARIES = ${GMPXX_LIBRARIES}")
set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-DCGAL_USE_GMPXX', ")
add_GUDHI_PYTHON_lib("${GMPXX_LIBRARIES}")
add_GUDHI_PYTHON_lib_dir(${GMPXX_LIBRARIES_DIR})
message("** Add gmpxx ${GMPXX_LIBRARIES_DIR}")
- endif(GMPXX_FOUND)
+ endif()
endif(GMP_FOUND)
if(MPFR_FOUND)
add_gudhi_debug_info("MPFR_LIBRARIES = ${MPFR_LIBRARIES}")
@@ -262,6 +272,10 @@ if(PYTHONINTERP_FOUND)
set(GUDHI_PYTHON_INCLUDE_DIRS "${GUDHI_PYTHON_INCLUDE_DIRS}'${TBB_INCLUDE_DIRS}', ")
endif()
+ if(DEBUG_TRACES)
+ set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-DDEBUG_TRACES', ")
+ endif(DEBUG_TRACES)
+
if(UNIX AND WITH_GUDHI_PYTHON_RUNTIME_LIBRARY_DIRS)
set( GUDHI_PYTHON_RUNTIME_LIBRARY_DIRS "${GUDHI_PYTHON_LIBRARY_DIRS}")
endif(UNIX AND WITH_GUDHI_PYTHON_RUNTIME_LIBRARY_DIRS)
@@ -276,13 +290,15 @@ if(PYTHONINTERP_FOUND)
# Other .py files
file(COPY "gudhi/persistence_graphical_tools.py" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi")
file(COPY "gudhi/representations" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi/")
- file(COPY "gudhi/wasserstein" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi")
+ file(COPY "gudhi/wasserstein" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi")
+ file(COPY "gudhi/tensorflow" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi")
file(COPY "gudhi/point_cloud" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi")
file(COPY "gudhi/clustering" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi" FILES_MATCHING PATTERN "*.py")
file(COPY "gudhi/weighted_rips_complex.py" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi")
file(COPY "gudhi/dtm_rips_complex.py" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi")
file(COPY "gudhi/hera/__init__.py" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi/hera")
file(COPY "gudhi/datasets" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi" FILES_MATCHING PATTERN "*.py")
+ file(COPY "gudhi/sklearn" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi/")
# Some files for pip package
@@ -313,12 +329,12 @@ if(PYTHONINTERP_FOUND)
if(SCIPY_FOUND)
if(SKLEARN_FOUND)
if(OT_FOUND)
- if(NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 4.11.0)
+ if(NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 5.1.0)
set (GUDHI_SPHINX_MESSAGE "Generating API documentation with Sphinx in ${CMAKE_CURRENT_BINARY_DIR}/sphinx/")
# User warning - Sphinx is a static pages generator, and configured to work fine with user_version
- # Images and biblio warnings because not found on developper version
+ # Images and biblio warnings because not found on developer version
if (GUDHI_PYTHON_PATH STREQUAL "src/python")
- set (GUDHI_SPHINX_MESSAGE "${GUDHI_SPHINX_MESSAGE} \n WARNING : Sphinx is configured for user version, you run it on developper version. Images and biblio will miss")
+ set (GUDHI_SPHINX_MESSAGE "${GUDHI_SPHINX_MESSAGE} \n WARNING : Sphinx is configured for user version, you run it on developer version. Images and biblio will miss")
endif()
# sphinx target requires gudhi.so, because conf.py reads gudhi version from it
add_custom_target(sphinx
@@ -333,10 +349,10 @@ if(PYTHONINTERP_FOUND)
${SPHINX_PATH} -b doctest ${CMAKE_CURRENT_SOURCE_DIR}/doc ${CMAKE_CURRENT_BINARY_DIR}/doctest)
# Set missing or not modules
set(GUDHI_MODULES ${GUDHI_MODULES} "python-documentation" CACHE INTERNAL "GUDHI_MODULES")
- else(NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 4.11.0)
- message("++ Python documentation module will not be compiled because it requires a Eigen3 and CGAL version >= 4.11.0")
+ else(NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 5.1.0)
+ message("++ Python documentation module will not be compiled because it requires a Eigen3 and CGAL version >= 5.1.0")
set(GUDHI_MISSING_MODULES ${GUDHI_MISSING_MODULES} "python-documentation" CACHE INTERNAL "GUDHI_MISSING_MODULES")
- endif(NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 4.11.0)
+ endif(NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 5.1.0)
else(OT_FOUND)
message("++ Python documentation module will not be compiled because POT was not found")
set(GUDHI_MISSING_MODULES ${GUDHI_MISSING_MODULES} "python-documentation" CACHE INTERNAL "GUDHI_MISSING_MODULES")
@@ -372,13 +388,15 @@ if(PYTHONINTERP_FOUND)
# Test examples
- if (NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 4.11.0)
+ if (NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 5.1.0)
# Bottleneck and Alpha
add_test(NAME alpha_rips_persistence_bottleneck_distance_py_test
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
COMMAND ${CMAKE_COMMAND} -E env "${GUDHI_PYTHON_PATH_ENV}"
${PYTHON_EXECUTABLE} "${CMAKE_CURRENT_SOURCE_DIR}/example/alpha_rips_persistence_bottleneck_distance.py"
-f ${CMAKE_SOURCE_DIR}/data/points/tore3D_300.off -t 0.15 -d 3)
+ endif (NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 5.1.0)
+ if (NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 4.11.0)
# Tangential
add_test(NAME tangential_complex_plain_homology_from_off_file_example_py_test
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
@@ -446,7 +464,7 @@ if(PYTHONINTERP_FOUND)
add_gudhi_py_test(test_cover_complex)
endif (NOT CGAL_VERSION VERSION_LESS 4.11.0)
- if (NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 4.11.0)
+ if (NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 5.1.0)
# Alpha
add_test(NAME alpha_complex_from_points_example_py_test
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
@@ -462,14 +480,14 @@ if(PYTHONINTERP_FOUND)
${PYTHON_EXECUTABLE} "${CMAKE_CURRENT_SOURCE_DIR}/example/alpha_complex_diagram_persistence_from_off_file_example.py"
--no-diagram -f ${CMAKE_SOURCE_DIR}/data/points/tore3D_300.off)
add_gudhi_py_test(test_alpha_complex)
- endif (NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 4.11.0)
+ endif (NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 5.1.0)
if (NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 4.11.0)
# Euclidean witness
add_gudhi_py_test(test_euclidean_witness_complex)
# Datasets generators
- add_gudhi_py_test(test_datasets_generators) # TODO separate full python datasets generators in another test file independant from CGAL ?
+ add_gudhi_py_test(test_datasets_generators) # TODO separate full python datasets generators in another test file independent from CGAL ?
endif (NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 4.11.0)
@@ -547,6 +565,21 @@ if(PYTHONINTERP_FOUND)
add_gudhi_py_test(test_representations)
endif()
+ # Differentiation
+ if(TENSORFLOW_FOUND)
+ add_gudhi_py_test(test_diff)
+ endif()
+
+ # Betti curves
+ if(SKLEARN_FOUND AND SCIPY_FOUND)
+ add_gudhi_py_test(test_betti_curve_representations)
+ endif()
+
+ # Representations preprocessing
+ if(SKLEARN_FOUND)
+ add_gudhi_py_test(test_representations_preprocessing)
+ endif()
+
# Time Delay
add_gudhi_py_test(test_time_delay)
@@ -571,6 +604,20 @@ if(PYTHONINTERP_FOUND)
add_gudhi_py_test(test_dtm_rips_complex)
endif()
+ # Fetch remote datasets
+ if(WITH_GUDHI_REMOTE_TEST)
+ add_gudhi_py_test(test_remote_datasets)
+ endif()
+
+ # sklearn
+ if(SKLEARN_FOUND)
+ add_gudhi_py_test(test_sklearn_cubical_persistence)
+ endif()
+
+ # persistence graphical tools
+ if(MATPLOTLIB_FOUND)
+ add_gudhi_py_test(test_persistence_graphical_tools)
+ endif()
# Set missing or not modules
set(GUDHI_MODULES ${GUDHI_MODULES} "python" CACHE INTERNAL "GUDHI_MODULES")
diff --git a/src/python/doc/alpha_complex_ref.rst b/src/python/doc/alpha_complex_ref.rst
index 7da79543..eaa72551 100644
--- a/src/python/doc/alpha_complex_ref.rst
+++ b/src/python/doc/alpha_complex_ref.rst
@@ -9,6 +9,5 @@ Alpha complex reference manual
.. autoclass:: gudhi.AlphaComplex
:members:
:undoc-members:
- :show-inheritance:
.. automethod:: gudhi.AlphaComplex.__init__
diff --git a/src/python/doc/alpha_complex_sum.inc b/src/python/doc/alpha_complex_sum.inc
index aeab493f..5c76fd54 100644
--- a/src/python/doc/alpha_complex_sum.inc
+++ b/src/python/doc/alpha_complex_sum.inc
@@ -1,15 +1,15 @@
.. table::
:widths: 30 40 30
- +----------------------------------------------------------------+-------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+
- | .. figure:: | Alpha complex is a simplicial complex constructed from the finite | :Author: Vincent Rouvreau |
- | ../../doc/Alpha_complex/alpha_complex_representation.png | cells of a Delaunay Triangulation. It has the same persistent homology | |
- | :alt: Alpha complex representation | as the Čech complex and is significantly smaller. | :Since: GUDHI 2.0.0 |
- | :figclass: align-center | | |
- | | | :License: MIT (`GPL v3 </licensing/>`_) |
- | | | |
- | | | :Requires: `Eigen <installation.html#eigen>`_ :math:`\geq` 3.1.0 and `CGAL <installation.html#cgal>`_ :math:`\geq` 4.11.0 |
- | | | |
- +----------------------------------------------------------------+-------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+
- | * :doc:`alpha_complex_user` | * :doc:`alpha_complex_ref` |
- +----------------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
+ +----------------------------------------------------------------+-------------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------+
+ | .. figure:: | Alpha complex is a simplicial complex constructed from the finite | :Author: Vincent Rouvreau |
+ | ../../doc/Alpha_complex/alpha_complex_representation.png | cells of a Delaunay Triangulation. It has the same persistent homology | |
+ | :alt: Alpha complex representation | as the Čech complex and is significantly smaller. | :Since: GUDHI 2.0.0 |
+ | :figclass: align-center | | |
+ | | | :License: MIT (`GPL v3 </licensing/>`_) |
+ | | | |
+ | | | :Requires: `Eigen <installation.html#eigen>`_ :math:`\geq` 3.1.0 and `CGAL <installation.html#cgal>`_ :math:`\geq` 5.1 |
+ | | | |
+ +----------------------------------------------------------------+-------------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------+
+ | * :doc:`alpha_complex_user` | * :doc:`alpha_complex_ref` |
+ +----------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
diff --git a/src/python/doc/alpha_complex_user.rst b/src/python/doc/alpha_complex_user.rst
index 96e267ef..9e67d38a 100644
--- a/src/python/doc/alpha_complex_user.rst
+++ b/src/python/doc/alpha_complex_user.rst
@@ -9,7 +9,7 @@ Definition
.. include:: alpha_complex_sum.inc
-:doc:`AlphaComplex <alpha_complex_ref>` is constructing a :doc:`SimplexTree <simplex_tree_ref>` using
+:class:`~gudhi.AlphaComplex` is constructing a :doc:`SimplexTree <simplex_tree_ref>` using
`Delaunay Triangulation <http://doc.cgal.org/latest/Triangulation/index.html#Chapter_Triangulations>`_
:cite:`cgal:hdj-t-19b` from the `Computational Geometry Algorithms Library <http://www.cgal.org/>`_
:cite:`cgal:eb-19b`.
@@ -27,15 +27,13 @@ Remarks
If you pass :code:`precision = 'exact'` to :func:`~gudhi.AlphaComplex.__init__`, the filtration values are the exact
ones converted to float. This can be very slow.
If you pass :code:`precision = 'safe'` (the default), the filtration values are only
- guaranteed to have a small multiplicative error compared to the exact value.
+ guaranteed to have a small multiplicative error compared to the exact value; see
+ :func:`~gudhi.AlphaComplex.set_float_relative_precision` to modify the precision.
A drawback, when computing persistence, is that an empty exact interval [10^12,10^12] may become a
non-empty approximate interval [10^12,10^12+10^6].
Using :code:`precision = 'fast'` makes the computations slightly faster, and the combinatorics are still exact, but
the computation of filtration values can exceptionally be arbitrarily bad. In all cases, we still guarantee that the
output is a valid filtration (faces have a filtration value no larger than their cofaces).
-* For performances reasons, it is advised to use Alpha_complex with `CGAL <installation.html#cgal>`_ :math:`\geq` 5.0.0.
-* The vertices in the output simplex tree are not guaranteed to match the order of the input points. One can use
- :func:`~gudhi.AlphaComplex.get_point` to get the initial point back.
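A minimal sketch of how these precision modes can be selected at construction time, assuming GUDHI >= 3.6.0 (the tolerance passed to :func:`~gudhi.AlphaComplex.set_float_relative_precision` is only illustrative):

.. code-block:: python

    from gudhi import AlphaComplex

    pts = [[1., 1.], [7., 0.], [4., 6.], [9., 6.]]

    # 'fast': quickest, but filtration values can exceptionally be arbitrarily bad
    stree_fast = AlphaComplex(points=pts, precision='fast').create_simplex_tree()

    # 'safe' (default): small multiplicative error on filtration values;
    # the relative tolerance can be adjusted before construction
    AlphaComplex.set_float_relative_precision(1e-10)
    stree_safe = AlphaComplex(points=pts, precision='safe').create_simplex_tree()

    # 'exact': exact values converted to float, can be very slow
    stree_exact = AlphaComplex(points=pts, precision='exact').create_simplex_tree()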
Example from points
-------------------
@@ -44,23 +42,22 @@ This example builds the alpha-complex from the given points:
.. testcode::
- import gudhi
- alpha_complex = gudhi.AlphaComplex(points=[[1, 1], [7, 0], [4, 6], [9, 6], [0, 14], [2, 19], [9, 17]])
+ from gudhi import AlphaComplex
+ ac = AlphaComplex(points=[[1, 1], [7, 0], [4, 6], [9, 6], [0, 14], [2, 19], [9, 17]])
+
+ stree = ac.create_simplex_tree()
+ print('Alpha complex is of dimension ', stree.dimension(), ' - ',
+ stree.num_simplices(), ' simplices - ', stree.num_vertices(), ' vertices.')
- simplex_tree = alpha_complex.create_simplex_tree()
- result_str = 'Alpha complex is of dimension ' + repr(simplex_tree.dimension()) + ' - ' + \
- repr(simplex_tree.num_simplices()) + ' simplices - ' + \
- repr(simplex_tree.num_vertices()) + ' vertices.'
- print(result_str)
fmt = '%s -> %.2f'
- for filtered_value in simplex_tree.get_filtration():
+ for filtered_value in stree.get_filtration():
print(fmt % tuple(filtered_value))
The output is:
.. testoutput::
- Alpha complex is of dimension 2 - 25 simplices - 7 vertices.
+ Alpha complex is of dimension 2 - 25 simplices - 7 vertices.
[0] -> 0.00
[1] -> 0.00
[2] -> 0.00
@@ -177,11 +174,75 @@ of speed-up, since we always first build the full filtered complex, so it is rec
:paramref:`~gudhi.AlphaComplex.create_simplex_tree.max_alpha_square`.
In the following example, a threshold of :math:`\alpha^2 = 32.0` is used.
+Weighted version
+^^^^^^^^^^^^^^^^
+
+A weighted version for Alpha complex is available. It is like a usual Alpha complex, but based on a
+`CGAL regular triangulation <https://doc.cgal.org/latest/Triangulation/index.html#TriangulationSecRT>`_.
+
+This example builds the weighted alpha-complex of a small molecule, where atoms have different sizes.
+It is taken from
+`CGAL 3d weighted alpha shapes <https://doc.cgal.org/latest/Alpha_shapes_3/index.html#AlphaShape_3DExampleforWeightedAlphaShapes>`_.
+
+Then, information about the alpha complex is displayed.
+
+.. testcode::
+
+ from gudhi import AlphaComplex
+ wgt_ac = AlphaComplex(points=[[ 1., -1., -1.],
+ [-1., 1., -1.],
+ [-1., -1., 1.],
+ [ 1., 1., 1.],
+ [ 2., 2., 2.]],
+ weights = [4., 4., 4., 4., 1.])
+
+ stree = wgt_ac.create_simplex_tree()
+ print('Weighted alpha complex is of dimension ', stree.dimension(), ' - ',
+ stree.num_simplices(), ' simplices - ', stree.num_vertices(), ' vertices.')
+ fmt = '%s -> %.2f'
+ for simplex in stree.get_simplices():
+ print(fmt % tuple(simplex))
+
+The output is:
+
+.. testoutput::
+
+ Weighted alpha complex is of dimension 3 - 29 simplices - 5 vertices.
+ [0, 1, 2, 3] -> -1.00
+ [0, 1, 2] -> -1.33
+ [0, 1, 3, 4] -> 95.00
+ [0, 1, 3] -> -1.33
+ [0, 1, 4] -> 95.00
+ [0, 1] -> -2.00
+ [0, 2, 3, 4] -> 95.00
+ [0, 2, 3] -> -1.33
+ [0, 2, 4] -> 95.00
+ [0, 2] -> -2.00
+ [0, 3, 4] -> 23.00
+ [0, 3] -> -2.00
+ [0, 4] -> 23.00
+ [0] -> -4.00
+ [1, 2, 3, 4] -> 95.00
+ [1, 2, 3] -> -1.33
+ [1, 2, 4] -> 95.00
+ [1, 2] -> -2.00
+ [1, 3, 4] -> 23.00
+ [1, 3] -> -2.00
+ [1, 4] -> 23.00
+ [1] -> -4.00
+ [2, 3, 4] -> 23.00
+ [2, 3] -> -2.00
+ [2, 4] -> 23.00
+ [2] -> -4.00
+ [3, 4] -> -1.00
+ [3] -> -4.00
+ [4] -> -1.00
Example from OFF file
^^^^^^^^^^^^^^^^^^^^^
-This example builds the alpha complex from 300 random points on a 2-torus.
+This example builds the alpha complex from 300 random points on a 2-torus, given by an
+`OFF file <fileformats.html#off-file-format>`_.
Then, it computes the persistence diagram and displays it:
@@ -189,14 +250,10 @@ Then, it computes the persistence diagram and displays it:
:include-source:
import matplotlib.pyplot as plt
- import gudhi
- alpha_complex = gudhi.AlphaComplex(off_file=gudhi.__root_source_dir__ + \
- '/data/points/tore3D_300.off')
- simplex_tree = alpha_complex.create_simplex_tree()
- result_str = 'Alpha complex is of dimension ' + repr(simplex_tree.dimension()) + ' - ' + \
- repr(simplex_tree.num_simplices()) + ' simplices - ' + \
- repr(simplex_tree.num_vertices()) + ' vertices.'
- print(result_str)
- diag = simplex_tree.persistence()
- gudhi.plot_persistence_diagram(diag)
+ import gudhi as gd
+ off_file = gd.__root_source_dir__ + '/data/points/tore3D_300.off'
+ points = gd.read_points_from_off_file(off_file = off_file)
+ stree = gd.AlphaComplex(points = points).create_simplex_tree()
+ dgm = stree.persistence()
+ gd.plot_persistence_diagram(dgm, legend = True)
plt.show()
diff --git a/src/python/doc/cubical_complex_sklearn_itf_ref.rst b/src/python/doc/cubical_complex_sklearn_itf_ref.rst
new file mode 100644
index 00000000..2fb8ec6a
--- /dev/null
+++ b/src/python/doc/cubical_complex_sklearn_itf_ref.rst
@@ -0,0 +1,102 @@
+:orphan:
+
+.. To get rid of WARNING: document isn't included in any toctree
+
+Cubical complex persistence scikit-learn like interface
+#######################################################
+
+.. list-table::
+ :width: 100%
+ :header-rows: 0
+
+ * - :Since: GUDHI 3.6.0
+ - :License: MIT
+ - :Requires: `Scikit-learn <installation.html#scikit-learn>`_
+
+Cubical complex persistence scikit-learn like interface example
+---------------------------------------------------------------
+
+In this example, handwritten digits are used as input.
+A TDA scikit-learn pipeline is constructed, composed of:
+
+#. :class:`~gudhi.sklearn.cubical_persistence.CubicalPersistence` that builds a cubical complex from the inputs and
+ returns its persistence diagrams
+#. :class:`~gudhi.representations.preprocessing.DiagramSelector` that removes non-finite persistence diagrams values
+#. :class:`~gudhi.representations.vector_methods.PersistenceImage` that builds the persistence images from persistence diagrams
+#. `SVC <https://scikit-learn.org/stable/modules/generated/sklearn.svm.SVC.html>`_ which is a scikit-learn support
+ vector classifier.
+
+This ML pipeline is trained to detect whether a handwritten digit is an '8' or not, thanks to the fact that an '8' has
+two holes in :math:`\mathbf{H}_1`, or, like in this example, three connected components in :math:`\mathbf{H}_0`.
+
+.. code-block:: python
+
+ # Standard scientific Python imports
+ import numpy as np
+
+ # Standard scikit-learn imports
+ from sklearn.datasets import fetch_openml
+ from sklearn.pipeline import Pipeline
+ from sklearn.model_selection import train_test_split
+ from sklearn.svm import SVC
+ from sklearn import metrics
+
+ # Import TDA pipeline requirements
+ from gudhi.sklearn.cubical_persistence import CubicalPersistence
+ from gudhi.representations import PersistenceImage, DiagramSelector
+
+ X, y = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)
+
+ # Target is: "is an eight ?"
+ y = (y == "8") * 1
+ print("There are", np.sum(y), "eights out of", len(y), "numbers.")
+
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=0)
+ pipe = Pipeline(
+ [
+ ("cub_pers", CubicalPersistence(homology_dimensions=0, newshape=[28, 28], n_jobs=-2)),
+ # Or for multiple persistence dimension computation
+ # ("cub_pers", CubicalPersistence(homology_dimensions=[0, 1], newshape=[28, 28], n_jobs=-2)),
+ # ("H0_diags", DimensionSelector(index=0), # where index is the index in homology_dimensions array
+ ("finite_diags", DiagramSelector(use=True, point_type="finite")),
+ (
+ "pers_img",
+ PersistenceImage(bandwidth=50, weight=lambda x: x[1] ** 2, im_range=[0, 256, 0, 256], resolution=[20, 20]),
+ ),
+ ("svc", SVC()),
+ ]
+ )
+
+ # Learn from the train subset
+ pipe.fit(X_train, y_train)
+ # Predict from the test subset
+ predicted = pipe.predict(X_test)
+
+ print(f"Classification report for TDA pipeline {pipe}:\n" f"{metrics.classification_report(y_test, predicted)}\n")
+
+.. code-block:: none
+
+ There are 6825 eights out of 70000 numbers.
+ Classification report for TDA pipeline Pipeline(steps=[('cub_pers',
+ CubicalPersistence(newshape=[28, 28], n_jobs=-2)),
+ ('finite_diags', DiagramSelector(use=True)),
+ ('pers_img',
+ PersistenceImage(bandwidth=50, im_range=[0, 256, 0, 256],
+ weight=<function <lambda> at 0x7f3e54137ae8>)),
+ ('svc', SVC())]):
+ precision recall f1-score support
+
+ 0 0.97 0.99 0.98 25284
+ 1 0.92 0.68 0.78 2716
+
+ accuracy 0.96 28000
+ macro avg 0.94 0.84 0.88 28000
+ weighted avg 0.96 0.96 0.96 28000
+
+Cubical complex persistence scikit-learn like interface reference
+-----------------------------------------------------------------
+
+.. autoclass:: gudhi.sklearn.cubical_persistence.CubicalPersistence
+ :members:
+ :special-members: __init__
+ :show-inheritance: \ No newline at end of file
diff --git a/src/python/doc/cubical_complex_sum.inc b/src/python/doc/cubical_complex_sum.inc
index 87db184d..b27843e5 100644
--- a/src/python/doc/cubical_complex_sum.inc
+++ b/src/python/doc/cubical_complex_sum.inc
@@ -1,14 +1,22 @@
.. table::
:widths: 30 40 30
- +--------------------------------------------------------------------------+----------------------------------------------------------------------+-----------------------------+
- | .. figure:: | The cubical complex represents a grid as a cell complex with | :Author: Pawel Dlotko |
- | ../../doc/Bitmap_cubical_complex/Cubical_complex_representation.png | cells of all dimensions. | |
- | :alt: Cubical complex representation | | :Since: GUDHI 2.0.0 |
- | :figclass: align-center | | |
- | | | :License: MIT |
- | | | |
- +--------------------------------------------------------------------------+----------------------------------------------------------------------+-----------------------------+
- | * :doc:`cubical_complex_user` | * :doc:`cubical_complex_ref` |
- | | * :doc:`periodic_cubical_complex_ref` |
- +--------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------+
+ +--------------------------------------------------------------------------+--------------------------------------------------------------+-------------------------------------------------------------+
+ | .. figure:: | The cubical complex represents a grid as a cell complex with | :Author: Pawel Dlotko |
+ | ../../doc/Bitmap_cubical_complex/Cubical_complex_representation.png | cells of all dimensions. | :Since: GUDHI 2.0.0 |
+ | :alt: Cubical complex representation | | :License: MIT |
+ | :figclass: align-center | | |
+ +--------------------------------------------------------------------------+--------------------------------------------------------------+-------------------------------------------------------------+
+ | * :doc:`cubical_complex_user` | * :doc:`cubical_complex_ref` |
+ | | * :doc:`periodic_cubical_complex_ref` |
+ +--------------------------------------------------------------------------+--------------------------------------------------------------+-------------------------------------------------------------+
+ | .. image:: | * :doc:`cubical_complex_tflow_itf_ref` | :requires: `TensorFlow <installation.html#tensorflow>`_ |
+ | img/tensorflow.png | | |
+ | :target: https://www.tensorflow.org | | |
+ | :height: 30 | | |
+ +--------------------------------------------------------------------------+--------------------------------------------------------------+-------------------------------------------------------------+
+ | .. image:: | * :doc:`cubical_complex_sklearn_itf_ref` | :Requires: `Scikit-learn <installation.html#scikit-learn>`_ |
+ | img/sklearn.png | | |
+ | :target: https://scikit-learn.org | | |
+ | :height: 30 | | |
+ +--------------------------------------------------------------------------+--------------------------------------------------------------+-------------------------------------------------------------+
diff --git a/src/python/doc/cubical_complex_tflow_itf_ref.rst b/src/python/doc/cubical_complex_tflow_itf_ref.rst
new file mode 100644
index 00000000..b32f5e47
--- /dev/null
+++ b/src/python/doc/cubical_complex_tflow_itf_ref.rst
@@ -0,0 +1,40 @@
+:orphan:
+
+.. To get rid of WARNING: document isn't included in any toctree
+
+TensorFlow layer for cubical persistence
+########################################
+
+.. include:: differentiation_sum.inc
+
+Example of gradient computed from cubical persistence
+-----------------------------------------------------
+
+.. testcode::
+
+ from gudhi.tensorflow import CubicalLayer
+ import tensorflow as tf
+
+ X = tf.Variable([[0.,2.,2.],[2.,2.,2.],[2.,2.,1.]], dtype=tf.float32, trainable=True)
+ cl = CubicalLayer(homology_dimensions=[0])
+
+ with tf.GradientTape() as tape:
+ dgm = cl.call(X)[0][0]
+ loss = tf.math.reduce_sum(tf.square(.5*(dgm[:,1]-dgm[:,0])))
+
+ grads = tape.gradient(loss, [X])
+ print(grads[0].numpy())
+
+.. testoutput::
+
+ [[ 0. 0. 0. ]
+ [ 0. 0.5 0. ]
+ [ 0. 0. -0.5]]
+
+Documentation for CubicalLayer
+------------------------------
+
+.. autoclass:: gudhi.tensorflow.CubicalLayer
+ :members:
+ :special-members: __init__
+ :show-inheritance:
diff --git a/src/python/doc/cubical_complex_user.rst b/src/python/doc/cubical_complex_user.rst
index 6a211347..42a23875 100644
--- a/src/python/doc/cubical_complex_user.rst
+++ b/src/python/doc/cubical_complex_user.rst
@@ -7,14 +7,7 @@ Cubical complex user manual
Definition
----------
-===================================== ===================================== =====================================
-:Author: Pawel Dlotko :Since: GUDHI PYTHON 2.0.0 :License: GPL v3
-===================================== ===================================== =====================================
-
-+---------------------------------------------+----------------------------------------------------------------------+
-| :doc:`cubical_complex_user` | * :doc:`cubical_complex_ref` |
-| | * :doc:`periodic_cubical_complex_ref` |
-+---------------------------------------------+----------------------------------------------------------------------+
+.. include:: cubical_complex_sum.inc
The cubical complex is an example of a structured complex useful in computational mathematics (especially rigorous
numerics) and image analysis.
@@ -163,4 +156,4 @@ Tutorial
--------
This `notebook <https://github.com/GUDHI/TDA-tutorial/blob/master/Tuto-GUDHI-cubical-complexes.ipynb>`_
-explains how to represent sublevels sets of functions using cubical complexes. \ No newline at end of file
+explains how to represent sublevel sets of functions using cubical complexes.
diff --git a/src/python/doc/datasets_generators.inc b/src/python/doc/datasets.inc
index 8d169275..95a87678 100644
--- a/src/python/doc/datasets_generators.inc
+++ b/src/python/doc/datasets.inc
@@ -2,7 +2,7 @@
:widths: 30 40 30
+-----------------------------------+--------------------------------------------+--------------------------------------------------------------------------------------+
- | .. figure:: | Datasets generators (points). | :Authors: Hind Montassif |
+ | .. figure:: | Datasets either generated or fetched. | :Authors: Hind Montassif |
| img/sphere_3d.png | | |
| | | :Since: GUDHI 3.5.0 |
| | | |
@@ -10,5 +10,5 @@
| | | |
| | | :Requires: `CGAL <installation.html#cgal>`_ |
+-----------------------------------+--------------------------------------------+--------------------------------------------------------------------------------------+
- | * :doc:`datasets_generators` |
+ | * :doc:`datasets` |
+-----------------------------------+-----------------------------------------------------------------------------------------------------------------------------------+
diff --git a/src/python/doc/datasets_generators.rst b/src/python/doc/datasets.rst
index 260c3882..2d11a19d 100644
--- a/src/python/doc/datasets_generators.rst
+++ b/src/python/doc/datasets.rst
@@ -3,12 +3,14 @@
.. To get rid of WARNING: document isn't included in any toctree
-===========================
-Datasets generators manual
-===========================
+================
+Datasets manual
+================
-We provide the generation of different customizable datasets to use as inputs for Gudhi complexes and data structures.
+Datasets generators
+===================
+We provide the generation of different customizable datasets to use as inputs for Gudhi complexes and data structures.
Points generators
------------------
@@ -103,3 +105,29 @@ Example
.. autofunction:: gudhi.datasets.generators.points.torus
+
+
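A quick usage sketch for the torus generator documented above (the shape of the returned array is an assumption based on the embedding in :math:`\mathbb{R}^{2 \times dim}`):

.. code-block:: python

    from gudhi.datasets.generators import points

    # 100 random points on a 2-torus embedded in R^4
    pts = points.torus(n_samples=100, dim=2, sample='random')
    print(pts.shape)  # (100, 4): each point has 2 * dim coordinates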
+Fetching datasets
+=================
+
+We provide some ready-to-use datasets that are not available by default when getting GUDHI, and need to be fetched explicitly.
+
+By **default**, the fetched datasets directory is set to a folder named **'gudhi_data'** in the **user home folder**.
+Alternatively, it can be set using the **'GUDHI_DATA'** environment variable.
+
+.. autofunction:: gudhi.datasets.remote.fetch_bunny
+
+.. figure:: ./img/bunny.png
+ :figclass: align-center
+
+ 3D Stanford bunny with 35947 vertices.
+
+
+.. autofunction:: gudhi.datasets.remote.fetch_spiral_2d
+
+.. figure:: ./img/spiral_2d.png
+ :figclass: align-center
+
+ 2D spiral with 114562 vertices.
+
+.. autofunction:: gudhi.datasets.remote.clear_data_home
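A minimal sketch of the fetching workflow described above; the expected shapes follow the vertex counts given in the figure captions:

.. code-block:: python

    from gudhi.datasets import remote

    # Downloaded on first call, then reused from the cache
    # (~/gudhi_data by default, or the directory named in GUDHI_DATA)
    spiral = remote.fetch_spiral_2d()
    bunny = remote.fetch_bunny()
    print(spiral.shape)  # expected: (114562, 2)
    print(bunny.shape)   # expected: (35947, 3)

    # Remove the cached datasets
    remote.clear_data_home()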
diff --git a/src/python/doc/differentiation_sum.inc b/src/python/doc/differentiation_sum.inc
new file mode 100644
index 00000000..140cf180
--- /dev/null
+++ b/src/python/doc/differentiation_sum.inc
@@ -0,0 +1,12 @@
+.. list-table::
+ :width: 100%
+ :header-rows: 0
+
+ * - :Since: GUDHI 3.6.0
+ - :License: MIT
+ - :Requires: `TensorFlow <installation.html#tensorflow>`_
+
+We provide TensorFlow 2 models that can handle automatic differentiation for the computation of persistence diagrams from complexes available in the Gudhi library.
+This includes simplex trees, cubical complexes and Vietoris-Rips complexes. Detailed examples on how to use these layers in practice are available
+in the following `notebook <https://github.com/GUDHI/TDA-tutorial/blob/master/Tuto-GUDHI-optimization.ipynb>`_. Note that even if TensorFlow GPU is enabled, all
+internal computations using Gudhi will be done on CPU.
diff --git a/src/python/doc/img/bunny.png b/src/python/doc/img/bunny.png
new file mode 100644
index 00000000..769aa530
--- /dev/null
+++ b/src/python/doc/img/bunny.png
Binary files differ
diff --git a/src/python/doc/img/sklearn.png b/src/python/doc/img/sklearn.png
new file mode 100644
index 00000000..d1fecbbf
--- /dev/null
+++ b/src/python/doc/img/sklearn.png
Binary files differ
diff --git a/src/python/doc/img/spiral_2d.png b/src/python/doc/img/spiral_2d.png
new file mode 100644
index 00000000..abd247cd
--- /dev/null
+++ b/src/python/doc/img/spiral_2d.png
Binary files differ
diff --git a/src/python/doc/img/tensorflow.png b/src/python/doc/img/tensorflow.png
new file mode 100644
index 00000000..a75f3f5b
--- /dev/null
+++ b/src/python/doc/img/tensorflow.png
Binary files differ
diff --git a/src/python/doc/index.rst b/src/python/doc/index.rst
index 2d7921ae..35f4ba46 100644
--- a/src/python/doc/index.rst
+++ b/src/python/doc/index.rst
@@ -92,7 +92,7 @@ Clustering
.. include:: clustering.inc
-Datasets generators
-*******************
+Datasets
+********
-.. include:: datasets_generators.inc
+.. include:: datasets.inc
diff --git a/src/python/doc/installation.rst b/src/python/doc/installation.rst
index 35c344e3..4eefd415 100644
--- a/src/python/doc/installation.rst
+++ b/src/python/doc/installation.rst
@@ -33,25 +33,19 @@ Compiling
These instructions are for people who want to compile gudhi from source; they are
unnecessary if you installed a binary package of Gudhi as above. They assume that
you have downloaded a `release <https://github.com/GUDHI/gudhi-devel/releases>`_,
-with a name like `gudhi.3.2.0.tar.gz`, then run `tar xf gudhi.3.2.0.tar.gz`, which
-created a directory `gudhi.3.2.0`, hereinafter referred to as `/path-to-gudhi/`.
+with a name like `gudhi.3.X.Y.tar.gz`, then run `tar xf gudhi.3.X.Y.tar.gz`, which
+created a directory `gudhi.3.X.Y`, hereinafter referred to as `/path-to-gudhi/`.
If you are instead using a git checkout, beware that the paths are a bit
different, and in particular the `python/` subdirectory is actually `src/python/`
there.
-The library uses c++14 and requires `Boost <https://www.boost.org/>`_ :math:`\geq` 1.56.0,
+The library uses c++14 and requires `Boost <https://www.boost.org/>`_ :math:`\geq` 1.66.0,
`CMake <https://www.cmake.org/>`_ :math:`\geq` 3.5 to generate makefiles,
-`NumPy <http://numpy.org>`_ :math:`\geq` 1.15.0, `Cython <https://www.cython.org/>`_ and
-`pybind11 <https://github.com/pybind/pybind11>`_ to compile
-the GUDHI Python module.
-It is a multi-platform library and compiles on Linux, Mac OSX and Visual
-Studio 2017 or later.
+Python :math:`\geq` 3.5, `NumPy <http://numpy.org>`_ :math:`\geq` 1.15.0, `Cython <https://www.cython.org/>`_
+:math:`\geq` 0.27 and `pybind11 <https://github.com/pybind/pybind11>`_ to compile the GUDHI Python module.
+It is a multi-platform library and compiles on Linux, Mac OSX and Visual Studio 2017 or later.
-On `Windows <https://wiki.python.org/moin/WindowsCompilers>`_ , only Python
-:math:`\geq` 3.5 are available because of the required Visual Studio version.
-
-On other systems, if you have several Python/python installed, the version 2.X
-will be used by default, but you can force it by adding
+If you have several Python versions installed, version 2.X may be used by default, but you can force it by adding
:code:`-DPython_ADDITIONAL_VERSIONS=3` to the cmake command.
GUDHI Python module compilation
@@ -142,54 +136,63 @@ If :code:`import gudhi` succeeds, please have a look to debug information:
.. code-block:: python
- import gudhi
- print(gudhi.__debug_info__)
+ import gudhi as gd
+ print(gd.__debug_info__)
+ print("+ Installed modules are: " + gd.__available_modules)
+ print("+ Missing modules are: " + gd.__missing_modules)
You should have something like:
.. code-block:: none
- Python version 2.7.15
- Cython version 0.26.1
- Numpy version 1.14.1
- Eigen3 version 3.1.1
- Installed modules are: off_reader;simplex_tree;rips_complex;
- cubical_complex;periodic_cubical_complex;reader_utils;witness_complex;
- strong_witness_complex;alpha_complex;
- Missing modules are: bottleneck_distance;nerve_gic;subsampling;
- tangential_complex;persistence_graphical_tools;
- euclidean_witness_complex;euclidean_strong_witness_complex;
- CGAL version 4.7.1000
- GMP_LIBRARIES = /usr/lib/x86_64-linux-gnu/libgmp.so
- GMPXX_LIBRARIES = /usr/lib/x86_64-linux-gnu/libgmpxx.so
- TBB version 9107 found and used
+ Pybind11 version 2.8.1
+ Python version 3.7.12
+ Cython version 0.29.25
+ Numpy version 1.21.4
+ Boost version 1.77.0
+ + Installed modules are: off_reader;simplex_tree;rips_complex;cubical_complex;periodic_cubical_complex;
+ persistence_graphical_tools;reader_utils;witness_complex;strong_witness_complex;
+ + Missing modules are: bottleneck;nerve_gic;subsampling;tangential_complex;alpha_complex;euclidean_witness_complex;
+ euclidean_strong_witness_complex;
-Here, you can see that bottleneck_distance, nerve_gic, subsampling and
-tangential_complex are missing because of the CGAL version.
-persistence_graphical_tools is not available as matplotlib is not
-available.
+Here, you can see that the modules that need CGAL are missing, because CGAL is not installed.
+:code:`persistence_graphical_tools` is installed, but
+`its functions <https://gudhi.inria.fr/python/latest/persistence_graphical_tools_ref.html>`_ will produce an error as
+matplotlib is not available.
Unit tests cannot be run as pytest is missing.
A complete configuration would be:
.. code-block:: none
- Python version 3.6.5
- Cython version 0.28.2
- Pytest version 3.3.2
- Matplotlib version 2.2.2
- Numpy version 1.14.5
- Eigen3 version 3.3.4
- Installed modules are: off_reader;simplex_tree;rips_complex;
- cubical_complex;periodic_cubical_complex;persistence_graphical_tools;
- reader_utils;witness_complex;strong_witness_complex;
- persistence_graphical_tools;bottleneck_distance;nerve_gic;subsampling;
- tangential_complex;alpha_complex;euclidean_witness_complex;
- euclidean_strong_witness_complex;
- CGAL header only version 4.11.0
+ Pybind11 version 2.8.1
+ Python version 3.9.7
+ Cython version 0.29.24
+ Pytest version 6.2.5
+ Matplotlib version 3.5.0
+ Numpy version 1.21.4
+ Scipy version 1.7.3
+ Scikit-learn version 1.0.1
+ POT version 0.8.0
+ HNSWlib found
+ PyKeOps version [pyKeOps]: 2.1
+ EagerPy version 0.30.0
+ TensorFlow version 2.7.0
+ Sphinx version 4.3.0
+ Sphinx-paramlinks version 0.5.2
+ python_docs_theme found
+ Eigen3 version 3.4.0
+ Boost version 1.74.0
+ CGAL version 5.3
GMP_LIBRARIES = /usr/lib/x86_64-linux-gnu/libgmp.so
GMPXX_LIBRARIES = /usr/lib/x86_64-linux-gnu/libgmpxx.so
+ MPFR_LIBRARIES = /usr/lib/x86_64-linux-gnu/libmpfr.so
TBB version 9107 found and used
+ + Installed modules are: bottleneck;off_reader;simplex_tree;rips_complex;cubical_complex;periodic_cubical_complex;
+ persistence_graphical_tools;reader_utils;witness_complex;strong_witness_complex;nerve_gic;subsampling;
+ tangential_complex;alpha_complex;euclidean_witness_complex;euclidean_strong_witness_complex;
+ + Missing modules are:
+
Documentation
=============
@@ -345,8 +348,8 @@ You can still deactivate LaTeX rendering by saying:
.. code-block:: python
- import gudhi
- gudhi.persistence_graphical_tools._gudhi_matplotlib_use_tex=False
+ import gudhi as gd
+ gd.persistence_graphical_tools._gudhi_matplotlib_use_tex=False
PyKeOps
-------
@@ -393,7 +396,11 @@ mathematics, science, and engineering.
TensorFlow
----------
-`TensorFlow <https://www.tensorflow.org>`_ is currently only used in some automatic differentiation tests.
+The :doc:`cubical complex </cubical_complex_tflow_itf_ref>`, :doc:`simplex tree </ls_simplex_tree_tflow_itf_ref>`
+and :doc:`Rips complex </rips_complex_tflow_itf_ref>` modules require `TensorFlow <https://www.tensorflow.org>`_
+in order to be incorporated into neural nets.
+
+`TensorFlow <https://www.tensorflow.org>`_ is also used in some automatic differentiation tests.
Bug reports and contributions
*****************************
diff --git a/src/python/doc/ls_simplex_tree_tflow_itf_ref.rst b/src/python/doc/ls_simplex_tree_tflow_itf_ref.rst
new file mode 100644
index 00000000..9d7d633f
--- /dev/null
+++ b/src/python/doc/ls_simplex_tree_tflow_itf_ref.rst
@@ -0,0 +1,53 @@
+:orphan:
+
+.. To get rid of WARNING: document isn't included in any toctree
+
+TensorFlow layer for lower-star persistence on simplex trees
+############################################################
+
+.. include:: differentiation_sum.inc
+
+Example of gradient computed from lower-star filtration of a simplex tree
+-------------------------------------------------------------------------
+
+.. testcode::
+
+ from gudhi.tensorflow import LowerStarSimplexTreeLayer
+ import tensorflow as tf
+ import gudhi as gd
+
+ st = gd.SimplexTree()
+ st.insert([0, 1])
+ st.insert([1, 2])
+ st.insert([2, 3])
+ st.insert([3, 4])
+ st.insert([4, 5])
+ st.insert([5, 6])
+ st.insert([6, 7])
+ st.insert([7, 8])
+ st.insert([8, 9])
+ st.insert([9, 10])
+
+ F = tf.Variable([6.,4.,3.,4.,5.,4.,3.,2.,3.,4.,5.], dtype=tf.float32, trainable=True)
+ sl = LowerStarSimplexTreeLayer(simplextree=st, homology_dimensions=[0])
+
+ with tf.GradientTape() as tape:
+ dgm = sl.call(F)[0][0]
+ loss = tf.math.reduce_sum(tf.square(.5*(dgm[:,1]-dgm[:,0])))
+
+ grads = tape.gradient(loss, [F])
+ print(grads[0].indices.numpy())
+ print(grads[0].values.numpy())
+
+.. testoutput::
+
+ [2 4]
+ [-1. 1.]
+
+Documentation for LowerStarSimplexTreeLayer
+-------------------------------------------
+
+.. autoclass:: gudhi.tensorflow.LowerStarSimplexTreeLayer
+ :members:
+ :special-members: __init__
+ :show-inheritance:
diff --git a/src/python/doc/nerve_gic_complex_user.rst b/src/python/doc/nerve_gic_complex_user.rst
index 0b820abf..8633cadb 100644
--- a/src/python/doc/nerve_gic_complex_user.rst
+++ b/src/python/doc/nerve_gic_complex_user.rst
@@ -12,7 +12,7 @@ Definition
Visualizations of the simplicial complexes can be done with either
neato (from `graphviz <http://www.graphviz.org/>`_),
`geomview <http://www.geomview.org/>`_,
-`KeplerMapper <https://github.com/MLWave/kepler-mapper>`_.
+`KeplerMapper <https://github.com/scikit-tda/kepler-mapper>`_.
Input point clouds are assumed to be OFF files (cf. `OFF file format <fileformats.html#off-file-format>`_).
Covers
diff --git a/src/python/doc/persistence_graphical_tools_user.rst b/src/python/doc/persistence_graphical_tools_user.rst
index d95b9d2b..e1d28c71 100644
--- a/src/python/doc/persistence_graphical_tools_user.rst
+++ b/src/python/doc/persistence_graphical_tools_user.rst
@@ -60,7 +60,7 @@ of shape (N x 2) encoding a persistence diagram (in a given dimension).
import matplotlib.pyplot as plt
import gudhi
import numpy as np
- d = np.array([[0, 1], [1, 2], [1, np.inf]])
+ d = np.array([[0., 1.], [1., 2.], [1., np.inf]])
gudhi.plot_persistence_diagram(d)
plt.show()
diff --git a/src/python/doc/persistent_cohomology_user.rst b/src/python/doc/persistent_cohomology_user.rst
index a3f294b2..39744b95 100644
--- a/src/python/doc/persistent_cohomology_user.rst
+++ b/src/python/doc/persistent_cohomology_user.rst
@@ -6,19 +6,24 @@ Persistent cohomology user manual
=================================
Definition
----------
-===================================== ===================================== =====================================
-:Author: Clément Maria :Since: GUDHI PYTHON 2.0.0 :License: GPL v3
-===================================== ===================================== =====================================
-
-+-----------------------------------------------------------------+-----------------------------------------------------------------------+
-| :doc:`persistent_cohomology_user` | Please refer to each data structure that contains persistence |
-| | feature for reference: |
-| | |
-| | * :doc:`simplex_tree_ref` |
-| | * :doc:`cubical_complex_ref` |
-| | * :doc:`periodic_cubical_complex_ref` |
-+-----------------------------------------------------------------+-----------------------------------------------------------------------+
+.. list-table::
+ :width: 100%
+ :header-rows: 0
+
+ * - :Author: Clément Maria
+ - :Since: GUDHI 2.0.0
+ - :License: MIT
+
+.. list-table::
+ :width: 100%
+ :header-rows: 0
+
+ * - :doc:`persistent_cohomology_user`
+ - Please refer to each data structure that contains persistence feature for reference:
+ * :doc:`simplex_tree_ref`
+ * :doc:`cubical_complex_ref`
+ * :doc:`periodic_cubical_complex_ref`
Computation of persistent cohomology using the algorithm of :cite:`DBLP:journals/dcg/SilvaMV11` and
:cite:`DBLP:conf/compgeom/DeyFW14` and the Compressed Annotation Matrix implementation of
diff --git a/src/python/doc/rips_complex_sum.inc b/src/python/doc/rips_complex_sum.inc
index 2cb24990..2b125e54 100644
--- a/src/python/doc/rips_complex_sum.inc
+++ b/src/python/doc/rips_complex_sum.inc
@@ -11,4 +11,9 @@
| | | |
+----------------------------------------------------------------+------------------------------------------------------------------------+----------------------------------------------------------------------------------+
| * :doc:`rips_complex_user` | * :doc:`rips_complex_ref` |
- +----------------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------+
+ +----------------------------------------------------------------+------------------------------------------------------------------------+----------------------------------------------------------------------------------+
+ | .. image:: | * :doc:`rips_complex_tflow_itf_ref` | :requires: `TensorFlow <installation.html#tensorflow>`_ |
+ | img/tensorflow.png | | |
+ | :target: https://www.tensorflow.org | | |
+ | :height: 30 | | |
+ +----------------------------------------------------------------+------------------------------------------------------------------------+----------------------------------------------------------------------------------+
diff --git a/src/python/doc/rips_complex_tflow_itf_ref.rst b/src/python/doc/rips_complex_tflow_itf_ref.rst
new file mode 100644
index 00000000..3ce75868
--- /dev/null
+++ b/src/python/doc/rips_complex_tflow_itf_ref.rst
@@ -0,0 +1,48 @@
+:orphan:
+
+.. To get rid of WARNING: document isn't included in any toctree
+
+TensorFlow layer for Vietoris-Rips persistence
+##############################################
+
+.. include:: differentiation_sum.inc
+
+Example of gradient computed from Vietoris-Rips persistence
+-----------------------------------------------------------
+
+.. testsetup::
+
+ import numpy
+ numpy.set_printoptions(precision=4)
+
+.. testcode::
+
+ from gudhi.tensorflow import RipsLayer
+ import tensorflow as tf
+
+ X = tf.Variable([[1.,1.],[2.,2.]], dtype=tf.float32, trainable=True)
+ rl = RipsLayer(maximum_edge_length=2., homology_dimensions=[0])
+
+ with tf.GradientTape() as tape:
+ dgm = rl.call(X)[0][0]
+ loss = tf.math.reduce_sum(tf.square(.5*(dgm[:,1]-dgm[:,0])))
+
+ grads = tape.gradient(loss, [X])
+ print(grads[0].numpy())
+
+.. testcleanup::
+
+ numpy.set_printoptions(precision=8)
+
+.. testoutput::
+
+ [[-0.5 -0.5]
+ [ 0.5 0.5]]
+
+Documentation for RipsLayer
+---------------------------
+
+.. autoclass:: gudhi.tensorflow.RipsLayer
+ :members:
+ :special-members: __init__
+ :show-inheritance:
diff --git a/src/python/doc/rips_complex_user.rst b/src/python/doc/rips_complex_user.rst
index 27d218d4..c41a7803 100644
--- a/src/python/doc/rips_complex_user.rst
+++ b/src/python/doc/rips_complex_user.rst
@@ -7,13 +7,7 @@ Rips complex user manual
Definition
----------
-================================================================================ ================================ ======================
-:Authors: Clément Maria, Pawel Dlotko, Vincent Rouvreau, Marc Glisse, Yuichi Ike :Since: GUDHI 2.0.0 :License: GPL v3
-================================================================================ ================================ ======================
-
-+-------------------------------------------+----------------------------------------------------------------------+
-| :doc:`rips_complex_user` | :doc:`rips_complex_ref` |
-+-------------------------------------------+----------------------------------------------------------------------+
+.. include:: rips_complex_sum.inc
The `Rips complex <https://en.wikipedia.org/wiki/Vietoris%E2%80%93Rips_complex>`_ is a simplicial complex that
generalizes proximity (:math:`\varepsilon`-ball) graphs to higher dimensions. The vertices correspond to the input
diff --git a/src/python/doc/simplex_tree_sum.inc b/src/python/doc/simplex_tree_sum.inc
index a8858f16..6b534c9e 100644
--- a/src/python/doc/simplex_tree_sum.inc
+++ b/src/python/doc/simplex_tree_sum.inc
@@ -1,13 +1,18 @@
.. table::
:widths: 30 40 30
- +----------------------------------------------------------------+------------------------------------------------------------------------+-----------------------------+
- | .. figure:: | The simplex tree is an efficient and flexible data structure for | :Author: Clément Maria |
- | ../../doc/Simplex_tree/Simplex_tree_representation.png | representing general (filtered) simplicial complexes. | |
- | :alt: Simplex tree representation | | :Since: GUDHI 2.0.0 |
- | :figclass: align-center | The data structure is described in | |
- | | :cite:`boissonnatmariasimplextreealgorithmica` | :License: MIT |
- | | | |
- +----------------------------------------------------------------+------------------------------------------------------------------------+-----------------------------+
- | * :doc:`simplex_tree_user` | * :doc:`simplex_tree_ref` |
- +----------------------------------------------------------------+------------------------------------------------------------------------------------------------------+
+ +----------------------------------------------------------------+------------------------------------------------------------------------+---------------------------------------------------------+
+ | .. figure:: | The simplex tree is an efficient and flexible data structure for | :Author: Clément Maria |
+ | ../../doc/Simplex_tree/Simplex_tree_representation.png | representing general (filtered) simplicial complexes. | |
+ | :alt: Simplex tree representation | | :Since: GUDHI 2.0.0 |
+ | :figclass: align-center | The data structure is described in | |
+ | | :cite:`boissonnatmariasimplextreealgorithmica` | :License: MIT |
+ | | | |
+ +----------------------------------------------------------------+------------------------------------------------------------------------+---------------------------------------------------------+
+ | * :doc:`simplex_tree_user` | * :doc:`simplex_tree_ref` |
+ +----------------------------------------------------------------+------------------------------------------------------------------------+---------------------------------------------------------+
+ | .. image:: | * :doc:`ls_simplex_tree_tflow_itf_ref` | :requires: `TensorFlow <installation.html#tensorflow>`_ |
+ | img/tensorflow.png | | |
+ | :target: https://www.tensorflow.org | | |
+ | :height: 30 | | |
+ +----------------------------------------------------------------+------------------------------------------------------------------------+---------------------------------------------------------+
diff --git a/src/python/example/alpha_complex_diagram_persistence_from_off_file_example.py b/src/python/example/alpha_complex_diagram_persistence_from_off_file_example.py
index fe03be31..c96121a6 100755
--- a/src/python/example/alpha_complex_diagram_persistence_from_off_file_example.py
+++ b/src/python/example/alpha_complex_diagram_persistence_from_off_file_example.py
@@ -1,9 +1,7 @@
#!/usr/bin/env python
import argparse
-import errno
-import os
-import gudhi
+import gudhi as gd
""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ -
which is released under MIT.
@@ -41,33 +39,24 @@ parser.add_argument(
args = parser.parse_args()
-with open(args.file, "r") as f:
- first_line = f.readline()
- if (first_line == "OFF\n") or (first_line == "nOFF\n"):
- print("##############################################################")
- print("AlphaComplex creation from points read in a OFF file")
-
- alpha_complex = gudhi.AlphaComplex(off_file=args.file)
- if args.max_alpha_square is not None:
- print("with max_edge_length=", args.max_alpha_square)
- simplex_tree = alpha_complex.create_simplex_tree(
- max_alpha_square=args.max_alpha_square
- )
- else:
- simplex_tree = alpha_complex.create_simplex_tree()
-
- print("Number of simplices=", simplex_tree.num_simplices())
-
- diag = simplex_tree.persistence()
-
- print("betti_numbers()=", simplex_tree.betti_numbers())
-
- if args.no_diagram == False:
- import matplotlib.pyplot as plot
- gudhi.plot_persistence_diagram(diag, band=args.band)
- plot.show()
- else:
- raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
- args.file)
-
- f.close()
+print("##############################################################")
+print("AlphaComplex creation from points read in a OFF file")
+
+points = gd.read_points_from_off_file(off_file = args.file)
+alpha_complex = gd.AlphaComplex(points = points)
+if args.max_alpha_square is not None:
+ print("with max_edge_length=", args.max_alpha_square)
+ simplex_tree = alpha_complex.create_simplex_tree(
+ max_alpha_square=args.max_alpha_square
+ )
+else:
+ simplex_tree = alpha_complex.create_simplex_tree()
+
+print("Number of simplices=", simplex_tree.num_simplices())
+
+diag = simplex_tree.persistence()
+print("betti_numbers()=", simplex_tree.betti_numbers())
+if not args.no_diagram:
+ import matplotlib.pyplot as plot
+ gd.plot_persistence_diagram(diag, band=args.band)
+ plot.show()
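Condensed, the rewritten example above follows this pattern (the OFF path is a placeholder):

import gudhi as gd

# Read the point cloud first, then hand the points to AlphaComplex;
# the off_file constructor argument is deprecated in this release.
points = gd.read_points_from_off_file(off_file="data.off")  # placeholder path
simplex_tree = gd.AlphaComplex(points=points).create_simplex_tree()
diag = simplex_tree.persistence()
print("betti_numbers()=", simplex_tree.betti_numbers())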
diff --git a/src/python/example/alpha_rips_persistence_bottleneck_distance.py b/src/python/example/alpha_rips_persistence_bottleneck_distance.py
index 3e12b0d5..6b97fb3b 100755
--- a/src/python/example/alpha_rips_persistence_bottleneck_distance.py
+++ b/src/python/example/alpha_rips_persistence_bottleneck_distance.py
@@ -1,10 +1,8 @@
#!/usr/bin/env python
-import gudhi
+import gudhi as gd
import argparse
import math
-import errno
-import os
import numpy as np
""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ -
@@ -37,70 +35,60 @@ parser.add_argument("-t", "--threshold", type=float, default=0.5)
parser.add_argument("-d", "--max_dimension", type=int, default=1)
args = parser.parse_args()
-with open(args.file, "r") as f:
- first_line = f.readline()
- if (first_line == "OFF\n") or (first_line == "nOFF\n"):
- point_cloud = gudhi.read_points_from_off_file(off_file=args.file)
- print("##############################################################")
- print("RipsComplex creation from points read in a OFF file")
+point_cloud = gd.read_points_from_off_file(off_file=args.file)
+print("##############################################################")
+print("RipsComplex creation from points read in a OFF file")
- message = "RipsComplex with max_edge_length=" + repr(args.threshold)
- print(message)
+message = "RipsComplex with max_edge_length=" + repr(args.threshold)
+print(message)
- rips_complex = gudhi.RipsComplex(
- points=point_cloud, max_edge_length=args.threshold
- )
-
- rips_stree = rips_complex.create_simplex_tree(
- max_dimension=args.max_dimension)
-
- message = "Number of simplices=" + repr(rips_stree.num_simplices())
- print(message)
-
- rips_stree.compute_persistence()
-
- print("##############################################################")
- print("AlphaComplex creation from points read in a OFF file")
-
- message = "AlphaComplex with max_edge_length=" + repr(args.threshold)
- print(message)
-
- alpha_complex = gudhi.AlphaComplex(points=point_cloud)
- alpha_stree = alpha_complex.create_simplex_tree(
- max_alpha_square=(args.threshold * args.threshold)
- )
-
- message = "Number of simplices=" + repr(alpha_stree.num_simplices())
- print(message)
+rips_complex = gd.RipsComplex(
+ points=point_cloud, max_edge_length=args.threshold
+)
- alpha_stree.compute_persistence()
+rips_stree = rips_complex.create_simplex_tree(
+ max_dimension=args.max_dimension)
- max_b_distance = 0.0
- for dim in range(args.max_dimension):
- # Alpha persistence values needs to be transform because filtration
- # values are alpha square values
- alpha_intervals = np.sqrt(alpha_stree.persistence_intervals_in_dimension(dim))
+message = "Number of simplices=" + repr(rips_stree.num_simplices())
+print(message)
- rips_intervals = rips_stree.persistence_intervals_in_dimension(dim)
- bottleneck_distance = gudhi.bottleneck_distance(
- rips_intervals, alpha_intervals
- )
- message = (
- "In dimension "
- + repr(dim)
- + ", bottleneck distance = "
- + repr(bottleneck_distance)
- )
- print(message)
- max_b_distance = max(bottleneck_distance, max_b_distance)
+rips_stree.compute_persistence()
- print("==============================================================")
- message = "Bottleneck distance is " + repr(max_b_distance)
- print(message)
+print("##############################################################")
+print("AlphaComplex creation from points read in a OFF file")
- else:
- raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
- args.file)
+message = "AlphaComplex with max_edge_length=" + repr(args.threshold)
+print(message)
+alpha_complex = gd.AlphaComplex(points=point_cloud)
+alpha_stree = alpha_complex.create_simplex_tree(
+ max_alpha_square=(args.threshold * args.threshold)
+)
- f.close()
+message = "Number of simplices=" + repr(alpha_stree.num_simplices())
+print(message)
+
+alpha_stree.compute_persistence()
+
+max_b_distance = 0.0
+for dim in range(args.max_dimension):
+    # Alpha persistence values need to be transformed because filtration
+    # values are alpha square values
+ alpha_intervals = np.sqrt(alpha_stree.persistence_intervals_in_dimension(dim))
+
+ rips_intervals = rips_stree.persistence_intervals_in_dimension(dim)
+ bottleneck_distance = gd.bottleneck_distance(
+ rips_intervals, alpha_intervals
+ )
+ message = (
+ "In dimension "
+ + repr(dim)
+ + ", bottleneck distance = "
+ + repr(bottleneck_distance)
+ )
+ print(message)
+ max_b_distance = max(bottleneck_distance, max_b_distance)
+
+print("==============================================================")
+message = "Bottleneck distance is " + repr(max_b_distance)
+print(message)
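The essential step in the comparison above is the rescaling: alpha filtration values are squared circumradii, so np.sqrt must be applied before matching them against Rips values. A self-contained sketch on random data (sizes and threshold are illustrative only):

import numpy as np
import gudhi as gd

pts = np.random.default_rng(0).random((100, 2))
threshold = 0.5

rips_st = gd.RipsComplex(points=pts, max_edge_length=threshold).create_simplex_tree(max_dimension=1)
rips_st.compute_persistence()

alpha_st = gd.AlphaComplex(points=pts).create_simplex_tree(max_alpha_square=threshold**2)
alpha_st.compute_persistence()

# sqrt turns squared-radius filtration values back into the length scale of Rips
d0 = gd.bottleneck_distance(
    rips_st.persistence_intervals_in_dimension(0),
    np.sqrt(alpha_st.persistence_intervals_in_dimension(0)),
)
print("H0 bottleneck distance:", d0)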
diff --git a/src/python/example/plot_alpha_complex.py b/src/python/example/plot_alpha_complex.py
index 99c18a7c..0924619b 100755
--- a/src/python/example/plot_alpha_complex.py
+++ b/src/python/example/plot_alpha_complex.py
@@ -1,8 +1,9 @@
#!/usr/bin/env python
import numpy as np
-import gudhi
-ac = gudhi.AlphaComplex(off_file='../../data/points/tore3D_1307.off')
+import gudhi as gd
+points = gd.read_points_from_off_file(off_file = '../../data/points/tore3D_1307.off')
+ac = gd.AlphaComplex(points = points)
st = ac.create_simplex_tree()
points = np.array([ac.get_point(i) for i in range(st.num_vertices())])
# We want to plot the alpha-complex with alpha=0.1.
diff --git a/src/python/example/rips_complex_diagram_persistence_from_correlation_matrix_file_example.py b/src/python/example/rips_complex_diagram_persistence_from_correlation_matrix_file_example.py
index ea2eb7e1..0b35dbc5 100755
--- a/src/python/example/rips_complex_diagram_persistence_from_correlation_matrix_file_example.py
+++ b/src/python/example/rips_complex_diagram_persistence_from_correlation_matrix_file_example.py
@@ -40,7 +40,7 @@ parser.add_argument(
args = parser.parse_args()
if not (-1.0 < args.min_edge_correlation < 1.0):
- print("Wrong value of the treshold corelation (should be between -1 and 1).")
+ print("Wrong value of the threshold corelation (should be between -1 and 1).")
sys.exit(1)
print("#####################################################################")
diff --git a/src/python/gudhi/__init__.py.in b/src/python/gudhi/__init__.py.in
index 3043201a..79e12fbc 100644
--- a/src/python/gudhi/__init__.py.in
+++ b/src/python/gudhi/__init__.py.in
@@ -23,10 +23,6 @@ __all__ = [@GUDHI_PYTHON_MODULES@ @GUDHI_PYTHON_MODULES_EXTRA@]
__available_modules = ''
__missing_modules = ''
-# For unitary tests purpose
-# could use "if 'collapse_edges' in gudhi.__all__" when collapse edges will have a python module
-__GUDHI_USE_EIGEN3 = @GUDHI_USE_EIGEN3@
-
# Try to import * from gudhi.__module_name for default modules.
# Extra modules require an explicit import by the user (mostly because of
# unusual dependencies, but also to avoid cluttering namespace gudhi and
diff --git a/src/python/gudhi/alpha_complex.pyx b/src/python/gudhi/alpha_complex.pyx
index ea128743..375e1561 100644
--- a/src/python/gudhi/alpha_complex.pyx
+++ b/src/python/gudhi/alpha_complex.pyx
@@ -16,7 +16,7 @@ from libcpp.utility cimport pair
from libcpp.string cimport string
from libcpp cimport bool
from libc.stdint cimport intptr_t
-import os
+import warnings
from gudhi.simplex_tree cimport *
from gudhi.simplex_tree import SimplexTree
@@ -28,66 +28,76 @@ __license__ = "GPL v3"
cdef extern from "Alpha_complex_interface.h" namespace "Gudhi":
cdef cppclass Alpha_complex_interface "Gudhi::alpha_complex::Alpha_complex_interface":
- Alpha_complex_interface(vector[vector[double]] points, bool fast_version, bool exact_version) nogil except +
+ Alpha_complex_interface(vector[vector[double]] points, vector[double] weights, bool fast_version, bool exact_version) nogil except +
vector[double] get_point(int vertex) nogil except +
void create_simplex_tree(Simplex_tree_interface_full_featured* simplex_tree, double max_alpha_square, bool default_filtration_value) nogil except +
+ @staticmethod
+ void set_float_relative_precision(double precision) nogil
+ @staticmethod
+ double get_float_relative_precision() nogil
# AlphaComplex python interface
cdef class AlphaComplex:
- """AlphaComplex is a simplicial complex constructed from the finite cells
- of a Delaunay Triangulation.
+ """AlphaComplex is a simplicial complex constructed from the finite cells of a Delaunay Triangulation.
- The filtration value of each simplex is computed as the square of the
- circumradius of the simplex if the circumsphere is empty (the simplex is
- then said to be Gabriel), and as the minimum of the filtration values of
- the codimension 1 cofaces that make it not Gabriel otherwise.
+ The filtration value of each simplex is computed as the square of the circumradius of the simplex if the
+ circumsphere is empty (the simplex is then said to be Gabriel), and as the minimum of the filtration values of the
+ codimension 1 cofaces that make it not Gabriel otherwise.
- All simplices that have a filtration value strictly greater than a given
- alpha squared value are not inserted into the complex.
+ All simplices that have a filtration value strictly greater than a given alpha squared value are not inserted into
+ the complex.
.. note::
- When Alpha_complex is constructed with an infinite value of alpha, the
- complex is a Delaunay complex.
-
+ When Alpha_complex is constructed with an infinite value of alpha, the complex is a Delaunay complex.
"""
cdef Alpha_complex_interface * this_ptr
# Fake constructor that does nothing but documenting the constructor
- def __init__(self, points=None, off_file='', precision='safe'):
+ def __init__(self, points=[], off_file='', weights=None, precision='safe'):
"""AlphaComplex constructor.
:param points: A list of points in d-Dimension.
- :type points: list of list of double
-
- Or
+ :type points: Iterable[Iterable[float]]
- :param off_file: An OFF file style name.
+ :param off_file: **[deprecated]** An `OFF file style <fileformats.html#off-file-format>`_ name.
+ If an `off_file` is given with `points` as arguments, only points from the file are taken into account.
:type off_file: string
+ :param weights: A list of weights. If set, the number of weights must correspond to the number of points.
+ :type weights: Iterable[float]
+
:param precision: Alpha complex precision can be 'fast', 'safe' or 'exact'. Default is 'safe'.
:type precision: string
+
+ :raises FileNotFoundError: **[deprecated]** If `off_file` is set but not found.
+ :raises ValueError: In case of inconsistency between the number of points and weights.
"""
# The real cython constructor
- def __cinit__(self, points = None, off_file = '', precision = 'safe'):
+ def __cinit__(self, points = [], off_file = '', weights=None, precision = 'safe'):
assert precision in ['fast', 'safe', 'exact'], "Alpha complex precision can only be 'fast', 'safe' or 'exact'"
cdef bool fast = precision == 'fast'
cdef bool exact = precision == 'exact'
- cdef vector[vector[double]] pts
if off_file:
- if os.path.isfile(off_file):
- points = read_points_from_off_file(off_file = off_file)
- else:
- print("file " + off_file + " not found.")
- if points is None:
- # Empty Alpha construction
- points=[]
+ warnings.warn("off_file is a deprecated parameter, please consider using gudhi.read_points_from_off_file",
+ DeprecationWarning)
+ points = read_points_from_off_file(off_file = off_file)
+
+    # weights are set but inconsistent with the number of points
+    if weights is not None and len(weights) != len(points):
+ raise ValueError("Inconsistency between the number of points and weights")
+
+    # need to copy the points to use them without the GIL
+ cdef vector[vector[double]] pts
+ cdef vector[double] wgts
pts = points
+    if weights is not None:
+ wgts = weights
with nogil:
- self.this_ptr = new Alpha_complex_interface(pts, fast, exact)
+ self.this_ptr = new Alpha_complex_interface(pts, wgts, fast, exact)
def __dealloc__(self):
if self.this_ptr != NULL:
@@ -127,3 +137,28 @@ cdef class AlphaComplex:
self.this_ptr.create_simplex_tree(<Simplex_tree_interface_full_featured*>stree_int_ptr,
mas, compute_filtration)
return stree
+
+ @staticmethod
+ def set_float_relative_precision(precision):
+ """
+ :param precision: When the AlphaComplex is constructed with :code:`precision = 'safe'` (the default),
+ one can set the float relative precision of filtration values computed in
+ :func:`~gudhi.AlphaComplex.create_simplex_tree`.
+ Default is :code:`1e-5` (cf. :func:`~gudhi.AlphaComplex.get_float_relative_precision`).
+ For more details, please refer to
+ `CGAL::Lazy_exact_nt<NT>::set_relative_precision_of_to_double <https://doc.cgal.org/latest/Number_types/classCGAL_1_1Lazy__exact__nt.html>`_
+ :type precision: float
+ """
+        if precision <= 0. or precision >= 1.:
+ raise ValueError("Relative precision value must be strictly greater than 0 and strictly lower than 1")
+ Alpha_complex_interface.set_float_relative_precision(precision)
+
+ @staticmethod
+ def get_float_relative_precision():
+ """
+ :returns: The float relative precision of filtration values computation in
+ :func:`~gudhi.AlphaComplex.create_simplex_tree` when the AlphaComplex is constructed with
+ :code:`precision = 'safe'` (the default).
+ :rtype: float
+ """
+ return Alpha_complex_interface.get_float_relative_precision()
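A short usage sketch covering the two additions to AlphaComplex above, weighted points and the new precision knob:

import gudhi as gd

# Weighted construction (new in this release): one weight per point,
# otherwise a ValueError is raised.
wac = gd.AlphaComplex(points=[[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]],
                      weights=[1.0, 0.5, 0.25])
wst = wac.create_simplex_tree()

# Relative precision of 'safe' mode filtration values; must lie strictly
# between 0 and 1.
print(gd.AlphaComplex.get_float_relative_precision())  # 1e-5 by default
gd.AlphaComplex.set_float_relative_precision(1e-8)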
diff --git a/src/python/gudhi/datasets/remote.py b/src/python/gudhi/datasets/remote.py
new file mode 100644
index 00000000..f6d3fe56
--- /dev/null
+++ b/src/python/gudhi/datasets/remote.py
@@ -0,0 +1,223 @@
+# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+# Author(s): Hind Montassif
+#
+# Copyright (C) 2021 Inria
+#
+# Modification(s):
+# - YYYY/MM Author: Description of the modification
+
+from os.path import join, split, exists, expanduser
+from os import makedirs, remove, environ
+
+from urllib.request import urlretrieve
+import hashlib
+import shutil
+
+import numpy as np
+
+def _get_data_home(data_home = None):
+ """
+ Return the path of the remote datasets directory.
+ This folder is used to store remotely fetched datasets.
+ By default the datasets directory is set to a folder named 'gudhi_data' in the user home folder.
+ Alternatively, it can be set by the 'GUDHI_DATA' environment variable.
+ The '~' symbol is expanded to the user home folder.
+ If the folder does not already exist, it is automatically created.
+
+ Parameters
+ ----------
+ data_home : string
+ The path to remote datasets directory.
+ Default is `None`, meaning that the data home directory will be set to "~/gudhi_data",
+ if the 'GUDHI_DATA' environment variable does not exist.
+
+ Returns
+ -------
+ data_home: string
+ The path to remote datasets directory.
+ """
+ if data_home is None:
+ data_home = environ.get("GUDHI_DATA", join("~", "gudhi_data"))
+ data_home = expanduser(data_home)
+ makedirs(data_home, exist_ok=True)
+ return data_home
+
+
+def clear_data_home(data_home = None):
+ """
+ Delete the data home cache directory and all its content.
+
+ Parameters
+ ----------
+ data_home : string, default is None.
+ The path to remote datasets directory.
+ If `None` and the 'GUDHI_DATA' environment variable does not exist,
+ the default directory to be removed is set to "~/gudhi_data".
+ """
+ data_home = _get_data_home(data_home)
+ shutil.rmtree(data_home)
+
+def _checksum_sha256(file_path):
+ """
+ Compute the file checksum using sha256.
+
+ Parameters
+ ----------
+ file_path: string
+ Full path of the created file including filename.
+
+ Returns
+ -------
+ The hex digest of file_path.
+ """
+ sha256_hash = hashlib.sha256()
+ chunk_size = 4096
+ with open(file_path,"rb") as f:
+ # Read and update hash string value in blocks of 4K
+ while True:
+ buffer = f.read(chunk_size)
+ if not buffer:
+ break
+ sha256_hash.update(buffer)
+ return sha256_hash.hexdigest()
+
+def _fetch_remote(url, file_path, file_checksum = None):
+ """
+ Fetch the wanted dataset from the given url and save it in file_path.
+
+ Parameters
+ ----------
+ url : string
+ The url to fetch the dataset from.
+ file_path : string
+ Full path of the downloaded file including filename.
+ file_checksum : string
+ The file checksum using sha256 to check against the one computed on the downloaded file.
+ Default is 'None', which means the checksum is not checked.
+
+ Raises
+ ------
+ IOError
+ If the computed SHA256 checksum of file does not match the one given by the user.
+ """
+
+ # Get the file
+ urlretrieve(url, file_path)
+
+ if file_checksum is not None:
+ checksum = _checksum_sha256(file_path)
+ if file_checksum != checksum:
+ # Remove file and raise error
+ remove(file_path)
+            raise IOError("{} has a SHA256 checksum: {}, "
+                "different from the expected one: {}. "
+                "The file may be corrupted or the given url may be wrong!".format(file_path, checksum, file_checksum))
+
+def _get_archive_path(file_path, label):
+ """
+ Get archive path based on file_path given by user and label.
+
+ Parameters
+ ----------
+ file_path: string
+ Full path of the file to get including filename, or None.
+ label: string
+ Label used along with 'data_home' to get archive path, in case 'file_path' is None.
+
+ Returns
+ -------
+ Full path of archive including filename.
+ """
+ if file_path is None:
+ archive_path = join(_get_data_home(), label)
+ dirname = split(archive_path)[0]
+ makedirs(dirname, exist_ok=True)
+ else:
+ archive_path = file_path
+ dirname = split(archive_path)[0]
+ makedirs(dirname, exist_ok=True)
+
+ return archive_path
+
+def fetch_spiral_2d(file_path = None):
+ """
+ Load the spiral_2d dataset.
+
+ Note that if the dataset already exists in the target location, it is not downloaded again,
+ and the corresponding array is returned from cache.
+
+ Parameters
+ ----------
+ file_path : string
+ Full path of the downloaded file including filename.
+
+ Default is None, meaning that it's set to "data_home/points/spiral_2d/spiral_2d.npy".
+
+ The "data_home" directory is set by default to "~/gudhi_data",
+ unless the 'GUDHI_DATA' environment variable is set.
+
+ Returns
+ -------
+ points: numpy array
+ Array of shape (114562, 2).
+ """
+ file_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d/spiral_2d.npy"
+ file_checksum = '2226024da76c073dd2f24b884baefbfd14928b52296df41ad2d9b9dc170f2401'
+
+ archive_path = _get_archive_path(file_path, "points/spiral_2d/spiral_2d.npy")
+
+ if not exists(archive_path):
+ _fetch_remote(file_url, archive_path, file_checksum)
+
+ return np.load(archive_path, mmap_mode='r')
+
+def fetch_bunny(file_path = None, accept_license = False):
+ """
+ Load the Stanford bunny dataset.
+
+ This dataset contains 35947 vertices.
+
+ Note that if the dataset already exists in the target location, it is not downloaded again,
+ and the corresponding array is returned from cache.
+
+ Parameters
+ ----------
+ file_path : string
+ Full path of the downloaded file including filename.
+
+ Default is None, meaning that it's set to "data_home/points/bunny/bunny.npy".
+ In this case, the LICENSE file would be downloaded as "data_home/points/bunny/bunny.LICENSE".
+
+ The "data_home" directory is set by default to "~/gudhi_data",
+ unless the 'GUDHI_DATA' environment variable is set.
+
+ accept_license : boolean
+        Flag to specify whether the user accepts the file LICENSE; if set to True, the corresponding license terms are not printed.
+
+ Default is False.
+
+ Returns
+ -------
+ points: numpy array
+ Array of shape (35947, 3).
+ """
+
+ file_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/bunny.npy"
+ file_checksum = 'f382482fd89df8d6444152dc8fd454444fe597581b193fd139725a85af4a6c6e'
+ license_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/bunny.LICENSE"
+ license_checksum = 'b763dbe1b2fc6015d05cbf7bcc686412a2eb100a1f2220296e3b4a644c69633a'
+
+ archive_path = _get_archive_path(file_path, "points/bunny/bunny.npy")
+
+ if not exists(archive_path):
+ _fetch_remote(file_url, archive_path, file_checksum)
+ license_path = join(split(archive_path)[0], "bunny.LICENSE")
+ _fetch_remote(license_url, license_path, license_checksum)
+ # Print license terms unless accept_license is set to True
+ if not accept_license:
+ if exists(license_path):
+ with open(license_path, 'r') as f:
+ print(f.read())
+
+ return np.load(archive_path, mmap_mode='r')
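A usage sketch for the new remote-datasets module (files are cached under ~/gudhi_data, or $GUDHI_DATA if set, and are not downloaded again on later calls):

from gudhi.datasets import remote

spiral = remote.fetch_spiral_2d()                # numpy array of shape (114562, 2)
bunny = remote.fetch_bunny(accept_license=True)  # numpy array of shape (35947, 3)

# remote.clear_data_home() would wipe the whole cache directory.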
diff --git a/src/python/gudhi/hera/wasserstein.cc b/src/python/gudhi/hera/wasserstein.cc
index 1a21f02f..fa0cf8aa 100644
--- a/src/python/gudhi/hera/wasserstein.cc
+++ b/src/python/gudhi/hera/wasserstein.cc
@@ -29,7 +29,7 @@ double wasserstein_distance(
if(std::isinf(internal_p)) internal_p = hera::get_infinity<double>();
params.internal_p = internal_p;
params.delta = delta;
- // The extra parameters are purposedly not exposed for now.
+ // The extra parameters are purposely not exposed for now.
return hera::wasserstein_dist(diag1, diag2, params);
}
diff --git a/src/python/gudhi/persistence_graphical_tools.py b/src/python/gudhi/persistence_graphical_tools.py
index 848dc03e..21275cdd 100644
--- a/src/python/gudhi/persistence_graphical_tools.py
+++ b/src/python/gudhi/persistence_graphical_tools.py
@@ -12,6 +12,9 @@ from os import path
from math import isfinite
import numpy as np
from functools import lru_cache
+import warnings
+import errno
+import os
from gudhi.reader_utils import read_persistence_intervals_in_dimension
from gudhi.reader_utils import read_persistence_intervals_grouped_by_dimension
@@ -22,6 +25,7 @@ __license__ = "MIT"
_gudhi_matplotlib_use_tex = True
+
def __min_birth_max_death(persistence, band=0.0):
"""This function returns (min_birth, max_death) from the persistence.
@@ -44,20 +48,46 @@ def __min_birth_max_death(persistence, band=0.0):
min_birth = float(interval[1][0])
if band > 0.0:
max_death += band
+    # can happen if there are only points with infinite death
+ if min_birth == max_death:
+ max_death = max_death + 1.0
return (min_birth, max_death)
def _array_handler(a):
- '''
+ """
:param a: if array, assumes it is a (n x 2) np.array and return a
persistence-compatible list (padding with 0), so that the
plot can be performed seamlessly.
- '''
- if isinstance(a[0][1], np.float64) or isinstance(a[0][1], float):
+ """
+ if isinstance(a[0][1], (np.floating, float)):
return [[0, x] for x in a]
else:
return a
+
+def _limit_to_max_intervals(persistence, max_intervals, key):
+ """This function returns truncated persistence if length is bigger than max_intervals.
+ :param persistence: Persistence intervals values list. Can be grouped by dimension or not.
+ :type persistence: an array of (dimension, array of (birth, death)) or an array of (birth, death).
+    :param max_intervals: maximal number of intervals to keep.
+        Selected intervals are those with the longest lifetime. Set it
+        to 0 to keep them all.
+ :type max_intervals: int.
+ :param key: key function for sort algorithm.
+ :type key: function or lambda.
+ """
+ if max_intervals > 0 and max_intervals < len(persistence):
+ warnings.warn(
+ "There are %s intervals given as input, whereas max_intervals is set to %s."
+ % (len(persistence), max_intervals)
+ )
+        # Sort by lifetime, then keep only the first max_intervals elements
+ return sorted(persistence, key=key, reverse=True)[:max_intervals]
+ else:
+ return persistence
+
+
@lru_cache(maxsize=1)
def _matplotlib_can_use_tex():
"""This function returns True if matplotlib can deal with LaTeX, False otherwise.
@@ -65,17 +95,17 @@ def _matplotlib_can_use_tex():
"""
try:
from matplotlib import checkdep_usetex
+
return checkdep_usetex(True)
- except ImportError:
- print("This function is not available, you may be missing matplotlib.")
+ except ImportError as import_error:
+ warnings.warn(f"This function is not available.\nModuleNotFoundError: No module named '{import_error.name}'.")
def plot_persistence_barcode(
persistence=[],
persistence_file="",
alpha=0.6,
- max_intervals=1000,
- max_barcodes=1000,
+ max_intervals=20000,
inf_delta=0.1,
legend=False,
colormap=None,
@@ -97,7 +127,7 @@ def plot_persistence_barcode(
:type alpha: float.
:param max_intervals: maximal number of intervals to display.
Selected intervals are those with the longest life time. Set it
- to 0 to see all. Default value is 1000.
+ to 0 to see all. Default value is 20000.
:type max_intervals: int.
:param inf_delta: Infinity is placed at :code:`((max_death - min_birth) x
inf_delta)` above :code:`max_death` value. A reasonable value is
@@ -119,99 +149,68 @@ def plot_persistence_barcode(
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib import rc
+
if _gudhi_matplotlib_use_tex and _matplotlib_can_use_tex():
- plt.rc('text', usetex=True)
- plt.rc('font', family='serif')
+ plt.rc("text", usetex=True)
+ plt.rc("font", family="serif")
else:
- plt.rc('text', usetex=False)
- plt.rc('font', family='DejaVu Sans')
+ plt.rc("text", usetex=False)
+ plt.rc("font", family="DejaVu Sans")
if persistence_file != "":
if path.isfile(persistence_file):
# Reset persistence
persistence = []
- diag = read_persistence_intervals_grouped_by_dimension(
- persistence_file=persistence_file
- )
+ diag = read_persistence_intervals_grouped_by_dimension(persistence_file=persistence_file)
for key in diag.keys():
for persistence_interval in diag[key]:
persistence.append((key, persistence_interval))
else:
- print("file " + persistence_file + " not found.")
- return None
-
- persistence = _array_handler(persistence)
-
- if max_barcodes != 1000:
- print("Deprecated parameter. It has been replaced by max_intervals")
- max_intervals = max_barcodes
-
- if max_intervals > 0 and max_intervals < len(persistence):
- # Sort by life time, then takes only the max_intervals elements
- persistence = sorted(
- persistence,
- key=lambda life_time: life_time[1][1] - life_time[1][0],
- reverse=True,
- )[:max_intervals]
-
- if colormap == None:
- colormap = plt.cm.Set1.colors
- if axes == None:
- fig, axes = plt.subplots(1, 1)
+ raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), persistence_file)
- persistence = sorted(persistence, key=lambda birth: birth[1][0])
+ try:
+ persistence = _array_handler(persistence)
+ persistence = _limit_to_max_intervals(
+ persistence, max_intervals, key=lambda life_time: life_time[1][1] - life_time[1][0]
+ )
+ (min_birth, max_death) = __min_birth_max_death(persistence)
+ persistence = sorted(persistence, key=lambda birth: birth[1][0])
+ except IndexError:
+ min_birth, max_death = 0.0, 1.0
+ pass
- (min_birth, max_death) = __min_birth_max_death(persistence)
- ind = 0
delta = (max_death - min_birth) * inf_delta
# Replace infinity values with max_death + delta for bar code to be more
# readable
infinity = max_death + delta
axis_start = min_birth - delta
- # Draw horizontal bars in loop
- for interval in reversed(persistence):
- if float(interval[1][1]) != float("inf"):
- # Finite death case
- axes.barh(
- ind,
- (interval[1][1] - interval[1][0]),
- height=0.8,
- left=interval[1][0],
- alpha=alpha,
- color=colormap[interval[0]],
- linewidth=0,
- )
- else:
- # Infinite death case for diagram to be nicer
- axes.barh(
- ind,
- (infinity - interval[1][0]),
- height=0.8,
- left=interval[1][0],
- alpha=alpha,
- color=colormap[interval[0]],
- linewidth=0,
- )
- ind = ind + 1
+
+ if axes == None:
+ _, axes = plt.subplots(1, 1)
+ if colormap == None:
+ colormap = plt.cm.Set1.colors
+
+    x = [birth for (dim, (birth, death)) in persistence]
+    y = [(death - birth) if death != float("inf") else (infinity - birth) for (dim, (birth, death)) in persistence]
+    c = [colormap[dim] for (dim, (birth, death)) in persistence]
+
+ axes.barh(list(reversed(range(len(x)))), y, height=0.8, left=x, alpha=alpha, color=c, linewidth=0)
if legend:
dimensions = list(set(item[0] for item in persistence))
axes.legend(
- handles=[
- mpatches.Patch(color=colormap[dim], label=str(dim))
- for dim in dimensions
- ],
- loc="lower right",
+ handles=[mpatches.Patch(color=colormap[dim], label=str(dim)) for dim in dimensions], loc="lower right",
)
axes.set_title("Persistence barcode", fontsize=fontsize)
# Ends plot on infinity value and starts a little bit before min_birth
- axes.axis([axis_start, infinity, 0, ind])
+ if len(x) != 0:
+ axes.axis([axis_start, infinity, 0, len(x)])
return axes
- except ImportError:
- print("This function is not available, you may be missing matplotlib.")
+ except ImportError as import_error:
+ warnings.warn(f"This function is not available.\nModuleNotFoundError: No module named '{import_error.name}'.")
def plot_persistence_diagram(
@@ -219,14 +218,13 @@ def plot_persistence_diagram(
persistence_file="",
alpha=0.6,
band=0.0,
- max_intervals=1000,
- max_plots=1000,
+ max_intervals=1000000,
inf_delta=0.1,
legend=False,
colormap=None,
axes=None,
fontsize=16,
- greyblock=True
+ greyblock=True,
):
"""This function plots the persistence diagram from persistence values
list, a np.array of shape (N x 2) representing a diagram in a single
@@ -244,7 +242,7 @@ def plot_persistence_diagram(
:type band: float.
:param max_intervals: maximal number of intervals to display.
Selected intervals are those with the longest life time. Set it
- to 0 to see all. Default value is 1000.
+ to 0 to see all. Default value is 1000000.
:type max_intervals: int.
:param inf_delta: Infinity is placed at :code:`((max_death - min_birth) x
inf_delta)` above :code:`max_death` value. A reasonable value is
@@ -268,47 +266,35 @@ def plot_persistence_diagram(
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib import rc
+
if _gudhi_matplotlib_use_tex and _matplotlib_can_use_tex():
- plt.rc('text', usetex=True)
- plt.rc('font', family='serif')
+ plt.rc("text", usetex=True)
+ plt.rc("font", family="serif")
else:
- plt.rc('text', usetex=False)
- plt.rc('font', family='DejaVu Sans')
+ plt.rc("text", usetex=False)
+ plt.rc("font", family="DejaVu Sans")
if persistence_file != "":
if path.isfile(persistence_file):
# Reset persistence
persistence = []
- diag = read_persistence_intervals_grouped_by_dimension(
- persistence_file=persistence_file
- )
+ diag = read_persistence_intervals_grouped_by_dimension(persistence_file=persistence_file)
for key in diag.keys():
for persistence_interval in diag[key]:
persistence.append((key, persistence_interval))
else:
- print("file " + persistence_file + " not found.")
- return None
-
- persistence = _array_handler(persistence)
+ raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), persistence_file)
- if max_plots != 1000:
- print("Deprecated parameter. It has been replaced by max_intervals")
- max_intervals = max_plots
-
- if max_intervals > 0 and max_intervals < len(persistence):
- # Sort by life time, then takes only the max_intervals elements
- persistence = sorted(
- persistence,
- key=lambda life_time: life_time[1][1] - life_time[1][0],
- reverse=True,
- )[:max_intervals]
-
- if colormap == None:
- colormap = plt.cm.Set1.colors
- if axes == None:
- fig, axes = plt.subplots(1, 1)
+ try:
+ persistence = _array_handler(persistence)
+ persistence = _limit_to_max_intervals(
+ persistence, max_intervals, key=lambda life_time: life_time[1][1] - life_time[1][0]
+ )
+ min_birth, max_death = __min_birth_max_death(persistence, band)
+ except IndexError:
+ min_birth, max_death = 0.0, 1.0
+ pass
- (min_birth, max_death) = __min_birth_max_death(persistence, band)
delta = (max_death - min_birth) * inf_delta
# Replace infinity values with max_death + delta for diagram to be more
# readable
@@ -316,61 +302,56 @@ def plot_persistence_diagram(
axis_end = max_death + delta / 2
axis_start = min_birth - delta
+ if axes == None:
+ _, axes = plt.subplots(1, 1)
+ if colormap == None:
+ colormap = plt.cm.Set1.colors
# bootstrap band
if band > 0.0:
x = np.linspace(axis_start, infinity, 1000)
axes.fill_between(x, x, x + band, alpha=alpha, facecolor="red")
# lower diag patch
if greyblock:
- axes.add_patch(mpatches.Polygon([[axis_start, axis_start], [axis_end, axis_start], [axis_end, axis_end]], fill=True, color='lightgrey'))
- # Draw points in loop
- pts_at_infty = False # Records presence of pts at infty
- for interval in reversed(persistence):
- if float(interval[1][1]) != float("inf"):
- # Finite death case
- axes.scatter(
- interval[1][0],
- interval[1][1],
- alpha=alpha,
- color=colormap[interval[0]],
+ axes.add_patch(
+ mpatches.Polygon(
+ [[axis_start, axis_start], [axis_end, axis_start], [axis_end, axis_end]],
+ fill=True,
+ color="lightgrey",
)
- else:
- pts_at_infty = True
- # Infinite death case for diagram to be nicer
- axes.scatter(
- interval[1][0], infinity, alpha=alpha, color=colormap[interval[0]]
- )
- if pts_at_infty:
+ )
+ # line display of equation : birth = death
+ axes.plot([axis_start, axis_end], [axis_start, axis_end], linewidth=1.0, color="k")
+
+    x = [birth for (dim, (birth, death)) in persistence]
+    y = [death if death != float("inf") else infinity for (dim, (birth, death)) in persistence]
+    c = [colormap[dim] for (dim, (birth, death)) in persistence]
+
+    axes.scatter(x, y, alpha=alpha, color=c)
+ if float("inf") in (death for (dim,(birth,death)) in persistence):
# infinity line and text
- axes.plot([axis_start, axis_end], [axis_start, axis_end], linewidth=1.0, color="k")
axes.plot([axis_start, axis_end], [infinity, infinity], linewidth=1.0, color="k", alpha=alpha)
# Infinity label
yt = axes.get_yticks()
- yt = yt[np.where(yt < axis_end)] # to avoid ploting ticklabel higher than infinity
+ yt = yt[np.where(yt < axis_end)] # to avoid plotting ticklabel higher than infinity
yt = np.append(yt, infinity)
ytl = ["%.3f" % e for e in yt] # to avoid float precision error
- ytl[-1] = r'$+\infty$'
+ ytl[-1] = r"$+\infty$"
axes.set_yticks(yt)
axes.set_yticklabels(ytl)
if legend:
dimensions = list(set(item[0] for item in persistence))
- axes.legend(
- handles=[
- mpatches.Patch(color=colormap[dim], label=str(dim))
- for dim in dimensions
- ]
- )
+ axes.legend(handles=[mpatches.Patch(color=colormap[dim], label=str(dim)) for dim in dimensions])
axes.set_xlabel("Birth", fontsize=fontsize)
axes.set_ylabel("Death", fontsize=fontsize)
axes.set_title("Persistence diagram", fontsize=fontsize)
# Ends plot on infinity value and starts a little bit before min_birth
- axes.axis([axis_start, axis_end, axis_start, infinity + delta/2])
+ axes.axis([axis_start, axis_end, axis_start, infinity + delta / 2])
return axes
- except ImportError:
- print("This function is not available, you may be missing matplotlib.")
+ except ImportError as import_error:
+ warnings.warn(f"This function is not available.\nModuleNotFoundError: No module named '{import_error.name}'.")
def plot_persistence_density(
@@ -384,7 +365,7 @@ def plot_persistence_density(
legend=False,
axes=None,
fontsize=16,
- greyblock=False
+ greyblock=False,
):
"""This function plots the persistence density from persistence
values list, np.array of shape (N x 2) representing a diagram
@@ -444,12 +425,13 @@ def plot_persistence_density(
import matplotlib.patches as mpatches
from scipy.stats import kde
from matplotlib import rc
+
if _gudhi_matplotlib_use_tex and _matplotlib_can_use_tex():
- plt.rc('text', usetex=True)
- plt.rc('font', family='serif')
+ plt.rc("text", usetex=True)
+ plt.rc("font", family="serif")
else:
- plt.rc('text', usetex=False)
- plt.rc('font', family='DejaVu Sans')
+ plt.rc("text", usetex=False)
+ plt.rc("font", family="DejaVu Sans")
if persistence_file != "":
if dimension is None:
@@ -460,10 +442,16 @@ def plot_persistence_density(
persistence_file=persistence_file, only_this_dim=dimension
)
else:
- print("file " + persistence_file + " not found.")
- return None
+ raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), persistence_file)
+
+    # the default cmap value cannot be set at argument-definition level, as matplotlib is not imported yet.
+ if cmap is None:
+ cmap = plt.cm.hot_r
+ if axes == None:
+ _, axes = plt.subplots(1, 1)
- if len(persistence) > 0:
+ try:
+        # persistence may have been passed directly as an argument rather than read from a file
persistence = _array_handler(persistence)
persistence_dim = np.array(
[
@@ -472,47 +460,54 @@ def plot_persistence_density(
if (dim_interval[0] == dimension) or (dimension is None)
]
)
-
- persistence_dim = persistence_dim[np.isfinite(persistence_dim[:, 1])]
- if max_intervals > 0 and max_intervals < len(persistence_dim):
- # Sort by life time, then takes only the max_intervals elements
+ persistence_dim = persistence_dim[np.isfinite(persistence_dim[:, 1])]
persistence_dim = np.array(
- sorted(
- persistence_dim,
- key=lambda life_time: life_time[1] - life_time[0],
- reverse=True,
- )[:max_intervals]
+ _limit_to_max_intervals(
+ persistence_dim, max_intervals, key=lambda life_time: life_time[1] - life_time[0]
+ )
)
- # Set as numpy array birth and death (remove undefined values - inf and NaN)
- birth = persistence_dim[:, 0]
- death = persistence_dim[:, 1]
-
- # default cmap value cannot be done at argument definition level as matplotlib is not yet defined.
- if cmap is None:
- cmap = plt.cm.hot_r
- if axes == None:
- fig, axes = plt.subplots(1, 1)
+ # Set as numpy array birth and death (remove undefined values - inf and NaN)
+ birth = persistence_dim[:, 0]
+ death = persistence_dim[:, 1]
+ birth_min = birth.min()
+ birth_max = birth.max()
+ death_min = death.min()
+ death_max = death.max()
+
+ # Evaluate a gaussian kde on a regular grid of nbins x nbins over data extents
+ k = kde.gaussian_kde([birth, death], bw_method=bw_method)
+ xi, yi = np.mgrid[
+ birth_min : birth_max : nbins * 1j, death_min : death_max : nbins * 1j,
+ ]
+ zi = k(np.vstack([xi.flatten(), yi.flatten()]))
+ # Make the plot
+ img = axes.pcolormesh(xi, yi, zi.reshape(xi.shape), cmap=cmap, shading="auto")
+ plot_success = True
+
+ # IndexError on empty diagrams, ValueError on only inf death values
+ except (IndexError, ValueError):
+ birth_min = 0.0
+ birth_max = 1.0
+ death_min = 0.0
+ death_max = 1.0
+ plot_success = False
+ pass
# line display of equation : birth = death
- x = np.linspace(death.min(), birth.max(), 1000)
+ x = np.linspace(death_min, birth_max, 1000)
axes.plot(x, x, color="k", linewidth=1.0)
- # Evaluate a gaussian kde on a regular grid of nbins x nbins over data extents
- k = kde.gaussian_kde([birth, death], bw_method=bw_method)
- xi, yi = np.mgrid[
- birth.min() : birth.max() : nbins * 1j,
- death.min() : death.max() : nbins * 1j,
- ]
- zi = k(np.vstack([xi.flatten(), yi.flatten()]))
-
- # Make the plot
- img = axes.pcolormesh(xi, yi, zi.reshape(xi.shape), cmap=cmap)
-
if greyblock:
- axes.add_patch(mpatches.Polygon([[birth.min(), birth.min()], [death.max(), birth.min()], [death.max(), death.max()]], fill=True, color='lightgrey'))
+ axes.add_patch(
+ mpatches.Polygon(
+ [[birth_min, birth_min], [death_max, birth_min], [death_max, death_max]],
+ fill=True,
+ color="lightgrey",
+ )
+ )
- if legend:
+ if plot_success and legend:
plt.colorbar(img, ax=axes)
axes.set_xlabel("Birth", fontsize=fontsize)
@@ -521,7 +516,5 @@ def plot_persistence_density(
return axes
- except ImportError:
- print(
- "This function is not available, you may be missing matplotlib and/or scipy."
- )
+ except ImportError as import_error:
+ warnings.warn(f"This function is not available.\nModuleNotFoundError: No module named '{import_error.name}'.")
diff --git a/src/python/gudhi/representations/preprocessing.py b/src/python/gudhi/representations/preprocessing.py
index a8545349..8722e162 100644
--- a/src/python/gudhi/representations/preprocessing.py
+++ b/src/python/gudhi/representations/preprocessing.py
@@ -1,10 +1,11 @@
# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
-# Author(s): Mathieu Carrière
+# Author(s): Mathieu Carrière, Vincent Rouvreau
#
# Copyright (C) 2018-2019 Inria
#
# Modification(s):
+# - 2021/10 Vincent Rouvreau: Add DimensionSelector
# - YYYY/MM Author: Description of the modification
import numpy as np
@@ -75,7 +76,7 @@ class Clamping(BaseEstimator, TransformerMixin):
Constructor for the Clamping class.
Parameters:
- limit (double): clamping value (default np.inf).
+ limit (float): clamping value (default np.inf).
"""
self.minimum = minimum
self.maximum = maximum
@@ -234,7 +235,7 @@ class ProminentPoints(BaseEstimator, TransformerMixin):
use (bool): whether to use the class or not (default False).
location (string): either "upper" or "lower" (default "upper"). Whether to keep the points that are far away ("upper") or close ("lower") to the diagonal.
num_pts (int): cardinality threshold (default 10). If location == "upper", keep the top **num_pts** points that are the farthest away from the diagonal. If location == "lower", keep the top **num_pts** points that are the closest to the diagonal.
- threshold (double): distance-to-diagonal threshold (default -1). If location == "upper", keep the points that are at least at a distance **threshold** from the diagonal. If location == "lower", keep the points that are at most at a distance **threshold** from the diagonal.
+ threshold (float): distance-to-diagonal threshold (default -1). If location == "upper", keep the points that are at least at a distance **threshold** from the diagonal. If location == "lower", keep the points that are at most at a distance **threshold** from the diagonal.
"""
self.num_pts = num_pts
self.threshold = threshold
@@ -317,7 +318,7 @@ class DiagramSelector(BaseEstimator, TransformerMixin):
Parameters:
use (bool): whether to use the class or not (default False).
- limit (double): second coordinate value that is the criterion for being an essential point (default numpy.inf).
+ limit (float): second coordinate value that is the criterion for being an essential point (default numpy.inf).
point_type (string): either "finite" or "essential". The type of the points that are going to be extracted.
"""
self.use, self.limit, self.point_type = use, limit, point_type
@@ -363,3 +364,51 @@ class DiagramSelector(BaseEstimator, TransformerMixin):
n x 2 numpy array: extracted persistence diagram.
"""
return self.fit_transform([diag])[0]
+
+
+# Mermaid sequence diagram - https://mermaid-js.github.io/mermaid-live-editor/
+# sequenceDiagram
+# USER->>DimensionSelector: fit_transform(<br/>[[array( Hi(X0) ), array( Hj(X0) ), ...],<br/> [array( Hi(X1) ), array( Hj(X1) ), ...],<br/> ...])
+# DimensionSelector->>thread1: _transform([array( Hi(X0) ), array( Hj(X0) )], ...)
+# DimensionSelector->>thread2: _transform([array( Hi(X1) ), array( Hj(X1) )], ...)
+# Note right of DimensionSelector: ...
+# thread1->>DimensionSelector: array( Hn(X0) )
+# thread2->>DimensionSelector: array( Hn(X1) )
+# Note right of DimensionSelector: ...
+# DimensionSelector->>USER: [array( Hn(X0) ), <br/> array( Hn(X1) ), <br/> ...]
+
+class DimensionSelector(BaseEstimator, TransformerMixin):
+ """
+    This is a class to select persistence diagrams of a specific dimension, given its index.
+ """
+
+ def __init__(self, index=0):
+ """
+ Constructor for the DimensionSelector class.
+
+ Parameters:
+            index (int): The dimension index of the persistence diagrams to return. Default value is `0`.
+ """
+ self.index = index
+
+ def fit(self, X, Y=None):
+ """
+ Nothing to be done, but useful when included in a scikit-learn Pipeline.
+ """
+ return self
+
+ def transform(self, X, Y=None):
+ """
+        Select persistence diagrams by their dimension index.
+
+ Parameters:
+ X (list of list of tuple): List of list of persistence pairs, i.e.
+ `[[array( Hi(X0) ), array( Hj(X0) ), ...], [array( Hi(X1) ), array( Hj(X1) ), ...], ...]`
+
+ Returns:
+ list of tuple:
+            Persistence diagrams in a specific dimension, i.e. if `index` was set to `m` and `Hn` is at index `m` of
+            the input, it returns `[array( Hn(X0) ), array( Hn(X1) ), ...]`
+ """
+
+ return [persistence[self.index] for persistence in X]
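A usage sketch for DimensionSelector, with input shaped as in the docstring above (the diagram arrays are made up for illustration):

import numpy as np
from gudhi.representations.preprocessing import DimensionSelector

H0_X0 = np.array([[0.0, 1.0], [0.0, 2.0]])
H1_X0 = np.array([[0.5, 0.9]])
H0_X1 = np.array([[0.0, 0.7]])
H1_X1 = np.array([[0.2, 0.4]])

# index=1 keeps the H1 diagram of each input
ds = DimensionSelector(index=1)
h1_diagrams = ds.fit_transform([[H0_X0, H1_X0], [H0_X1, H1_X1]])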
diff --git a/src/python/gudhi/representations/vector_methods.py b/src/python/gudhi/representations/vector_methods.py
index e883b5dd..69ff5e1e 100644
--- a/src/python/gudhi/representations/vector_methods.py
+++ b/src/python/gudhi/representations/vector_methods.py
@@ -1,15 +1,17 @@
# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
-# Author(s): Mathieu Carrière, Martin Royer
+# Author(s): Mathieu Carrière, Martin Royer, Gard Spreemann
#
# Copyright (C) 2018-2020 Inria
#
# Modification(s):
# - 2020/06 Martin: ATOL integration
+# - 2020/12 Gard: A more flexible Betti curve class capable of computing exact curves.
# - 2021/11 Vincent Rouvreau: factorize _automatic_sample_range
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
+from sklearn.exceptions import NotFittedError
from sklearn.preprocessing import MinMaxScaler, MaxAbsScaler
from sklearn.neighbors import DistanceMetric
from sklearn.metrics import pairwise
@@ -306,67 +308,162 @@ class Silhouette(BaseEstimator, TransformerMixin):
"""
return self.fit_transform([diag])[0,:]
+
class BettiCurve(BaseEstimator, TransformerMixin):
"""
- This is a class for computing Betti curves from a list of persistence diagrams. A Betti curve is a 1D piecewise-constant function obtained from the rank function. It is sampled evenly on a given range and the vector of samples is returned. See https://www.researchgate.net/publication/316604237_Time_Series_Classification_via_Topological_Data_Analysis for more details.
+    Compute Betti curves from persistence diagrams. There are several modes of operation: with a given resolution (with or without a sample_range), with a predefined grid, and with neither. With a predefined grid, the class computes the Betti numbers at those grid points. Without a predefined grid, if the resolution is set to None, the class can be fitted to a list of persistence diagrams and produces a grid consisting of (at least) the filtration values at which at least one of those persistence diagrams changes Betti numbers; the Betti numbers are then computed at those grid points, and the exact Betti curve is obtained for the entire real line. Otherwise, if the resolution is given, the Betti curve is obtained by sampling evenly, using either the given sample_range or a range computed from the persistence diagrams.
"""
- def __init__(self, resolution=100, sample_range=[np.nan, np.nan]):
+
+ def __init__(self, resolution=100, sample_range=[np.nan, np.nan], predefined_grid=None):
"""
Constructor for the BettiCurve class.
Parameters:
            resolution (int): number of samples for the piecewise-constant function (default 100).
sample_range ([double, double]): minimum and maximum of the piecewise-constant function domain, of the form [x_min, x_max] (default [numpy.nan, numpy.nan]). It is the interval on which samples will be drawn evenly. If one of the values is numpy.nan, it can be computed from the persistence diagrams with the fit() method.
+ predefined_grid (1d array or None, default=None): Predefined filtration grid points at which to compute the Betti curves. Must be strictly ordered. Infinities are ok. If None (default), and resolution is given, the grid will be uniform from x_min to x_max in 'resolution' steps, otherwise a grid will be computed that captures all changes in Betti numbers in the provided data.
+
+ Attributes:
+ grid_ (1d array): The grid on which the Betti numbers are computed. If predefined_grid was specified, `grid_` will always be that grid, independently of data. If not, the grid is fitted to capture all filtration values at which the Betti numbers change.
+
+ Examples
+ --------
+ If pd is a persistence diagram and xs is a nonempty grid of finite values such that xs[0] >= pd.min(), then the results of:
+
+ >>> bc = BettiCurve(predefined_grid=xs) # doctest: +SKIP
+ >>> result = bc(pd) # doctest: +SKIP
+
+ and
+
+ >>> from scipy.interpolate import interp1d # doctest: +SKIP
+ >>> bc = BettiCurve(resolution=None, predefined_grid=None) # doctest: +SKIP
+ >>> bettis = bc.fit_transform([pd]) # doctest: +SKIP
+ >>> interp = interp1d(bc.grid_, bettis[0, :], kind="previous", fill_value="extrapolate") # doctest: +SKIP
+ >>> result = np.array(interp(xs), dtype=int) # doctest: +SKIP
+
+ are the same.
"""
- self.resolution, self.sample_range = resolution, sample_range
- def fit(self, X, y=None):
+ if (predefined_grid is not None) and (not isinstance(predefined_grid, np.ndarray)):
+ raise ValueError("Expected predefined_grid as array or None.")
+
+ self.predefined_grid = predefined_grid
+ self.resolution = resolution
+ self.sample_range = sample_range
+
+ def is_fitted(self):
+ return hasattr(self, "grid_")
+
+ def fit(self, X, y = None):
"""
- Fit the BettiCurve class on a list of persistence diagrams: if any of the values in **sample_range** is numpy.nan, replace it with the corresponding value computed on the given list of persistence diagrams.
+        Fit the BettiCurve class on a list of persistence diagrams: if any of the values in **sample_range** is numpy.nan, replace it with the corresponding value computed on the given list of persistence diagrams. When no predefined grid is provided and the resolution is set to None, compute a filtration grid that captures all changes in Betti numbers for all the given persistence diagrams.
Parameters:
- X (list of n x 2 numpy arrays): input persistence diagrams.
- y (n x 1 array): persistence diagram labels (unused).
+ X (list of 2d arrays): Persistence diagrams.
+ y (None): Ignored.
"""
- self.sample_range = _automatic_sample_range(np.array(self.sample_range), X, y)
+
+ if self.predefined_grid is None:
+ if self.resolution is None: # Flexible/exact version
+ events = np.unique(np.concatenate([pd.flatten() for pd in X] + [[-np.inf]], axis=0))
+ self.grid_ = np.array(events)
+ else:
+ self.sample_range = _automatic_sample_range(np.array(self.sample_range), X, y)
+ self.grid_ = np.linspace(self.sample_range[0], self.sample_range[1], self.resolution)
+ else:
+ self.grid_ = self.predefined_grid # Get the predefined grid from user
+
return self
def transform(self, X):
"""
- Compute the Betti curve for each persistence diagram individually and concatenate the results.
+ Compute Betti curves.
Parameters:
- X (list of n x 2 numpy arrays): input persistence diagrams.
-
+ X (list of 2d arrays): Persistence diagrams.
+
Returns:
- numpy array with shape (number of diagrams) x (**resolution**): output Betti curves.
+            `len(X) x len(self.grid_)` array of ints: Betti numbers of the given persistence diagrams at the grid points given in `self.grid_`.
"""
- Xfit = []
- x_values = np.linspace(self.sample_range[0], self.sample_range[1], self.resolution)
- step_x = x_values[1] - x_values[0]
- for diagram in X:
- diagram_int = np.clip(np.ceil((diagram[:,:2] - self.sample_range[0]) / step_x), 0, self.resolution).astype(int)
- bc = np.zeros(self.resolution)
- for interval in diagram_int:
- bc[interval[0]:interval[1]] += 1
- Xfit.append(np.reshape(bc,[1,-1]))
+ if not self.is_fitted():
+ raise NotFittedError("Not fitted.")
- Xfit = np.concatenate(Xfit, 0)
+ if not X:
+ X = [np.zeros((0, 2))]
+
+ N = len(X)
- return Xfit
+ events = np.concatenate([pd.flatten(order="F") for pd in X], axis=0)
+ sorting = np.argsort(events)
+ offsets = np.zeros(1 + N, dtype=int)
+ for i in range(0, N):
+ offsets[i+1] = offsets[i] + 2*X[i].shape[0]
+ starts = offsets[0:N]
+ ends = offsets[1:N + 1] - 1
- def __call__(self, diag):
+ bettis = [[0] for i in range(0, N)]
+
+ i = 0
+ for x in self.grid_:
+ while i < len(sorting) and events[sorting[i]] <= x:
+ j = np.searchsorted(ends, sorting[i])
+ delta = 1 if sorting[i] - starts[j] < len(X[j]) else -1
+ bettis[j][-1] += delta
+ i += 1
+ for k in range(0, N):
+ bettis[k].append(bettis[k][-1])
+
+ return np.array(bettis, dtype=int)[:, 0:-1]
+
+ def fit_transform(self, X):
+ """
+ The result is the same as fit(X) followed by transform(X), but potentially faster.
"""
- Apply BettiCurve on a single persistence diagram and outputs the result.
- Parameters:
- diag (n x 2 numpy array): input persistence diagram.
+ if self.predefined_grid is None and self.resolution is None:
+ if not X:
+ X = [np.zeros((0, 2))]
- Returns:
- numpy array with shape (**resolution**): output Betti curve.
+ N = len(X)
+
+ events = np.concatenate([pd.flatten(order="F") for pd in X], axis=0)
+ sorting = np.argsort(events)
+ offsets = np.zeros(1 + N, dtype=int)
+ for i in range(0, N):
+ offsets[i+1] = offsets[i] + 2*X[i].shape[0]
+ starts = offsets[0:N]
+ ends = offsets[1:N + 1] - 1
+
+ xs = [-np.inf]
+ bettis = [[0] for i in range(0, N)]
+
+ for i in sorting:
+ j = np.searchsorted(ends, i)
+ delta = 1 if i - starts[j] < len(X[j]) else -1
+ if events[i] == xs[-1]:
+ bettis[j][-1] += delta
+ else:
+ xs.append(events[i])
+ for k in range(0, j):
+ bettis[k].append(bettis[k][-1])
+ bettis[j].append(bettis[j][-1] + delta)
+ for k in range(j+1, N):
+ bettis[k].append(bettis[k][-1])
+
+ self.grid_ = np.array(xs)
+ return np.array(bettis, dtype=int)
+
+ else:
+ return self.fit(X).transform(X)
+
+ def __call__(self, diag):
"""
- return self.fit_transform([diag])[0,:]
+ Shorthand for transform on a single persistence diagram.
+ """
+ return self.fit_transform([diag])[0, :]
+
+
class Entropy(BaseEstimator, TransformerMixin):
"""
@@ -411,26 +508,20 @@ class Entropy(BaseEstimator, TransformerMixin):
new_X = BirthPersistenceTransform().fit_transform(X)
for i in range(num_diag):
- orig_diagram, diagram, num_pts_in_diag = X[i], new_X[i], X[i].shape[0]
- try:
- new_diagram = DiagramScaler(use=True, scalers=[([1], MaxAbsScaler())]).fit_transform([diagram])[0]
- except ValueError:
- # Empty persistence diagram case - https://github.com/GUDHI/gudhi-devel/issues/507
- assert len(diagram) == 0
- new_diagram = np.empty(shape = [0, 2])
-
+ orig_diagram, new_diagram, num_pts_in_diag = X[i], new_X[i], X[i].shape[0]
+
+ p = new_diagram[:,1]
+ p = p/np.sum(p)
if self.mode == "scalar":
- ent = - np.sum( np.multiply(new_diagram[:,1], np.log(new_diagram[:,1])) )
+ ent = -np.dot(p, np.log(p))
Xfit.append(np.array([[ent]]))
-
else:
ent = np.zeros(self.resolution)
for j in range(num_pts_in_diag):
[px,py] = orig_diagram[j,:2]
min_idx = np.clip(np.ceil((px - self.sample_range[0]) / step_x).astype(int), 0, self.resolution)
max_idx = np.clip(np.ceil((py - self.sample_range[0]) / step_x).astype(int), 0, self.resolution)
- for k in range(min_idx, max_idx):
- ent[k] += (-1) * new_diagram[j,1] * np.log(new_diagram[j,1])
+                    ent[min_idx:max_idx] -= p[j] * np.log(p[j])
if self.normalized:
ent = ent / np.linalg.norm(ent, ord=1)
Xfit.append(np.reshape(ent,[1,-1]))
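The exact mode of the reworked BettiCurve can be sketched on a single diagram: with resolution=None and no predefined grid, fit_transform returns Betti numbers on a grid containing every filtration value at which they change.

import numpy as np
from gudhi.representations.vector_methods import BettiCurve

pd = np.array([[0.0, 1.0], [0.5, 2.0]])
bc = BettiCurve(resolution=None, predefined_grid=None)
bettis = bc.fit_transform([pd])
print(bc.grid_)   # [-inf  0.   0.5  1.   2. ]
print(bettis[0])  # [0 1 2 1 0]: Betti numbers at those grid points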
diff --git a/src/python/gudhi/simplex_tree.pxd b/src/python/gudhi/simplex_tree.pxd
index 006a24ed..5642f82d 100644
--- a/src/python/gudhi/simplex_tree.pxd
+++ b/src/python/gudhi/simplex_tree.pxd
@@ -45,6 +45,7 @@ cdef extern from "Simplex_tree_interface.h" namespace "Gudhi":
cdef cppclass Simplex_tree_interface_full_featured "Gudhi::Simplex_tree_interface<Gudhi::Simplex_tree_options_full_featured>":
Simplex_tree_interface_full_featured() nogil
+ Simplex_tree_interface_full_featured(Simplex_tree_interface_full_featured&) nogil
double simplex_filtration(vector[int] simplex) nogil
void assign_simplex_filtration(vector[int] simplex, double filtration) nogil
void initialize_filtration() nogil
@@ -62,9 +63,9 @@ cdef extern from "Simplex_tree_interface.h" namespace "Gudhi":
bool prune_above_filtration(double filtration) nogil
bool make_filtration_non_decreasing() nogil
void compute_extended_filtration() nogil
- vector[vector[pair[int, pair[double, double]]]] compute_extended_persistence_subdiagrams(vector[pair[int, pair[double, double]]] dgm, double min_persistence) nogil
Simplex_tree_interface_full_featured* collapse_edges(int nb_collapse_iteration) nogil except +
void reset_filtration(double filtration, int dimension) nogil
+ bint operator==(Simplex_tree_interface_full_featured) nogil
# Iterators over Simplex tree
pair[vector[int], double] get_simplex_and_filtration(Simplex_tree_simplex_handle f_simplex) nogil
Simplex_tree_simplices_iterator get_simplices_iterator_begin() nogil
@@ -74,9 +75,12 @@ cdef extern from "Simplex_tree_interface.h" namespace "Gudhi":
Simplex_tree_skeleton_iterator get_skeleton_iterator_begin(int dimension) nogil
Simplex_tree_skeleton_iterator get_skeleton_iterator_end(int dimension) nogil
pair[Simplex_tree_boundary_iterator, Simplex_tree_boundary_iterator] get_boundary_iterators(vector[int] simplex) nogil except +
+ # Expansion with blockers
+ ctypedef bool (*blocker_func_t)(vector[int], void *user_data)
+ void expansion_with_blockers_callback(int dimension, blocker_func_t user_func, void *user_data)
cdef extern from "Persistent_cohomology_interface.h" namespace "Gudhi":
- cdef cppclass Simplex_tree_persistence_interface "Gudhi::Persistent_cohomology_interface<Gudhi::Simplex_tree<Gudhi::Simplex_tree_options_full_featured>>":
+ cdef cppclass Simplex_tree_persistence_interface "Gudhi::Persistent_cohomology_interface<Gudhi::Simplex_tree_interface<Gudhi::Simplex_tree_options_full_featured>>":
Simplex_tree_persistence_interface(Simplex_tree_interface_full_featured * st, bool persistence_dim_max) nogil
void compute_persistence(int homology_coeff_field, double min_persistence) nogil except +
vector[pair[int, pair[double, double]]] get_persistence() nogil
@@ -87,3 +91,4 @@ cdef extern from "Persistent_cohomology_interface.h" namespace "Gudhi":
vector[pair[vector[int], vector[int]]] persistence_pairs() nogil
pair[vector[vector[int]], vector[vector[int]]] lower_star_generators() nogil
pair[vector[vector[int]], vector[vector[int]]] flag_generators() nogil
+ vector[vector[pair[int, pair[double, double]]]] compute_extended_persistence_subdiagrams(double min_persistence) nogil
diff --git a/src/python/gudhi/simplex_tree.pyx b/src/python/gudhi/simplex_tree.pyx
index c3720936..05bfe22e 100644
--- a/src/python/gudhi/simplex_tree.pyx
+++ b/src/python/gudhi/simplex_tree.pyx
@@ -16,6 +16,9 @@ __author__ = "Vincent Rouvreau"
__copyright__ = "Copyright (C) 2016 Inria"
__license__ = "MIT"
+cdef bool callback(vector[int] simplex, void *blocker_func):
+ return (<object>blocker_func)(simplex)
+
# SimplexTree python interface
cdef class SimplexTree:
"""The simplex tree is an efficient and flexible data structure for
@@ -38,13 +41,29 @@ cdef class SimplexTree:
cdef Simplex_tree_persistence_interface * pcohptr
# Fake constructor that does nothing but documenting the constructor
- def __init__(self):
+ def __init__(self, other = None):
"""SimplexTree constructor.
+
+ :param other: If `other` is `None` (default value), an empty `SimplexTree` is created.
+ If `other` is a `SimplexTree`, the `SimplexTree` is constructed from a deep copy of `other`.
+ :type other: SimplexTree (Optional)
+ :returns: An empty or a copy simplex tree.
+ :rtype: SimplexTree
+
+ :raises TypeError: In case `other` is neither `None`, nor a `SimplexTree`.
+ :note: If the `SimplexTree` is a copy, the persistence information is not copied. If you need it in the clone,
+ you have to call :func:`compute_persistence` on it even if you had already computed it in the original.
"""
# The real cython constructor
- def __cinit__(self):
- self.thisptr = <intptr_t>(new Simplex_tree_interface_full_featured())
+ def __cinit__(self, other = None):
+ if other:
+ if isinstance(other, SimplexTree):
+ self.thisptr = _get_copy_intptr(other)
+ else:
+ raise TypeError("`other` argument requires to be of type `SimplexTree`, or `None`.")
+ else:
+ self.thisptr = <intptr_t>(new Simplex_tree_interface_full_featured())
def __dealloc__(self):
cdef Simplex_tree_interface_full_featured* ptr = self.get_ptr()
@@ -63,6 +82,21 @@ cdef class SimplexTree:
"""
return self.pcohptr != NULL
+ def copy(self):
+ """
+ :returns: A simplex tree that is a deep copy of itself.
+ :rtype: SimplexTree
+
+ :note: The persistence information is not copied. If you need it in the clone, you have to call
+ :func:`compute_persistence` on it even if you had already computed it in the original.
+ """
+ stree = SimplexTree()
+ stree.thisptr = _get_copy_intptr(self)
+ return stree
+
+ def __deepcopy__(self, memo):
+ return self.copy()
+
def filtration(self, simplex):
"""This function returns the filtration value for a given N-simplex in
this simplicial complex, or +infinity if it is not in the complex.
@@ -440,9 +474,29 @@ cdef class SimplexTree:
del self.pcohptr
self.pcohptr = new Simplex_tree_persistence_interface(self.get_ptr(), False)
self.pcohptr.compute_persistence(homology_coeff_field, -1.)
- persistence_result = self.pcohptr.get_persistence()
- return self.get_ptr().compute_extended_persistence_subdiagrams(persistence_result, min_persistence)
+ return self.pcohptr.compute_extended_persistence_subdiagrams(min_persistence)
+
+ def expansion_with_blocker(self, max_dim, blocker_func):
+ """Expands the Simplex_tree containing only a graph. Simplices corresponding to cliques in the graph are added
+ incrementally, faces before cofaces, unless the simplex has dimension larger than `max_dim` or `blocker_func`
+ returns `True` for this simplex.
+
+ The function identifies a candidate simplex whose faces are all already in the complex, inserts it with a
+ filtration value corresponding to the maximum of the filtration values of the faces, then calls `blocker_func`
+ with this new simplex (represented as a list of int). If `blocker_func` returns `True`, the simplex is removed,
+ otherwise it is kept. The algorithm then proceeds with the next candidate.
+
+ .. warning::
+ Several candidates of the same dimension may be inserted simultaneously before calling `blocker_func`, so
+ if you examine the complex in `blocker_func`, you may hit a few simplices of the same dimension that have
+ not been vetted by `blocker_func` yet, or have already been rejected but not yet removed.
+
+ :param max_dim: Expansion maximal dimension value.
+ :type max_dim: int
+ :param blocker_func: Blocker oracle.
+ :type blocker_func: Callable[[List[int]], bool]
+ """
+ self.get_ptr().expansion_with_blockers_callback(max_dim, callback, <void*>blocker_func)
def persistence(self, homology_coeff_field=11, min_persistence=0, persistence_dim_max = False):
"""This function computes and returns the persistence of the simplicial complex.
@@ -621,18 +675,17 @@ cdef class SimplexTree:
return (normal0, normals, infinite0, infinites)
def collapse_edges(self, nb_iterations = 1):
- """Assuming the simplex tree is a 1-skeleton graph, this method collapse edges (simplices of higher dimension
- are ignored) and resets the simplex tree from the remaining edges.
- A good candidate is to build a simplex tree on top of a :class:`~gudhi.RipsComplex` of dimension 1 before
- collapsing edges
+ """Assuming the complex is a graph (simplices of higher dimension are ignored), this method implicitly
+ interprets it as the 1-skeleton of a flag complex, and replaces it with another (smaller) graph whose
+ expansion has the same persistent homology, using a technique known as edge collapses
+ (see :cite:`edgecollapsearxiv`).
+
+ A natural application is to get a simplex tree of dimension 1 from :class:`~gudhi.RipsComplex`,
+ then collapse edges, perform :meth:`expansion()` and finally compute persistence
(cf. :download:`rips_complex_edge_collapse_example.py <../example/rips_complex_edge_collapse_example.py>`).
- For implementation details, please refer to :cite:`edgecollapsesocg2020`.
:param nb_iterations: The number of edge collapse iterations to perform. Default is 1.
:type nb_iterations: int
-
- :note: collapse_edges method requires `Eigen <installation.html#eigen>`_ >= 3.1.0 and an exception is thrown
- if this method is not available.
"""
# Backup old pointer
cdef Simplex_tree_interface_full_featured* ptr = self.get_ptr()
@@ -642,3 +695,13 @@ cdef class SimplexTree:
self.thisptr = <intptr_t>(ptr.collapse_edges(nb_iter))
# Delete old pointer
del ptr
+
+ def __eq__(self, other: SimplexTree):
+ """Test for structural equality.
+ :returns: True if the 2 simplex trees are equal, False otherwise.
+ :rtype: bool
+ """
+ return dereference(self.get_ptr()) == dereference(other.get_ptr())
+
+cdef intptr_t _get_copy_intptr(SimplexTree stree) nogil:
+ return <intptr_t>(new Simplex_tree_interface_full_featured(dereference(stree.get_ptr())))
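A minimal sketch of the copy and blocked-expansion APIs introduced above (SimplexTree(other), ==, expansion_with_blocker); the filtration values and the blocker predicate are illustrative only:

    import gudhi

    st = gudhi.SimplexTree()
    st.insert([0, 1], 1.0)
    st.insert([1, 2], 1.0)
    st.insert([0, 2], 2.0)

    # Deep copy; persistence information (if any) is not carried over
    clone = gudhi.SimplexTree(st)
    assert clone == st and clone is not st

    # Expand to cliques, but reject any candidate simplex containing vertex 2
    st.expansion_with_blocker(2, lambda simplex: 2 in simplex)
    assert st.dimension() == 1  # the triangle [0, 1, 2] was blocked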
diff --git a/src/python/gudhi/sklearn/__init__.py b/src/python/gudhi/sklearn/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/src/python/gudhi/sklearn/__init__.py
diff --git a/src/python/gudhi/sklearn/cubical_persistence.py b/src/python/gudhi/sklearn/cubical_persistence.py
new file mode 100644
index 00000000..672af278
--- /dev/null
+++ b/src/python/gudhi/sklearn/cubical_persistence.py
@@ -0,0 +1,110 @@
+# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+# Author(s): Vincent Rouvreau
+#
+# Copyright (C) 2021 Inria
+#
+# Modification(s):
+# - YYYY/MM Author: Description of the modification
+
+from .. import CubicalComplex
+from sklearn.base import BaseEstimator, TransformerMixin
+
+import numpy as np
+# joblib is required by scikit-learn
+from joblib import Parallel, delayed
+
+# Mermaid sequence diagram - https://mermaid-js.github.io/mermaid-live-editor/
+# sequenceDiagram
+# USER->>CubicalPersistence: fit_transform(X)
+ # CubicalPersistence->>thread1: _transform(X[0])
+ # CubicalPersistence->>thread2: _transform(X[1])
+# Note right of CubicalPersistence: ...
+# thread1->>CubicalPersistence: [array( H0(X[0]) ), array( H1(X[0]) )]
+# thread2->>CubicalPersistence: [array( H0(X[1]) ), array( H1(X[1]) )]
+# Note right of CubicalPersistence: ...
+# CubicalPersistence->>USER: [[array( H0(X[0]) ), array( H1(X[0]) )],<br/> [array( H0(X[1]) ), array( H1(X[1]) )],<br/> ...]
+
+
+class CubicalPersistence(BaseEstimator, TransformerMixin):
+ """
+ This is a class for computing the persistence diagrams from a cubical complex.
+ """
+
+ def __init__(
+ self,
+ homology_dimensions,
+ newshape=None,
+ homology_coeff_field=11,
+ min_persistence=0.0,
+ n_jobs=None,
+ ):
+ """
+ Constructor for the CubicalPersistence class.
+
+ Parameters:
+ homology_dimensions (int or list of int): The returned persistence diagrams dimension(s).
+ Short circuit the use of :class:`~gudhi.representations.preprocessing.DimensionSelector` when only one
+ dimension matters (in other words, when `homology_dimensions` is an int).
+ newshape (tuple of ints): If the cells' filtration values need to be reshaped
+ (cf. :func:`~gudhi.sklearn.cubical_persistence.CubicalPersistence.transform`), set `newshape`
+ to perform `numpy.reshape(X, newshape, order='C')` in
+ :func:`~gudhi.sklearn.cubical_persistence.CubicalPersistence.transform` method.
+ homology_coeff_field (int): The homology coefficient field. Must be a prime number. Default value is 11.
+ min_persistence (float): The minimum persistence value to take into account (strictly greater than
+ `min_persistence`). Default value is `0.0`. Set `min_persistence` to `-1.0` to see all values.
+ n_jobs (int): cf. https://joblib.readthedocs.io/en/latest/generated/joblib.Parallel.html
+ """
+ self.homology_dimensions = homology_dimensions
+ self.newshape = newshape
+ self.homology_coeff_field = homology_coeff_field
+ self.min_persistence = min_persistence
+ self.n_jobs = n_jobs
+
+ def fit(self, X, Y=None):
+ """
+ Nothing to be done, but useful when included in a scikit-learn Pipeline.
+ """
+ return self
+
+ def __transform(self, cells):
+ cubical_complex = CubicalComplex(top_dimensional_cells=cells)
+ cubical_complex.compute_persistence(
+ homology_coeff_field=self.homology_coeff_field, min_persistence=self.min_persistence
+ )
+ return [
+ cubical_complex.persistence_intervals_in_dimension(dim) for dim in self.homology_dimensions
+ ]
+
+ def __transform_only_this_dim(self, cells):
+ cubical_complex = CubicalComplex(top_dimensional_cells=cells)
+ cubical_complex.compute_persistence(
+ homology_coeff_field=self.homology_coeff_field, min_persistence=self.min_persistence
+ )
+ return cubical_complex.persistence_intervals_in_dimension(self.homology_dimensions)
+
+ def transform(self, X, Y=None):
+ """Compute all the cubical complexes and their associated persistence diagrams.
+
+ :param X: List of cells filtration values (`numpy.reshape(X, newshape, order='C')` is applied if `newshape` is set with a tuple of ints).
+ :type X: list of list of float OR list of numpy.ndarray
+
+ :return: Persistence diagrams in the format:
+
+ - If `homology_dimensions` was set to `n`: `[array( Hn(X[0]) ), array( Hn(X[1]) ), ...]`
+ - If `homology_dimensions` was set to `[i, j]`: `[[array( Hi(X[0]) ), array( Hj(X[0]) )], [array( Hi(X[1]) ), array( Hj(X[1]) )], ...]`
+ :rtype: list of (,2) array_like or list of list of (,2) array_like
+ """
+ if self.newshape is not None:
+ X = np.reshape(X, self.newshape, order='C')
+
+ # Depending on whether homology_dimensions is an integer or a list of integers (else case)
+ if isinstance(self.homology_dimensions, int):
+ # threads are preferred, as cubical complex construction and persistence computation release the GIL
+ return Parallel(n_jobs=self.n_jobs, prefer="threads")(
+ delayed(self.__transform_only_this_dim)(cells) for cells in X
+ )
+ else:
+ # threads are preferred, as cubical complex construction and persistence computation release the GIL
+ return Parallel(n_jobs=self.n_jobs, prefer="threads")(delayed(self.__transform)(cells) for cells in X)
+
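A minimal usage sketch of this transformer, assuming a list of 2d images whose pixel values define the filtration; with an int homology_dimensions, one diagram per image is returned:

    import numpy as np
    from gudhi.sklearn.cubical_persistence import CubicalPersistence

    images = [np.random.rand(3, 3) for _ in range(2)]

    # One H0 persistence diagram per image ([0, 1] would give a pair of diagrams each)
    cp = CubicalPersistence(homology_dimensions=0, n_jobs=1)
    diagrams = cp.fit_transform(images)
    print(diagrams[0].shape)  # (n, 2) array of (birth, death) pairs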
diff --git a/src/python/gudhi/tensorflow/__init__.py b/src/python/gudhi/tensorflow/__init__.py
new file mode 100644
index 00000000..1599cf52
--- /dev/null
+++ b/src/python/gudhi/tensorflow/__init__.py
@@ -0,0 +1,5 @@
+from .cubical_layer import CubicalLayer
+from .lower_star_simplex_tree_layer import LowerStarSimplexTreeLayer
+from .rips_layer import RipsLayer
+
+__all__ = ["LowerStarSimplexTreeLayer", "RipsLayer", "CubicalLayer"]
diff --git a/src/python/gudhi/tensorflow/cubical_layer.py b/src/python/gudhi/tensorflow/cubical_layer.py
new file mode 100644
index 00000000..3304e719
--- /dev/null
+++ b/src/python/gudhi/tensorflow/cubical_layer.py
@@ -0,0 +1,82 @@
+import numpy as np
+import tensorflow as tf
+from ..cubical_complex import CubicalComplex
+
+######################
+# Cubical filtration #
+######################
+
+# The parameters of the model are the pixel values.
+
+def _Cubical(Xflat, Xdim, dimensions, homology_coeff_field):
+ # Parameters: Xflat (flattened image),
+ # Xdim (shape of non-flattened image),
+ # dimensions (homology dimensions),
+ # homology_coeff_field (homology field coefficient)
+
+ # Compute the persistence pairs with Gudhi
+ # We reverse the dimensions because CubicalComplex uses Fortran ordering
+ cc = CubicalComplex(dimensions=Xdim[::-1], top_dimensional_cells=Xflat)
+ cc.compute_persistence(homology_coeff_field=homology_coeff_field)
+
+ # Retrieve and output image indices/pixels corresponding to positive and negative simplices
+ cof_pp = cc.cofaces_of_persistence_pairs()
+
+ L_cofs = []
+ for dim in dimensions:
+
+ try:
+ cof = cof_pp[0][dim]
+ except IndexError:
+ cof = np.array([])
+
+ L_cofs.append(np.array(cof, dtype=np.int32))
+
+ return L_cofs
+
+class CubicalLayer(tf.keras.layers.Layer):
+ """
+ TensorFlow layer for computing the persistent homology of a cubical complex
+ """
+ def __init__(self, homology_dimensions, min_persistence=None, homology_coeff_field=11, **kwargs):
+ """
+ Constructor for the CubicalLayer class
+
+ Parameters:
+ homology_dimensions (List[int]): list of homology dimensions
+ min_persistence (List[float]): minimum distance-to-diagonal of the points in the output persistence diagrams (default None, in which case 0. is used for all dimensions)
+ homology_coeff_field (int): homology field coefficient. Must be a prime number. Default value is 11. Max is 46337.
+ """
+ super().__init__(dynamic=True, **kwargs)
+ self.dimensions = homology_dimensions
+ self.min_persistence = min_persistence if min_persistence is not None else [0.] * len(self.dimensions)
+ self.hcf = homology_coeff_field
+ assert len(self.min_persistence) == len(self.dimensions)
+
+ def call(self, X):
+ """
+ Compute persistence diagram associated to a cubical complex filtered by some pixel values
+
+ Parameters:
+ X (TensorFlow variable): pixel values of the cubical complex
+
+ Returns:
+ List[Tuple[tf.Tensor,tf.Tensor]]: List of cubical persistence diagrams. The length of this list is the same as that of dimensions, i.e., there is one persistence diagram per homology dimension provided in the input list dimensions. Moreover, the finite and essential parts of the persistence diagrams are provided separately: each element of this list is a tuple of size two that contains the finite and essential parts of the corresponding persistence diagram, of shapes [num_finite_points, 2] and [num_essential_points, 1] respectively. Note that the essential part is always empty in cubical persistence diagrams, except in homology dimension zero, where the essential part always contains a single point, with abscissa equal to the smallest value in the complex, and infinite ordinate.
+ """
+ # Compute pixels associated to positive and negative simplices
+ # Don't compute gradient for this operation
+ Xflat = tf.reshape(X, [-1])
+ Xdim, Xflat_numpy = X.shape, Xflat.numpy()
+ indices_list = _Cubical(Xflat_numpy, Xdim, self.dimensions, self.hcf)
+ index_essential = np.argmin(Xflat_numpy) # index of minimum pixel value for essential persistence diagram
+ # Get persistence diagram by simply picking the corresponding entries in the image
+ self.dgms = []
+ for idx_dim, dimension in enumerate(self.dimensions):
+ finite_dgm = tf.reshape(tf.gather(Xflat, indices_list[idx_dim]), [-1,2])
+ essential_dgm = tf.reshape(tf.gather(Xflat, index_essential), [-1,1]) if dimension == 0 else tf.zeros([0, 1])
+ min_pers = self.min_persistence[idx_dim]
+ if min_pers >= 0:
+ persistent_indices = tf.where(tf.math.abs(finite_dgm[:,1]-finite_dgm[:,0]) > min_pers)
+ self.dgms.append((tf.reshape(tf.gather(finite_dgm, indices=persistent_indices), [-1,2]), essential_dgm))
+ else:
+ self.dgms.append((finite_dgm, essential_dgm))
+ return self.dgms
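A minimal differentiation sketch for this layer, in the same spirit as the new test_diff.py at the end of this diff; the pixel values and the loss are illustrative:

    import numpy as np
    import tensorflow as tf
    from gudhi.tensorflow import CubicalLayer

    Xinit = np.array([[0., 2., 2.], [2., 2., 2.], [2., 2., 1.]], dtype=np.float32)
    X = tf.Variable(initial_value=Xinit, trainable=True)
    cl = CubicalLayer(homology_dimensions=[0])

    with tf.GradientTape() as tape:
        dgm = cl.call(X)[0][0]  # finite part of the H0 diagram
        loss = tf.math.reduce_sum(tf.square(.5 * (dgm[:, 1] - dgm[:, 0])))
    grads = tape.gradient(loss, [X])  # d(loss)/d(pixel values)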
diff --git a/src/python/gudhi/tensorflow/lower_star_simplex_tree_layer.py b/src/python/gudhi/tensorflow/lower_star_simplex_tree_layer.py
new file mode 100644
index 00000000..5a8e5b75
--- /dev/null
+++ b/src/python/gudhi/tensorflow/lower_star_simplex_tree_layer.py
@@ -0,0 +1,87 @@
+import numpy as np
+import tensorflow as tf
+
+#########################################
+# Lower star filtration on simplex tree #
+#########################################
+
+# The parameters of the model are the vertex function values of the simplex tree.
+
+def _LowerStarSimplexTree(simplextree, filtration, dimensions, homology_coeff_field):
+ # Parameters: simplextree (simplex tree on which to compute persistence)
+ # filtration (function values on the vertices of st),
+ # dimensions (homology dimensions),
+ # homology_coeff_field (homology field coefficient)
+
+ simplextree.reset_filtration(-np.inf, 0)
+
+ # Assign new filtration values
+ for i in range(simplextree.num_vertices()):
+ simplextree.assign_filtration([i], filtration[i])
+ simplextree.make_filtration_non_decreasing()
+
+ # Compute persistence diagram
+ simplextree.compute_persistence(homology_coeff_field=homology_coeff_field)
+
+ # Get vertex pairs for optimization. First, get all simplex pairs
+ pairs = simplextree.lower_star_persistence_generators()
+
+ L_indices = []
+ for dimension in dimensions:
+
+ finite_pairs = pairs[0][dimension] if len(pairs[0]) >= dimension+1 else np.empty(shape=[0,2])
+ essential_pairs = pairs[1][dimension] if len(pairs[1]) >= dimension+1 else np.empty(shape=[0,1])
+
+ finite_indices = np.array(finite_pairs.flatten(), dtype=np.int32)
+ essential_indices = np.array(essential_pairs.flatten(), dtype=np.int32)
+
+ L_indices.append((finite_indices, essential_indices))
+
+ return L_indices
+
+class LowerStarSimplexTreeLayer(tf.keras.layers.Layer):
+ """
+ TensorFlow layer for computing lower-star persistence out of a simplex tree
+ """
+ def __init__(self, simplextree, homology_dimensions, min_persistence=None, homology_coeff_field=11, **kwargs):
+ """
+ Constructor for the LowerStarSimplexTreeLayer class
+
+ Parameters:
+ simplextree (gudhi.SimplexTree): underlying simplex tree. Its vertices MUST be named with integers from 0 to n-1, where n is its number of vertices. Note that its filtration values are modified in each call of the class.
+ homology_dimensions (List[int]): list of homology dimensions
+ min_persistence (List[float]): minimum distance-to-diagonal of the points in the output persistence diagrams (default None, in which case 0. is used for all dimensions)
+ homology_coeff_field (int): homology field coefficient. Must be a prime number. Default value is 11. Max is 46337.
+ """
+ super().__init__(dynamic=True, **kwargs)
+ self.dimensions = homology_dimensions
+ self.simplextree = simplextree
+ self.min_persistence = min_persistence if min_persistence is not None else [0.] * len(self.dimensions)
+ self.hcf = homology_coeff_field
+ assert len(self.min_persistence) == len(self.dimensions)
+
+ def call(self, filtration):
+ """
+ Compute lower-star persistence diagram associated to a function defined on the vertices of the simplex tree
+
+ Parameters:
+ filtration (TensorFlow variable): filter function values over the vertices of the simplex tree. The ith entry of filtration corresponds to vertex i in self.simplextree
+
+ Returns:
+ List[Tuple[tf.Tensor,tf.Tensor]]: List of lower-star persistence diagrams. The length of this list is the same as that of dimensions, i.e., there is one persistence diagram per homology dimension provided in the input list dimensions. Moreover, the finite and essential parts of the persistence diagrams are provided separately: each element of this list is a tuple of size two that contains the finite and essential parts of the corresponding persistence diagram, of shapes [num_finite_points, 2] and [num_essential_points, 1] respectively.
+ """
+ # Don't try to compute gradients for the vertex pairs
+ indices = _LowerStarSimplexTree(self.simplextree, filtration.numpy(), self.dimensions, self.hcf)
+ # Get persistence diagrams
+ self.dgms = []
+ for idx_dim, dimension in enumerate(self.dimensions):
+ finite_dgm = tf.reshape(tf.gather(filtration, indices[idx_dim][0]), [-1,2])
+ essential_dgm = tf.reshape(tf.gather(filtration, indices[idx_dim][1]), [-1,1])
+ min_pers = self.min_persistence[idx_dim]
+ if min_pers >= 0:
+ persistent_indices = tf.where(tf.math.abs(finite_dgm[:,1]-finite_dgm[:,0]) > min_pers)
+ self.dgms.append((tf.reshape(tf.gather(finite_dgm, indices=persistent_indices),[-1,2]), essential_dgm))
+ else:
+ self.dgms.append((finite_dgm, essential_dgm))
+ return self.dgms
+
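A minimal differentiation sketch for this layer, assuming a small path graph; note that, as the docstring above warns, the filtration values of the simplex tree are modified in place on each call:

    import gudhi as gd
    import tensorflow as tf
    from gudhi.tensorflow import LowerStarSimplexTreeLayer

    st = gd.SimplexTree()  # vertices MUST be numbered 0 .. n-1
    st.insert([0, 1])
    st.insert([1, 2])

    F = tf.Variable([0., 2., 1.], dtype=tf.float32, trainable=True)
    layer = LowerStarSimplexTreeLayer(simplextree=st, homology_dimensions=[0])

    with tf.GradientTape() as tape:
        dgm = layer.call(F)[0][0]  # finite part of the H0 diagram
        loss = tf.math.reduce_sum(dgm[:, 1] - dgm[:, 0])
    grads = tape.gradient(loss, [F])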
diff --git a/src/python/gudhi/tensorflow/rips_layer.py b/src/python/gudhi/tensorflow/rips_layer.py
new file mode 100644
index 00000000..2a73472c
--- /dev/null
+++ b/src/python/gudhi/tensorflow/rips_layer.py
@@ -0,0 +1,93 @@
+import numpy as np
+import tensorflow as tf
+from ..rips_complex import RipsComplex
+
+############################
+# Vietoris-Rips filtration #
+############################
+
+# The parameters of the model are the point coordinates.
+
+def _Rips(DX, max_edge, dimensions, homology_coeff_field):
+ # Parameters: DX (distance matrix),
+ # max_edge (maximum edge length for Rips filtration),
+ # dimensions (homology dimensions),
+ # homology_coeff_field (homology field coefficient)
+
+ # Compute the persistence pairs with Gudhi
+ rc = RipsComplex(distance_matrix=DX, max_edge_length=max_edge)
+ st = rc.create_simplex_tree(max_dimension=max(dimensions)+1)
+ st.compute_persistence(homology_coeff_field=homology_coeff_field)
+ pairs = st.flag_persistence_generators()
+
+ L_indices = []
+ for dimension in dimensions:
+
+ if dimension == 0:
+ finite_pairs = pairs[0]
+ essential_pairs = pairs[2]
+ else:
+ finite_pairs = pairs[1][dimension-1] if len(pairs[1]) >= dimension else np.empty(shape=[0,4])
+ essential_pairs = pairs[3][dimension-1] if len(pairs[3]) >= dimension else np.empty(shape=[0,2])
+
+ finite_indices = np.array(finite_pairs.flatten(), dtype=np.int32)
+ essential_indices = np.array(essential_pairs.flatten(), dtype=np.int32)
+
+ L_indices.append((finite_indices, essential_indices))
+
+ return L_indices
+
+class RipsLayer(tf.keras.layers.Layer):
+ """
+ TensorFlow layer for computing Rips persistence out of a point cloud
+ """
+ def __init__(self, homology_dimensions, maximum_edge_length=np.inf, min_persistence=None, homology_coeff_field=11, **kwargs):
+ """
+ Constructor for the RipsLayer class
+
+ Parameters:
+ maximum_edge_length (float): maximum edge length for the Rips complex
+ homology_dimensions (List[int]): list of homology dimensions
+ min_persistence (List[float]): minimum distance-to-diagonal of the points in the output persistence diagrams (default None, in which case 0. is used for all dimensions)
+ homology_coeff_field (int): homology field coefficient. Must be a prime number. Default value is 11. Max is 46337.
+ """
+ super().__init__(dynamic=True, **kwargs)
+ self.max_edge = maximum_edge_length
+ self.dimensions = homology_dimensions
+ self.min_persistence = min_persistence if min_persistence is not None else [0.] * len(self.dimensions)
+ self.hcf = homology_coeff_field
+ assert len(self.min_persistence) == len(self.dimensions)
+
+ def call(self, X):
+ """
+ Compute Rips persistence diagram associated to a point cloud
+
+ Parameters:
+ X (TensorFlow variable): point cloud of shape [number of points, number of dimensions]
+
+ Returns:
+ List[Tuple[tf.Tensor,tf.Tensor]]: List of Rips persistence diagrams. The length of this list is the same as that of dimensions, i.e., there is one persistence diagram per homology dimension provided in the input list dimensions. Moreover, the finite and essential parts of the persistence diagrams are provided separately: each element of this list is a tuple of size two that contains the finite and essential parts of the corresponding persistence diagram, of shapes [num_finite_points, 2] and [num_essential_points, 1] respectively.
+ """
+ # Compute distance matrix
+ DX = tf.norm(tf.expand_dims(X, 1)-tf.expand_dims(X, 0), axis=2)
+ # Compute vertices associated to positive and negative simplices
+ # Don't compute gradient for this operation
+ indices = _Rips(DX.numpy(), self.max_edge, self.dimensions, self.hcf)
+ # Get persistence diagrams by simply picking the corresponding entries in the distance matrix
+ self.dgms = []
+ for idx_dim, dimension in enumerate(self.dimensions):
+ cur_idx = indices[idx_dim]
+ if dimension > 0:
+ finite_dgm = tf.reshape(tf.gather_nd(DX, tf.reshape(cur_idx[0], [-1,2])), [-1,2])
+ essential_dgm = tf.reshape(tf.gather_nd(DX, tf.reshape(cur_idx[1], [-1,2])), [-1,1])
+ else:
+ reshaped_cur_idx = tf.reshape(cur_idx[0], [-1,3])
+ finite_dgm = tf.concat([tf.zeros([reshaped_cur_idx.shape[0],1]), tf.reshape(tf.gather_nd(DX, reshaped_cur_idx[:,1:]), [-1,1])], axis=1)
+ essential_dgm = tf.zeros([cur_idx[1].shape[0],1])
+ min_pers = self.min_persistence[idx_dim]
+ if min_pers >= 0:
+ persistent_indices = tf.where(tf.math.abs(finite_dgm[:,1]-finite_dgm[:,0]) > min_pers)
+ self.dgms.append((tf.reshape(tf.gather(finite_dgm, indices=persistent_indices),[-1,2]), essential_dgm))
+ else:
+ self.dgms.append((finite_dgm, essential_dgm))
+ return self.dgms
+
diff --git a/src/python/gudhi/wasserstein/barycenter.py b/src/python/gudhi/wasserstein/barycenter.py
index d67bcde7..bb6e641e 100644
--- a/src/python/gudhi/wasserstein/barycenter.py
+++ b/src/python/gudhi/wasserstein/barycenter.py
@@ -37,7 +37,7 @@ def lagrangian_barycenter(pdiagset, init=None, verbose=False):
:param init: The initial value for barycenter estimate.
If ``None``, init is made on a random diagram from the dataset.
Otherwise, it can be an ``int`` (then initialization is made on ``pdiagset[init]``)
- or a `(n x 2)` ``numpy.array`` enconding a persistence diagram with `n` points.
+ or a `(n x 2)` ``numpy.array`` encoding a persistence diagram with `n` points.
:type init: ``int``, or (n x 2) ``np.array``
:param verbose: if ``True``, returns additional information about the barycenter.
:type verbose: boolean
@@ -45,7 +45,7 @@ def lagrangian_barycenter(pdiagset, init=None, verbose=False):
(local minimum of the energy function).
If ``pdiagset`` is empty, returns ``None``.
If verbose, returns a couple ``(Y, log)`` where ``Y`` is the barycenter estimate,
- and ``log`` is a ``dict`` that contains additional informations:
+ and ``log`` is a ``dict`` that contains additional information:
- `"groupings"`, a list of list of pairs ``(i,j)``. Namely, ``G[k] = [...(i, j)...]``, where ``(i,j)`` indicates that `pdiagset[k][i]`` is matched to ``Y[j]`` if ``i = -1`` or ``j = -1``, it means they represent the diagonal.
@@ -73,7 +73,7 @@ def lagrangian_barycenter(pdiagset, init=None, verbose=False):
nb_iter = 0
- converged = False # stoping criterion
+ converged = False # stopping criterion
while not converged:
nb_iter += 1
K = len(Y) # current nb of points in Y (some might be on diagonal)
diff --git a/src/python/gudhi/weighted_rips_complex.py b/src/python/gudhi/weighted_rips_complex.py
index 0541572b..16f63c3d 100644
--- a/src/python/gudhi/weighted_rips_complex.py
+++ b/src/python/gudhi/weighted_rips_complex.py
@@ -12,9 +12,11 @@ from gudhi import SimplexTree
class WeightedRipsComplex:
"""
Class to generate a weighted Rips complex from a distance matrix and weights on vertices,
- in the way described in :cite:`dtmfiltrations`.
+ in the way described in :cite:`dtmfiltrations` with `p=1`. The filtration value of vertex `i` is `2*weights[i]`,
+ and the filtration value of edge `ij` is the largest of `distance_matrix[i][j]+weights[i]+weights[j]`
+ and the filtration values of its two vertices.
Remark that all the filtration values are doubled compared to the definition in the paper
- for the consistency with RipsComplex.
+ for consistency with RipsComplex.
"""
def __init__(self,
distance_matrix,
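A small numeric illustration of the doubled convention spelled out above, assuming the constructor shown here (distance matrix plus weights):

    from gudhi.weighted_rips_complex import WeightedRipsComplex

    dist = [[0, 3], [3, 0]]
    st = WeightedRipsComplex(distance_matrix=dist,
                             weights=[1.0, 2.0]).create_simplex_tree(max_dimension=1)

    # Vertices enter at 2*1=2 and 2*2=4; edge 01 at max(3+1+2, 2, 4) = 6
    print(list(st.get_filtration()))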
diff --git a/src/python/include/Alpha_complex_factory.h b/src/python/include/Alpha_complex_factory.h
index 3405fdd6..3d20aa8f 100644
--- a/src/python/include/Alpha_complex_factory.h
+++ b/src/python/include/Alpha_complex_factory.h
@@ -31,15 +31,34 @@ namespace Gudhi {
namespace alpha_complex {
-template <typename CgalPointType>
-std::vector<double> pt_cgal_to_cython(CgalPointType const& point) {
- std::vector<double> vd;
- vd.reserve(point.dimension());
- for (auto coord = point.cartesian_begin(); coord != point.cartesian_end(); coord++)
- vd.push_back(CGAL::to_double(*coord));
- return vd;
-}
+// template Functor that transforms a CGAL point to a vector of double as expected by cython
+template<typename CgalPointType, bool Weighted>
+struct Point_cgal_to_cython;
+
+// Specialized Unweighted Functor
+template<typename CgalPointType>
+struct Point_cgal_to_cython<CgalPointType, false> {
+ std::vector<double> operator()(CgalPointType const& point) const
+ {
+ std::vector<double> vd;
+ vd.reserve(point.dimension());
+ for (auto coord = point.cartesian_begin(); coord != point.cartesian_end(); coord++)
+ vd.push_back(CGAL::to_double(*coord));
+ return vd;
+ }
+};
+// Specialized Weighted Functor
+template<typename CgalPointType>
+struct Point_cgal_to_cython<CgalPointType, true> {
+ std::vector<double> operator()(CgalPointType const& weighted_point) const
+ {
+ const auto& point = weighted_point.point();
+ return Point_cgal_to_cython<decltype(point), false>()(point);
+ }
+};
+
+// Function that transforms a cython point (aka. a vector of double) to a CGAL point
template <typename CgalPointType>
static CgalPointType pt_cython_to_cgal(std::vector<double> const& vec) {
return CgalPointType(vec.size(), vec.begin(), vec.end());
@@ -51,24 +70,35 @@ class Abstract_alpha_complex {
virtual bool create_simplex_tree(Simplex_tree_interface<>* simplex_tree, double max_alpha_square,
bool default_filtration_value) = 0;
+
+ virtual std::size_t num_vertices() const = 0;
virtual ~Abstract_alpha_complex() = default;
};
-class Exact_Alphacomplex_dD final : public Abstract_alpha_complex {
+template <bool Weighted = false>
+class Exact_alpha_complex_dD final : public Abstract_alpha_complex {
private:
using Kernel = CGAL::Epeck_d<CGAL::Dynamic_dimension_tag>;
- using Point = typename Kernel::Point_d;
+ using Bare_point = typename Kernel::Point_d;
+ using Point = std::conditional_t<Weighted, typename Kernel::Weighted_point_d,
+ typename Kernel::Point_d>;
public:
- Exact_Alphacomplex_dD(const std::vector<std::vector<double>>& points, bool exact_version)
+ Exact_alpha_complex_dD(const std::vector<std::vector<double>>& points, bool exact_version)
+ : exact_version_(exact_version),
+ alpha_complex_(boost::adaptors::transform(points, pt_cython_to_cgal<Bare_point>)) {
+ }
+
+ Exact_alpha_complex_dD(const std::vector<std::vector<double>>& points,
+ const std::vector<double>& weights, bool exact_version)
: exact_version_(exact_version),
- alpha_complex_(boost::adaptors::transform(points, pt_cython_to_cgal<Point>)) {
+ alpha_complex_(boost::adaptors::transform(points, pt_cython_to_cgal<Bare_point>), weights) {
}
virtual std::vector<double> get_point(int vh) override {
- Point const& point = alpha_complex_.get_point(vh);
- return pt_cgal_to_cython(point);
+ // Can be a Weighted or a Bare point depending on Weighted
+ return Point_cgal_to_cython<Point, Weighted>()(alpha_complex_.get_point(vh));
}
virtual bool create_simplex_tree(Simplex_tree_interface<>* simplex_tree, double max_alpha_square,
@@ -76,65 +106,49 @@ class Exact_Alphacomplex_dD final : public Abstract_alpha_complex {
return alpha_complex_.create_complex(*simplex_tree, max_alpha_square, exact_version_, default_filtration_value);
}
+ virtual std::size_t num_vertices() const override {
+ return alpha_complex_.num_vertices();
+ }
+
private:
bool exact_version_;
- Alpha_complex<Kernel> alpha_complex_;
+ Alpha_complex<Kernel, Weighted> alpha_complex_;
};
-class Inexact_Alphacomplex_dD final : public Abstract_alpha_complex {
+template <bool Weighted = false>
+class Inexact_alpha_complex_dD final : public Abstract_alpha_complex {
private:
using Kernel = CGAL::Epick_d<CGAL::Dynamic_dimension_tag>;
- using Point = typename Kernel::Point_d;
+ using Bare_point = typename Kernel::Point_d;
+ using Point = std::conditional_t<Weighted, typename Kernel::Weighted_point_d,
+ typename Kernel::Point_d>;
public:
- Inexact_Alphacomplex_dD(const std::vector<std::vector<double>>& points, bool exact_version)
- : exact_version_(exact_version),
- alpha_complex_(boost::adaptors::transform(points, pt_cython_to_cgal<Point>)) {
+ Inexact_alpha_complex_dD(const std::vector<std::vector<double>>& points)
+ : alpha_complex_(boost::adaptors::transform(points, pt_cython_to_cgal<Bare_point>)) {
+ }
+
+ Inexact_alpha_complex_dD(const std::vector<std::vector<double>>& points, const std::vector<double>& weights)
+ : alpha_complex_(boost::adaptors::transform(points, pt_cython_to_cgal<Bare_point>), weights) {
}
virtual std::vector<double> get_point(int vh) override {
- Point const& point = alpha_complex_.get_point(vh);
- return pt_cgal_to_cython(point);
+ // Can be a Weighted or a Bare point depending on Weighted
+ return Point_cgal_to_cython<Point, Weighted>()(alpha_complex_.get_point(vh));
}
virtual bool create_simplex_tree(Simplex_tree_interface<>* simplex_tree, double max_alpha_square,
bool default_filtration_value) override {
- return alpha_complex_.create_complex(*simplex_tree, max_alpha_square, exact_version_, default_filtration_value);
+ return alpha_complex_.create_complex(*simplex_tree, max_alpha_square, false, default_filtration_value);
}
- private:
- bool exact_version_;
- Alpha_complex<Kernel> alpha_complex_;
-};
-
-template <complexity Complexity>
-class Alphacomplex_3D final : public Abstract_alpha_complex {
- private:
- using Point = typename Alpha_complex_3d<Complexity, false, false>::Bare_point_3;
-
- static Point pt_cython_to_cgal_3(std::vector<double> const& vec) {
- return Point(vec[0], vec[1], vec[2]);
- }
-
- public:
- Alphacomplex_3D(const std::vector<std::vector<double>>& points)
- : alpha_complex_(boost::adaptors::transform(points, pt_cython_to_cgal_3)) {
- }
-
- virtual std::vector<double> get_point(int vh) override {
- Point const& point = alpha_complex_.get_point(vh);
- return pt_cgal_to_cython(point);
- }
-
- virtual bool create_simplex_tree(Simplex_tree_interface<>* simplex_tree, double max_alpha_square,
- bool default_filtration_value) override {
- return alpha_complex_.create_complex(*simplex_tree, max_alpha_square);
+ virtual std::size_t num_vertices() const override {
+ return alpha_complex_.num_vertices();
}
private:
- Alpha_complex_3d<Complexity, false, false> alpha_complex_;
+ Alpha_complex<Kernel, Weighted> alpha_complex_;
};
-
} // namespace alpha_complex
} // namespace Gudhi
diff --git a/src/python/include/Alpha_complex_interface.h b/src/python/include/Alpha_complex_interface.h
index 23be194d..469b91ce 100644
--- a/src/python/include/Alpha_complex_interface.h
+++ b/src/python/include/Alpha_complex_interface.h
@@ -27,10 +27,23 @@ namespace alpha_complex {
class Alpha_complex_interface {
public:
- Alpha_complex_interface(const std::vector<std::vector<double>>& points, bool fast_version, bool exact_version)
- : points_(points),
- fast_version_(fast_version),
- exact_version_(exact_version) {
+ Alpha_complex_interface(const std::vector<std::vector<double>>& points,
+ const std::vector<double>& weights,
+ bool fast_version, bool exact_version) {
+ const bool weighted = (weights.size() > 0);
+ if (fast_version) {
+ if (weighted) {
+ alpha_ptr_ = std::make_unique<Inexact_alpha_complex_dD<true>>(points, weights);
+ } else {
+ alpha_ptr_ = std::make_unique<Inexact_alpha_complex_dD<false>>(points);
+ }
+ } else {
+ if (weighted) {
+ alpha_ptr_ = std::make_unique<Exact_alpha_complex_dD<true>>(points, weights, exact_version);
+ } else {
+ alpha_ptr_ = std::make_unique<Exact_alpha_complex_dD<false>>(points, exact_version);
+ }
+ }
}
std::vector<double> get_point(int vh) {
@@ -39,38 +52,23 @@ class Alpha_complex_interface {
void create_simplex_tree(Simplex_tree_interface<>* simplex_tree, double max_alpha_square,
bool default_filtration_value) {
- if (points_.size() > 0) {
- std::size_t dimension = points_[0].size();
- if (dimension == 3 && !default_filtration_value) {
- if (fast_version_)
- alpha_ptr_ = std::make_unique<Alphacomplex_3D<Gudhi::alpha_complex::complexity::FAST>>(points_);
- else if (exact_version_)
- alpha_ptr_ = std::make_unique<Alphacomplex_3D<Gudhi::alpha_complex::complexity::EXACT>>(points_);
- else
- alpha_ptr_ = std::make_unique<Alphacomplex_3D<Gudhi::alpha_complex::complexity::SAFE>>(points_);
- if (!alpha_ptr_->create_simplex_tree(simplex_tree, max_alpha_square, default_filtration_value)) {
- // create_simplex_tree will fail if all points are on a plane - Retry with dD by setting dimension to 2
- dimension--;
- alpha_ptr_.reset();
- }
- }
- // Not ** else ** because we have to take into account if 3d fails
- if (dimension != 3 || default_filtration_value) {
- if (fast_version_) {
- alpha_ptr_ = std::make_unique<Inexact_Alphacomplex_dD>(points_, exact_version_);
- } else {
- alpha_ptr_ = std::make_unique<Exact_Alphacomplex_dD>(points_, exact_version_);
- }
- alpha_ptr_->create_simplex_tree(simplex_tree, max_alpha_square, default_filtration_value);
- }
- }
+ // Nothing to be done in case of an empty point set
+ if (alpha_ptr_->num_vertices() > 0)
+ alpha_ptr_->create_simplex_tree(simplex_tree, max_alpha_square, default_filtration_value);
+ }
+
+ static void set_float_relative_precision(double precision) {
+ // cf. Exact_alpha_complex_dD kernel type in Alpha_complex_factory.h
+ CGAL::Epeck_d<CGAL::Dynamic_dimension_tag>::FT::set_relative_precision_of_to_double(precision);
+ }
+
+ static double get_float_relative_precision() {
+ // cf. Exact_alpha_complex_dD kernel type in Alpha_complex_factory.h
+ return CGAL::Epeck_d<CGAL::Dynamic_dimension_tag>::FT::get_relative_precision_of_to_double();
}
private:
std::unique_ptr<Abstract_alpha_complex> alpha_ptr_;
- std::vector<std::vector<double>> points_;
- bool fast_version_;
- bool exact_version_;
};
} // namespace alpha_complex
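On the Python side, the choice among these instantiations is driven by the weights and precision arguments; a minimal sketch, assuming the AlphaComplex constructor exercised in the tests below:

    from gudhi import AlphaComplex

    points = [[1.0, 1.0], [7.0, 0.0], [4.0, 6.0]]

    # No weights: unweighted branch; precision='fast' picks the inexact kernel
    st_fast = AlphaComplex(points=points, precision='fast').create_simplex_tree()

    # A non-empty weights list selects the weighted specializations
    st_weighted = AlphaComplex(points=points, weights=[1.0, 1.0, 2.0],
                               precision='safe').create_simplex_tree()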
diff --git a/src/python/include/Persistent_cohomology_interface.h b/src/python/include/Persistent_cohomology_interface.h
index e5a3dfba..945378a0 100644
--- a/src/python/include/Persistent_cohomology_interface.h
+++ b/src/python/include/Persistent_cohomology_interface.h
@@ -12,6 +12,8 @@
#define INCLUDE_PERSISTENT_COHOMOLOGY_INTERFACE_H_
#include <gudhi/Persistent_cohomology.h>
+#include <gudhi/Simplex_tree.h> // for Extended_simplex_type
+
#include <cstdlib>
#include <vector>
@@ -223,6 +225,44 @@ persistent_cohomology::Persistent_cohomology<FilteredComplex, persistent_cohomol
return out;
}
+ using Filtration_value = typename FilteredComplex::Filtration_value;
+ using Birth_death = std::pair<Filtration_value, Filtration_value>;
+ using Persistence_subdiagrams = std::vector<std::vector<std::pair<int, Birth_death>>>;
+
+ Persistence_subdiagrams compute_extended_persistence_subdiagrams(Filtration_value min_persistence){
+ Persistence_subdiagrams pers_subs(4);
+ auto const& persistent_pairs = Base::get_persistent_pairs();
+ for (auto pair : persistent_pairs) {
+ std::pair<Filtration_value, Extended_simplex_type> px = stptr_->decode_extended_filtration(stptr_->filtration(get<0>(pair)),
+ stptr_->efd);
+ std::pair<Filtration_value, Extended_simplex_type> py = stptr_->decode_extended_filtration(stptr_->filtration(get<1>(pair)),
+ stptr_->efd);
+ std::pair<int, Birth_death> pd_point = std::make_pair(stptr_->dimension(get<0>(pair)),
+ std::make_pair(px.first, py.first));
+ if(std::abs(px.first - py.first) > min_persistence){
+ //Ordinary
+ if (px.second == Extended_simplex_type::UP && py.second == Extended_simplex_type::UP){
+ pers_subs[0].push_back(pd_point);
+ }
+ // Relative
+ else if (px.second == Extended_simplex_type::DOWN && py.second == Extended_simplex_type::DOWN){
+ pers_subs[1].push_back(pd_point);
+ }
+ else{
+ // Extended+
+ if (px.first < py.first){
+ pers_subs[2].push_back(pd_point);
+ }
+ //Extended-
+ else{
+ pers_subs[3].push_back(pd_point);
+ }
+ }
+ }
+ }
+ return pers_subs;
+ }
+
private:
// A copy
FilteredComplex* stptr_;
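This routine backs SimplexTree.extended_persistence() on the Python side (see the simplex_tree.pyx hunk above); a minimal sketch of the four subdiagrams it produces:

    import gudhi

    st = gudhi.SimplexTree()
    st.insert([0, 1], 1.0)
    st.insert([1, 2], 2.0)
    st.insert([0, 2], 3.0)

    st.extend_filtration()  # required before extended_persistence()
    ordinary, relative, extended_plus, extended_minus = st.extended_persistence()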
diff --git a/src/python/include/Simplex_tree_interface.h b/src/python/include/Simplex_tree_interface.h
index 629f6083..3848c5ad 100644
--- a/src/python/include/Simplex_tree_interface.h
+++ b/src/python/include/Simplex_tree_interface.h
@@ -15,9 +15,7 @@
#include <gudhi/distance_functions.h>
#include <gudhi/Simplex_tree.h>
#include <gudhi/Points_off_io.h>
-#ifdef GUDHI_USE_EIGEN3
#include <gudhi/Flag_complex_edge_collapser.h>
-#endif
#include <iostream>
#include <vector>
@@ -42,6 +40,7 @@ class Simplex_tree_interface : public Simplex_tree<SimplexTreeOptions> {
using Complex_simplex_iterator = typename Base::Complex_simplex_iterator;
using Extended_filtration_data = typename Base::Extended_filtration_data;
using Boundary_simplex_iterator = typename Base::Boundary_simplex_iterator;
+ typedef bool (*blocker_func_t)(Simplex simplex, void *user_data);
public:
@@ -133,38 +132,7 @@ class Simplex_tree_interface : public Simplex_tree<SimplexTreeOptions> {
return;
}
- std::vector<std::vector<std::pair<int, std::pair<Filtration_value, Filtration_value>>>> compute_extended_persistence_subdiagrams(const std::vector<std::pair<int, std::pair<Filtration_value, Filtration_value>>>& dgm, Filtration_value min_persistence){
- std::vector<std::vector<std::pair<int, std::pair<Filtration_value, Filtration_value>>>> new_dgm(4);
- for (unsigned int i = 0; i < dgm.size(); i++){
- std::pair<Filtration_value, Extended_simplex_type> px = this->decode_extended_filtration(dgm[i].second.first, this->efd);
- std::pair<Filtration_value, Extended_simplex_type> py = this->decode_extended_filtration(dgm[i].second.second, this->efd);
- std::pair<int, std::pair<Filtration_value, Filtration_value>> pd_point = std::make_pair(dgm[i].first, std::make_pair(px.first, py.first));
- if(std::abs(px.first - py.first) > min_persistence){
- //Ordinary
- if (px.second == Extended_simplex_type::UP && py.second == Extended_simplex_type::UP){
- new_dgm[0].push_back(pd_point);
- }
- // Relative
- else if (px.second == Extended_simplex_type::DOWN && py.second == Extended_simplex_type::DOWN){
- new_dgm[1].push_back(pd_point);
- }
- else{
- // Extended+
- if (px.first < py.first){
- new_dgm[2].push_back(pd_point);
- }
- //Extended-
- else{
- new_dgm[3].push_back(pd_point);
- }
- }
- }
- }
- return new_dgm;
- }
-
Simplex_tree_interface* collapse_edges(int nb_collapse_iteration) {
-#ifdef GUDHI_USE_EIGEN3
using Filtered_edge = std::tuple<Vertex_handle, Vertex_handle, Filtration_value>;
std::vector<Filtered_edge> edges;
for (Simplex_handle sh : Base::skeleton_simplex_range(1)) {
@@ -178,7 +146,7 @@ class Simplex_tree_interface : public Simplex_tree<SimplexTreeOptions> {
}
for (int iteration = 0; iteration < nb_collapse_iteration; iteration++) {
- edges = Gudhi::collapse::flag_complex_collapse_edges(edges);
+ edges = Gudhi::collapse::flag_complex_collapse_edges(std::move(edges));
}
Simplex_tree_interface* collapsed_stree_ptr = new Simplex_tree_interface();
// Copy the original 0-skeleton
@@ -190,9 +158,13 @@ class Simplex_tree_interface : public Simplex_tree<SimplexTreeOptions> {
collapsed_stree_ptr->insert({std::get<0>(remaining_edge), std::get<1>(remaining_edge)}, std::get<2>(remaining_edge));
}
return collapsed_stree_ptr;
-#else
- throw std::runtime_error("Unable to collapse edges as it requires Eigen3 >= 3.1.0.");
-#endif
+ }
+
+ void expansion_with_blockers_callback(int dimension, blocker_func_t user_func, void *user_data) {
+ Base::expansion_with_blockers(dimension, [&](Simplex_handle sh){
+ Simplex simplex(Base::simplex_vertex_range(sh).begin(), Base::simplex_vertex_range(sh).end());
+ return user_func(simplex, user_data);
+ });
}
// Iterator over the simplex tree
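The recipe from the updated collapse_edges docstring, as a minimal end-to-end sketch (Rips at dimension 1, collapse, expand, then persistence); the point cloud is illustrative:

    import numpy as np
    import gudhi

    points = np.random.rand(30, 2)
    st = gudhi.RipsComplex(points=points, max_edge_length=0.5).create_simplex_tree(max_dimension=1)

    st.collapse_edges()  # smaller graph with the same flag persistent homology
    st.expansion(2)      # rebuild the 2-skeleton from the collapsed graph
    diag = st.persistence()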
diff --git a/src/python/pyproject.toml b/src/python/pyproject.toml
index a9fb4985..55b64466 100644
--- a/src/python/pyproject.toml
+++ b/src/python/pyproject.toml
@@ -1,3 +1,3 @@
[build-system]
-requires = ["setuptools", "wheel", "numpy>=1.15.0", "cython", "pybind11"]
+requires = ["setuptools>=24.2.0", "wheel", "numpy>=1.15.0", "cython>=0.27", "pybind11"]
build-backend = "setuptools.build_meta"
diff --git a/src/python/setup.py.in b/src/python/setup.py.in
index 23746998..2c67c2c5 100644
--- a/src/python/setup.py.in
+++ b/src/python/setup.py.in
@@ -5,6 +5,7 @@
Copyright (C) 2019 Inria
Modification(s):
+ - 2021/12 Vincent Rouvreau: Python 3.5 as minimal version
- YYYY/MM Author: Description of the modification
"""
@@ -43,7 +44,7 @@ for module in cython_modules:
include_dirs=include_dirs,
runtime_library_dirs=runtime_library_dirs,))
-ext_modules = cythonize(ext_modules, compiler_directives={'language_level': str(sys.version_info[0])})
+ext_modules = cythonize(ext_modules, compiler_directives={'language_level': '3'})
for module in pybind11_modules:
my_include_dirs = include_dirs + [pybind11.get_include(False), pybind11.get_include(True)]
@@ -86,6 +87,7 @@ setup(
long_description_content_type='text/x-rst',
long_description=long_description,
ext_modules = ext_modules,
+ python_requires='>=3.5.0',
install_requires = ['numpy >= 1.15.0',],
package_data={"": ["*.dll"], },
)
diff --git a/src/python/test/test_alpha_complex.py b/src/python/test/test_alpha_complex.py
index 814f8289..f81e6137 100755
--- a/src/python/test/test_alpha_complex.py
+++ b/src/python/test/test_alpha_complex.py
@@ -8,10 +8,12 @@
- YYYY/MM Author: Description of the modification
"""
-import gudhi as gd
+from gudhi import AlphaComplex
import math
import numpy as np
import pytest
+import warnings
+
try:
# python3
from itertools import zip_longest
@@ -19,22 +21,24 @@ except ImportError:
# python2
from itertools import izip_longest as zip_longest
-__author__ = "Vincent Rouvreau"
-__copyright__ = "Copyright (C) 2016 Inria"
-__license__ = "MIT"
def _empty_alpha(precision):
- alpha_complex = gd.AlphaComplex(points=[[0, 0]], precision = precision)
+ alpha_complex = AlphaComplex(precision = precision)
+ assert alpha_complex.__is_defined() == True
+
+def _one_2d_point_alpha(precision):
+ alpha_complex = AlphaComplex(points=[[0, 0]], precision = precision)
assert alpha_complex.__is_defined() == True
def test_empty_alpha():
for precision in ['fast', 'safe', 'exact']:
_empty_alpha(precision)
+ _one_2d_point_alpha(precision)
def _infinite_alpha(precision):
point_list = [[0, 0], [1, 0], [0, 1], [1, 1]]
- alpha_complex = gd.AlphaComplex(points=point_list, precision = precision)
+ alpha_complex = AlphaComplex(points=point_list, precision = precision)
assert alpha_complex.__is_defined() == True
simplex_tree = alpha_complex.create_simplex_tree()
@@ -69,18 +73,9 @@ def _infinite_alpha(precision):
assert point_list[1] == alpha_complex.get_point(1)
assert point_list[2] == alpha_complex.get_point(2)
assert point_list[3] == alpha_complex.get_point(3)
- try:
- alpha_complex.get_point(4) == []
- except IndexError:
- pass
- else:
- assert False
- try:
- alpha_complex.get_point(125) == []
- except IndexError:
- pass
- else:
- assert False
+
+ with pytest.raises(IndexError):
+ alpha_complex.get_point(len(point_list))
def test_infinite_alpha():
for precision in ['fast', 'safe', 'exact']:
@@ -88,7 +83,7 @@ def test_infinite_alpha():
def _filtered_alpha(precision):
point_list = [[0, 0], [1, 0], [0, 1], [1, 1]]
- filtered_alpha = gd.AlphaComplex(points=point_list, precision = precision)
+ filtered_alpha = AlphaComplex(points=point_list, precision = precision)
simplex_tree = filtered_alpha.create_simplex_tree(max_alpha_square=0.25)
@@ -99,18 +94,9 @@ def _filtered_alpha(precision):
assert point_list[1] == filtered_alpha.get_point(1)
assert point_list[2] == filtered_alpha.get_point(2)
assert point_list[3] == filtered_alpha.get_point(3)
- try:
- filtered_alpha.get_point(4) == []
- except IndexError:
- pass
- else:
- assert False
- try:
- filtered_alpha.get_point(125) == []
- except IndexError:
- pass
- else:
- assert False
+
+ with pytest.raises(IndexError):
+ filtered_alpha.get_point(len(point_list))
assert list(simplex_tree.get_filtration()) == [
([0], 0.0),
@@ -141,10 +127,10 @@ def _safe_alpha_persistence_comparison(precision):
embedding2 = [[signal[i], delayed[i]] for i in range(len(time))]
#build alpha complex and simplex tree
- alpha_complex1 = gd.AlphaComplex(points=embedding1, precision = precision)
+ alpha_complex1 = AlphaComplex(points=embedding1, precision = precision)
simplex_tree1 = alpha_complex1.create_simplex_tree()
- alpha_complex2 = gd.AlphaComplex(points=embedding2, precision = precision)
+ alpha_complex2 = AlphaComplex(points=embedding2, precision = precision)
simplex_tree2 = alpha_complex2.create_simplex_tree()
diag1 = simplex_tree1.persistence()
@@ -162,7 +148,7 @@ def test_safe_alpha_persistence_comparison():
def _delaunay_complex(precision):
point_list = [[0, 0], [1, 0], [0, 1], [1, 1]]
- filtered_alpha = gd.AlphaComplex(points=point_list, precision = precision)
+ filtered_alpha = AlphaComplex(points=point_list, precision = precision)
simplex_tree = filtered_alpha.create_simplex_tree(default_filtration_value = True)
@@ -173,18 +159,11 @@ def _delaunay_complex(precision):
assert point_list[1] == filtered_alpha.get_point(1)
assert point_list[2] == filtered_alpha.get_point(2)
assert point_list[3] == filtered_alpha.get_point(3)
- try:
- filtered_alpha.get_point(4) == []
- except IndexError:
- pass
- else:
- assert False
- try:
- filtered_alpha.get_point(125) == []
- except IndexError:
- pass
- else:
- assert False
+
+ with pytest.raises(IndexError):
+ filtered_alpha.get_point(4)
+ with pytest.raises(IndexError):
+ filtered_alpha.get_point(125)
for filtered_value in simplex_tree.get_filtration():
assert math.isnan(filtered_value[1])
@@ -198,7 +177,13 @@ def test_delaunay_complex():
_delaunay_complex(precision)
def _3d_points_on_a_plane(precision, default_filtration_value):
- alpha = gd.AlphaComplex(off_file='alphacomplexdoc.off', precision = precision)
+ alpha = AlphaComplex(points = [[1.0, 1.0 , 0.0],
+ [7.0, 0.0 , 0.0],
+ [4.0, 6.0 , 0.0],
+ [9.0, 6.0 , 0.0],
+ [0.0, 14.0, 0.0],
+ [2.0, 19.0, 0.0],
+ [9.0, 17.0, 0.0]], precision = precision)
simplex_tree = alpha.create_simplex_tree(default_filtration_value = default_filtration_value)
assert simplex_tree.dimension() == 2
@@ -206,28 +191,16 @@ def _3d_points_on_a_plane(precision, default_filtration_value):
assert simplex_tree.num_simplices() == 25
def test_3d_points_on_a_plane():
- off_file = open("alphacomplexdoc.off", "w")
- off_file.write("OFF \n" \
- "7 0 0 \n" \
- "1.0 1.0 0.0\n" \
- "7.0 0.0 0.0\n" \
- "4.0 6.0 0.0\n" \
- "9.0 6.0 0.0\n" \
- "0.0 14.0 0.0\n" \
- "2.0 19.0 0.0\n" \
- "9.0 17.0 0.0\n" )
- off_file.close()
-
for default_filtration_value in [True, False]:
for precision in ['fast', 'safe', 'exact']:
_3d_points_on_a_plane(precision, default_filtration_value)
def _3d_tetrahedrons(precision):
points = 10*np.random.rand(10, 3)
- alpha = gd.AlphaComplex(points=points, precision = precision)
+ alpha = AlphaComplex(points = points, precision = precision)
st_alpha = alpha.create_simplex_tree(default_filtration_value = False)
# New AlphaComplex for get_point to work
- delaunay = gd.AlphaComplex(points=points, precision = precision)
+ delaunay = AlphaComplex(points = points, precision = precision)
st_delaunay = delaunay.create_simplex_tree(default_filtration_value = True)
delaunay_tetra = []
@@ -256,3 +229,87 @@ def _3d_tetrahedrons(precision):
def test_3d_tetrahedrons():
for precision in ['fast', 'safe', 'exact']:
_3d_tetrahedrons(precision)
+
+def test_off_file_deprecation_warning():
+ off_file = open("alphacomplexdoc.off", "w")
+ off_file.write("OFF \n" \
+ "7 0 0 \n" \
+ "1.0 1.0 0.0\n" \
+ "7.0 0.0 0.0\n" \
+ "4.0 6.0 0.0\n" \
+ "9.0 6.0 0.0\n" \
+ "0.0 14.0 0.0\n" \
+ "2.0 19.0 0.0\n" \
+ "9.0 17.0 0.0\n" )
+ off_file.close()
+
+ with pytest.warns(DeprecationWarning):
+ alpha = AlphaComplex(off_file="alphacomplexdoc.off")
+
+def test_non_existing_off_file():
+ with pytest.warns(DeprecationWarning):
+ with pytest.raises(FileNotFoundError):
+ alpha = AlphaComplex(off_file="pouetpouettralala.toubiloubabdou")
+
+def test_inconsistency_points_and_weights():
+ points = [[1.0, 1.0 , 0.0],
+ [7.0, 0.0 , 0.0],
+ [4.0, 6.0 , 0.0],
+ [9.0, 6.0 , 0.0],
+ [0.0, 14.0, 0.0],
+ [2.0, 19.0, 0.0],
+ [9.0, 17.0, 0.0]]
+ with pytest.raises(ValueError):
+ # 7 points, 8 weights, on purpose
+ alpha = AlphaComplex(points = points,
+ weights = [1., 2., 3., 4., 5., 6., 7., 8.])
+
+ with pytest.raises(ValueError):
+ # 7 points, 6 weights, on purpose
+ alpha = AlphaComplex(points = points,
+ weights = [1., 2., 3., 4., 5., 6.])
+
+def _weighted_doc_example(precision):
+ stree = AlphaComplex(points=[[ 1., -1., -1.],
+ [-1., 1., -1.],
+ [-1., -1., 1.],
+ [ 1., 1., 1.],
+ [ 2., 2., 2.]],
+ weights = [4., 4., 4., 4., 1.],
+ precision = precision).create_simplex_tree()
+
+ assert stree.filtration([0, 1, 2, 3]) == pytest.approx(-1.)
+ assert stree.filtration([0, 1, 3, 4]) == pytest.approx(95.)
+ assert stree.filtration([0, 2, 3, 4]) == pytest.approx(95.)
+ assert stree.filtration([1, 2, 3, 4]) == pytest.approx(95.)
+
+def test_weighted_doc_example():
+ for precision in ['fast', 'safe', 'exact']:
+ _weighted_doc_example(precision)
+
+def test_float_relative_precision():
+ assert AlphaComplex.get_float_relative_precision() == 1e-5
+ # Must be > 0.
+ with pytest.raises(ValueError):
+ AlphaComplex.set_float_relative_precision(0.)
+ # Must be < 1.
+ with pytest.raises(ValueError):
+ AlphaComplex.set_float_relative_precision(1.)
+
+ points = [[1, 1], [7, 0], [4, 6], [9, 6], [0, 14], [2, 19], [9, 17]]
+ st = AlphaComplex(points=points).create_simplex_tree()
+ filtrations = list(st.get_filtration())
+
+ # Switch to a finer precision
+ AlphaComplex.set_float_relative_precision(1e-15)
+ assert AlphaComplex.get_float_relative_precision() == 1e-15
+
+ st = AlphaComplex(points=points).create_simplex_tree()
+ filtrations_better_resolution = list(st.get_filtration())
+
+ assert len(filtrations) == len(filtrations_better_resolution)
+ for idx in range(len(filtrations)):
+ # check the simplex is the same
+ assert filtrations[idx][0] == filtrations_better_resolution[idx][0]
+ # check the filtration values agree up to the worst-case relative precision (1e-5)
+ assert filtrations[idx][1] == pytest.approx(filtrations_better_resolution[idx][1], rel=1e-5)
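
The precision setting exercised above is a class-wide knob: once changed, it applies to every AlphaComplex built afterwards. A minimal usage sketch (the point set is illustrative):

    import gudhi

    # tighten the default 1e-5 relative precision before building the complex
    gudhi.AlphaComplex.set_float_relative_precision(1e-15)
    st = gudhi.AlphaComplex(points=[[1., 1.], [7., 0.], [4., 6.]]).create_simplex_tree()
    for simplex, filtration in st.get_filtration():
        print(simplex, filtration)
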
diff --git a/src/python/test/test_betti_curve_representations.py b/src/python/test/test_betti_curve_representations.py
new file mode 100755
index 00000000..6a45da4d
--- /dev/null
+++ b/src/python/test/test_betti_curve_representations.py
@@ -0,0 +1,59 @@
+import numpy as np
+import scipy.interpolate
+import pytest
+
+from gudhi.representations.vector_methods import BettiCurve
+
+def test_betti_curve_is_irregular_betti_curve_followed_by_interpolation():
+ m = 10
+ n = 1000
+ pinf = 0.05
+ pzero = 0.05
+ res = 100
+
+ pds = []
+ for i in range(0, m):
+ pd = np.zeros((n, 2))
+ pd[:, 0] = np.random.uniform(0, 10, n)
+ pd[:, 1] = np.random.uniform(pd[:, 0], 10, n)
+ pd[np.random.uniform(0, 1, n) < pzero, 0] = 0
+ pd[np.random.uniform(0, 1, n) < pinf, 1] = np.inf
+ pds.append(pd)
+
+ bc = BettiCurve(resolution=None, predefined_grid=None)
+ bc.fit(pds)
+ bettis = bc.transform(pds)
+
+ bc2 = BettiCurve(resolution=None, predefined_grid=None)
+ bettis2 = bc2.fit_transform(pds)
+ assert((bc2.grid_ == bc.grid_).all())
+ assert((bettis2 == bettis).all())
+
+ for i in range(0, m):
+ grid = np.linspace(pds[i][np.isfinite(pds[i])].min(), pds[i][np.isfinite(pds[i])].max() + 1, res)
+ bc_gridded = BettiCurve(predefined_grid=grid)
+ bc_gridded.fit([])
+ bettis_gridded = bc_gridded(pds[i])
+
+ interp = scipy.interpolate.interp1d(bc.grid_, bettis[i, :], kind="previous", fill_value="extrapolate")
+ bettis_interp = np.array(interp(grid), dtype=int)
+ assert((bettis_interp == bettis_gridded).all())
+
+
+def test_empty_with_predefined_grid():
+ random_grid = np.sort(np.random.uniform(0, 1, 100))
+ bc = BettiCurve(predefined_grid=random_grid)
+ bettis = bc.fit_transform([])
+ assert((bc.grid_ == random_grid).all())
+ assert((bettis == 0).all())
+
+
+def test_empty():
+ bc = BettiCurve(resolution=None, predefined_grid=None)
+ bettis = bc.fit_transform([])
+ assert(bc.grid_ == [-np.inf])
+ assert((bettis == 0).all())
+
+def test_wrong_value_of_predefined_grid():
+ with pytest.raises(ValueError):
+ BettiCurve(predefined_grid=[1, 2, 3])
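
What these tests pin down: a BettiCurve built with resolution=None and predefined_grid=None learns its grid from the data in fit(), and evaluating the curve on any other grid amounts to piecewise-constant ("previous") interpolation of the exact curve. A small sketch on a made-up diagram:

    import numpy as np
    from gudhi.representations.vector_methods import BettiCurve

    diag = np.array([[0.0, 1.0], [0.5, 2.0], [1.5, np.inf]])
    bc = BettiCurve(resolution=None, predefined_grid=None)
    betti = bc.fit_transform([diag])  # exact curve on the grid learned from the data
    print(bc.grid_, betti[0])
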
diff --git a/src/python/test/test_diff.py b/src/python/test/test_diff.py
new file mode 100644
index 00000000..dca001a9
--- /dev/null
+++ b/src/python/test/test_diff.py
@@ -0,0 +1,78 @@
+from gudhi.tensorflow import *
+import numpy as np
+import tensorflow as tf
+import gudhi as gd
+
+def test_rips_diff():
+
+ Xinit = np.array([[1.,1.],[2.,2.]], dtype=np.float32)
+ X = tf.Variable(initial_value=Xinit, trainable=True)
+ rl = RipsLayer(maximum_edge_length=2., homology_dimensions=[0])
+
+ with tf.GradientTape() as tape:
+ dgm = rl.call(X)[0][0]
+ loss = tf.math.reduce_sum(tf.square(.5*(dgm[:,1]-dgm[:,0])))
+ grads = tape.gradient(loss, [X])
+ assert tf.norm(grads[0]-tf.constant([[-.5,-.5],[.5,.5]]),1) <= 1e-6
+
+def test_cubical_diff():
+
+ Xinit = np.array([[0.,2.,2.],[2.,2.,2.],[2.,2.,1.]], dtype=np.float32)
+ X = tf.Variable(initial_value=Xinit, trainable=True)
+ cl = CubicalLayer(homology_dimensions=[0])
+
+ with tf.GradientTape() as tape:
+ dgm = cl.call(X)[0][0]
+ loss = tf.math.reduce_sum(tf.square(.5*(dgm[:,1]-dgm[:,0])))
+ grads = tape.gradient(loss, [X])
+ assert tf.norm(grads[0]-tf.constant([[0.,0.,0.],[0.,.5,0.],[0.,0.,-.5]]),1) <= 1e-6
+
+def test_nonsquare_cubical_diff():
+
+ Xinit = np.array([[-1.,1.,0.],[1.,1.,1.]], dtype=np.float32)
+ X = tf.Variable(initial_value=Xinit, trainable=True)
+ cl = CubicalLayer(homology_dimensions=[0])
+
+ with tf.GradientTape() as tape:
+ dgm = cl.call(X)[0][0]
+ loss = tf.math.reduce_sum(tf.square(.5*(dgm[:,1]-dgm[:,0])))
+ grads = tape.gradient(loss, [X])
+ assert tf.norm(grads[0]-tf.constant([[0.,0.5,-0.5],[0.,0.,0.]]),1) <= 1e-6
+
+def test_st_diff():
+
+ st = gd.SimplexTree()
+ st.insert([0])
+ st.insert([1])
+ st.insert([2])
+ st.insert([3])
+ st.insert([4])
+ st.insert([5])
+ st.insert([6])
+ st.insert([7])
+ st.insert([8])
+ st.insert([9])
+ st.insert([10])
+ st.insert([0, 1])
+ st.insert([1, 2])
+ st.insert([2, 3])
+ st.insert([3, 4])
+ st.insert([4, 5])
+ st.insert([5, 6])
+ st.insert([6, 7])
+ st.insert([7, 8])
+ st.insert([8, 9])
+ st.insert([9, 10])
+
+ Finit = np.array([6.,4.,3.,4.,5.,4.,3.,2.,3.,4.,5.], dtype=np.float32)
+ F = tf.Variable(initial_value=Finit, trainable=True)
+ sl = LowerStarSimplexTreeLayer(simplextree=st, homology_dimensions=[0])
+
+ with tf.GradientTape() as tape:
+ dgm = sl.call(F)[0][0]
+ loss = tf.math.reduce_sum(tf.square(.5*(dgm[:,1]-dgm[:,0])))
+ grads = tape.gradient(loss, [F])
+
+ assert tf.math.reduce_all(tf.math.equal(grads[0].indices, tf.constant([2,4])))
+ assert tf.math.reduce_all(tf.math.equal(grads[0].values, tf.constant([-1.,1.])))
+
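All three layers follow the same calling convention: call() returns, per requested homology dimension, persistence diagram tensors that are differentiable in the layer input, so any diagram-based loss can be driven by gradient descent. A hedged sketch of such a loop around RipsLayer (the optimizer, learning rate and loss are illustrative choices, not part of the API):

    import numpy as np
    import tensorflow as tf
    from gudhi.tensorflow import RipsLayer

    X = tf.Variable(np.random.rand(10, 2).astype(np.float32), trainable=True)
    rl = RipsLayer(maximum_edge_length=2., homology_dimensions=[0])
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
    for _ in range(20):
        with tf.GradientTape() as tape:
            dgm = rl.call(X)[0][0]  # finite part of the H0 diagram, as in the tests above
            # maximize total persistence, i.e. spread the points apart
            loss = -tf.math.reduce_sum(dgm[:, 1] - dgm[:, 0])
        optimizer.apply_gradients(zip(tape.gradient(loss, [X]), [X]))
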
diff --git a/src/python/test/test_dtm.py b/src/python/test/test_dtm.py
index e46d616c..b276f041 100755
--- a/src/python/test/test_dtm.py
+++ b/src/python/test/test_dtm.py
@@ -91,11 +91,11 @@ def test_density():
def test_dtm_overflow_warnings():
pts = numpy.array([[10., 100000000000000000000000000000.], [1000., 100000000000000000000000000.]])
-
- with warnings.catch_warnings(record=True) as w:
- # TODO Test "keops" implementation as well when next version of pykeops (current is 1.5) is released (should fix the problem (cf. issue #543))
- dtm = DistanceToMeasure(2, implementation="hnsw")
- r = dtm.fit_transform(pts)
- assert len(w) == 1
- assert issubclass(w[0].category, RuntimeWarning)
- assert "Overflow" in str(w[0].message)
+ impl_warn = ["keops", "hnsw"]
+ for impl in impl_warn:
+ with warnings.catch_warnings(record=True) as w:
+ dtm = DistanceToMeasure(2, implementation=impl)
+ r = dtm.fit_transform(pts)
+ assert len(w) == 1
+ assert issubclass(w[0].category, RuntimeWarning)
+ assert "Overflow" in str(w[0].message)
diff --git a/src/python/test/test_persistence_graphical_tools.py b/src/python/test/test_persistence_graphical_tools.py
new file mode 100644
index 00000000..c19836b7
--- /dev/null
+++ b/src/python/test/test_persistence_graphical_tools.py
@@ -0,0 +1,121 @@
+""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+ See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+ Author(s): Vincent Rouvreau
+
+ Copyright (C) 2021 Inria
+
+ Modification(s):
+ - YYYY/MM Author: Description of the modification
+"""
+
+import gudhi as gd
+import numpy as np
+import matplotlib as plt  # matplotlib itself (not pyplot): only plt.axes.SubplotBase is used
+import pytest
+
+
+def test_array_handler():
+ diags = np.array([[1, 2], [3, 4], [5, 6]], float)
+ arr_diags = gd.persistence_graphical_tools._array_handler(diags)
+ for idx in range(len(diags)):
+ assert arr_diags[idx][0] == 0
+ np.testing.assert_array_equal(arr_diags[idx][1], diags[idx])
+
+ diags = [(1.0, 2.0), (3.0, 4.0), (5.0, 6.0)]
+ arr_diags = gd.persistence_graphical_tools._array_handler(diags)
+ for idx in range(len(diags)):
+ assert arr_diags[idx][0] == 0
+ assert arr_diags[idx][1] == diags[idx]
+
+ diags = [(0, (1.0, 2.0)), (0, (3.0, 4.0)), (0, (5.0, 6.0))]
+ assert gd.persistence_graphical_tools._array_handler(diags) == diags
+
+
+def test_min_birth_max_death():
+ diags = [
+ (0, (0.0, float("inf"))),
+ (0, (0.0983494, float("inf"))),
+ (0, (0.0, 0.122545)),
+ (0, (0.0, 0.12047)),
+ (0, (0.0, 0.118398)),
+ (0, (0.118398, 1.0)),
+ (0, (0.0, 0.117908)),
+ (0, (0.0, 0.112307)),
+ (0, (0.0, 0.107535)),
+ (0, (0.0, 0.106382)),
+ ]
+ assert gd.persistence_graphical_tools.__min_birth_max_death(diags) == (0.0, 1.0)
+ assert gd.persistence_graphical_tools.__min_birth_max_death(diags, band=4.0) == (0.0, 5.0)
+
+
+def test_limit_min_birth_max_death():
+ diags = [
+ (0, (2.0, float("inf"))),
+ (0, (2.0, float("inf"))),
+ ]
+ assert gd.persistence_graphical_tools.__min_birth_max_death(diags) == (2.0, 3.0)
+ assert gd.persistence_graphical_tools.__min_birth_max_death(diags, band=4.0) == (2.0, 6.0)
+
+
+def test_limit_to_max_intervals():
+ diags = [
+ (0, (0.0, float("inf"))),
+ (0, (0.0983494, float("inf"))),
+ (0, (0.0, 0.122545)),
+ (0, (0.0, 0.12047)),
+ (0, (0.0, 0.118398)),
+ (0, (0.118398, 1.0)),
+ (0, (0.0, 0.117908)),
+ (0, (0.0, 0.112307)),
+ (0, (0.0, 0.107535)),
+ (0, (0.0, 0.106382)),
+ ]
+ # check that no warning is raised when max_intervals equals the number of diagrams
+ with pytest.warns(None) as record:
+ truncated_diags = gd.persistence_graphical_tools._limit_to_max_intervals(
+ diags, 10, key=lambda life_time: life_time[1][1] - life_time[1][0]
+ )
+ # check diagrams are not sorted
+ assert truncated_diags == diags
+ assert len(record) == 0
+
+ # check that a warning is raised when max_intervals is lower than the number of diagrams
+ with pytest.warns(UserWarning) as record:
+ truncated_diags = gd.persistence_graphical_tools._limit_to_max_intervals(
+ diags, 5, key=lambda life_time: life_time[1][1] - life_time[1][0]
+ )
+ # check diagrams are truncated and sorted by lifetime
+ assert truncated_diags == [
+ (0, (0.0, float("inf"))),
+ (0, (0.0983494, float("inf"))),
+ (0, (0.118398, 1.0)),
+ (0, (0.0, 0.122545)),
+ (0, (0.0, 0.12047)),
+ ]
+ assert len(record) == 1
+
+
+def _limit_plot_persistence(function):
+ pplot = function(persistence=[])
+ assert isinstance(pplot, plt.axes.SubplotBase)
+ pplot = function(persistence=[], legend=True)
+ assert isinstance(pplot, plt.axes.SubplotBase)
+ pplot = function(persistence=[(0, float("inf"))])
+ assert isinstance(pplot, plt.axes.SubplotBase)
+ pplot = function(persistence=[(0, float("inf"))], legend=True)
+ assert isinstance(pplot, plt.axes.SubplotBase)
+
+
+def test_limit_plot_persistence():
+ for function in [gd.plot_persistence_barcode, gd.plot_persistence_diagram, gd.plot_persistence_density]:
+ _limit_plot_persistence(function)
+
+
+def _non_existing_persistence_file(function):
+ with pytest.raises(FileNotFoundError):
+ function(persistence_file="pouetpouettralala.toubiloubabdou")
+
+
+def test_non_existing_persistence_file():
+ for function in [gd.plot_persistence_barcode, gd.plot_persistence_diagram, gd.plot_persistence_density]:
+ _non_existing_persistence_file(function)
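
The plotting helpers tested here all share the persistence=[(dim, (birth, death)), ...] input convention checked by test_array_handler. A short sketch (the diagram values are made up):

    import gudhi as gd
    import matplotlib.pyplot as plt

    diag = [(0, (0.0, float("inf"))), (0, (0.0, 0.12)), (1, (0.11, 1.0))]
    ax = gd.plot_persistence_diagram(persistence=diag, legend=True)
    plt.show()
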
diff --git a/src/python/test/test_reader_utils.py b/src/python/test/test_reader_utils.py
index e96e0569..fdfddc4b 100755
--- a/src/python/test/test_reader_utils.py
+++ b/src/python/test/test_reader_utils.py
@@ -8,8 +8,9 @@
- YYYY/MM Author: Description of the modification
"""
-import gudhi
+import gudhi as gd
import numpy as np
+from pytest import raises
__author__ = "Vincent Rouvreau"
__copyright__ = "Copyright (C) 2017 Inria"
@@ -18,7 +19,7 @@ __license__ = "MIT"
def test_non_existing_csv_file():
# Try to open a non existing file
- matrix = gudhi.read_lower_triangular_matrix_from_csv_file(
+ matrix = gd.read_lower_triangular_matrix_from_csv_file(
csv_file="pouetpouettralala.toubiloubabdou"
)
assert matrix == []
@@ -29,7 +30,7 @@ def test_full_square_distance_matrix_csv_file():
test_file = open("full_square_distance_matrix.csv", "w")
test_file.write("0;1;2;3;\n1;0;4;5;\n2;4;0;6;\n3;5;6;0;")
test_file.close()
- matrix = gudhi.read_lower_triangular_matrix_from_csv_file(
+ matrix = gd.read_lower_triangular_matrix_from_csv_file(
csv_file="full_square_distance_matrix.csv", separator=";"
)
assert matrix == [[], [1.0], [2.0, 4.0], [3.0, 5.0, 6.0]]
@@ -40,7 +41,7 @@ def test_lower_triangular_distance_matrix_csv_file():
test_file = open("lower_triangular_distance_matrix.csv", "w")
test_file.write("\n1,\n2,3,\n4,5,6,\n7,8,9,10,")
test_file.close()
- matrix = gudhi.read_lower_triangular_matrix_from_csv_file(
+ matrix = gd.read_lower_triangular_matrix_from_csv_file(
csv_file="lower_triangular_distance_matrix.csv", separator=","
)
assert matrix == [[], [1.0], [2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0, 10.0]]
@@ -48,11 +49,11 @@ def test_lower_triangular_distance_matrix_csv_file():
def test_non_existing_persistence_file():
# Try to open a non existing file
- persistence = gudhi.read_persistence_intervals_grouped_by_dimension(
+ persistence = gd.read_persistence_intervals_grouped_by_dimension(
persistence_file="pouetpouettralala.toubiloubabdou"
)
assert persistence == []
- persistence = gudhi.read_persistence_intervals_in_dimension(
+ persistence = gd.read_persistence_intervals_in_dimension(
persistence_file="pouetpouettralala.toubiloubabdou", only_this_dim=1
)
np.testing.assert_array_equal(persistence, [])
@@ -65,21 +66,21 @@ def test_read_persistence_intervals_without_dimension():
"# Simple persistence diagram without dimension\n2.7 3.7\n9.6 14.\n34.2 34.974\n3. inf"
)
test_file.close()
- persistence = gudhi.read_persistence_intervals_in_dimension(
+ persistence = gd.read_persistence_intervals_in_dimension(
persistence_file="persistence_intervals_without_dimension.pers"
)
np.testing.assert_array_equal(
persistence, [(2.7, 3.7), (9.6, 14.0), (34.2, 34.974), (3.0, float("Inf"))]
)
- persistence = gudhi.read_persistence_intervals_in_dimension(
+ persistence = gd.read_persistence_intervals_in_dimension(
persistence_file="persistence_intervals_without_dimension.pers", only_this_dim=0
)
np.testing.assert_array_equal(persistence, [])
- persistence = gudhi.read_persistence_intervals_in_dimension(
+ persistence = gd.read_persistence_intervals_in_dimension(
persistence_file="persistence_intervals_without_dimension.pers", only_this_dim=1
)
np.testing.assert_array_equal(persistence, [])
- persistence = gudhi.read_persistence_intervals_grouped_by_dimension(
+ persistence = gd.read_persistence_intervals_grouped_by_dimension(
persistence_file="persistence_intervals_without_dimension.pers"
)
assert persistence == {
@@ -94,29 +95,29 @@ def test_read_persistence_intervals_with_dimension():
"# Simple persistence diagram with dimension\n0 2.7 3.7\n1 9.6 14.\n3 34.2 34.974\n1 3. inf"
)
test_file.close()
- persistence = gudhi.read_persistence_intervals_in_dimension(
+ persistence = gd.read_persistence_intervals_in_dimension(
persistence_file="persistence_intervals_with_dimension.pers"
)
np.testing.assert_array_equal(
persistence, [(2.7, 3.7), (9.6, 14.0), (34.2, 34.974), (3.0, float("Inf"))]
)
- persistence = gudhi.read_persistence_intervals_in_dimension(
+ persistence = gd.read_persistence_intervals_in_dimension(
persistence_file="persistence_intervals_with_dimension.pers", only_this_dim=0
)
np.testing.assert_array_equal(persistence, [(2.7, 3.7)])
- persistence = gudhi.read_persistence_intervals_in_dimension(
+ persistence = gd.read_persistence_intervals_in_dimension(
persistence_file="persistence_intervals_with_dimension.pers", only_this_dim=1
)
np.testing.assert_array_equal(persistence, [(9.6, 14.0), (3.0, float("Inf"))])
- persistence = gudhi.read_persistence_intervals_in_dimension(
+ persistence = gd.read_persistence_intervals_in_dimension(
persistence_file="persistence_intervals_with_dimension.pers", only_this_dim=2
)
np.testing.assert_array_equal(persistence, [])
- persistence = gudhi.read_persistence_intervals_in_dimension(
+ persistence = gd.read_persistence_intervals_in_dimension(
persistence_file="persistence_intervals_with_dimension.pers", only_this_dim=3
)
np.testing.assert_array_equal(persistence, [(34.2, 34.974)])
- persistence = gudhi.read_persistence_intervals_grouped_by_dimension(
+ persistence = gd.read_persistence_intervals_grouped_by_dimension(
persistence_file="persistence_intervals_with_dimension.pers"
)
assert persistence == {
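
The .pers files written above hold one interval per line, optionally prefixed by its dimension, with '#' starting a comment. A reading sketch (assuming a file "diagram.pers" shaped like the ones these tests create):

    import gudhi as gd

    # e.g. a line "1 9.6 14." means dimension 1, birth 9.6, death 14.
    h1 = gd.read_persistence_intervals_in_dimension(
        persistence_file="diagram.pers", only_this_dim=1
    )
    by_dim = gd.read_persistence_intervals_grouped_by_dimension(
        persistence_file="diagram.pers"
    )
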
diff --git a/src/python/test/test_remote_datasets.py b/src/python/test/test_remote_datasets.py
new file mode 100644
index 00000000..e5d2de82
--- /dev/null
+++ b/src/python/test/test_remote_datasets.py
@@ -0,0 +1,87 @@
+# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+# Author(s): Hind Montassif
+#
+# Copyright (C) 2021 Inria
+#
+# Modification(s):
+# - YYYY/MM Author: Description of the modification
+
+from gudhi.datasets import remote
+
+import shutil
+import io
+import sys
+import pytest
+
+from os.path import isdir, expanduser, exists
+from os import remove, environ
+
+def test_data_home():
+ # Test _get_data_home and clear_data_home on a new empty folder
+ empty_data_home = remote._get_data_home(data_home="empty_folder_for_test")
+ assert isdir(empty_data_home)
+
+ remote.clear_data_home(data_home=empty_data_home)
+ assert not isdir(empty_data_home)
+
+def test_fetch_remote():
+ # Test fetch with a wrong checksum
+ with pytest.raises(OSError):
+ remote._fetch_remote("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d/spiral_2d.npy", "tmp_spiral_2d.npy", file_checksum = 'XXXXXXXXXX')
+ assert not exists("tmp_spiral_2d.npy")
+
+def _get_bunny_license_print(accept_license = False):
+ capturedOutput = io.StringIO()
+ # Redirect stdout
+ sys.stdout = capturedOutput
+
+ bunny_arr = remote.fetch_bunny("./tmp_for_test/bunny.npy", accept_license)
+ assert bunny_arr.shape == (35947, 3)
+ del bunny_arr
+ remove("./tmp_for_test/bunny.npy")
+
+ # Reset redirect
+ sys.stdout = sys.__stdout__
+ return capturedOutput
+
+def test_print_bunny_license():
+ # Check that the bunny LICENSE is not printed when accept_license = True
+ assert "" == _get_bunny_license_print(accept_license = True).getvalue()
+ # Check that the content of the bunny.LICENSE file is printed when fetching bunny.npy with accept_license = False (default)
+ with open("./tmp_for_test/bunny.LICENSE") as f:
+ assert f.read().rstrip("\n") == _get_bunny_license_print().getvalue().rstrip("\n")
+ shutil.rmtree("./tmp_for_test")
+
+def test_fetch_remote_datasets_wrapped():
+ # Test the fetch_spiral_2d and fetch_bunny wrappers with a data directory different from the default (run twice, to cover the case of already fetched files)
+ # The default case is not tested because it would fail if the user sets the 'GUDHI_DATA' environment variable locally
+ for i in range(2):
+ spiral_2d_arr = remote.fetch_spiral_2d("./another_fetch_folder_for_test/spiral_2d.npy")
+ assert spiral_2d_arr.shape == (114562, 2)
+
+ bunny_arr = remote.fetch_bunny("./another_fetch_folder_for_test/bunny.npy")
+ assert bunny_arr.shape == (35947, 3)
+
+ # Check that the directory was created
+ assert isdir("./another_fetch_folder_for_test")
+ # Check downloaded files
+ assert exists("./another_fetch_folder_for_test/spiral_2d.npy")
+ assert exists("./another_fetch_folder_for_test/bunny.npy")
+ assert exists("./another_fetch_folder_for_test/bunny.LICENSE")
+
+ # Remove test folders
+ del spiral_2d_arr
+ del bunny_arr
+ shutil.rmtree("./another_fetch_folder_for_test")
+
+def test_gudhi_data_env():
+ # Set environment variable "GUDHI_DATA"
+ environ["GUDHI_DATA"] = "./test_folder_from_env_var"
+ bunny_arr = remote.fetch_bunny()
+ assert bunny_arr.shape == (35947, 3)
+ assert exists("./test_folder_from_env_var/points/bunny/bunny.npy")
+ assert exists("./test_folder_from_env_var/points/bunny/bunny.LICENSE")
+ # Remove test folder
+ del bunny_arr
+ shutil.rmtree("./test_folder_from_env_var")
diff --git a/src/python/test/test_representations.py b/src/python/test/test_representations.py
index 93461f1e..4a455bb6 100755
--- a/src/python/test/test_representations.py
+++ b/src/python/test/test_representations.py
@@ -105,7 +105,6 @@ def test_dummy_atol():
from gudhi.representations.vector_methods import BettiCurve
-
def test_infinity():
a = np.array([[1.0, 8.0], [2.0, np.inf], [3.0, 4.0]])
c = BettiCurve(20, [0.0, 10.0])(a)
@@ -113,7 +112,6 @@ def test_infinity():
assert c[7] == 3
assert c[9] == 2
-
def test_preprocessing_empty_diagrams():
empty_diag = np.empty(shape = [0, 2])
assert not np.any(BirthPersistenceTransform()(empty_diag))
@@ -154,7 +152,26 @@ def test_vectorization_empty_diagrams():
scv = Entropy(mode="vector", normalized=False, resolution=random_resolution)(empty_diag)
assert not np.any(scv)
assert scv.shape[0] == random_resolution
-
+
+def test_entropy_miscalculation():
+ diag_ex = np.array([[0.0,1.0], [0.0,1.0], [0.0,2.0]])
+ def pe(pd):
+ l = pd[:,1] - pd[:,0]
+ l = l/sum(l)
+ return -np.dot(l, np.log(l))
+ sce = Entropy(mode="scalar")
+ assert [[pe(diag_ex)]] == sce.fit_transform([diag_ex])
+ sce = Entropy(mode="vector", resolution=4, normalized=False)
+ pef = [-1/4*np.log(1/4)-1/4*np.log(1/4)-1/2*np.log(1/2),
+ -1/4*np.log(1/4)-1/4*np.log(1/4)-1/2*np.log(1/2),
+ -1/2*np.log(1/2),
+ 0.0]
+ assert all(([pef] == sce.fit_transform([diag_ex]))[0])
+ sce = Entropy(mode="vector", resolution=4, normalized=True)
+ pefN = (sce.fit_transform([diag_ex]))[0]
+ area = np.linalg.norm(pefN, ord=1)
+ assert area==1
+
def test_kernel_empty_diagrams():
empty_diag = np.empty(shape = [0, 2])
assert SlicedWassersteinDistance(num_directions=100)(empty_diag, empty_diag) == 0.
@@ -169,3 +186,4 @@ def test_kernel_empty_diagrams():
# PersistenceScaleSpaceKernel(kernel_approx=RBFSampler(gamma=1./2, n_components=100000).fit(np.ones([1,2])))(empty_diag, empty_diag)
# PersistenceFisherKernel(bandwidth_fisher=1., bandwidth=1.)(empty_diag, empty_diag)
# PersistenceFisherKernel(bandwidth_fisher=1., bandwidth=1., kernel_approx=RBFSampler(gamma=1./2, n_components=100000).fit(np.ones([1,2])))(empty_diag, empty_diag)
+
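The pe() helper in test_entropy_miscalculation spells out the definition of persistent entropy: with lifetimes l_i = d_i - b_i and L = sum_i l_i, E(D) = -sum_i (l_i / L) log(l_i / L). A standalone check of the scalar mode on the same diagram:

    import numpy as np
    from gudhi.representations.vector_methods import Entropy

    diag = np.array([[0.0, 1.0], [0.0, 1.0], [0.0, 2.0]])
    lifetimes = diag[:, 1] - diag[:, 0]   # [1, 1, 2], so L = 4
    p = lifetimes / lifetimes.sum()       # [1/4, 1/4, 1/2]
    print(-np.dot(p, np.log(p)))          # the value Entropy(mode="scalar") should reproduce
    print(Entropy(mode="scalar").fit_transform([diag]))
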
diff --git a/src/python/test/test_representations_preprocessing.py b/src/python/test/test_representations_preprocessing.py
new file mode 100644
index 00000000..838cf30c
--- /dev/null
+++ b/src/python/test/test_representations_preprocessing.py
@@ -0,0 +1,39 @@
+""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+ See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+ Author(s): Vincent Rouvreau
+
+ Copyright (C) 2021 Inria
+
+ Modification(s):
+ - YYYY/MM Author: Description of the modification
+"""
+
+from gudhi.representations.preprocessing import DimensionSelector
+import numpy as np
+import pytest
+
+H0_0 = np.array([0.0, 0.0])
+H1_0 = np.array([1.0, 0.0])
+H0_1 = np.array([0.0, 1.0])
+H1_1 = np.array([1.0, 1.0])
+H0_2 = np.array([0.0, 2.0])
+H1_2 = np.array([1.0, 2.0])
+
+
+def test_dimension_selector():
+ X = [[H0_0, H1_0], [H0_1, H1_1], [H0_2, H1_2]]
+ ds = DimensionSelector(index=0)
+ h0 = ds.fit_transform(X)
+ np.testing.assert_array_equal(h0[0], H0_0)
+ np.testing.assert_array_equal(h0[1], H0_1)
+ np.testing.assert_array_equal(h0[2], H0_2)
+
+ ds = DimensionSelector(index=1)
+ h1 = ds.fit_transform(X)
+ np.testing.assert_array_equal(h1[0], H1_0)
+ np.testing.assert_array_equal(h1[1], H1_1)
+ np.testing.assert_array_equal(h1[2], H1_2)
+
+ ds = DimensionSelector(index=2)
+ with pytest.raises(IndexError):
+ h2 = ds.fit_transform([[H0_0, H1_0], [H0_1, H1_1], [H0_2, H1_2]])
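
DimensionSelector is a thin scikit-learn-style transformer: each sample is a list of diagrams (one per homology dimension), and only the diagram at the requested index is kept. A sketch:

    import numpy as np
    from gudhi.representations.preprocessing import DimensionSelector

    # two samples, each carrying an H0 and an H1 diagram
    X = [[np.array([[0., 1.]]), np.array([[1., 2.]])],
         [np.array([[0., 3.]]), np.array([[1., 4.]])]]
    h1_only = DimensionSelector(index=1).fit_transform(X)  # keeps the H1 diagrams
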
diff --git a/src/python/test/test_simplex_tree.py b/src/python/test/test_simplex_tree.py
index 31c46213..54bafed5 100755
--- a/src/python/test/test_simplex_tree.py
+++ b/src/python/test/test_simplex_tree.py
@@ -8,7 +8,7 @@
- YYYY/MM Author: Description of the modification
"""
-from gudhi import SimplexTree, __GUDHI_USE_EIGEN3
+from gudhi import SimplexTree
import numpy as np
import pytest
@@ -320,6 +320,10 @@ def test_extend_filtration():
]
dgms = st.extended_persistence(min_persistence=-1.)
+ assert len(dgms) == 4
+ # Sort by (death - birth) in descending order - we are only interested in the intervals with the longest life span
+ for idx in range(4):
+ dgms[idx] = sorted(dgms[idx], key=lambda x:(-abs(x[1][0]-x[1][1])))
assert dgms[0][0][1][0] == pytest.approx(2.)
assert dgms[0][0][1][1] == pytest.approx(3.)
@@ -354,16 +358,11 @@ def test_collapse_edges():
assert st.num_simplices() == 10
- if __GUDHI_USE_EIGEN3:
- st.collapse_edges()
- assert st.num_simplices() == 9
- assert st.find([1, 3]) == False
- for simplex in st.get_skeleton(0):
- assert simplex[1] == 1.
- else:
- # If no Eigen3, collapse_edges throws an exception
- with pytest.raises(RuntimeError):
- st.collapse_edges()
+ st.collapse_edges()
+ assert st.num_simplices() == 9
+ assert st.find([0, 2]) == False # [1, 3] would be fine as well
+ for simplex in st.get_skeleton(0):
+ assert simplex[1] == 1.
def test_reset_filtration():
st = SimplexTree()
@@ -447,4 +446,116 @@ def test_persistence_intervals_in_dimension():
assert np.array_equal(H2, np.array([[ 0., float("inf")]]))
# Test empty case
assert st.persistence_intervals_in_dimension(3).shape == (0, 2)
- \ No newline at end of file
+
+def test_equality_operator():
+ st1 = SimplexTree()
+ st2 = SimplexTree()
+
+ assert st1 == st2
+
+ st1.insert([1,2,3], 4.)
+ assert st1 != st2
+
+ st2.insert([1,2,3], 4.)
+ assert st1 == st2
+
+def test_simplex_tree_deep_copy():
+ st = SimplexTree()
+ st.insert([1, 2, 3], 0.)
+ # compute persistence only on the original
+ st.compute_persistence()
+
+ st_copy = st.copy()
+ assert st_copy == st
+ st_filt_list = list(st.get_filtration())
+
+ # check persistence is not copied
+ assert st.__is_persistence_defined() == True
+ assert st_copy.__is_persistence_defined() == False
+
+ # remove something in the copy and check the copy is included in the original
+ st_copy.remove_maximal_simplex([1, 2, 3])
+ a_filt_list = list(st_copy.get_filtration())
+ assert len(a_filt_list) < len(st_filt_list)
+
+ for a_splx in a_filt_list:
+ assert a_splx in st_filt_list
+
+ # test double free
+ del st
+ del st_copy
+
+def test_simplex_tree_deep_copy_constructor():
+ st = SimplexTree()
+ st.insert([1, 2, 3], 0.)
+ # compute persistence only on the original
+ st.compute_persistence()
+
+ st_copy = SimplexTree(st)
+ assert st_copy == st
+ st_filt_list = list(st.get_filtration())
+
+ # check persistence is not copied
+ assert st.__is_persistence_defined() == True
+ assert st_copy.__is_persistence_defined() == False
+
+ # remove something in the copy and check the copy is included in the original
+ st_copy.remove_maximal_simplex([1, 2, 3])
+ a_filt_list = list(st_copy.get_filtration())
+ assert len(a_filt_list) < len(st_filt_list)
+
+ for a_splx in a_filt_list:
+ assert a_splx in st_filt_list
+
+ # test double free
+ del st
+ del st_copy
+
+def test_simplex_tree_constructor_exception():
+ with pytest.raises(TypeError):
+ st = SimplexTree(other = "Construction from a string shall raise an exception")
+
+def test_expansion_with_blocker():
+ st=SimplexTree()
+ st.insert([0,1],0)
+ st.insert([0,2],1)
+ st.insert([0,3],2)
+ st.insert([1,2],3)
+ st.insert([1,3],4)
+ st.insert([2,3],5)
+ st.insert([2,4],6)
+ st.insert([3,6],7)
+ st.insert([4,5],8)
+ st.insert([4,6],9)
+ st.insert([5,6],10)
+ st.insert([6],10)
+
+ def blocker(simplex):
+ try:
+ # Block all simplices that contain vertex 6
+ simplex.index(6)
+ print(simplex, ' is blocked')
+ return True
+ except ValueError:
+ print(simplex, ' is accepted')
+ st.assign_filtration(simplex, st.filtration(simplex) + 1.)
+ return False
+
+ st.expansion_with_blocker(2, blocker)
+ assert st.num_simplices() == 22
+ assert st.dimension() == 2
+ assert st.find([4,5,6]) == False
+ assert st.filtration([0,1,2]) == 4.
+ assert st.filtration([0,1,3]) == 5.
+ assert st.filtration([0,2,3]) == 6.
+ assert st.filtration([1,2,3]) == 6.
+
+ st.expansion_with_blocker(3, blocker)
+ assert st.num_simplices() == 23
+ assert st.dimension() == 3
+ assert st.find([4,5,6]) == False
+ assert st.filtration([0,1,2]) == 4.
+ assert st.filtration([0,1,3]) == 5.
+ assert st.filtration([0,2,3]) == 6.
+ assert st.filtration([1,2,3]) == 6.
+ assert st.filtration([0,1,2,3]) == 7.
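
expansion_with_blocker() calls the blocker on each candidate simplex during expansion; returning True rejects it, and the callback may adjust the candidate's filtration value before accepting it, as the blocker above does. A stripped-down sketch of that mechanism:

    from gudhi import SimplexTree

    st = SimplexTree()
    for edge in [[0, 1], [0, 2], [1, 2]]:
        st.insert(edge, 1.)

    def blocker(simplex):
        # accept every candidate, but bump its filtration value by 1
        st.assign_filtration(simplex, st.filtration(simplex) + 1.)
        return False

    st.expansion_with_blocker(2, blocker)
    print(st.filtration([0, 1, 2]))  # max of the edges (1.0) plus 1.0 -> 2.0
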
diff --git a/src/python/test/test_sklearn_cubical_persistence.py b/src/python/test/test_sklearn_cubical_persistence.py
new file mode 100644
index 00000000..1c05a215
--- /dev/null
+++ b/src/python/test/test_sklearn_cubical_persistence.py
@@ -0,0 +1,59 @@
+""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+ See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+ Author(s): Vincent Rouvreau
+
+ Copyright (C) 2021 Inria
+
+ Modification(s):
+ - YYYY/MM Author: Description of the modification
+"""
+
+from gudhi.sklearn.cubical_persistence import CubicalPersistence
+import numpy as np
+from sklearn import datasets
+
+CUBICAL_PERSISTENCE_H0_IMG0 = np.array([[0.0, 6.0], [0.0, 8.0], [0.0, np.inf]])
+
+
+def test_simple_constructor_from_top_cells():
+ cells = datasets.load_digits().images[0]
+ cp = CubicalPersistence(homology_dimensions=0)
+ np.testing.assert_array_equal(cp._CubicalPersistence__transform_only_this_dim(cells), CUBICAL_PERSISTENCE_H0_IMG0)
+ cp = CubicalPersistence(homology_dimensions=[0, 2])
+ diags = cp._CubicalPersistence__transform(cells)
+ assert len(diags) == 2
+ np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0)
+
+
+def test_simple_constructor_from_top_cells_list():
+ digits = datasets.load_digits().images[:10]
+ cp = CubicalPersistence(homology_dimensions=0, n_jobs=-2)
+
+ diags = cp.fit_transform(digits)
+ assert len(diags) == 10
+ np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0)
+
+ cp = CubicalPersistence(homology_dimensions=[0, 1], n_jobs=-1)
+ diagsH0H1 = cp.fit_transform(digits)
+ assert len(diagsH0H1) == 10
+ for idx in range(10):
+ np.testing.assert_array_equal(diags[idx], diagsH0H1[idx][0])
+
+def test_simple_constructor_from_flattened_cells():
+ cells = datasets.load_digits().images[0]
+ # Non-square (extended) flattened cells
+ flat_cells = np.hstack((cells, np.zeros((cells.shape[0], 2)))).flatten()
+
+ cp = CubicalPersistence(homology_dimensions=0, newshape=[-1, 8, 10])
+ diags = cp.fit_transform([flat_cells])
+
+ np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0)
+
+ # Non-square (extended) non-flattened cells
+ cells = np.hstack((cells, np.zeros((cells.shape[0], 2))))
+
+ # The aim of this second part of the test is to exercise the reshape even when it is not mandatory
+ cp = CubicalPersistence(homology_dimensions=0, newshape=[-1, 8, 10])
+ diags = cp.fit_transform([cells])
+
+ np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0)
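
CubicalPersistence follows the scikit-learn transformer API (fit_transform, n_jobs), and newshape is there to reshape inputs that arrive flattened. A sketch on the same digits data:

    from gudhi.sklearn.cubical_persistence import CubicalPersistence
    from sklearn import datasets

    imgs = datasets.load_digits().images[:10]
    cp = CubicalPersistence(homology_dimensions=[0, 1], n_jobs=-1)
    diagrams = cp.fit_transform(imgs)  # one [H0, H1] pair of diagrams per image
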
diff --git a/src/python/test/test_subsampling.py b/src/python/test/test_subsampling.py
index 4019852e..3431f372 100755
--- a/src/python/test/test_subsampling.py
+++ b/src/python/test/test_subsampling.py
@@ -91,7 +91,7 @@ def test_simple_choose_n_farthest_points_randomed():
assert gudhi.choose_n_farthest_points(points=[], nb_points=1) == []
assert gudhi.choose_n_farthest_points(points=point_set, nb_points=0) == []
- # Go furter than point set on purpose
+ # Go further than point set on purpose
for iter in range(1, 10):
sub_set = gudhi.choose_n_farthest_points(points=point_set, nb_points=iter)
for sub in sub_set:
@@ -117,7 +117,7 @@ def test_simple_pick_n_random_points():
assert gudhi.pick_n_random_points(points=[], nb_points=1) == []
assert gudhi.pick_n_random_points(points=point_set, nb_points=0) == []
- # Go furter than point set on purpose
+ # Go further than point set on purpose
for iter in range(1, 10):
sub_set = gudhi.pick_n_random_points(points=point_set, nb_points=iter)
for sub in sub_set:
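
Both helpers take an explicit point list and a target count: choose_n_farthest_points performs greedy farthest-point sampling, while pick_n_random_points samples uniformly. A sketch:

    import gudhi

    points = [[0., 0.], [1., 0.], [0., 1.], [5., 5.]]
    far = gudhi.choose_n_farthest_points(points=points, nb_points=2)
    rnd = gudhi.pick_n_random_points(points=points, nb_points=2)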