summaryrefslogtreecommitdiff
path: root/src/python
diff options
context:
space:
mode:
Diffstat (limited to 'src/python')
-rw-r--r--src/python/CMakeLists.txt368
-rw-r--r--src/python/doc/_templates/layout.html1
-rw-r--r--src/python/doc/alpha_complex_ref.rst1
-rw-r--r--src/python/doc/alpha_complex_sum.inc24
-rw-r--r--src/python/doc/alpha_complex_user.rst112
-rw-r--r--src/python/doc/clustering.rst5
-rwxr-xr-xsrc/python/doc/conf.py5
-rw-r--r--src/python/doc/cubical_complex_sklearn_itf_ref.rst102
-rw-r--r--src/python/doc/cubical_complex_sum.inc30
-rw-r--r--src/python/doc/cubical_complex_tflow_itf_ref.rst40
-rw-r--r--src/python/doc/cubical_complex_user.rst11
-rw-r--r--src/python/doc/datasets.inc14
-rw-r--r--src/python/doc/datasets.rst133
-rw-r--r--src/python/doc/differentiation_sum.inc12
-rw-r--r--src/python/doc/examples.rst1
-rw-r--r--src/python/doc/img/bunny.pngbin0 -> 48040 bytes
-rw-r--r--src/python/doc/img/sklearn.pngbin0 -> 9368 bytes
-rw-r--r--src/python/doc/img/sphere_3d.pngbin0 -> 529148 bytes
-rw-r--r--src/python/doc/img/spiral_2d.pngbin0 -> 279276 bytes
-rw-r--r--src/python/doc/img/tensorflow.pngbin0 -> 3846 bytes
-rw-r--r--src/python/doc/index.rst5
-rw-r--r--src/python/doc/installation.rst140
-rw-r--r--src/python/doc/ls_simplex_tree_tflow_itf_ref.rst53
-rw-r--r--src/python/doc/nerve_gic_complex_user.rst2
-rw-r--r--src/python/doc/persistence_graphical_tools_user.rst2
-rw-r--r--src/python/doc/persistent_cohomology_user.rst29
-rw-r--r--src/python/doc/point_cloud.rst5
-rw-r--r--src/python/doc/representations.rst119
-rw-r--r--src/python/doc/representations_sum.inc24
-rw-r--r--src/python/doc/rips_complex_sum.inc7
-rw-r--r--src/python/doc/rips_complex_tflow_itf_ref.rst48
-rw-r--r--src/python/doc/rips_complex_user.rst135
-rw-r--r--src/python/doc/simplex_tree_sum.inc25
-rw-r--r--src/python/doc/wasserstein_distance_user.rst29
-rwxr-xr-xsrc/python/example/alpha_complex_diagram_persistence_from_off_file_example.py60
-rw-r--r--src/python/example/alpha_complex_from_generated_points_on_sphere_example.py35
-rwxr-xr-xsrc/python/example/alpha_complex_from_points_example.py2
-rwxr-xr-xsrc/python/example/alpha_rips_persistence_bottleneck_distance.py110
-rwxr-xr-xsrc/python/example/plot_alpha_complex.py5
-rwxr-xr-xsrc/python/example/rips_complex_diagram_persistence_from_correlation_matrix_file_example.py2
-rwxr-xr-xsrc/python/example/rips_complex_diagram_persistence_from_distance_matrix_file_example.py5
-rw-r--r--src/python/gudhi/__init__.py.in4
-rw-r--r--src/python/gudhi/alpha_complex.pyx91
-rw-r--r--src/python/gudhi/bottleneck.cc18
-rw-r--r--src/python/gudhi/clustering/tomato.py4
-rw-r--r--src/python/gudhi/cubical_complex.pyx12
-rw-r--r--src/python/gudhi/datasets/__init__.py0
-rw-r--r--src/python/gudhi/datasets/generators/__init__.py0
-rw-r--r--src/python/gudhi/datasets/generators/_points.cc121
-rw-r--r--src/python/gudhi/datasets/generators/points.py59
-rw-r--r--src/python/gudhi/datasets/remote.py223
-rw-r--r--src/python/gudhi/hera/bottleneck.cc11
-rw-r--r--src/python/gudhi/hera/wasserstein.cc137
-rw-r--r--src/python/gudhi/off_utils.pyx (renamed from src/python/gudhi/off_reader.pyx)23
-rw-r--r--src/python/gudhi/periodic_cubical_complex.pyx12
-rw-r--r--src/python/gudhi/persistence_graphical_tools.py353
-rw-r--r--src/python/gudhi/point_cloud/knn.py16
-rw-r--r--src/python/gudhi/representations/preprocessing.py57
-rw-r--r--src/python/gudhi/representations/vector_methods.py438
-rw-r--r--src/python/gudhi/rips_complex.pyx17
-rw-r--r--src/python/gudhi/simplex_tree.pxd15
-rw-r--r--src/python/gudhi/simplex_tree.pyx227
-rw-r--r--src/python/gudhi/sklearn/__init__.py0
-rw-r--r--src/python/gudhi/sklearn/cubical_persistence.py110
-rw-r--r--src/python/gudhi/tensorflow/__init__.py5
-rw-r--r--src/python/gudhi/tensorflow/cubical_layer.py82
-rw-r--r--src/python/gudhi/tensorflow/lower_star_simplex_tree_layer.py87
-rw-r--r--src/python/gudhi/tensorflow/perslay.py284
-rw-r--r--src/python/gudhi/tensorflow/rips_layer.py93
-rw-r--r--src/python/gudhi/wasserstein/barycenter.py6
-rw-r--r--src/python/gudhi/wasserstein/wasserstein.py222
-rw-r--r--src/python/gudhi/weighted_rips_complex.py6
-rw-r--r--src/python/include/Alpha_complex_factory.h118
-rw-r--r--src/python/include/Alpha_complex_interface.h62
-rw-r--r--src/python/include/Persistent_cohomology_interface.h40
-rw-r--r--src/python/include/Simplex_tree_interface.h72
-rw-r--r--src/python/include/pybind11_diagram_utils.h25
-rw-r--r--src/python/pyproject.toml3
-rw-r--r--src/python/setup.py.in17
-rwxr-xr-xsrc/python/test/test_alpha_complex.py179
-rwxr-xr-xsrc/python/test/test_betti_curve_representations.py59
-rwxr-xr-xsrc/python/test/test_cubical_complex.py25
-rwxr-xr-xsrc/python/test/test_datasets_generators.py39
-rw-r--r--src/python/test/test_diff.py78
-rwxr-xr-xsrc/python/test/test_dtm.py12
-rw-r--r--src/python/test/test_off.py21
-rw-r--r--src/python/test/test_persistence_graphical_tools.py122
-rw-r--r--src/python/test/test_perslay.py147
-rwxr-xr-xsrc/python/test/test_reader_utils.py35
-rw-r--r--src/python/test/test_remote_datasets.py87
-rwxr-xr-xsrc/python/test/test_representations.py197
-rw-r--r--src/python/test/test_representations_preprocessing.py39
-rwxr-xr-xsrc/python/test/test_rips_complex.py21
-rwxr-xr-xsrc/python/test/test_simplex_generators.py2
-rwxr-xr-xsrc/python/test/test_simplex_tree.py422
-rw-r--r--src/python/test/test_sklearn_cubical_persistence.py59
-rwxr-xr-xsrc/python/test/test_subsampling.py107
-rwxr-xr-xsrc/python/test/test_tomato.py2
-rwxr-xr-xsrc/python/test/test_wasserstein_distance.py134
99 files changed, 5131 insertions, 1427 deletions
diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt
index 5c1402a6..74d1c4c6 100644
--- a/src/python/CMakeLists.txt
+++ b/src/python/CMakeLists.txt
@@ -14,13 +14,16 @@ function( add_GUDHI_PYTHON_lib THE_LIB )
endif(EXISTS ${THE_LIB})
endfunction( add_GUDHI_PYTHON_lib )
-function( add_GUDHI_PYTHON_lib_dir THE_LIB_DIR )
- # deals when it is not set - error on windows
- if(EXISTS ${THE_LIB_DIR})
- set(GUDHI_PYTHON_LIBRARY_DIRS "${GUDHI_PYTHON_LIBRARY_DIRS}'${THE_LIB_DIR}', " PARENT_SCOPE)
- else()
- message("add_GUDHI_PYTHON_lib_dir - '${THE_LIB_DIR}' does not exist")
- endif()
+function( add_GUDHI_PYTHON_lib_dir)
+ # Argument may be a list (specifically on windows with release/debug paths)
+ foreach(THE_LIB_DIR IN LISTS ARGN)
+ # deals when it is not set - error on windows
+ if(EXISTS ${THE_LIB_DIR})
+ set(GUDHI_PYTHON_LIBRARY_DIRS "${GUDHI_PYTHON_LIBRARY_DIRS}'${THE_LIB_DIR}', " PARENT_SCOPE)
+ else()
+ message("add_GUDHI_PYTHON_lib_dir - '${THE_LIB_DIR}' does not exist")
+ endif()
+ endforeach()
endfunction( add_GUDHI_PYTHON_lib_dir )
# THE_TEST is the python test file name (without .py extension) containing tests functions
@@ -41,14 +44,16 @@ function( add_gudhi_debug_info DEBUG_INFO )
endfunction( add_gudhi_debug_info )
if(PYTHONINTERP_FOUND)
- if(PYBIND11_FOUND)
+ if(NUMPY_FOUND AND PYBIND11_FOUND AND CYTHON_FOUND)
add_gudhi_debug_info("Pybind11 version ${PYBIND11_VERSION}")
+ # PyBind11 modules
set(GUDHI_PYTHON_MODULES "${GUDHI_PYTHON_MODULES}'bottleneck', ")
set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'hera', ")
set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'clustering', ")
- endif()
- if(CYTHON_FOUND)
- set(GUDHI_PYTHON_MODULES "${GUDHI_PYTHON_MODULES}'off_reader', ")
+ set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'datasets', ")
+
+ # Cython modules
+ set(GUDHI_PYTHON_MODULES "${GUDHI_PYTHON_MODULES}'off_utils', ")
set(GUDHI_PYTHON_MODULES "${GUDHI_PYTHON_MODULES}'simplex_tree', ")
set(GUDHI_PYTHON_MODULES "${GUDHI_PYTHON_MODULES}'rips_complex', ")
set(GUDHI_PYTHON_MODULES "${GUDHI_PYTHON_MODULES}'cubical_complex', ")
@@ -65,6 +70,7 @@ if(PYTHONINTERP_FOUND)
set(GUDHI_PYTHON_MODULES "${GUDHI_PYTHON_MODULES}'euclidean_strong_witness_complex', ")
# Modules that should not be auto-imported in __init__.py
set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'representations', ")
+ set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'tensorflow', ")
set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'wasserstein', ")
set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'point_cloud', ")
set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'weighted_rips_complex', ")
@@ -106,6 +112,16 @@ if(PYTHONINTERP_FOUND)
if(TENSORFLOW_FOUND)
add_gudhi_debug_info("TensorFlow version ${TENSORFLOW_VERSION}")
endif()
+ if(SPHINX_FOUND)
+ add_gudhi_debug_info("Sphinx version ${SPHINX_VERSION}")
+ endif()
+ if(SPHINX_PARAMLINKS_FOUND)
+ add_gudhi_debug_info("Sphinx-paramlinks version ${SPHINX_PARAMLINKS_VERSION}")
+ endif()
+ if(PYTHON_DOCS_THEME_FOUND)
+ # Does not have a version number...
+ add_gudhi_debug_info("python_docs_theme found")
+ endif()
set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-DBOOST_RESULT_OF_USE_DECLTYPE', ")
set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-DBOOST_ALL_NO_LIB', ")
@@ -113,9 +129,10 @@ if(PYTHONINTERP_FOUND)
# Gudhi and CGAL compilation option
if(MSVC)
+ set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'/std:c++17', ")
set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'/fp:strict', ")
else(MSVC)
- set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-std=c++14', ")
+ set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-std=c++17', ")
endif(MSVC)
if(CMAKE_COMPILER_IS_GNUCXX)
set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-frounding-math', ")
@@ -133,13 +150,9 @@ if(PYTHONINTERP_FOUND)
add_gudhi_debug_info("Eigen3 version ${EIGEN3_VERSION}")
# No problem, even if no CGAL found
set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-DCGAL_EIGEN3_ENABLED', ")
- set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-DGUDHI_USE_EIGEN3', ")
- set(GUDHI_USE_EIGEN3 "True")
- else (EIGEN3_FOUND)
- set(GUDHI_USE_EIGEN3 "False")
endif (EIGEN3_FOUND)
- set(GUDHI_CYTHON_MODULES "${GUDHI_CYTHON_MODULES}'off_reader', ")
+ set(GUDHI_CYTHON_MODULES "${GUDHI_CYTHON_MODULES}'off_utils', ")
set(GUDHI_CYTHON_MODULES "${GUDHI_CYTHON_MODULES}'simplex_tree', ")
set(GUDHI_CYTHON_MODULES "${GUDHI_CYTHON_MODULES}'rips_complex', ")
set(GUDHI_CYTHON_MODULES "${GUDHI_CYTHON_MODULES}'cubical_complex', ")
@@ -150,19 +163,36 @@ if(PYTHONINTERP_FOUND)
set(GUDHI_PYBIND11_MODULES "${GUDHI_PYBIND11_MODULES}'clustering/_tomato', ")
set(GUDHI_PYBIND11_MODULES "${GUDHI_PYBIND11_MODULES}'hera/wasserstein', ")
set(GUDHI_PYBIND11_MODULES "${GUDHI_PYBIND11_MODULES}'hera/bottleneck', ")
+ set(GUDHI_CYTHON_MODULES "${GUDHI_CYTHON_MODULES}'nerve_gic', ")
if (NOT CGAL_VERSION VERSION_LESS 4.11.0)
+ set(GUDHI_PYBIND11_MODULES "${GUDHI_PYBIND11_MODULES}'datasets/generators/_points', ")
set(GUDHI_PYBIND11_MODULES "${GUDHI_PYBIND11_MODULES}'bottleneck', ")
- set(GUDHI_CYTHON_MODULES "${GUDHI_CYTHON_MODULES}'nerve_gic', ")
endif ()
if (NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 4.11.0)
- set(GUDHI_CYTHON_MODULES "${GUDHI_CYTHON_MODULES}'alpha_complex', ")
set(GUDHI_CYTHON_MODULES "${GUDHI_CYTHON_MODULES}'subsampling', ")
set(GUDHI_CYTHON_MODULES "${GUDHI_CYTHON_MODULES}'tangential_complex', ")
set(GUDHI_CYTHON_MODULES "${GUDHI_CYTHON_MODULES}'euclidean_witness_complex', ")
set(GUDHI_CYTHON_MODULES "${GUDHI_CYTHON_MODULES}'euclidean_strong_witness_complex', ")
endif ()
+ if (NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 5.1.0)
+ set(GUDHI_CYTHON_MODULES "${GUDHI_CYTHON_MODULES}'alpha_complex', ")
+ endif ()
+
+ # from windows vcpkg eigen 3.4.0#2 : build fails with
+ # error C2440: '<function-style-cast>': cannot convert from 'Eigen::EigenBase<Derived>::Index' to '__gmp_expr<mpq_t,mpq_t>'
+ # cf. https://gitlab.com/libeigen/eigen/-/issues/2476
+ # Workaround is to compile with '-DEIGEN_DEFAULT_DENSE_INDEX_TYPE=int'
+ if (FORCE_EIGEN_DEFAULT_DENSE_INDEX_TYPE_TO_INT)
+ set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-DEIGEN_DEFAULT_DENSE_INDEX_TYPE=int', ")
+ endif()
+
+ add_gudhi_debug_info("Boost version ${Boost_VERSION}")
if(CGAL_FOUND)
+ if(NOT CGAL_VERSION VERSION_LESS 5.3.0)
+ # CGAL_HEADER_ONLY has been dropped for CGAL >= 5.3. Only the header-only version is supported.
+ set(CGAL_HEADER_ONLY True)
+ endif(NOT CGAL_VERSION VERSION_LESS 5.3.0)
# Add CGAL compilation args
if(CGAL_HEADER_ONLY)
add_gudhi_debug_info("CGAL header only version ${CGAL_VERSION}")
@@ -170,7 +200,7 @@ if(PYTHONINTERP_FOUND)
else(CGAL_HEADER_ONLY)
add_gudhi_debug_info("CGAL version ${CGAL_VERSION}")
add_GUDHI_PYTHON_lib("${CGAL_LIBRARY}")
- add_GUDHI_PYTHON_lib_dir("${CGAL_LIBRARIES_DIR}")
+ add_GUDHI_PYTHON_lib_dir(${CGAL_LIBRARIES_DIR})
message("** Add CGAL ${CGAL_LIBRARIES_DIR}")
# If CGAL is not header only, CGAL library may link with boost system,
if(CMAKE_BUILD_TYPE MATCHES Debug)
@@ -178,7 +208,7 @@ if(PYTHONINTERP_FOUND)
else()
add_GUDHI_PYTHON_lib("${Boost_SYSTEM_LIBRARY_RELEASE}")
endif()
- add_GUDHI_PYTHON_lib_dir("${Boost_LIBRARY_DIRS}")
+ add_GUDHI_PYTHON_lib_dir(${Boost_LIBRARY_DIRS})
message("** Add Boost ${Boost_LIBRARY_DIRS}")
endif(CGAL_HEADER_ONLY)
# GMP and GMPXX are not required, but if present, CGAL will link with them.
@@ -190,15 +220,16 @@ if(PYTHONINTERP_FOUND)
get_filename_component(GMP_LIBRARIES_DIR ${GMP_LIBRARIES} PATH)
message("GMP_LIBRARIES_DIR from GMP_LIBRARIES set to ${GMP_LIBRARIES_DIR}")
endif(NOT GMP_LIBRARIES_DIR)
- add_GUDHI_PYTHON_lib_dir("${GMP_LIBRARIES_DIR}")
+ add_GUDHI_PYTHON_lib_dir(${GMP_LIBRARIES_DIR})
message("** Add gmp ${GMP_LIBRARIES_DIR}")
+ # When FORCE_CGAL_NOT_TO_BUILD_WITH_GMPXX is set, not defining CGAL_USE_GMPXX is sufficient enough
if(GMPXX_FOUND)
add_gudhi_debug_info("GMPXX_LIBRARIES = ${GMPXX_LIBRARIES}")
set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-DCGAL_USE_GMPXX', ")
add_GUDHI_PYTHON_lib("${GMPXX_LIBRARIES}")
- add_GUDHI_PYTHON_lib_dir("${GMPXX_LIBRARIES_DIR}")
+ add_GUDHI_PYTHON_lib_dir(${GMPXX_LIBRARIES_DIR})
message("** Add gmpxx ${GMPXX_LIBRARIES_DIR}")
- endif(GMPXX_FOUND)
+ endif()
endif(GMP_FOUND)
if(MPFR_FOUND)
add_gudhi_debug_info("MPFR_LIBRARIES = ${MPFR_LIBRARIES}")
@@ -209,17 +240,24 @@ if(PYTHONINTERP_FOUND)
get_filename_component(MPFR_LIBRARIES_DIR ${MPFR_LIBRARIES} PATH)
message("MPFR_LIBRARIES_DIR from MPFR_LIBRARIES set to ${MPFR_LIBRARIES_DIR}")
endif(NOT MPFR_LIBRARIES_DIR)
- add_GUDHI_PYTHON_lib_dir("${MPFR_LIBRARIES_DIR}")
+ add_GUDHI_PYTHON_lib_dir(${MPFR_LIBRARIES_DIR})
message("** Add mpfr ${MPFR_LIBRARIES_DIR}")
endif(MPFR_FOUND)
endif(CGAL_FOUND)
# Specific for Mac
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
- set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-mmacosx-version-min=10.12', ")
- set(GUDHI_PYTHON_EXTRA_LINK_ARGS "${GUDHI_PYTHON_EXTRA_LINK_ARGS}'-mmacosx-version-min=10.12', ")
+ set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-mmacosx-version-min=10.14', ")
+ set(GUDHI_PYTHON_EXTRA_LINK_ARGS "${GUDHI_PYTHON_EXTRA_LINK_ARGS}'-mmacosx-version-min=10.14', ")
endif(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
+ # Strip dynamic libraries in release mode for smaller pip packages under linux
+ if(CMAKE_COMPILER_IS_GNUCXX)
+ if(CMAKE_BUILD_TYPE MATCHES Release)
+ set(GUDHI_PYTHON_EXTRA_LINK_ARGS "${GUDHI_PYTHON_EXTRA_LINK_ARGS}'-s', ")
+ endif(CMAKE_BUILD_TYPE MATCHES Release)
+ endif(CMAKE_COMPILER_IS_GNUCXX)
+
# Loop on INCLUDE_DIRECTORIES PROPERTY
get_property(GUDHI_INCLUDE_DIRECTORIES DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY INCLUDE_DIRECTORIES)
foreach(GUDHI_INCLUDE_DIRECTORY ${GUDHI_INCLUDE_DIRECTORIES})
@@ -230,18 +268,22 @@ if(PYTHONINTERP_FOUND)
if (TBB_FOUND AND WITH_GUDHI_USE_TBB)
add_gudhi_debug_info("TBB version ${TBB_INTERFACE_VERSION} found and used")
set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-DGUDHI_USE_TBB', ")
- if(CMAKE_BUILD_TYPE MATCHES Debug)
+ if((CMAKE_BUILD_TYPE MATCHES Debug) AND TBB_DEBUG_LIBRARY)
add_GUDHI_PYTHON_lib("${TBB_DEBUG_LIBRARY}")
add_GUDHI_PYTHON_lib("${TBB_MALLOC_DEBUG_LIBRARY}")
else()
add_GUDHI_PYTHON_lib("${TBB_RELEASE_LIBRARY}")
add_GUDHI_PYTHON_lib("${TBB_MALLOC_RELEASE_LIBRARY}")
endif()
- add_GUDHI_PYTHON_lib_dir("${TBB_LIBRARY_DIRS}")
+ add_GUDHI_PYTHON_lib_dir(${TBB_LIBRARY_DIRS})
message("** Add tbb ${TBB_LIBRARY_DIRS}")
set(GUDHI_PYTHON_INCLUDE_DIRS "${GUDHI_PYTHON_INCLUDE_DIRS}'${TBB_INCLUDE_DIRS}', ")
endif()
+ if(DEBUG_TRACES)
+ set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-DDEBUG_TRACES', ")
+ endif(DEBUG_TRACES)
+
if(UNIX AND WITH_GUDHI_PYTHON_RUNTIME_LIBRARY_DIRS)
set( GUDHI_PYTHON_RUNTIME_LIBRARY_DIRS "${GUDHI_PYTHON_LIBRARY_DIRS}")
endif(UNIX AND WITH_GUDHI_PYTHON_RUNTIME_LIBRARY_DIRS)
@@ -256,15 +298,20 @@ if(PYTHONINTERP_FOUND)
# Other .py files
file(COPY "gudhi/persistence_graphical_tools.py" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi")
file(COPY "gudhi/representations" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi/")
+ file(COPY "gudhi/tensorflow" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi/")
file(COPY "gudhi/wasserstein" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi")
file(COPY "gudhi/point_cloud" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi")
file(COPY "gudhi/clustering" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi" FILES_MATCHING PATTERN "*.py")
file(COPY "gudhi/weighted_rips_complex.py" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi")
file(COPY "gudhi/dtm_rips_complex.py" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi")
file(COPY "gudhi/hera/__init__.py" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi/hera")
+ file(COPY "gudhi/datasets" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi" FILES_MATCHING PATTERN "*.py")
+ file(COPY "gudhi/sklearn" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi/")
+
# Some files for pip package
file(COPY "introduction.rst" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/")
+ file(COPY "pyproject.toml" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/")
add_custom_command(
OUTPUT gudhi.so
@@ -274,67 +321,74 @@ if(PYTHONINTERP_FOUND)
add_custom_target(python ALL DEPENDS gudhi.so
COMMENT "Do not forget to add ${CMAKE_CURRENT_BINARY_DIR}/ to your PYTHONPATH before using examples or tests")
- install(CODE "execute_process(COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_BINARY_DIR}/setup.py install)")
-
+ # Path separator management for windows
+ if (WIN32)
+ set(GUDHI_PYTHON_PATH_ENV "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR};$ENV{PYTHONPATH}")
+ else(WIN32)
+ set(GUDHI_PYTHON_PATH_ENV "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}:$ENV{PYTHONPATH}")
+ endif(WIN32)
# Documentation generation is available through sphinx - requires all modules
# Make it first as sphinx test is by far the longest test which is nice when testing in parallel
if(SPHINX_PATH)
- if(MATPLOTLIB_FOUND)
- if(NUMPY_FOUND)
- if(SCIPY_FOUND)
- if(SKLEARN_FOUND)
- if(OT_FOUND)
- if(PYBIND11_FOUND)
- if(NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 4.11.0)
- set (GUDHI_SPHINX_MESSAGE "Generating API documentation with Sphinx in ${CMAKE_CURRENT_BINARY_DIR}/sphinx/")
- # User warning - Sphinx is a static pages generator, and configured to work fine with user_version
- # Images and biblio warnings because not found on developper version
- if (GUDHI_PYTHON_PATH STREQUAL "src/python")
- set (GUDHI_SPHINX_MESSAGE "${GUDHI_SPHINX_MESSAGE} \n WARNING : Sphinx is configured for user version, you run it on developper version. Images and biblio will miss")
- endif()
- # sphinx target requires gudhi.so, because conf.py reads gudhi version from it
- add_custom_target(sphinx
- WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/doc
- COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}"
- ${SPHINX_PATH} -b html ${CMAKE_CURRENT_SOURCE_DIR}/doc ${CMAKE_CURRENT_BINARY_DIR}/sphinx
- DEPENDS "${CMAKE_CURRENT_BINARY_DIR}/gudhi.so"
- COMMENT "${GUDHI_SPHINX_MESSAGE}" VERBATIM)
-
- add_test(NAME sphinx_py_test
- WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
- COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}"
- ${SPHINX_PATH} -b doctest ${CMAKE_CURRENT_SOURCE_DIR}/doc ${CMAKE_CURRENT_BINARY_DIR}/doctest)
-
- # Set missing or not modules
- set(GUDHI_MODULES ${GUDHI_MODULES} "python-documentation" CACHE INTERNAL "GUDHI_MODULES")
- else(NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 4.11.0)
- message("++ Python documentation module will not be compiled because it requires a Eigen3 and CGAL version >= 4.11.0")
+ if(SPHINX_PARAMLINKS_FOUND)
+ if(PYTHON_DOCS_THEME_FOUND)
+ if(MATPLOTLIB_FOUND)
+ if(NUMPY_FOUND)
+ if(SCIPY_FOUND)
+ if(SKLEARN_FOUND)
+ if(OT_FOUND)
+ if(NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 5.1.0)
+ set (GUDHI_SPHINX_MESSAGE "Generating API documentation with Sphinx in ${CMAKE_CURRENT_BINARY_DIR}/sphinx/")
+ # User warning - Sphinx is a static pages generator, and configured to work fine with user_version
+ # Images and biblio warnings because not found on developer version
+ if (GUDHI_PYTHON_PATH STREQUAL "src/python")
+ set (GUDHI_SPHINX_MESSAGE "${GUDHI_SPHINX_MESSAGE} \n WARNING : Sphinx is configured for user version, you run it on developer version. Images and biblio will miss")
+ endif()
+ # sphinx target requires gudhi.so, because conf.py reads gudhi version from it
+ add_custom_target(sphinx
+ WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/doc
+ COMMAND ${CMAKE_COMMAND} -E env "${GUDHI_PYTHON_PATH_ENV}"
+ ${SPHINX_PATH} -b html ${CMAKE_CURRENT_SOURCE_DIR}/doc ${CMAKE_CURRENT_BINARY_DIR}/sphinx
+ DEPENDS "${CMAKE_CURRENT_BINARY_DIR}/gudhi.so"
+ COMMENT "${GUDHI_SPHINX_MESSAGE}" VERBATIM)
+ add_test(NAME sphinx_py_test
+ WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
+ COMMAND ${CMAKE_COMMAND} -E env "${GUDHI_PYTHON_PATH_ENV}"
+ ${SPHINX_PATH} -b doctest ${CMAKE_CURRENT_SOURCE_DIR}/doc ${CMAKE_CURRENT_BINARY_DIR}/doctest)
+ # Set missing or not modules
+ set(GUDHI_MODULES ${GUDHI_MODULES} "python-documentation" CACHE INTERNAL "GUDHI_MODULES")
+ else(NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 5.1.0)
+ message("++ Python documentation module will not be compiled because it requires a Eigen3 and CGAL version >= 5.1.0")
+ set(GUDHI_MISSING_MODULES ${GUDHI_MISSING_MODULES} "python-documentation" CACHE INTERNAL "GUDHI_MISSING_MODULES")
+ endif(NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 5.1.0)
+ else(OT_FOUND)
+ message("++ Python documentation module will not be compiled because POT was not found")
set(GUDHI_MISSING_MODULES ${GUDHI_MISSING_MODULES} "python-documentation" CACHE INTERNAL "GUDHI_MISSING_MODULES")
- endif(NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 4.11.0)
- else(PYBIND11_FOUND)
- message("++ Python documentation module will not be compiled because pybind11 was not found")
+ endif(OT_FOUND)
+ else(SKLEARN_FOUND)
+ message("++ Python documentation module will not be compiled because scikit-learn was not found")
set(GUDHI_MISSING_MODULES ${GUDHI_MISSING_MODULES} "python-documentation" CACHE INTERNAL "GUDHI_MISSING_MODULES")
- endif(PYBIND11_FOUND)
- else(OT_FOUND)
- message("++ Python documentation module will not be compiled because POT was not found")
+ endif(SKLEARN_FOUND)
+ else(SCIPY_FOUND)
+ message("++ Python documentation module will not be compiled because scipy was not found")
set(GUDHI_MISSING_MODULES ${GUDHI_MISSING_MODULES} "python-documentation" CACHE INTERNAL "GUDHI_MISSING_MODULES")
- endif(OT_FOUND)
- else(SKLEARN_FOUND)
- message("++ Python documentation module will not be compiled because scikit-learn was not found")
+ endif(SCIPY_FOUND)
+ else(NUMPY_FOUND)
+ message("++ Python documentation module will not be compiled because numpy was not found")
set(GUDHI_MISSING_MODULES ${GUDHI_MISSING_MODULES} "python-documentation" CACHE INTERNAL "GUDHI_MISSING_MODULES")
- endif(SKLEARN_FOUND)
- else(SCIPY_FOUND)
- message("++ Python documentation module will not be compiled because scipy was not found")
+ endif(NUMPY_FOUND)
+ else(MATPLOTLIB_FOUND)
+ message("++ Python documentation module will not be compiled because matplotlib was not found")
set(GUDHI_MISSING_MODULES ${GUDHI_MISSING_MODULES} "python-documentation" CACHE INTERNAL "GUDHI_MISSING_MODULES")
- endif(SCIPY_FOUND)
- else(NUMPY_FOUND)
- message("++ Python documentation module will not be compiled because numpy was not found")
+ endif(MATPLOTLIB_FOUND)
+ else(PYTHON_DOCS_THEME_FOUND)
+ message("++ Python documentation module will not be compiled because python-docs-theme was not found")
set(GUDHI_MISSING_MODULES ${GUDHI_MISSING_MODULES} "python-documentation" CACHE INTERNAL "GUDHI_MISSING_MODULES")
- endif(NUMPY_FOUND)
- else(MATPLOTLIB_FOUND)
- message("++ Python documentation module will not be compiled because matplotlib was not found")
+ endif(PYTHON_DOCS_THEME_FOUND)
+ else(SPHINX_PARAMLINKS_FOUND)
+ message("++ Python documentation module will not be compiled because sphinxcontrib-paramlinks was not found")
set(GUDHI_MISSING_MODULES ${GUDHI_MISSING_MODULES} "python-documentation" CACHE INTERNAL "GUDHI_MISSING_MODULES")
- endif(MATPLOTLIB_FOUND)
+ endif(SPHINX_PARAMLINKS_FOUND)
else(SPHINX_PATH)
message("++ Python documentation module will not be compiled because sphinx and sphinxcontrib-bibtex were not found")
set(GUDHI_MISSING_MODULES ${GUDHI_MISSING_MODULES} "python-documentation" CACHE INTERNAL "GUDHI_MISSING_MODULES")
@@ -342,17 +396,19 @@ if(PYTHONINTERP_FOUND)
# Test examples
- if (NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 4.11.0)
+ if (NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 5.1.0)
# Bottleneck and Alpha
add_test(NAME alpha_rips_persistence_bottleneck_distance_py_test
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
- COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}"
+ COMMAND ${CMAKE_COMMAND} -E env "${GUDHI_PYTHON_PATH_ENV}"
${PYTHON_EXECUTABLE} "${CMAKE_CURRENT_SOURCE_DIR}/example/alpha_rips_persistence_bottleneck_distance.py"
-f ${CMAKE_SOURCE_DIR}/data/points/tore3D_300.off -t 0.15 -d 3)
+ endif (NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 5.1.0)
+ if (NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 4.11.0)
# Tangential
add_test(NAME tangential_complex_plain_homology_from_off_file_example_py_test
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
- COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}"
+ COMMAND ${CMAKE_COMMAND} -E env "${GUDHI_PYTHON_PATH_ENV}"
${PYTHON_EXECUTABLE} "${CMAKE_CURRENT_SOURCE_DIR}/example/tangential_complex_plain_homology_from_off_file_example.py"
--no-diagram -i 2 -f ${CMAKE_SOURCE_DIR}/data/points/tore3D_300.off)
@@ -361,13 +417,13 @@ if(PYTHONINTERP_FOUND)
# Witness complex
add_test(NAME euclidean_strong_witness_complex_diagram_persistence_from_off_file_example_py_test
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
- COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}"
+ COMMAND ${CMAKE_COMMAND} -E env "${GUDHI_PYTHON_PATH_ENV}"
${PYTHON_EXECUTABLE} "${CMAKE_CURRENT_SOURCE_DIR}/example/euclidean_strong_witness_complex_diagram_persistence_from_off_file_example.py"
--no-diagram -f ${CMAKE_SOURCE_DIR}/data/points/tore3D_300.off -a 1.0 -n 20 -d 2)
add_test(NAME euclidean_witness_complex_diagram_persistence_from_off_file_example_py_test
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
- COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}"
+ COMMAND ${CMAKE_COMMAND} -E env "${GUDHI_PYTHON_PATH_ENV}"
${PYTHON_EXECUTABLE} "${CMAKE_CURRENT_SOURCE_DIR}/example/euclidean_witness_complex_diagram_persistence_from_off_file_example.py"
--no-diagram -f ${CMAKE_SOURCE_DIR}/data/points/tore3D_300.off -a 1.0 -n 20 -d 2)
@@ -379,75 +435,80 @@ if(PYTHONINTERP_FOUND)
# Bottleneck
add_test(NAME bottleneck_basic_example_py_test
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
- COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}"
+ COMMAND ${CMAKE_COMMAND} -E env "${GUDHI_PYTHON_PATH_ENV}"
${PYTHON_EXECUTABLE} "${CMAKE_CURRENT_SOURCE_DIR}/example/bottleneck_basic_example.py")
- if (PYBIND11_FOUND)
- add_gudhi_py_test(test_bottleneck_distance)
- endif()
+ add_gudhi_py_test(test_bottleneck_distance)
+ endif (NOT CGAL_VERSION VERSION_LESS 4.11.0)
- # Cover complex
- file(COPY ${CMAKE_SOURCE_DIR}/data/points/human.off DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/)
- file(COPY ${CMAKE_SOURCE_DIR}/data/points/COIL_database/lucky_cat.off DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/)
- file(COPY ${CMAKE_SOURCE_DIR}/data/points/COIL_database/lucky_cat_PCA1 DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/)
- add_test(NAME cover_complex_nerve_example_py_test
- WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
- COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}"
- ${PYTHON_EXECUTABLE} "${CMAKE_CURRENT_SOURCE_DIR}/example/nerve_of_a_covering.py"
- -f human.off -c 2 -r 10 -g 0.3)
+ # Cover complex
+ file(COPY ${CMAKE_SOURCE_DIR}/data/points/human.off DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/)
+ file(COPY ${CMAKE_SOURCE_DIR}/data/points/COIL_database/lucky_cat.off DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/)
+ file(COPY ${CMAKE_SOURCE_DIR}/data/points/COIL_database/lucky_cat_PCA1 DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/)
+ add_test(NAME cover_complex_nerve_example_py_test
+ WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
+ COMMAND ${CMAKE_COMMAND} -E env "${GUDHI_PYTHON_PATH_ENV}"
+ ${PYTHON_EXECUTABLE} "${CMAKE_CURRENT_SOURCE_DIR}/example/nerve_of_a_covering.py"
+ -f human.off -c 2 -r 10 -g 0.3)
- add_test(NAME cover_complex_coordinate_gic_example_py_test
- WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
- COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}"
- ${PYTHON_EXECUTABLE} "${CMAKE_CURRENT_SOURCE_DIR}/example/coordinate_graph_induced_complex.py"
- -f human.off -c 0 -v)
+ add_test(NAME cover_complex_coordinate_gic_example_py_test
+ WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
+ COMMAND ${CMAKE_COMMAND} -E env "${GUDHI_PYTHON_PATH_ENV}"
+ ${PYTHON_EXECUTABLE} "${CMAKE_CURRENT_SOURCE_DIR}/example/coordinate_graph_induced_complex.py"
+ -f human.off -c 0 -v)
- add_test(NAME cover_complex_functional_gic_example_py_test
- WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
- COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}"
- ${PYTHON_EXECUTABLE} "${CMAKE_CURRENT_SOURCE_DIR}/example/functional_graph_induced_complex.py"
- -o lucky_cat.off
- -f lucky_cat_PCA1 -v)
+ add_test(NAME cover_complex_functional_gic_example_py_test
+ WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
+ COMMAND ${CMAKE_COMMAND} -E env "${GUDHI_PYTHON_PATH_ENV}"
+ ${PYTHON_EXECUTABLE} "${CMAKE_CURRENT_SOURCE_DIR}/example/functional_graph_induced_complex.py"
+ -o lucky_cat.off
+ -f lucky_cat_PCA1 -v)
- add_test(NAME cover_complex_voronoi_gic_example_py_test
- WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
- COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}"
- ${PYTHON_EXECUTABLE} "${CMAKE_CURRENT_SOURCE_DIR}/example/voronoi_graph_induced_complex.py"
- -f human.off -n 700 -v)
+ add_test(NAME cover_complex_voronoi_gic_example_py_test
+ WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
+ COMMAND ${CMAKE_COMMAND} -E env "${GUDHI_PYTHON_PATH_ENV}"
+ ${PYTHON_EXECUTABLE} "${CMAKE_CURRENT_SOURCE_DIR}/example/voronoi_graph_induced_complex.py"
+ -f human.off -n 700 -v)
- add_gudhi_py_test(test_cover_complex)
- endif (NOT CGAL_VERSION VERSION_LESS 4.11.0)
+ add_gudhi_py_test(test_cover_complex)
- if (NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 4.11.0)
+ if (NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 5.1.0)
# Alpha
add_test(NAME alpha_complex_from_points_example_py_test
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
- COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}"
+ COMMAND ${CMAKE_COMMAND} -E env "${GUDHI_PYTHON_PATH_ENV}"
${PYTHON_EXECUTABLE} "${CMAKE_CURRENT_SOURCE_DIR}/example/alpha_complex_from_points_example.py")
+ add_test(NAME alpha_complex_from_generated_points_on_sphere_example_py_test
+ WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
+ COMMAND ${CMAKE_COMMAND} -E env "${GUDHI_PYTHON_PATH_ENV}"
+ ${PYTHON_EXECUTABLE} "${CMAKE_CURRENT_SOURCE_DIR}/example/alpha_complex_from_generated_points_on_sphere_example.py")
add_test(NAME alpha_complex_diagram_persistence_from_off_file_example_py_test
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
- COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}"
+ COMMAND ${CMAKE_COMMAND} -E env "${GUDHI_PYTHON_PATH_ENV}"
${PYTHON_EXECUTABLE} "${CMAKE_CURRENT_SOURCE_DIR}/example/alpha_complex_diagram_persistence_from_off_file_example.py"
- --no-diagram -f ${CMAKE_SOURCE_DIR}/data/points/tore3D_300.off -a 0.6)
+ --no-diagram -f ${CMAKE_SOURCE_DIR}/data/points/tore3D_300.off)
add_gudhi_py_test(test_alpha_complex)
- endif (NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 4.11.0)
+ endif (NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 5.1.0)
if (NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 4.11.0)
# Euclidean witness
add_gudhi_py_test(test_euclidean_witness_complex)
+ # Datasets generators
+ add_gudhi_py_test(test_datasets_generators) # TODO separate full python datasets generators in another test file independent from CGAL ?
+
endif (NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 4.11.0)
# Cubical
add_test(NAME periodic_cubical_complex_barcode_persistence_from_perseus_file_example_py_test
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
- COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}"
+ COMMAND ${CMAKE_COMMAND} -E env "${GUDHI_PYTHON_PATH_ENV}"
${PYTHON_EXECUTABLE} "${CMAKE_CURRENT_SOURCE_DIR}/example/periodic_cubical_complex_barcode_persistence_from_perseus_file_example.py"
--no-barcode -f ${CMAKE_SOURCE_DIR}/data/bitmap/CubicalTwoSphere.txt)
add_test(NAME random_cubical_complex_persistence_example_py_test
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
- COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}"
+ COMMAND ${CMAKE_COMMAND} -E env "${GUDHI_PYTHON_PATH_ENV}"
${PYTHON_EXECUTABLE} "${CMAKE_CURRENT_SOURCE_DIR}/example/random_cubical_complex_persistence_example.py"
10 10 10)
@@ -456,19 +517,19 @@ if(PYTHONINTERP_FOUND)
# Rips
add_test(NAME rips_complex_diagram_persistence_from_distance_matrix_file_example_py_test
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
- COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}"
+ COMMAND ${CMAKE_COMMAND} -E env "${GUDHI_PYTHON_PATH_ENV}"
${PYTHON_EXECUTABLE} "${CMAKE_CURRENT_SOURCE_DIR}/example/rips_complex_diagram_persistence_from_distance_matrix_file_example.py"
- --no-diagram -f ${CMAKE_SOURCE_DIR}/data/distance_matrix/lower_triangular_distance_matrix.csv -e 12.0 -d 3)
+ --no-diagram -f ${CMAKE_SOURCE_DIR}/data/distance_matrix/lower_triangular_distance_matrix.csv -s , -e 12.0 -d 3)
add_test(NAME rips_complex_diagram_persistence_from_off_file_example_py_test
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
- COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}"
+ COMMAND ${CMAKE_COMMAND} -E env "${GUDHI_PYTHON_PATH_ENV}"
${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/example/rips_complex_diagram_persistence_from_off_file_example.py
--no-diagram -f ${CMAKE_SOURCE_DIR}/data/points/tore3D_300.off -e 0.25 -d 3)
add_test(NAME rips_complex_from_points_example_py_test
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
- COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}"
+ COMMAND ${CMAKE_COMMAND} -E env "${GUDHI_PYTHON_PATH_ENV}"
${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/example/rips_complex_from_points_example.py)
add_gudhi_py_test(test_rips_complex)
@@ -476,7 +537,7 @@ if(PYTHONINTERP_FOUND)
# Simplex tree
add_test(NAME simplex_tree_example_py_test
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
- COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}"
+ COMMAND ${CMAKE_COMMAND} -E env "${GUDHI_PYTHON_PATH_ENV}"
${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/example/simplex_tree_example.py)
add_gudhi_py_test(test_simplex_tree)
@@ -485,23 +546,24 @@ if(PYTHONINTERP_FOUND)
# Witness
add_test(NAME witness_complex_from_nearest_landmark_table_py_test
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
- COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}"
+ COMMAND ${CMAKE_COMMAND} -E env "${GUDHI_PYTHON_PATH_ENV}"
${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/example/witness_complex_from_nearest_landmark_table.py)
add_gudhi_py_test(test_witness_complex)
# Reader utils
add_gudhi_py_test(test_reader_utils)
+ add_gudhi_py_test(test_off)
# Wasserstein
- if(OT_FOUND AND PYBIND11_FOUND)
+ if(OT_FOUND)
# EagerPy dependency because of enable_autodiff=True
if(EAGERPY_FOUND)
add_gudhi_py_test(test_wasserstein_distance)
endif()
+
add_gudhi_py_test(test_wasserstein_barycenter)
- endif()
- if(OT_FOUND)
+
if(TORCH_FOUND AND TENSORFLOW_FOUND AND EAGERPY_FOUND)
add_gudhi_py_test(test_wasserstein_with_tensors)
endif()
@@ -512,6 +574,26 @@ if(PYTHONINTERP_FOUND)
add_gudhi_py_test(test_representations)
endif()
+ # Differentiation
+ if(TENSORFLOW_FOUND)
+ add_gudhi_py_test(test_diff)
+ endif()
+
+ # Perslay
+ if(TENSORFLOW_FOUND AND SKLEARN_FOUND)
+ add_gudhi_py_test(test_perslay)
+ endif()
+
+ # Betti curves
+ if(SKLEARN_FOUND AND SCIPY_FOUND)
+ add_gudhi_py_test(test_betti_curve_representations)
+ endif()
+
+ # Representations preprocessing
+ if(SKLEARN_FOUND)
+ add_gudhi_py_test(test_representations_preprocessing)
+ endif()
+
# Time Delay
add_gudhi_py_test(test_time_delay)
@@ -522,7 +604,7 @@ if(PYTHONINTERP_FOUND)
endif()
# Tomato
- if(SCIPY_FOUND AND SKLEARN_FOUND AND PYBIND11_FOUND)
+ if(SCIPY_FOUND AND SKLEARN_FOUND)
add_gudhi_py_test(test_tomato)
endif()
@@ -536,13 +618,27 @@ if(PYTHONINTERP_FOUND)
add_gudhi_py_test(test_dtm_rips_complex)
endif()
+ # Fetch remote datasets
+ if(WITH_GUDHI_REMOTE_TEST)
+ add_gudhi_py_test(test_remote_datasets)
+ endif()
+
+ # sklearn
+ if(SKLEARN_FOUND)
+ add_gudhi_py_test(test_sklearn_cubical_persistence)
+ endif()
+
+ # persistence graphical tools
+ if(MATPLOTLIB_FOUND)
+ add_gudhi_py_test(test_persistence_graphical_tools)
+ endif()
# Set missing or not modules
set(GUDHI_MODULES ${GUDHI_MODULES} "python" CACHE INTERNAL "GUDHI_MODULES")
- else(CYTHON_FOUND)
- message("++ Python module will not be compiled because cython was not found")
+ else(NUMPY_FOUND AND PYBIND11_FOUND AND CYTHON_FOUND)
+ message("++ Python module will not be compiled because numpy and/or cython and/or pybind11 was/were not found")
set(GUDHI_MISSING_MODULES ${GUDHI_MISSING_MODULES} "python" CACHE INTERNAL "GUDHI_MISSING_MODULES")
- endif(CYTHON_FOUND)
+ endif(NUMPY_FOUND AND PYBIND11_FOUND AND CYTHON_FOUND)
else(PYTHONINTERP_FOUND)
message("++ Python module will not be compiled because no Python interpreter was found")
set(GUDHI_MISSING_MODULES ${GUDHI_MISSING_MODULES} "python" CACHE INTERNAL "GUDHI_MISSING_MODULES")
diff --git a/src/python/doc/_templates/layout.html b/src/python/doc/_templates/layout.html
index cd40a51b..e074b6c7 100644
--- a/src/python/doc/_templates/layout.html
+++ b/src/python/doc/_templates/layout.html
@@ -194,6 +194,7 @@
<li><a href="/relatedprojects/">Related projects</a></li>
<li><a href="/theyaretalkingaboutus/">They are talking about us</a></li>
<li><a href="/inaction/">GUDHI in action</a></li>
+ <li><a href="/etymology/">Etymology</a></li>
</ul>
</li>
<li class="divider"></li>
diff --git a/src/python/doc/alpha_complex_ref.rst b/src/python/doc/alpha_complex_ref.rst
index 7da79543..eaa72551 100644
--- a/src/python/doc/alpha_complex_ref.rst
+++ b/src/python/doc/alpha_complex_ref.rst
@@ -9,6 +9,5 @@ Alpha complex reference manual
.. autoclass:: gudhi.AlphaComplex
:members:
:undoc-members:
- :show-inheritance:
.. automethod:: gudhi.AlphaComplex.__init__
diff --git a/src/python/doc/alpha_complex_sum.inc b/src/python/doc/alpha_complex_sum.inc
index aeab493f..5c76fd54 100644
--- a/src/python/doc/alpha_complex_sum.inc
+++ b/src/python/doc/alpha_complex_sum.inc
@@ -1,15 +1,15 @@
.. table::
:widths: 30 40 30
- +----------------------------------------------------------------+-------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+
- | .. figure:: | Alpha complex is a simplicial complex constructed from the finite | :Author: Vincent Rouvreau |
- | ../../doc/Alpha_complex/alpha_complex_representation.png | cells of a Delaunay Triangulation. It has the same persistent homology | |
- | :alt: Alpha complex representation | as the Čech complex and is significantly smaller. | :Since: GUDHI 2.0.0 |
- | :figclass: align-center | | |
- | | | :License: MIT (`GPL v3 </licensing/>`_) |
- | | | |
- | | | :Requires: `Eigen <installation.html#eigen>`_ :math:`\geq` 3.1.0 and `CGAL <installation.html#cgal>`_ :math:`\geq` 4.11.0 |
- | | | |
- +----------------------------------------------------------------+-------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+
- | * :doc:`alpha_complex_user` | * :doc:`alpha_complex_ref` |
- +----------------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
+ +----------------------------------------------------------------+-------------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------+
+ | .. figure:: | Alpha complex is a simplicial complex constructed from the finite | :Author: Vincent Rouvreau |
+ | ../../doc/Alpha_complex/alpha_complex_representation.png | cells of a Delaunay Triangulation. It has the same persistent homology | |
+ | :alt: Alpha complex representation | as the Čech complex and is significantly smaller. | :Since: GUDHI 2.0.0 |
+ | :figclass: align-center | | |
+ | | | :License: MIT (`GPL v3 </licensing/>`_) |
+ | | | |
+ | | | :Requires: `Eigen <installation.html#eigen>`_ :math:`\geq` 3.1.0 and `CGAL <installation.html#cgal>`_ :math:`\geq` 5.1 |
+ | | | |
+ +----------------------------------------------------------------+-------------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------+
+ | * :doc:`alpha_complex_user` | * :doc:`alpha_complex_ref` |
+ +----------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
diff --git a/src/python/doc/alpha_complex_user.rst b/src/python/doc/alpha_complex_user.rst
index fffcb3db..9e67d38a 100644
--- a/src/python/doc/alpha_complex_user.rst
+++ b/src/python/doc/alpha_complex_user.rst
@@ -9,7 +9,7 @@ Definition
.. include:: alpha_complex_sum.inc
-:doc:`AlphaComplex <alpha_complex_ref>` is constructing a :doc:`SimplexTree <simplex_tree_ref>` using
+:class:`~gudhi.AlphaComplex` is constructing a :doc:`SimplexTree <simplex_tree_ref>` using
`Delaunay Triangulation <http://doc.cgal.org/latest/Triangulation/index.html#Chapter_Triangulations>`_
:cite:`cgal:hdj-t-19b` from the `Computational Geometry Algorithms Library <http://www.cgal.org/>`_
:cite:`cgal:eb-19b`.
@@ -27,15 +27,13 @@ Remarks
If you pass :code:`precision = 'exact'` to :func:`~gudhi.AlphaComplex.__init__`, the filtration values are the exact
ones converted to float. This can be very slow.
If you pass :code:`precision = 'safe'` (the default), the filtration values are only
- guaranteed to have a small multiplicative error compared to the exact value.
+ guaranteed to have a small multiplicative error compared to the exact value, see
+ :func:`~gudhi.AlphaComplex.set_float_relative_precision` to modify the precision.
A drawback, when computing persistence, is that an empty exact interval [10^12,10^12] may become a
non-empty approximate interval [10^12,10^12+10^6].
Using :code:`precision = 'fast'` makes the computations slightly faster, and the combinatorics are still exact, but
the computation of filtration values can exceptionally be arbitrarily bad. In all cases, we still guarantee that the
output is a valid filtration (faces have a filtration value no larger than their cofaces).
-* For performances reasons, it is advised to use Alpha_complex with `CGAL <installation.html#cgal>`_ :math:`\geq` 5.0.0.
-* The vertices in the output simplex tree are not guaranteed to match the order of the input points. One can use
- :func:`~gudhi.AlphaComplex.get_point` to get the initial point back.
Example from points
-------------------
@@ -44,23 +42,22 @@ This example builds the alpha-complex from the given points:
.. testcode::
- import gudhi
- alpha_complex = gudhi.AlphaComplex(points=[[1, 1], [7, 0], [4, 6], [9, 6], [0, 14], [2, 19], [9, 17]])
+ from gudhi import AlphaComplex
+ ac = AlphaComplex(points=[[1, 1], [7, 0], [4, 6], [9, 6], [0, 14], [2, 19], [9, 17]])
+
+ stree = ac.create_simplex_tree()
+ print('Alpha complex is of dimension ', stree.dimension(), ' - ',
+ stree.num_simplices(), ' simplices - ', stree.num_vertices(), ' vertices.')
- simplex_tree = alpha_complex.create_simplex_tree()
- result_str = 'Alpha complex is of dimension ' + repr(simplex_tree.dimension()) + ' - ' + \
- repr(simplex_tree.num_simplices()) + ' simplices - ' + \
- repr(simplex_tree.num_vertices()) + ' vertices.'
- print(result_str)
fmt = '%s -> %.2f'
- for filtered_value in simplex_tree.get_filtration():
+ for filtered_value in stree.get_filtration():
print(fmt % tuple(filtered_value))
The output is:
.. testoutput::
- Alpha complex is of dimension 2 - 25 simplices - 7 vertices.
+ Alpha complex is of dimension 2 - 25 simplices - 7 vertices.
[0] -> 0.00
[1] -> 0.00
[2] -> 0.00
@@ -163,7 +160,10 @@ As the squared radii computed by CGAL are an approximation, it might happen that
:math:`\alpha^2` values do not quite define a proper filtration (i.e. non-decreasing with
respect to inclusion).
We fix that up by calling :func:`~gudhi.SimplexTree.make_filtration_non_decreasing` (cf.
-`C++ version <http://gudhi.gforge.inria.fr/doc/latest/index.html>`_).
+`C++ version <https://gudhi.inria.fr/doc/latest/class_gudhi_1_1_simplex__tree.html>`_).
+
+.. note::
+ This is not the case in `exact` version, this is the reason why it is not called in this case.
Prune above given filtration value
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -174,11 +174,75 @@ of speed-up, since we always first build the full filtered complex, so it is rec
:paramref:`~gudhi.AlphaComplex.create_simplex_tree.max_alpha_square`.
In the following example, a threshold of :math:`\alpha^2 = 32.0` is used.
+Weighted version
+^^^^^^^^^^^^^^^^
+
+A weighted version for Alpha complex is available. It is like a usual Alpha complex, but based on a
+`CGAL regular triangulation <https://doc.cgal.org/latest/Triangulation/index.html#TriangulationSecRT>`_.
+
+This example builds the weighted alpha-complex of a small molecule, where atoms have different sizes.
+It is taken from
+`CGAL 3d weighted alpha shapes <https://doc.cgal.org/latest/Alpha_shapes_3/index.html#AlphaShape_3DExampleforWeightedAlphaShapes>`_.
+
+Then, it is asked to display information about the alpha complex.
+
+.. testcode::
+
+ from gudhi import AlphaComplex
+ wgt_ac = AlphaComplex(points=[[ 1., -1., -1.],
+ [-1., 1., -1.],
+ [-1., -1., 1.],
+ [ 1., 1., 1.],
+ [ 2., 2., 2.]],
+ weights = [4., 4., 4., 4., 1.])
+
+ stree = wgt_ac.create_simplex_tree()
+ print('Weighted alpha complex is of dimension ', stree.dimension(), ' - ',
+ stree.num_simplices(), ' simplices - ', stree.num_vertices(), ' vertices.')
+ fmt = '%s -> %.2f'
+ for simplex in stree.get_simplices():
+ print(fmt % tuple(simplex))
+
+The output is:
+
+.. testoutput::
+
+ Weighted alpha complex is of dimension 3 - 29 simplices - 5 vertices.
+ [0, 1, 2, 3] -> -1.00
+ [0, 1, 2] -> -1.33
+ [0, 1, 3, 4] -> 95.00
+ [0, 1, 3] -> -1.33
+ [0, 1, 4] -> 95.00
+ [0, 1] -> -2.00
+ [0, 2, 3, 4] -> 95.00
+ [0, 2, 3] -> -1.33
+ [0, 2, 4] -> 95.00
+ [0, 2] -> -2.00
+ [0, 3, 4] -> 23.00
+ [0, 3] -> -2.00
+ [0, 4] -> 23.00
+ [0] -> -4.00
+ [1, 2, 3, 4] -> 95.00
+ [1, 2, 3] -> -1.33
+ [1, 2, 4] -> 95.00
+ [1, 2] -> -2.00
+ [1, 3, 4] -> 23.00
+ [1, 3] -> -2.00
+ [1, 4] -> 23.00
+ [1] -> -4.00
+ [2, 3, 4] -> 23.00
+ [2, 3] -> -2.00
+ [2, 4] -> 23.00
+ [2] -> -4.00
+ [3, 4] -> -1.00
+ [3] -> -4.00
+ [4] -> -1.00
Example from OFF file
^^^^^^^^^^^^^^^^^^^^^
-This example builds the alpha complex from 300 random points on a 2-torus.
+This example builds the alpha complex from 300 random points on a 2-torus, given by an
+`OFF file <fileformats.html#off-file-format>`_.
Then, it computes the persistence diagram and displays it:
@@ -186,14 +250,10 @@ Then, it computes the persistence diagram and displays it:
:include-source:
import matplotlib.pyplot as plt
- import gudhi
- alpha_complex = gudhi.AlphaComplex(off_file=gudhi.__root_source_dir__ + \
- '/data/points/tore3D_300.off')
- simplex_tree = alpha_complex.create_simplex_tree()
- result_str = 'Alpha complex is of dimension ' + repr(simplex_tree.dimension()) + ' - ' + \
- repr(simplex_tree.num_simplices()) + ' simplices - ' + \
- repr(simplex_tree.num_vertices()) + ' vertices.'
- print(result_str)
- diag = simplex_tree.persistence()
- gudhi.plot_persistence_diagram(diag)
+ import gudhi as gd
+ off_file = gd.__root_source_dir__ + '/data/points/tore3D_300.off'
+ points = gd.read_points_from_off_file(off_file = off_file)
+ stree = gd.AlphaComplex(points = points).create_simplex_tree()
+ dgm = stree.persistence()
+ gd.plot_persistence_diagram(dgm, legend = True)
plt.show()
diff --git a/src/python/doc/clustering.rst b/src/python/doc/clustering.rst
index c5a57d3c..62422682 100644
--- a/src/python/doc/clustering.rst
+++ b/src/python/doc/clustering.rst
@@ -17,9 +17,8 @@ As a by-product, we produce the persistence diagram of the merge tree of the ini
:include-source:
import gudhi
- f = open(gudhi.__root_source_dir__ + '/data/points/spiral_2d.csv', 'r')
- import numpy as np
- data = np.loadtxt(f)
+ from gudhi.datasets.remote import fetch_spiral_2d
+ data = fetch_spiral_2d()
import matplotlib.pyplot as plt
plt.scatter(data[:,0],data[:,1],marker='.',s=1)
plt.show()
diff --git a/src/python/doc/conf.py b/src/python/doc/conf.py
index b06baf9c..e69e2751 100755
--- a/src/python/doc/conf.py
+++ b/src/python/doc/conf.py
@@ -120,15 +120,12 @@ pygments_style = 'sphinx'
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
-html_theme = 'classic'
+html_theme = 'python_docs_theme'
# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
html_theme_options = {
- "sidebarbgcolor": "#A1ADCD",
- "sidebartextcolor": "black",
- "sidebarlinkcolor": "#334D5C",
"body_max_width": "100%",
}
diff --git a/src/python/doc/cubical_complex_sklearn_itf_ref.rst b/src/python/doc/cubical_complex_sklearn_itf_ref.rst
new file mode 100644
index 00000000..90ae9ccd
--- /dev/null
+++ b/src/python/doc/cubical_complex_sklearn_itf_ref.rst
@@ -0,0 +1,102 @@
+:orphan:
+
+.. To get rid of WARNING: document isn't included in any toctree
+
+Cubical complex persistence scikit-learn like interface
+#######################################################
+
+.. list-table::
+ :width: 100%
+ :header-rows: 0
+
+ * - :Since: GUDHI 3.6.0
+ - :License: MIT
+ - :Requires: `Scikit-learn <installation.html#scikit-learn>`_
+
+Cubical complex persistence scikit-learn like interface example
+---------------------------------------------------------------
+
+In this example, hand written digits are used as an input.
+a TDA scikit-learn pipeline is constructed and is composed of:
+
+#. :class:`~gudhi.sklearn.cubical_persistence.CubicalPersistence` that builds a cubical complex from the inputs and
+ returns its persistence diagrams
+#. :class:`~gudhi.representations.preprocessing.DiagramSelector` that removes non-finite persistence diagrams values
+#. :class:`~gudhi.representations.vector_methods.PersistenceImage` that builds the persistence images from persistence diagrams
+#. `SVC <https://scikit-learn.org/stable/modules/generated/sklearn.svm.SVC.html>`_ which is a scikit-learn support
+ vector classifier.
+
+This ML pipeline is trained to detect if the hand written digit is an '8' or not, thanks to the fact that an '8' has
+two holes in :math:`\mathbf{H}_1`, or, like in this example, three connected components in :math:`\mathbf{H}_0`.
+
+.. code-block:: python
+
+ # Standard scientific Python imports
+ import numpy as np
+
+ # Standard scikit-learn imports
+ from sklearn.datasets import fetch_openml
+ from sklearn.pipeline import Pipeline
+ from sklearn.model_selection import train_test_split
+ from sklearn.svm import SVC
+ from sklearn import metrics
+
+ # Import TDA pipeline requirements
+ from gudhi.sklearn.cubical_persistence import CubicalPersistence
+ from gudhi.representations import PersistenceImage, DiagramSelector
+
+ X, y = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)
+
+ # Target is: "is an eight ?"
+ y = (y == "8") * 1
+ print("There are", np.sum(y), "eights out of", len(y), "numbers.")
+
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=0)
+ pipe = Pipeline(
+ [
+ ("cub_pers", CubicalPersistence(homology_dimensions=0, newshape=[-1, 28, 28], n_jobs=-2)),
+ # Or for multiple persistence dimension computation
+ # ("cub_pers", CubicalPersistence(homology_dimensions=[0, 1], newshape=[-1, 28, 28])),
+ # ("H0_diags", DimensionSelector(index=0), # where index is the index in homology_dimensions array
+ ("finite_diags", DiagramSelector(use=True, point_type="finite")),
+ (
+ "pers_img",
+ PersistenceImage(bandwidth=50, weight=lambda x: x[1] ** 2, im_range=[0, 256, 0, 256], resolution=[20, 20]),
+ ),
+ ("svc", SVC()),
+ ]
+ )
+
+ # Learn from the train subset
+ pipe.fit(X_train, y_train)
+ # Predict from the test subset
+ predicted = pipe.predict(X_test)
+
+ print(f"Classification report for TDA pipeline {pipe}:\n" f"{metrics.classification_report(y_test, predicted)}\n")
+
+.. code-block:: none
+
+ There are 6825 eights out of 70000 numbers.
+ Classification report for TDA pipeline Pipeline(steps=[('cub_pers',
+ CubicalPersistence(newshape=[28, 28], n_jobs=-2)),
+ ('finite_diags', DiagramSelector(use=True)),
+ ('pers_img',
+ PersistenceImage(bandwidth=50, im_range=[0, 256, 0, 256],
+ weight=<function <lambda> at 0x7f3e54137ae8>)),
+ ('svc', SVC())]):
+ precision recall f1-score support
+
+ 0 0.97 0.99 0.98 25284
+ 1 0.92 0.68 0.78 2716
+
+ accuracy 0.96 28000
+ macro avg 0.94 0.84 0.88 28000
+ weighted avg 0.96 0.96 0.96 28000
+
+Cubical complex persistence scikit-learn like interface reference
+-----------------------------------------------------------------
+
+.. autoclass:: gudhi.sklearn.cubical_persistence.CubicalPersistence
+ :members:
+ :special-members: __init__
+ :show-inheritance: \ No newline at end of file
diff --git a/src/python/doc/cubical_complex_sum.inc b/src/python/doc/cubical_complex_sum.inc
index 87db184d..b27843e5 100644
--- a/src/python/doc/cubical_complex_sum.inc
+++ b/src/python/doc/cubical_complex_sum.inc
@@ -1,14 +1,22 @@
.. table::
:widths: 30 40 30
- +--------------------------------------------------------------------------+----------------------------------------------------------------------+-----------------------------+
- | .. figure:: | The cubical complex represents a grid as a cell complex with | :Author: Pawel Dlotko |
- | ../../doc/Bitmap_cubical_complex/Cubical_complex_representation.png | cells of all dimensions. | |
- | :alt: Cubical complex representation | | :Since: GUDHI 2.0.0 |
- | :figclass: align-center | | |
- | | | :License: MIT |
- | | | |
- +--------------------------------------------------------------------------+----------------------------------------------------------------------+-----------------------------+
- | * :doc:`cubical_complex_user` | * :doc:`cubical_complex_ref` |
- | | * :doc:`periodic_cubical_complex_ref` |
- +--------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------+
+ +--------------------------------------------------------------------------+--------------------------------------------------------------+-------------------------------------------------------------+
+ | .. figure:: | The cubical complex represents a grid as a cell complex with | :Author: Pawel Dlotko |
+ | ../../doc/Bitmap_cubical_complex/Cubical_complex_representation.png | cells of all dimensions. | :Since: GUDHI 2.0.0 |
+ | :alt: Cubical complex representation | | :License: MIT |
+ | :figclass: align-center | | |
+ +--------------------------------------------------------------------------+--------------------------------------------------------------+-------------------------------------------------------------+
+ | * :doc:`cubical_complex_user` | * :doc:`cubical_complex_ref` |
+ | | * :doc:`periodic_cubical_complex_ref` |
+ +--------------------------------------------------------------------------+--------------------------------------------------------------+-------------------------------------------------------------+
+ | .. image:: | * :doc:`cubical_complex_tflow_itf_ref` | :requires: `TensorFlow <installation.html#tensorflow>`_ |
+ | img/tensorflow.png | | |
+ | :target: https://www.tensorflow.org | | |
+ | :height: 30 | | |
+ +--------------------------------------------------------------------------+--------------------------------------------------------------+-------------------------------------------------------------+
+ | .. image:: | * :doc:`cubical_complex_sklearn_itf_ref` | :Requires: `Scikit-learn <installation.html#scikit-learn>`_ |
+ | img/sklearn.png | | |
+ | :target: https://scikit-learn.org | | |
+ | :height: 30 | | |
+ +--------------------------------------------------------------------------+--------------------------------------------------------------+-------------------------------------------------------------+
diff --git a/src/python/doc/cubical_complex_tflow_itf_ref.rst b/src/python/doc/cubical_complex_tflow_itf_ref.rst
new file mode 100644
index 00000000..b32f5e47
--- /dev/null
+++ b/src/python/doc/cubical_complex_tflow_itf_ref.rst
@@ -0,0 +1,40 @@
+:orphan:
+
+.. To get rid of WARNING: document isn't included in any toctree
+
+TensorFlow layer for cubical persistence
+########################################
+
+.. include:: differentiation_sum.inc
+
+Example of gradient computed from cubical persistence
+-----------------------------------------------------
+
+.. testcode::
+
+ from gudhi.tensorflow import CubicalLayer
+ import tensorflow as tf
+
+ X = tf.Variable([[0.,2.,2.],[2.,2.,2.],[2.,2.,1.]], dtype=tf.float32, trainable=True)
+ cl = CubicalLayer(homology_dimensions=[0])
+
+ with tf.GradientTape() as tape:
+ dgm = cl.call(X)[0][0]
+ loss = tf.math.reduce_sum(tf.square(.5*(dgm[:,1]-dgm[:,0])))
+
+ grads = tape.gradient(loss, [X])
+ print(grads[0].numpy())
+
+.. testoutput::
+
+ [[ 0. 0. 0. ]
+ [ 0. 0.5 0. ]
+ [ 0. 0. -0.5]]
+
+Documentation for CubicalLayer
+------------------------------
+
+.. autoclass:: gudhi.tensorflow.CubicalLayer
+ :members:
+ :special-members: __init__
+ :show-inheritance:
diff --git a/src/python/doc/cubical_complex_user.rst b/src/python/doc/cubical_complex_user.rst
index 6a211347..42a23875 100644
--- a/src/python/doc/cubical_complex_user.rst
+++ b/src/python/doc/cubical_complex_user.rst
@@ -7,14 +7,7 @@ Cubical complex user manual
Definition
----------
-===================================== ===================================== =====================================
-:Author: Pawel Dlotko :Since: GUDHI PYTHON 2.0.0 :License: GPL v3
-===================================== ===================================== =====================================
-
-+---------------------------------------------+----------------------------------------------------------------------+
-| :doc:`cubical_complex_user` | * :doc:`cubical_complex_ref` |
-| | * :doc:`periodic_cubical_complex_ref` |
-+---------------------------------------------+----------------------------------------------------------------------+
+.. include:: cubical_complex_sum.inc
The cubical complex is an example of a structured complex useful in computational mathematics (specially rigorous
numerics) and image analysis.
@@ -163,4 +156,4 @@ Tutorial
--------
This `notebook <https://github.com/GUDHI/TDA-tutorial/blob/master/Tuto-GUDHI-cubical-complexes.ipynb>`_
-explains how to represent sublevels sets of functions using cubical complexes. \ No newline at end of file
+explains how to represent sublevels sets of functions using cubical complexes.
diff --git a/src/python/doc/datasets.inc b/src/python/doc/datasets.inc
new file mode 100644
index 00000000..95a87678
--- /dev/null
+++ b/src/python/doc/datasets.inc
@@ -0,0 +1,14 @@
+.. table::
+ :widths: 30 40 30
+
+ +-----------------------------------+--------------------------------------------+--------------------------------------------------------------------------------------+
+ | .. figure:: | Datasets either generated or fetched. | :Authors: Hind Montassif |
+ | img/sphere_3d.png | | |
+ | | | :Since: GUDHI 3.5.0 |
+ | | | |
+ | | | :License: MIT (`LGPL v3 </licensing/>`_) |
+ | | | |
+ | | | :Requires: `CGAL <installation.html#cgal>`_ |
+ +-----------------------------------+--------------------------------------------+--------------------------------------------------------------------------------------+
+ | * :doc:`datasets` |
+ +-----------------------------------+-----------------------------------------------------------------------------------------------------------------------------------+
diff --git a/src/python/doc/datasets.rst b/src/python/doc/datasets.rst
new file mode 100644
index 00000000..2d11a19d
--- /dev/null
+++ b/src/python/doc/datasets.rst
@@ -0,0 +1,133 @@
+
+:orphan:
+
+.. To get rid of WARNING: document isn't included in any toctree
+
+================
+Datasets manual
+================
+
+Datasets generators
+===================
+
+We provide the generation of different customizable datasets to use as inputs for Gudhi complexes and data structures.
+
+Points generators
+------------------
+
+The module **points** enables the generation of random points on a sphere, random points on a torus and as a grid.
+
+Points on sphere
+^^^^^^^^^^^^^^^^
+
+The function **sphere** enables the generation of random i.i.d. points uniformly on a (d-1)-sphere in :math:`R^d`.
+The user should provide the number of points to be generated on the sphere :code:`n_samples` and the ambient dimension :code:`ambient_dim`.
+The :code:`radius` of sphere is optional and is equal to **1** by default.
+Only random points generation is currently available.
+
+The generated points are given as an array of shape :math:`(n\_samples, ambient\_dim)`.
+
+Example
+"""""""
+
+.. code-block:: python
+
+ from gudhi.datasets.generators import points
+ from gudhi import AlphaComplex
+
+ # Generate 50 points on a sphere in R^2
+ gen_points = points.sphere(n_samples = 50, ambient_dim = 2, radius = 1, sample = "random")
+
+ # Create an alpha complex from the generated points
+ alpha_complex = AlphaComplex(points = gen_points)
+
+.. autofunction:: gudhi.datasets.generators.points.sphere
+
+Points on a flat torus
+^^^^^^^^^^^^^^^^^^^^^^
+
+You can also generate points on a torus.
+
+Two functions are available and give the same output: the first one depends on **CGAL** and the second does not and consists of full python code.
+
+On another hand, two sample types are provided: you can either generate i.i.d. points on a d-torus in :math:`R^{2d}` *randomly* or on a *grid*.
+
+First function: **ctorus**
+"""""""""""""""""""""""""""
+
+The user should provide the number of points to be generated on the torus :code:`n_samples`, and the dimension :code:`dim` of the torus on which points would be generated in :math:`R^{2dim}`.
+The :code:`sample` argument is optional and is set to **'random'** by default.
+In this case, the returned generated points would be an array of shape :math:`(n\_samples, 2*dim)`.
+Otherwise, if set to **'grid'**, the points are generated on a grid and would be given as an array of shape:
+
+.. math::
+
+ ( ⌊n\_samples^{1 \over {dim}}⌋^{dim}, 2*dim )
+
+**Note 1:** The output array first shape is rounded down to the closest perfect :math:`dim^{th}` power.
+
+**Note 2:** This version is recommended when the user wishes to use **'grid'** as sample type, or **'random'** with a relatively small number of samples (~ less than 150).
+
+Example
+"""""""
+.. code-block:: python
+
+ from gudhi.datasets.generators import points
+
+ # Generate 50 points randomly on a torus in R^6
+ gen_points = points.ctorus(n_samples = 50, dim = 3)
+
+ # Generate 27 points on a torus as a grid in R^6
+ gen_points = points.ctorus(n_samples = 50, dim = 3, sample = 'grid')
+
+.. autofunction:: gudhi.datasets.generators.points.ctorus
+
+Second function: **torus**
+"""""""""""""""""""""""""""
+
+The user should provide the number of points to be generated on the torus :code:`n_samples` and the dimension :code:`dim` of the torus on which points would be generated in :math:`R^{2dim}`.
+The :code:`sample` argument is optional and is set to **'random'** by default.
+The other allowed value of sample type is **'grid'**.
+
+**Note:** This version is recommended when the user wishes to use **'random'** as sample type with a great number of samples and a low dimension.
+
+Example
+"""""""
+.. code-block:: python
+
+ from gudhi.datasets.generators import points
+
+ # Generate 50 points randomly on a torus in R^6
+ gen_points = points.torus(n_samples = 50, dim = 3)
+
+ # Generate 27 points on a torus as a grid in R^6
+ gen_points = points.torus(n_samples = 50, dim = 3, sample = 'grid')
+
+
+.. autofunction:: gudhi.datasets.generators.points.torus
+
+
+Fetching datasets
+=================
+
+We provide some ready-to-use datasets that are not available by default when getting GUDHI, and need to be fetched explicitly.
+
+By **default**, the fetched datasets directory is set to a folder named **'gudhi_data'** in the **user home folder**.
+Alternatively, it can be set using the **'GUDHI_DATA'** environment variable.
+
+.. autofunction:: gudhi.datasets.remote.fetch_bunny
+
+.. figure:: ./img/bunny.png
+ :figclass: align-center
+
+ 3D Stanford bunny with 35947 vertices.
+
+
+.. autofunction:: gudhi.datasets.remote.fetch_spiral_2d
+
+.. figure:: ./img/spiral_2d.png
+ :figclass: align-center
+
+ 2D spiral with 114562 vertices.
+
+.. autofunction:: gudhi.datasets.remote.clear_data_home
diff --git a/src/python/doc/differentiation_sum.inc b/src/python/doc/differentiation_sum.inc
new file mode 100644
index 00000000..140cf180
--- /dev/null
+++ b/src/python/doc/differentiation_sum.inc
@@ -0,0 +1,12 @@
+.. list-table::
+ :width: 100%
+ :header-rows: 0
+
+ * - :Since: GUDHI 3.6.0
+ - :License: MIT
+ - :Requires: `TensorFlow <installation.html#tensorflow>`_
+
+We provide TensorFlow 2 models that can handle automatic differentiation for the computation of persistence diagrams from complexes available in the Gudhi library.
+This includes simplex trees, cubical complexes and Vietoris-Rips complexes. Detailed example on how to use these layers in practice are available
+in the following `notebook <https://github.com/GUDHI/TDA-tutorial/blob/master/Tuto-GUDHI-optimization.ipynb>`_. Note that even if TensorFlow GPU is enabled, all
+internal computations using Gudhi will be done on CPU.
diff --git a/src/python/doc/examples.rst b/src/python/doc/examples.rst
index 76e5d4c7..1442f185 100644
--- a/src/python/doc/examples.rst
+++ b/src/python/doc/examples.rst
@@ -8,6 +8,7 @@ Examples
.. only:: builder_html
* :download:`alpha_complex_diagram_persistence_from_off_file_example.py <../example/alpha_complex_diagram_persistence_from_off_file_example.py>`
+ * :download:`alpha_complex_from_generated_points_on_sphere_example.py <../example/alpha_complex_from_generated_points_on_sphere_example.py>`
* :download:`alpha_complex_from_points_example.py <../example/alpha_complex_from_points_example.py>`
* :download:`alpha_rips_persistence_bottleneck_distance.py <../example/alpha_rips_persistence_bottleneck_distance.py>`
* :download:`bottleneck_basic_example.py <../example/bottleneck_basic_example.py>`
diff --git a/src/python/doc/img/bunny.png b/src/python/doc/img/bunny.png
new file mode 100644
index 00000000..769aa530
--- /dev/null
+++ b/src/python/doc/img/bunny.png
Binary files differ
diff --git a/src/python/doc/img/sklearn.png b/src/python/doc/img/sklearn.png
new file mode 100644
index 00000000..d1fecbbf
--- /dev/null
+++ b/src/python/doc/img/sklearn.png
Binary files differ
diff --git a/src/python/doc/img/sphere_3d.png b/src/python/doc/img/sphere_3d.png
new file mode 100644
index 00000000..70f3184f
--- /dev/null
+++ b/src/python/doc/img/sphere_3d.png
Binary files differ
diff --git a/src/python/doc/img/spiral_2d.png b/src/python/doc/img/spiral_2d.png
new file mode 100644
index 00000000..abd247cd
--- /dev/null
+++ b/src/python/doc/img/spiral_2d.png
Binary files differ
diff --git a/src/python/doc/img/tensorflow.png b/src/python/doc/img/tensorflow.png
new file mode 100644
index 00000000..a75f3f5b
--- /dev/null
+++ b/src/python/doc/img/tensorflow.png
Binary files differ
diff --git a/src/python/doc/index.rst b/src/python/doc/index.rst
index 040e57a4..35f4ba46 100644
--- a/src/python/doc/index.rst
+++ b/src/python/doc/index.rst
@@ -91,3 +91,8 @@ Clustering
**********
.. include:: clustering.inc
+
+Datasets
+********
+
+.. include:: datasets.inc
diff --git a/src/python/doc/installation.rst b/src/python/doc/installation.rst
index 66efe45a..7200b2f0 100644
--- a/src/python/doc/installation.rst
+++ b/src/python/doc/installation.rst
@@ -33,25 +33,19 @@ Compiling
These instructions are for people who want to compile gudhi from source, they are
unnecessary if you installed a binary package of Gudhi as above. They assume that
you have downloaded a `release <https://github.com/GUDHI/gudhi-devel/releases>`_,
-with a name like `gudhi.3.2.0.tar.gz`, then run `tar xf gudhi.3.2.0.tar.gz`, which
-created a directory `gudhi.3.2.0`, hereinafter referred to as `/path-to-gudhi/`.
+with a name like `gudhi.3.X.Y.tar.gz`, then run `tar xf gudhi.3.X.Y.tar.gz`, which
+created a directory `gudhi.3.X.Y`, hereinafter referred to as `/path-to-gudhi/`.
If you are instead using a git checkout, beware that the paths are a bit
different, and in particular the `python/` subdirectory is actually `src/python/`
there.
-The library uses c++14 and requires `Boost <https://www.boost.org/>`_ :math:`\geq` 1.56.0,
+The library uses c++17 and requires `Boost <https://www.boost.org/>`_ :math:`\geq` 1.66.0,
`CMake <https://www.cmake.org/>`_ :math:`\geq` 3.5 to generate makefiles,
-`NumPy <http://numpy.org>`_, `Cython <https://www.cython.org/>`_ and
-`pybind11 <https://github.com/pybind/pybind11>`_ to compile
-the GUDHI Python module.
-It is a multi-platform library and compiles on Linux, Mac OSX and Visual
-Studio 2017 or later.
+Python :math:`\geq` 3.5, `NumPy <http://numpy.org>`_ :math:`\geq` 1.15.0, `Cython <https://www.cython.org/>`_
+:math:`\geq` 0.27 and `pybind11 <https://github.com/pybind/pybind11>`_ to compile the GUDHI Python module.
+It is a multi-platform library and compiles on Linux, Mac OSX and Visual Studio 2017 or later.
-On `Windows <https://wiki.python.org/moin/WindowsCompilers>`_ , only Python
-:math:`\geq` 3.5 are available because of the required Visual Studio version.
-
-On other systems, if you have several Python/python installed, the version 2.X
-will be used by default, but you can force it by adding
+If you have several Python/python installed, the version 2.X may be used by default, but you can force it by adding
:code:`-DPython_ADDITIONAL_VERSIONS=3` to the cmake command.
GUDHI Python module compilation
@@ -99,20 +93,14 @@ Or install it definitely in your Python packages folder:
.. code-block:: bash
cd /path-to-gudhi/build/python
- # May require sudo or administrator privileges
- make install
+ python setup.py install # add --user to the command if you do not have the permission
+ # Or 'pip install .'
.. note::
- :code:`make install` is only a
- `CMake custom targets <https://cmake.org/cmake/help/latest/command/add_custom_target.html>`_
- to shortcut :code:`python setup.py install` command.
It does not take into account :code:`CMAKE_INSTALL_PREFIX`.
- But one can use :code:`python setup.py install ...` specific options in the python directory:
-
-.. code-block:: bash
-
- python setup.py install --prefix /home/gudhi # Install in /home/gudhi directory
+ But one can use
+ `alternate location installation <https://docs.python.org/3/install/#alternate-installation>`_.
Test suites
===========
@@ -148,60 +136,71 @@ If :code:`import gudhi` succeeds, please have a look to debug information:
.. code-block:: python
- import gudhi
- print(gudhi.__debug_info__)
+ import gudhi as gd
+ print(gd.__debug_info__)
+ print("+ Installed modules are: " + gd.__available_modules)
+ print("+ Missing modules are: " + gd.__missing_modules)
You shall have something like:
.. code-block:: none
- Python version 2.7.15
- Cython version 0.26.1
- Numpy version 1.14.1
- Eigen3 version 3.1.1
- Installed modules are: off_reader;simplex_tree;rips_complex;
- cubical_complex;periodic_cubical_complex;reader_utils;witness_complex;
- strong_witness_complex;alpha_complex;
- Missing modules are: bottleneck_distance;nerve_gic;subsampling;
- tangential_complex;persistence_graphical_tools;
- euclidean_witness_complex;euclidean_strong_witness_complex;
- CGAL version 4.7.1000
- GMP_LIBRARIES = /usr/lib/x86_64-linux-gnu/libgmp.so
- GMPXX_LIBRARIES = /usr/lib/x86_64-linux-gnu/libgmpxx.so
- TBB version 9107 found and used
+ Pybind11 version 2.8.1
+ Python version 3.7.12
+ Cython version 0.29.25
+ Numpy version 1.21.4
+ Boost version 1.77.0
+ + Installed modules are: off_utils;simplex_tree;rips_complex;cubical_complex;periodic_cubical_complex;
+ persistence_graphical_tools;reader_utils;witness_complex;strong_witness_complex;
+ + Missing modules are: bottleneck;nerve_gic;subsampling;tangential_complex;alpha_complex;euclidean_witness_complex;
+ euclidean_strong_witness_complex;
-Here, you can see that bottleneck_distance, nerve_gic, subsampling and
-tangential_complex are missing because of the CGAL version.
-persistence_graphical_tools is not available as matplotlib is not
-available.
+Here, you can see that the modules that need CGAL are missing, because CGAL is not installed.
+:code:`persistence_graphical_tools` is installed, but
+`its functions <https://gudhi.inria.fr/python/latest/persistence_graphical_tools_ref.html>`_ will produce an error as
+matplotlib is not available.
Unitary tests cannot be run as pytest is missing.
A complete configuration would be :
.. code-block:: none
- Python version 3.6.5
- Cython version 0.28.2
- Pytest version 3.3.2
- Matplotlib version 2.2.2
- Numpy version 1.14.5
- Eigen3 version 3.3.4
- Installed modules are: off_reader;simplex_tree;rips_complex;
- cubical_complex;periodic_cubical_complex;persistence_graphical_tools;
- reader_utils;witness_complex;strong_witness_complex;
- persistence_graphical_tools;bottleneck_distance;nerve_gic;subsampling;
- tangential_complex;alpha_complex;euclidean_witness_complex;
- euclidean_strong_witness_complex;
- CGAL header only version 4.11.0
+ Pybind11 version 2.8.1
+ Python version 3.9.7
+ Cython version 0.29.24
+ Pytest version 6.2.5
+ Matplotlib version 3.5.0
+ Numpy version 1.21.4
+ Scipy version 1.7.3
+ Scikit-learn version 1.0.1
+ POT version 0.8.0
+ HNSWlib found
+ PyKeOps version [pyKeOps]: 2.1
+ EagerPy version 0.30.0
+ TensorFlow version 2.7.0
+ Sphinx version 4.3.0
+ Sphinx-paramlinks version 0.5.2
+ python_docs_theme found
+ Eigen3 version 3.4.0
+ Boost version 1.74.0
+ CGAL version 5.3
GMP_LIBRARIES = /usr/lib/x86_64-linux-gnu/libgmp.so
GMPXX_LIBRARIES = /usr/lib/x86_64-linux-gnu/libgmpxx.so
+ MPFR_LIBRARIES = /usr/lib/x86_64-linux-gnu/libmpfr.so
TBB version 9107 found and used
+ + Installed modules are: bottleneck;off_utils;simplex_tree;rips_complex;cubical_complex;periodic_cubical_complex;
+ persistence_graphical_tools;reader_utils;witness_complex;strong_witness_complex;nerve_gic;subsampling;
+ tangential_complex;alpha_complex;euclidean_witness_complex;euclidean_strong_witness_complex;
+ + Missing modules are:
+
Documentation
=============
-To build the documentation, `sphinx-doc <http://www.sphinx-doc.org>`_ and
-`sphinxcontrib-bibtex <https://sphinxcontrib-bibtex.readthedocs.io>`_ are
+To build the documentation, `sphinx-doc <http://www.sphinx-doc.org>`_,
+`sphinxcontrib-bibtex <https://sphinxcontrib-bibtex.readthedocs.io>`_,
+`sphinxcontrib-paramlinks <https://github.com/sqlalchemyorg/sphinx-paramlinks>`_ and
+`python-docs-theme <https://github.com/python/python-docs-theme>`_ are
required. As the documentation is auto-tested, `CGAL`_, `Eigen`_,
`Matplotlib`_, `NumPy`_, `POT`_, `Scikit-learn`_ and `SciPy`_ are
also mandatory to build the documentation.
@@ -349,8 +348,8 @@ You can still deactivate LaTeX rendering by saying:
.. code-block:: python
- import gudhi
- gudhi.persistence_graphical_tools._gudhi_matplotlib_use_tex=False
+ import gudhi as gd
+ gd.persistence_graphical_tools._gudhi_matplotlib_use_tex=False
PyKeOps
-------
@@ -363,7 +362,7 @@ Python Optimal Transport
------------------------
The :doc:`Wasserstein distance </wasserstein_distance_user>`
-module requires `POT <https://pot.readthedocs.io/>`_, a library that provides
+module requires `POT <https://pythonot.github.io/>`_, a library that provides
several solvers for optimization problems related to Optimal Transport.
PyTorch
@@ -392,18 +391,25 @@ The :doc:`persistence graphical tools </persistence_graphical_tools_user>` and
mathematics, science, and engineering.
:class:`~gudhi.point_cloud.knn.KNearestNeighbors` can use the Python package
-`SciPy <http://scipy.org>`_ as a backend if explicitly requested.
+`SciPy <http://scipy.org>`_ :math:`\geq` 1.6.0 as a backend if explicitly requested.
TensorFlow
----------
-`TensorFlow <https://www.tensorflow.org>`_ is currently only used in some automatic differentiation tests.
+:class:`~gudhi.tensorflow.perslay.Perslay` from the :doc:`persistence representations </representations>` module
+requires `TensorFlow <https://www.tensorflow.org/>`_.
+The :doc:`cubical complex </cubical_complex_tflow_itf_ref>`, :doc:`simplex tree </ls_simplex_tree_tflow_itf_ref>`
+and :doc:`Rips complex </rips_complex_tflow_itf_ref>` modules require `TensorFlow`_
+for incorporating them in neural nets.
+
+`TensorFlow`_ is also used in some automatic differentiation tests.
Bug reports and contributions
*****************************
-Please help us improving the quality of the GUDHI library. You may report bugs or suggestions to:
-
- Contact: gudhi-users@lists.gforge.inria.fr
+Please help us improving the quality of the GUDHI library.
+You may `report bugs <https://github.com/GUDHI/gudhi-devel/issues>`_ or
+`contact us <https://gudhi.inria.fr/contact/>`_ for any suggestions.
-GUDHI is open to external contributions. If you want to join our development team, please contact us.
+GUDHI is open to external contributions. If you want to join our development team, please take some time to read our
+`contributing guide <https://github.com/GUDHI/gudhi-devel/blob/master/.github/CONTRIBUTING.md>`_.
diff --git a/src/python/doc/ls_simplex_tree_tflow_itf_ref.rst b/src/python/doc/ls_simplex_tree_tflow_itf_ref.rst
new file mode 100644
index 00000000..9d7d633f
--- /dev/null
+++ b/src/python/doc/ls_simplex_tree_tflow_itf_ref.rst
@@ -0,0 +1,53 @@
+:orphan:
+
+.. To get rid of WARNING: document isn't included in any toctree
+
+TensorFlow layer for lower-star persistence on simplex trees
+############################################################
+
+.. include:: differentiation_sum.inc
+
+Example of gradient computed from lower-star filtration of a simplex tree
+-------------------------------------------------------------------------
+
+.. testcode::
+
+ from gudhi.tensorflow import LowerStarSimplexTreeLayer
+ import tensorflow as tf
+ import gudhi as gd
+
+ st = gd.SimplexTree()
+ st.insert([0, 1])
+ st.insert([1, 2])
+ st.insert([2, 3])
+ st.insert([3, 4])
+ st.insert([4, 5])
+ st.insert([5, 6])
+ st.insert([6, 7])
+ st.insert([7, 8])
+ st.insert([8, 9])
+ st.insert([9, 10])
+
+ F = tf.Variable([6.,4.,3.,4.,5.,4.,3.,2.,3.,4.,5.], dtype=tf.float32, trainable=True)
+ sl = LowerStarSimplexTreeLayer(simplextree=st, homology_dimensions=[0])
+
+ with tf.GradientTape() as tape:
+ dgm = sl.call(F)[0][0]
+ loss = tf.math.reduce_sum(tf.square(.5*(dgm[:,1]-dgm[:,0])))
+
+ grads = tape.gradient(loss, [F])
+ print(grads[0].indices.numpy())
+ print(grads[0].values.numpy())
+
+.. testoutput::
+
+ [2 4]
+ [-1. 1.]
+
+Documentation for LowerStarSimplexTreeLayer
+-------------------------------------------
+
+.. autoclass:: gudhi.tensorflow.LowerStarSimplexTreeLayer
+ :members:
+ :special-members: __init__
+ :show-inheritance:
diff --git a/src/python/doc/nerve_gic_complex_user.rst b/src/python/doc/nerve_gic_complex_user.rst
index 0b820abf..8633cadb 100644
--- a/src/python/doc/nerve_gic_complex_user.rst
+++ b/src/python/doc/nerve_gic_complex_user.rst
@@ -12,7 +12,7 @@ Definition
Visualizations of the simplicial complexes can be done with either
neato (from `graphviz <http://www.graphviz.org/>`_),
`geomview <http://www.geomview.org/>`_,
-`KeplerMapper <https://github.com/MLWave/kepler-mapper>`_.
+`KeplerMapper <https://github.com/scikit-tda/kepler-mapper>`_.
Input point clouds are assumed to be OFF files (cf. `OFF file format <fileformats.html#off-file-format>`_).
Covers
diff --git a/src/python/doc/persistence_graphical_tools_user.rst b/src/python/doc/persistence_graphical_tools_user.rst
index d95b9d2b..e1d28c71 100644
--- a/src/python/doc/persistence_graphical_tools_user.rst
+++ b/src/python/doc/persistence_graphical_tools_user.rst
@@ -60,7 +60,7 @@ of shape (N x 2) encoding a persistence diagram (in a given dimension).
import matplotlib.pyplot as plt
import gudhi
import numpy as np
- d = np.array([[0, 1], [1, 2], [1, np.inf]])
+ d = np.array([[0., 1.], [1., 2.], [1., np.inf]])
gudhi.plot_persistence_diagram(d)
plt.show()
diff --git a/src/python/doc/persistent_cohomology_user.rst b/src/python/doc/persistent_cohomology_user.rst
index a3f294b2..39744b95 100644
--- a/src/python/doc/persistent_cohomology_user.rst
+++ b/src/python/doc/persistent_cohomology_user.rst
@@ -6,19 +6,24 @@ Persistent cohomology user manual
=================================
Definition
----------
-===================================== ===================================== =====================================
-:Author: Clément Maria :Since: GUDHI PYTHON 2.0.0 :License: GPL v3
-===================================== ===================================== =====================================
-
-+-----------------------------------------------------------------+-----------------------------------------------------------------------+
-| :doc:`persistent_cohomology_user` | Please refer to each data structure that contains persistence |
-| | feature for reference: |
-| | |
-| | * :doc:`simplex_tree_ref` |
-| | * :doc:`cubical_complex_ref` |
-| | * :doc:`periodic_cubical_complex_ref` |
-+-----------------------------------------------------------------+-----------------------------------------------------------------------+
+.. list-table::
+ :width: 100%
+ :header-rows: 0
+
+ * - :Author: Clément Maria
+ - :Since: GUDHI 2.0.0
+ - :License: MIT
+
+.. list-table::
+ :width: 100%
+ :header-rows: 0
+
+ * - :doc:`persistent_cohomology_user`
+ - Please refer to each data structure that contains persistence feature for reference:
+ * :doc:`simplex_tree_ref`
+ * :doc:`cubical_complex_ref`
+ * :doc:`periodic_cubical_complex_ref`
Computation of persistent cohomology using the algorithm of :cite:`DBLP:journals/dcg/SilvaMV11` and
:cite:`DBLP:conf/compgeom/DeyFW14` and the Compressed Annotation Matrix implementation of
diff --git a/src/python/doc/point_cloud.rst b/src/python/doc/point_cloud.rst
index ffd8f85b..473b303f 100644
--- a/src/python/doc/point_cloud.rst
+++ b/src/python/doc/point_cloud.rst
@@ -13,6 +13,11 @@ File Readers
.. autofunction:: gudhi.read_lower_triangular_matrix_from_csv_file
+File Writers
+------------
+
+.. autofunction:: gudhi.write_points_to_off_file
+
Subsampling
-----------
diff --git a/src/python/doc/representations.rst b/src/python/doc/representations.rst
index b0477197..5686974a 100644
--- a/src/python/doc/representations.rst
+++ b/src/python/doc/representations.rst
@@ -8,10 +8,16 @@ Representations manual
.. include:: representations_sum.inc
-This module, originally available at https://github.com/MathieuCarriere/sklearn-tda and named sklearn_tda, aims at bridging the gap between persistence diagrams and machine learning, by providing implementations of most of the vector representations for persistence diagrams in the literature, in a scikit-learn format. More specifically, it provides tools, using the scikit-learn standard interface, to compute distances and kernels on persistence diagrams, and to convert these diagrams into vectors in Euclidean space.
+This module aims at bridging the gap between persistence diagrams and machine learning, by providing implementations of most of the vector representations for persistence diagrams in the literature, in a scikit-learn format. More specifically, it provides tools, using the scikit-learn standard interface, to compute distances and kernels on persistence diagrams, and to convert these diagrams into vectors in Euclidean space. Moreover, this module also contains `PersLay <http://proceedings.mlr.press/v108/carriere20a.html>`_, which is a general neural network layer for performing deep learning with persistence diagrams, implemented in TensorFlow.
A diagram is represented as a numpy array of shape (n,2), as can be obtained from :func:`~gudhi.SimplexTree.persistence_intervals_in_dimension` for instance. Points at infinity are represented as a numpy array of shape (n,1), storing only the birth time. The classes in this module can handle several persistence diagrams at once. In that case, the diagrams are provided as a list of numpy arrays. Note that it is not necessary for the diagrams to have the same number of points, i.e., for the corresponding arrays to have the same number of rows: all classes can handle arrays with different shapes.
+This `notebook <https://github.com/GUDHI/TDA-tutorial/blob/master/Tuto-GUDHI-representations.ipynb>`__ explains how to
+efficiently combine machine learning and topological data analysis with the
+:doc:`representations module<representations>` in a scikit-learn fashion. This `notebook <https://github.com/MathieuCarriere/tda-tutorials/blob/perslay/Tuto-GUDHI-perslay-expe.ipynb>`__
+and `this one <https://github.com/MathieuCarriere/tda-tutorials/blob/perslay/Tuto-GUDHI-perslay-visu.ipynb>`__ explain how to use PersLay.
+
+
Examples
--------
@@ -30,8 +36,6 @@ This example computes the first two Landscapes associated to a persistence diagr
l=Landscape(num_landscapes=2,resolution=10).fit_transform(diags)
print(l)
-The output is:
-
.. testoutput::
[[1.02851895 2.05703791 2.57129739 1.54277843 0.89995409 1.92847304
@@ -45,13 +49,71 @@ Various kernels
This small example is also provided
:download:`diagram_vectorizations_distances_kernels.py <../example/diagram_vectorizations_distances_kernels.py>`
-Machine Learning and Topological Data Analysis
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+PersLay
+^^^^^^^
-This `notebook <https://github.com/GUDHI/TDA-tutorial/blob/master/Tuto-GUDHI-representations.ipynb>`_ explains how to
-efficiently combine machine learning and topological data analysis with the
-:doc:`representations module<representations>`.
+.. testsetup:: perslay
+
+ import numpy
+ numpy.set_printoptions(precision=5)
+
+.. testcode:: perslay
+
+ import numpy as np
+ import tensorflow as tf
+ from sklearn.preprocessing import MinMaxScaler
+ import gudhi.representations as gdr
+ import gudhi.tensorflow.perslay as prsl
+
+ diagrams = [np.array([[0.,4.],[1.,2.],[3.,8.],[6.,8.]])]
+ diagrams = gdr.DiagramScaler(use=True, scalers=[([0,1], MinMaxScaler())]).fit_transform(diagrams)
+ diagrams = tf.RaggedTensor.from_tensor(tf.constant(diagrams, dtype=tf.float32))
+
+ rho = tf.identity
+ phi = prsl.GaussianPerslayPhi((5, 5), ((-.5, 1.5), (-.5, 1.5)), .1)
+ weight = prsl.PowerPerslayWeight(1.,0.)
+ perm_op = tf.math.reduce_sum
+
+ perslay = prsl.Perslay(phi=phi, weight=weight, perm_op=perm_op, rho=rho)
+ vectors = perslay(diagrams)
+ print(vectors)
+
+.. testcleanup:: perslay
+ numpy.set_printoptions(precision=8)
+
+.. testoutput:: perslay
+
+ tf.Tensor(
+ [[[[1.72661e-16]
+ [4.17060e-09]
+ [1.13369e-08]
+ [8.57388e-12]
+ [2.12439e-14]]
+
+ [[4.17151e-09]
+ [1.00741e-01]
+ [2.73843e-01]
+ [3.07242e-02]
+ [7.61575e-05]]
+
+ [[8.03829e-06]
+ [1.58027e+00]
+ [8.29970e-01]
+ [1.23954e+01]
+ [3.07241e-02]]
+
+ [[8.02694e-06]
+ [1.30657e+00]
+ [9.09230e+00]
+ [6.16648e-02]
+ [1.39492e-06]]
+
+ [[9.03313e-13]
+ [1.49548e-07]
+ [1.51460e-04]
+ [1.02051e-06]
+ [7.80935e-16]]]], shape=(1, 5, 5, 1), dtype=float32)
Preprocessing
-------------
@@ -80,3 +142,44 @@ Metrics
:members:
:special-members:
:show-inheritance:
+
+PersLay
+-------
+.. autoclass:: gudhi.tensorflow.perslay.Perslay
+ :members:
+ :special-members:
+ :show-inheritance:
+
+Weight functions
+^^^^^^^^^^^^^^^^
+.. autoclass:: gudhi.tensorflow.perslay.GaussianMixturePerslayWeight
+ :members:
+ :special-members:
+ :show-inheritance:
+
+.. autoclass:: gudhi.tensorflow.perslay.GridPerslayWeight
+ :members:
+ :special-members:
+ :show-inheritance:
+
+.. autoclass:: gudhi.tensorflow.perslay.PowerPerslayWeight
+ :members:
+ :special-members:
+ :show-inheritance:
+
+Phi functions
+^^^^^^^^^^^^^
+.. autoclass:: gudhi.tensorflow.perslay.FlatPerslayPhi
+ :members:
+ :special-members:
+ :show-inheritance:
+
+.. autoclass:: gudhi.tensorflow.perslay.GaussianPerslayPhi
+ :members:
+ :special-members:
+ :show-inheritance:
+
+.. autoclass:: gudhi.tensorflow.perslay.TentPerslayPhi
+ :members:
+ :special-members:
+ :show-inheritance:
diff --git a/src/python/doc/representations_sum.inc b/src/python/doc/representations_sum.inc
index 4298aea9..f8c31ecf 100644
--- a/src/python/doc/representations_sum.inc
+++ b/src/python/doc/representations_sum.inc
@@ -1,14 +1,16 @@
.. table::
:widths: 30 40 30
- +------------------------------------------------------------------+----------------------------------------------------------------+-------------------------------------------------------------+
- | .. figure:: | Vectorizations, distances and kernels that work on persistence | :Author: Mathieu Carrière, Martin Royer |
- | img/sklearn-tda.png | diagrams, compatible with scikit-learn. | |
- | | | :Since: GUDHI 3.1.0 |
- | | | |
- | | | :License: MIT |
- | | | |
- | | | :Requires: `Scikit-learn <installation.html#scikit-learn>`_ |
- +------------------------------------------------------------------+----------------------------------------------------------------+-------------------------------------------------------------+
- | * :doc:`representations` |
- +------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------------+
+ +------------------------------------------------------------------+----------------------------------------------------------------+------------------------------------------------------------------------------------------------------------+
+ | .. figure:: | Vectorizations, distances and kernels that work on persistence | :Author: Mathieu Carrière, Martin Royer, Gard Spreemann, Wojciech Reise |
+ | img/sklearn-tda.png | diagrams, compatible with scikit-learn and tensorflow. | |
+ | | | :Since: GUDHI 3.1.0 |
+ | | | |
+ | | | :License: MIT |
+ | | | |
+ | | | :Requires: `Scikit-learn <installation.html#scikit-learn>`_, `TensorFlow <installation.html#tensorflow>`_ |
+ | | | |
+ +------------------------------------------------------------------+----------------------------------------------------------------+------------------------------------------------------------------------------------------------------------+
+ | * :doc:`representations` |
+ +------------------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
+
diff --git a/src/python/doc/rips_complex_sum.inc b/src/python/doc/rips_complex_sum.inc
index 2cb24990..2b125e54 100644
--- a/src/python/doc/rips_complex_sum.inc
+++ b/src/python/doc/rips_complex_sum.inc
@@ -11,4 +11,9 @@
| | | |
+----------------------------------------------------------------+------------------------------------------------------------------------+----------------------------------------------------------------------------------+
| * :doc:`rips_complex_user` | * :doc:`rips_complex_ref` |
- +----------------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------+
+ +----------------------------------------------------------------+------------------------------------------------------------------------+----------------------------------------------------------------------------------+
+ | .. image:: | * :doc:`rips_complex_tflow_itf_ref` | :requires: `TensorFlow <installation.html#tensorflow>`_ |
+ | img/tensorflow.png | | |
+ | :target: https://www.tensorflow.org | | |
+ | :height: 30 | | |
+ +----------------------------------------------------------------+------------------------------------------------------------------------+----------------------------------------------------------------------------------+
diff --git a/src/python/doc/rips_complex_tflow_itf_ref.rst b/src/python/doc/rips_complex_tflow_itf_ref.rst
new file mode 100644
index 00000000..3ce75868
--- /dev/null
+++ b/src/python/doc/rips_complex_tflow_itf_ref.rst
@@ -0,0 +1,48 @@
+:orphan:
+
+.. To get rid of WARNING: document isn't included in any toctree
+
+TensorFlow layer for Vietoris-Rips persistence
+##############################################
+
+.. include:: differentiation_sum.inc
+
+Example of gradient computed from Vietoris-Rips persistence
+-----------------------------------------------------------
+
+.. testsetup::
+
+ import numpy
+ numpy.set_printoptions(precision=4)
+
+.. testcode::
+
+ from gudhi.tensorflow import RipsLayer
+ import tensorflow as tf
+
+ X = tf.Variable([[1.,1.],[2.,2.]], dtype=tf.float32, trainable=True)
+ rl = RipsLayer(maximum_edge_length=2., homology_dimensions=[0])
+
+ with tf.GradientTape() as tape:
+ dgm = rl.call(X)[0][0]
+ loss = tf.math.reduce_sum(tf.square(.5*(dgm[:,1]-dgm[:,0])))
+
+ grads = tape.gradient(loss, [X])
+ print(grads[0].numpy())
+
+.. testcleanup::
+
+ numpy.set_printoptions(precision=8)
+
+.. testoutput::
+
+ [[-0.5 -0.5]
+ [ 0.5 0.5]]
+
+Documentation for RipsLayer
+---------------------------
+
+.. autoclass:: gudhi.tensorflow.RipsLayer
+ :members:
+ :special-members: __init__
+ :show-inheritance:
diff --git a/src/python/doc/rips_complex_user.rst b/src/python/doc/rips_complex_user.rst
index 27d218d4..a4e83462 100644
--- a/src/python/doc/rips_complex_user.rst
+++ b/src/python/doc/rips_complex_user.rst
@@ -7,13 +7,7 @@ Rips complex user manual
Definition
----------
-================================================================================ ================================ ======================
-:Authors: Clément Maria, Pawel Dlotko, Vincent Rouvreau, Marc Glisse, Yuichi Ike :Since: GUDHI 2.0.0 :License: GPL v3
-================================================================================ ================================ ======================
-
-+-------------------------------------------+----------------------------------------------------------------------+
-| :doc:`rips_complex_user` | :doc:`rips_complex_ref` |
-+-------------------------------------------+----------------------------------------------------------------------+
+.. include:: rips_complex_sum.inc
The `Rips complex <https://en.wikipedia.org/wiki/Vietoris%E2%80%93Rips_complex>`_ is a simplicial complex that
generalizes proximity (:math:`\varepsilon`-ball) graphs to higher dimensions. The vertices correspond to the input
@@ -40,9 +34,6 @@ A vertex name corresponds to the index of the point in the given range (aka. the
On this example, as edges (4,5), (4,6) and (5,6) are in the complex, simplex (4,5,6) is added with the filtration value
set with :math:`max(filtration(4,5), filtration(4,6), filtration(5,6))`. And so on for simplex (0,1,2,3).
-If the :doc:`RipsComplex <rips_complex_ref>` interfaces are not detailed enough for your need, please refer to
-rips_persistence_step_by_step.cpp C++ example, where the graph construction over the Simplex_tree is more detailed.
-
A Rips complex can easily become huge, even if we limit the length of the edges
and the dimension of the simplices. One easy trick, before building a Rips
complex on a point cloud, is to call :func:`~gudhi.sparsify_point_set` which removes points
@@ -61,6 +52,13 @@ construction of a :class:`~gudhi.RipsComplex` object asks it to build a sparse R
parameter :math:`\varepsilon=0.3`, while the default `sparse=None` builds the
regular Rips complex.
+Another option which is especially useful if you want to compute persistent homology in "high" dimension (2 or more,
+sometimes even 1), is to build the Rips complex only up to dimension 1 (a graph), then use
+:func:`~gudhi.SimplexTree.collapse_edges` to reduce the size of this graph, and finally call
+:func:`~gudhi.SimplexTree.expansion` to get a simplicial complex of a suitable dimension to compute its homology. This
+trick gives the same persistence diagram as one would get with a plain use of `RipsComplex`, with a complex that is
+often significantly smaller and thus faster to process.
+
Point cloud
-----------
@@ -123,54 +121,44 @@ Notice that if we use
asking for a very sparse version (theory only gives some guarantee on the meaning of the output if `sparse<1`),
2 to 5 edges disappear, depending on the random vertex used to start the sparsification.
-Example from OFF file
-^^^^^^^^^^^^^^^^^^^^^
+Example step by step
+^^^^^^^^^^^^^^^^^^^^
-This example builds the :doc:`RipsComplex <rips_complex_ref>` from the given
-points in an OFF file, and max_edge_length value.
-Then it creates a :doc:`SimplexTree <simplex_tree_ref>` with it.
+While :doc:`RipsComplex <rips_complex_ref>` is convenient, for instance to build a simplicial complex in one line
+
+.. testcode::
-Finally, it is asked to display information about the Rips complex.
+ import gudhi
+ points = [[1, 1], [7, 0], [4, 6], [9, 6], [0, 14], [2, 19], [9, 17]]
+ cplx = gudhi.RipsComplex(points=points, max_edge_length=12.0).create_simplex_tree(max_dimension=2)
+you can achieve the same result without this class for more flexibility
.. testcode::
- import gudhi
- off_file = gudhi.__root_source_dir__ + '/data/points/alphacomplexdoc.off'
- point_cloud = gudhi.read_points_from_off_file(off_file = off_file)
- rips_complex = gudhi.RipsComplex(points=point_cloud, max_edge_length=12.0)
- simplex_tree = rips_complex.create_simplex_tree(max_dimension=1)
- result_str = 'Rips complex is of dimension ' + repr(simplex_tree.dimension()) + ' - ' + \
- repr(simplex_tree.num_simplices()) + ' simplices - ' + \
- repr(simplex_tree.num_vertices()) + ' vertices.'
- print(result_str)
- fmt = '%s -> %.2f'
- for filtered_value in simplex_tree.get_filtration():
- print(fmt % tuple(filtered_value))
+ import gudhi
+ from scipy.spatial.distance import cdist
+ points = [[1, 1], [7, 0], [4, 6], [9, 6], [0, 14], [2, 19], [9, 17]]
+ distance_matrix = cdist(points, points)
+ cplx = gudhi.SimplexTree.create_from_array(distance_matrix, max_filtration=12.0)
+ cplx.expansion(2)
-the program output is:
+or
-.. testoutput::
+.. testcode::
+
+ import gudhi
+ from scipy.spatial import cKDTree
+ points = [[1, 1], [7, 0], [4, 6], [9, 6], [0, 14], [2, 19], [9, 17]]
+ tree = cKDTree(points)
+ edges = tree.sparse_distance_matrix(tree, max_distance=12.0, output_type="coo_matrix")
+ cplx = gudhi.SimplexTree()
+ cplx.insert_edges_from_coo_matrix(edges)
+ cplx.expansion(2)
- Rips complex is of dimension 1 - 18 simplices - 7 vertices.
- [0] -> 0.00
- [1] -> 0.00
- [2] -> 0.00
- [3] -> 0.00
- [4] -> 0.00
- [5] -> 0.00
- [6] -> 0.00
- [2, 3] -> 5.00
- [4, 5] -> 5.39
- [0, 2] -> 5.83
- [0, 1] -> 6.08
- [1, 3] -> 6.32
- [1, 2] -> 6.71
- [5, 6] -> 7.28
- [2, 4] -> 8.94
- [0, 3] -> 9.43
- [4, 6] -> 9.49
- [3, 6] -> 11.00
+
+This way, you can easily add a call to :func:`~gudhi.SimplexTree.collapse_edges` before the expansion,
+use a different metric to compute the matrix, or other variations.
Distance matrix
---------------
@@ -229,54 +217,7 @@ until dimension 1 - one skeleton graph in other words), the output is:
[4, 6] -> 9.49
[3, 6] -> 11.00
-Example from csv file
-^^^^^^^^^^^^^^^^^^^^^
-
-This example builds the :doc:`RipsComplex <rips_complex_ref>` from the given
-distance matrix in a csv file, and max_edge_length value.
-Then it creates a :doc:`SimplexTree <simplex_tree_ref>` with it.
-
-Finally, it is asked to display information about the Rips complex.
-
-
-.. testcode::
-
- import gudhi
- distance_matrix = gudhi.read_lower_triangular_matrix_from_csv_file(csv_file=gudhi.__root_source_dir__ + \
- '/data/distance_matrix/full_square_distance_matrix.csv')
- rips_complex = gudhi.RipsComplex(distance_matrix=distance_matrix, max_edge_length=12.0)
- simplex_tree = rips_complex.create_simplex_tree(max_dimension=1)
- result_str = 'Rips complex is of dimension ' + repr(simplex_tree.dimension()) + ' - ' + \
- repr(simplex_tree.num_simplices()) + ' simplices - ' + \
- repr(simplex_tree.num_vertices()) + ' vertices.'
- print(result_str)
- fmt = '%s -> %.2f'
- for filtered_value in simplex_tree.get_filtration():
- print(fmt % tuple(filtered_value))
-
-the program output is:
-
-.. testoutput::
-
- Rips complex is of dimension 1 - 18 simplices - 7 vertices.
- [0] -> 0.00
- [1] -> 0.00
- [2] -> 0.00
- [3] -> 0.00
- [4] -> 0.00
- [5] -> 0.00
- [6] -> 0.00
- [2, 3] -> 5.00
- [4, 5] -> 5.39
- [0, 2] -> 5.83
- [0, 1] -> 6.08
- [1, 3] -> 6.32
- [1, 2] -> 6.71
- [5, 6] -> 7.28
- [2, 4] -> 8.94
- [0, 3] -> 9.43
- [4, 6] -> 9.49
- [3, 6] -> 11.00
+In case this lower triangular matrix is stored in a CSV file, like `data/distance_matrix/full_square_distance_matrix.csv` in the Gudhi distribution, you can read it with :func:`~gudhi.read_lower_triangular_matrix_from_csv_file`.
Correlation matrix
------------------
diff --git a/src/python/doc/simplex_tree_sum.inc b/src/python/doc/simplex_tree_sum.inc
index a8858f16..6b534c9e 100644
--- a/src/python/doc/simplex_tree_sum.inc
+++ b/src/python/doc/simplex_tree_sum.inc
@@ -1,13 +1,18 @@
.. table::
:widths: 30 40 30
- +----------------------------------------------------------------+------------------------------------------------------------------------+-----------------------------+
- | .. figure:: | The simplex tree is an efficient and flexible data structure for | :Author: Clément Maria |
- | ../../doc/Simplex_tree/Simplex_tree_representation.png | representing general (filtered) simplicial complexes. | |
- | :alt: Simplex tree representation | | :Since: GUDHI 2.0.0 |
- | :figclass: align-center | The data structure is described in | |
- | | :cite:`boissonnatmariasimplextreealgorithmica` | :License: MIT |
- | | | |
- +----------------------------------------------------------------+------------------------------------------------------------------------+-----------------------------+
- | * :doc:`simplex_tree_user` | * :doc:`simplex_tree_ref` |
- +----------------------------------------------------------------+------------------------------------------------------------------------------------------------------+
+ +----------------------------------------------------------------+------------------------------------------------------------------------+---------------------------------------------------------+
+ | .. figure:: | The simplex tree is an efficient and flexible data structure for | :Author: Clément Maria |
+ | ../../doc/Simplex_tree/Simplex_tree_representation.png | representing general (filtered) simplicial complexes. | |
+ | :alt: Simplex tree representation | | :Since: GUDHI 2.0.0 |
+ | :figclass: align-center | The data structure is described in | |
+ | | :cite:`boissonnatmariasimplextreealgorithmica` | :License: MIT |
+ | | | |
+ +----------------------------------------------------------------+------------------------------------------------------------------------+---------------------------------------------------------+
+ | * :doc:`simplex_tree_user` | * :doc:`simplex_tree_ref` |
+ +----------------------------------------------------------------+------------------------------------------------------------------------+---------------------------------------------------------+
+ | .. image:: | * :doc:`ls_simplex_tree_tflow_itf_ref` | :requires: `TensorFlow <installation.html#tensorflow>`_ |
+ | img/tensorflow.png | | |
+ | :target: https://www.tensorflow.org | | |
+ | :height: 30 | | |
+ +----------------------------------------------------------------+------------------------------------------------------------------------+---------------------------------------------------------+
diff --git a/src/python/doc/wasserstein_distance_user.rst b/src/python/doc/wasserstein_distance_user.rst
index 9ffc2759..76eb1469 100644
--- a/src/python/doc/wasserstein_distance_user.rst
+++ b/src/python/doc/wasserstein_distance_user.rst
@@ -44,7 +44,7 @@ Basic example
*************
This example computes the 1-Wasserstein distance from 2 persistence diagrams with Euclidean ground metric.
-Note that persistence diagrams must be submitted as (n x 2) numpy arrays and must not contain inf values.
+Note that persistence diagrams must be submitted as (n x 2) numpy arrays.
.. testcode::
@@ -67,14 +67,16 @@ We can also have access to the optimal matching by letting `matching=True`.
It is encoded as a list of indices (i,j), meaning that the i-th point in X
is mapped to the j-th point in Y.
An index of -1 represents the diagonal.
+It handles essential parts (points with infinite coordinates). However if the cardinalities of the essential parts differ,
+any matching has a cost +inf and thus can be considered to be optimal. In such a case, the function returns `(np.inf, None)`.
.. testcode::
import gudhi.wasserstein
import numpy as np
- dgm1 = np.array([[2.7, 3.7],[9.6, 14.],[34.2, 34.974]])
- dgm2 = np.array([[2.8, 4.45], [5, 6], [9.5, 14.1]])
+ dgm1 = np.array([[2.7, 3.7],[9.6, 14.],[34.2, 34.974], [3, np.inf]])
+ dgm2 = np.array([[2.8, 4.45], [5, 6], [9.5, 14.1], [4, np.inf]])
cost, matchings = gudhi.wasserstein.wasserstein_distance(dgm1, dgm2, matching=True, order=1, internal_p=2)
message_cost = "Wasserstein distance value = %.2f" %cost
@@ -90,16 +92,31 @@ An index of -1 represents the diagonal.
for j in dgm2_to_diagonal:
print("point %s in dgm2 is matched to the diagonal" %j)
-The output is:
+ # An example where essential part cardinalities differ
+ dgm3 = np.array([[1, 2], [0, np.inf]])
+ dgm4 = np.array([[1, 2], [0, np.inf], [1, np.inf]])
+ cost, matchings = gudhi.wasserstein.wasserstein_distance(dgm3, dgm4, matching=True, order=1, internal_p=2)
+ print("\nSecond example:")
+ print("cost:", cost)
+ print("matchings:", matchings)
+
+
+The output is:
.. testoutput::
- Wasserstein distance value = 2.15
+ Wasserstein distance value = 3.15
point 0 in dgm1 is matched to point 0 in dgm2
point 1 in dgm1 is matched to point 2 in dgm2
+ point 3 in dgm1 is matched to point 3 in dgm2
point 2 in dgm1 is matched to the diagonal
point 1 in dgm2 is matched to the diagonal
+ Second example:
+ cost: inf
+ matchings: None
+
+
Barycenters
-----------
@@ -181,4 +198,4 @@ Tutorial
This
`notebook <https://github.com/GUDHI/TDA-tutorial/blob/master/Tuto-GUDHI-Barycenters-of-persistence-diagrams.ipynb>`_
-presents the concept of barycenter, or Fréchet mean, of a family of persistence diagrams. \ No newline at end of file
+presents the concept of barycenter, or Fréchet mean, of a family of persistence diagrams.
diff --git a/src/python/example/alpha_complex_diagram_persistence_from_off_file_example.py b/src/python/example/alpha_complex_diagram_persistence_from_off_file_example.py
index 1e0273b3..c96121a6 100755
--- a/src/python/example/alpha_complex_diagram_persistence_from_off_file_example.py
+++ b/src/python/example/alpha_complex_diagram_persistence_from_off_file_example.py
@@ -1,9 +1,7 @@
#!/usr/bin/env python
import argparse
-import errno
-import os
-import gudhi
+import gudhi as gd
""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ -
which is released under MIT.
@@ -25,12 +23,12 @@ parser = argparse.ArgumentParser(
description="AlphaComplex creation from " "points read in a OFF file.",
epilog="Example: "
"example/alpha_complex_diagram_persistence_from_off_file_example.py "
- "-f ../data/points/tore3D_300.off -a 0.6"
+ "-f ../data/points/tore3D_300.off"
"- Constructs a alpha complex with the "
"points from the given OFF file.",
)
parser.add_argument("-f", "--file", type=str, required=True)
-parser.add_argument("-a", "--max_alpha_square", type=float, default=0.5)
+parser.add_argument("-a", "--max_alpha_square", type=float, required=False)
parser.add_argument("-b", "--band", type=float, default=0.0)
parser.add_argument(
"--no-diagram",
@@ -41,34 +39,24 @@ parser.add_argument(
args = parser.parse_args()
-with open(args.file, "r") as f:
- first_line = f.readline()
- if (first_line == "OFF\n") or (first_line == "nOFF\n"):
- print("##############################################################")
- print("AlphaComplex creation from points read in a OFF file")
-
- message = "AlphaComplex with max_edge_length=" + repr(args.max_alpha_square)
- print(message)
-
- alpha_complex = gudhi.AlphaComplex(off_file=args.file)
- simplex_tree = alpha_complex.create_simplex_tree(
- max_alpha_square=args.max_alpha_square
- )
-
- message = "Number of simplices=" + repr(simplex_tree.num_simplices())
- print(message)
-
- diag = simplex_tree.persistence()
-
- print("betti_numbers()=")
- print(simplex_tree.betti_numbers())
-
- if args.no_diagram == False:
- import matplotlib.pyplot as plot
- gudhi.plot_persistence_diagram(diag, band=args.band)
- plot.show()
- else:
- raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
- args.file)
-
- f.close()
+print("##############################################################")
+print("AlphaComplex creation from points read in a OFF file")
+
+points = gd.read_points_from_off_file(off_file = args.file)
+alpha_complex = gd.AlphaComplex(points = points)
+if args.max_alpha_square is not None:
+ print("with max_edge_length=", args.max_alpha_square)
+ simplex_tree = alpha_complex.create_simplex_tree(
+ max_alpha_square=args.max_alpha_square
+ )
+else:
+ simplex_tree = alpha_complex.create_simplex_tree()
+
+print("Number of simplices=", simplex_tree.num_simplices())
+
+diag = simplex_tree.persistence()
+print("betti_numbers()=", simplex_tree.betti_numbers())
+if args.no_diagram == False:
+ import matplotlib.pyplot as plot
+ gd.plot_persistence_diagram(diag, band=args.band)
+ plot.show()
diff --git a/src/python/example/alpha_complex_from_generated_points_on_sphere_example.py b/src/python/example/alpha_complex_from_generated_points_on_sphere_example.py
new file mode 100644
index 00000000..3558077e
--- /dev/null
+++ b/src/python/example/alpha_complex_from_generated_points_on_sphere_example.py
@@ -0,0 +1,35 @@
+#!/usr/bin/env python
+
+from gudhi.datasets.generators import _points
+from gudhi import AlphaComplex
+
+
+""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+ See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+ Author(s): Hind Montassif
+
+ Copyright (C) 2021 Inria
+
+ Modification(s):
+ - YYYY/MM Author: Description of the modification
+"""
+
+__author__ = "Hind Montassif"
+__copyright__ = "Copyright (C) 2021 Inria"
+__license__ = "MIT"
+
+print("#####################################################################")
+print("AlphaComplex creation from generated points on sphere")
+
+
+gen_points = _points.sphere(n_samples = 50, ambient_dim = 2, radius = 1, sample = "random")
+
+# Create an alpha complex
+alpha_complex = AlphaComplex(points = gen_points)
+simplex_tree = alpha_complex.create_simplex_tree()
+
+result_str = 'Alpha complex is of dimension ' + repr(simplex_tree.dimension()) + ' - ' + \
+ repr(simplex_tree.num_simplices()) + ' simplices - ' + \
+ repr(simplex_tree.num_vertices()) + ' vertices.'
+print(result_str)
+
diff --git a/src/python/example/alpha_complex_from_points_example.py b/src/python/example/alpha_complex_from_points_example.py
index 465632eb..5d5ca66a 100755
--- a/src/python/example/alpha_complex_from_points_example.py
+++ b/src/python/example/alpha_complex_from_points_example.py
@@ -19,7 +19,7 @@ __license__ = "MIT"
print("#####################################################################")
print("AlphaComplex creation from points")
alpha_complex = AlphaComplex(points=[[0, 0], [1, 0], [0, 1], [1, 1]])
-simplex_tree = alpha_complex.create_simplex_tree(max_alpha_square=60.0)
+simplex_tree = alpha_complex.create_simplex_tree()
if simplex_tree.find([0, 1]):
print("[0, 1] Found !!")
diff --git a/src/python/example/alpha_rips_persistence_bottleneck_distance.py b/src/python/example/alpha_rips_persistence_bottleneck_distance.py
index 3e12b0d5..6b97fb3b 100755
--- a/src/python/example/alpha_rips_persistence_bottleneck_distance.py
+++ b/src/python/example/alpha_rips_persistence_bottleneck_distance.py
@@ -1,10 +1,8 @@
#!/usr/bin/env python
-import gudhi
+import gudhi as gd
import argparse
import math
-import errno
-import os
import numpy as np
""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ -
@@ -37,70 +35,60 @@ parser.add_argument("-t", "--threshold", type=float, default=0.5)
parser.add_argument("-d", "--max_dimension", type=int, default=1)
args = parser.parse_args()
-with open(args.file, "r") as f:
- first_line = f.readline()
- if (first_line == "OFF\n") or (first_line == "nOFF\n"):
- point_cloud = gudhi.read_points_from_off_file(off_file=args.file)
- print("##############################################################")
- print("RipsComplex creation from points read in a OFF file")
+point_cloud = gd.read_points_from_off_file(off_file=args.file)
+print("##############################################################")
+print("RipsComplex creation from points read in a OFF file")
- message = "RipsComplex with max_edge_length=" + repr(args.threshold)
- print(message)
+message = "RipsComplex with max_edge_length=" + repr(args.threshold)
+print(message)
- rips_complex = gudhi.RipsComplex(
- points=point_cloud, max_edge_length=args.threshold
- )
-
- rips_stree = rips_complex.create_simplex_tree(
- max_dimension=args.max_dimension)
-
- message = "Number of simplices=" + repr(rips_stree.num_simplices())
- print(message)
-
- rips_stree.compute_persistence()
-
- print("##############################################################")
- print("AlphaComplex creation from points read in a OFF file")
-
- message = "AlphaComplex with max_edge_length=" + repr(args.threshold)
- print(message)
-
- alpha_complex = gudhi.AlphaComplex(points=point_cloud)
- alpha_stree = alpha_complex.create_simplex_tree(
- max_alpha_square=(args.threshold * args.threshold)
- )
-
- message = "Number of simplices=" + repr(alpha_stree.num_simplices())
- print(message)
+rips_complex = gd.RipsComplex(
+ points=point_cloud, max_edge_length=args.threshold
+)
- alpha_stree.compute_persistence()
+rips_stree = rips_complex.create_simplex_tree(
+ max_dimension=args.max_dimension)
- max_b_distance = 0.0
- for dim in range(args.max_dimension):
- # Alpha persistence values needs to be transform because filtration
- # values are alpha square values
- alpha_intervals = np.sqrt(alpha_stree.persistence_intervals_in_dimension(dim))
+message = "Number of simplices=" + repr(rips_stree.num_simplices())
+print(message)
- rips_intervals = rips_stree.persistence_intervals_in_dimension(dim)
- bottleneck_distance = gudhi.bottleneck_distance(
- rips_intervals, alpha_intervals
- )
- message = (
- "In dimension "
- + repr(dim)
- + ", bottleneck distance = "
- + repr(bottleneck_distance)
- )
- print(message)
- max_b_distance = max(bottleneck_distance, max_b_distance)
+rips_stree.compute_persistence()
- print("==============================================================")
- message = "Bottleneck distance is " + repr(max_b_distance)
- print(message)
+print("##############################################################")
+print("AlphaComplex creation from points read in a OFF file")
- else:
- raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
- args.file)
+message = "AlphaComplex with max_edge_length=" + repr(args.threshold)
+print(message)
+alpha_complex = gd.AlphaComplex(points=point_cloud)
+alpha_stree = alpha_complex.create_simplex_tree(
+ max_alpha_square=(args.threshold * args.threshold)
+)
- f.close()
+message = "Number of simplices=" + repr(alpha_stree.num_simplices())
+print(message)
+
+alpha_stree.compute_persistence()
+
+max_b_distance = 0.0
+for dim in range(args.max_dimension):
+ # Alpha persistence values needs to be transform because filtration
+ # values are alpha square values
+ alpha_intervals = np.sqrt(alpha_stree.persistence_intervals_in_dimension(dim))
+
+ rips_intervals = rips_stree.persistence_intervals_in_dimension(dim)
+ bottleneck_distance = gd.bottleneck_distance(
+ rips_intervals, alpha_intervals
+ )
+ message = (
+ "In dimension "
+ + repr(dim)
+ + ", bottleneck distance = "
+ + repr(bottleneck_distance)
+ )
+ print(message)
+ max_b_distance = max(bottleneck_distance, max_b_distance)
+
+print("==============================================================")
+message = "Bottleneck distance is " + repr(max_b_distance)
+print(message)
diff --git a/src/python/example/plot_alpha_complex.py b/src/python/example/plot_alpha_complex.py
index 99c18a7c..0924619b 100755
--- a/src/python/example/plot_alpha_complex.py
+++ b/src/python/example/plot_alpha_complex.py
@@ -1,8 +1,9 @@
#!/usr/bin/env python
import numpy as np
-import gudhi
-ac = gudhi.AlphaComplex(off_file='../../data/points/tore3D_1307.off')
+import gudhi as gd
+points = gd.read_points_from_off_file(off_file = '../../data/points/tore3D_1307.off')
+ac = gd.AlphaComplex(points = points)
st = ac.create_simplex_tree()
points = np.array([ac.get_point(i) for i in range(st.num_vertices())])
# We want to plot the alpha-complex with alpha=0.1.
diff --git a/src/python/example/rips_complex_diagram_persistence_from_correlation_matrix_file_example.py b/src/python/example/rips_complex_diagram_persistence_from_correlation_matrix_file_example.py
index ea2eb7e1..0b35dbc5 100755
--- a/src/python/example/rips_complex_diagram_persistence_from_correlation_matrix_file_example.py
+++ b/src/python/example/rips_complex_diagram_persistence_from_correlation_matrix_file_example.py
@@ -40,7 +40,7 @@ parser.add_argument(
args = parser.parse_args()
if not (-1.0 < args.min_edge_correlation < 1.0):
- print("Wrong value of the treshold corelation (should be between -1 and 1).")
+ print("Wrong value of the threshold corelation (should be between -1 and 1).")
sys.exit(1)
print("#####################################################################")
diff --git a/src/python/example/rips_complex_diagram_persistence_from_distance_matrix_file_example.py b/src/python/example/rips_complex_diagram_persistence_from_distance_matrix_file_example.py
index 236d085d..8a9cc857 100755
--- a/src/python/example/rips_complex_diagram_persistence_from_distance_matrix_file_example.py
+++ b/src/python/example/rips_complex_diagram_persistence_from_distance_matrix_file_example.py
@@ -21,11 +21,12 @@ parser = argparse.ArgumentParser(
description="RipsComplex creation from " "a distance matrix read in a csv file.",
epilog="Example: "
"example/rips_complex_diagram_persistence_from_distance_matrix_file_example.py "
- "-f ../data/distance_matrix/lower_triangular_distance_matrix.csv -e 12.0 -d 3"
+ "-f ../data/distance_matrix/lower_triangular_distance_matrix.csv -s , -e 12.0 -d 3"
"- Constructs a Rips complex with the "
"distance matrix from the given csv file.",
)
parser.add_argument("-f", "--file", type=str, required=True)
+parser.add_argument("-s", "--separator", type=str, required=True)
parser.add_argument("-e", "--max_edge_length", type=float, default=0.5)
parser.add_argument("-d", "--max_dimension", type=int, default=1)
parser.add_argument("-b", "--band", type=float, default=0.0)
@@ -44,7 +45,7 @@ print("RipsComplex creation from distance matrix read in a csv file")
message = "RipsComplex with max_edge_length=" + repr(args.max_edge_length)
print(message)
-distance_matrix = gudhi.read_lower_triangular_matrix_from_csv_file(csv_file=args.file)
+distance_matrix = gudhi.read_lower_triangular_matrix_from_csv_file(csv_file=args.file, separator=args.separator)
rips_complex = gudhi.RipsComplex(
distance_matrix=distance_matrix, max_edge_length=args.max_edge_length
)
diff --git a/src/python/gudhi/__init__.py.in b/src/python/gudhi/__init__.py.in
index 3043201a..79e12fbc 100644
--- a/src/python/gudhi/__init__.py.in
+++ b/src/python/gudhi/__init__.py.in
@@ -23,10 +23,6 @@ __all__ = [@GUDHI_PYTHON_MODULES@ @GUDHI_PYTHON_MODULES_EXTRA@]
__available_modules = ''
__missing_modules = ''
-# For unitary tests purpose
-# could use "if 'collapse_edges' in gudhi.__all__" when collapse edges will have a python module
-__GUDHI_USE_EIGEN3 = @GUDHI_USE_EIGEN3@
-
# Try to import * from gudhi.__module_name for default modules.
# Extra modules require an explicit import by the user (mostly because of
# unusual dependencies, but also to avoid cluttering namespace gudhi and
diff --git a/src/python/gudhi/alpha_complex.pyx b/src/python/gudhi/alpha_complex.pyx
index ea128743..375e1561 100644
--- a/src/python/gudhi/alpha_complex.pyx
+++ b/src/python/gudhi/alpha_complex.pyx
@@ -16,7 +16,7 @@ from libcpp.utility cimport pair
from libcpp.string cimport string
from libcpp cimport bool
from libc.stdint cimport intptr_t
-import os
+import warnings
from gudhi.simplex_tree cimport *
from gudhi.simplex_tree import SimplexTree
@@ -28,66 +28,76 @@ __license__ = "GPL v3"
cdef extern from "Alpha_complex_interface.h" namespace "Gudhi":
cdef cppclass Alpha_complex_interface "Gudhi::alpha_complex::Alpha_complex_interface":
- Alpha_complex_interface(vector[vector[double]] points, bool fast_version, bool exact_version) nogil except +
+ Alpha_complex_interface(vector[vector[double]] points, vector[double] weights, bool fast_version, bool exact_version) nogil except +
vector[double] get_point(int vertex) nogil except +
void create_simplex_tree(Simplex_tree_interface_full_featured* simplex_tree, double max_alpha_square, bool default_filtration_value) nogil except +
+ @staticmethod
+ void set_float_relative_precision(double precision) nogil
+ @staticmethod
+ double get_float_relative_precision() nogil
# AlphaComplex python interface
cdef class AlphaComplex:
- """AlphaComplex is a simplicial complex constructed from the finite cells
- of a Delaunay Triangulation.
+ """AlphaComplex is a simplicial complex constructed from the finite cells of a Delaunay Triangulation.
- The filtration value of each simplex is computed as the square of the
- circumradius of the simplex if the circumsphere is empty (the simplex is
- then said to be Gabriel), and as the minimum of the filtration values of
- the codimension 1 cofaces that make it not Gabriel otherwise.
+ The filtration value of each simplex is computed as the square of the circumradius of the simplex if the
+ circumsphere is empty (the simplex is then said to be Gabriel), and as the minimum of the filtration values of the
+ codimension 1 cofaces that make it not Gabriel otherwise.
- All simplices that have a filtration value strictly greater than a given
- alpha squared value are not inserted into the complex.
+ All simplices that have a filtration value strictly greater than a given alpha squared value are not inserted into
+ the complex.
.. note::
- When Alpha_complex is constructed with an infinite value of alpha, the
- complex is a Delaunay complex.
-
+ When Alpha_complex is constructed with an infinite value of alpha, the complex is a Delaunay complex.
"""
cdef Alpha_complex_interface * this_ptr
# Fake constructor that does nothing but documenting the constructor
- def __init__(self, points=None, off_file='', precision='safe'):
+ def __init__(self, points=[], off_file='', weights=None, precision='safe'):
"""AlphaComplex constructor.
:param points: A list of points in d-Dimension.
- :type points: list of list of double
-
- Or
+ :type points: Iterable[Iterable[float]]
- :param off_file: An OFF file style name.
+ :param off_file: **[deprecated]** An `OFF file style <fileformats.html#off-file-format>`_ name.
+ If an `off_file` is given with `points` as arguments, only points from the file are taken into account.
:type off_file: string
+ :param weights: A list of weights. If set, the number of weights must correspond to the number of points.
+ :type weights: Iterable[float]
+
:param precision: Alpha complex precision can be 'fast', 'safe' or 'exact'. Default is 'safe'.
:type precision: string
+
+ :raises FileNotFoundError: **[deprecated]** If `off_file` is set but not found.
+ :raises ValueError: In case of inconsistency between the number of points and weights.
"""
# The real cython constructor
- def __cinit__(self, points = None, off_file = '', precision = 'safe'):
+ def __cinit__(self, points = [], off_file = '', weights=None, precision = 'safe'):
assert precision in ['fast', 'safe', 'exact'], "Alpha complex precision can only be 'fast', 'safe' or 'exact'"
cdef bool fast = precision == 'fast'
cdef bool exact = precision == 'exact'
- cdef vector[vector[double]] pts
if off_file:
- if os.path.isfile(off_file):
- points = read_points_from_off_file(off_file = off_file)
- else:
- print("file " + off_file + " not found.")
- if points is None:
- # Empty Alpha construction
- points=[]
+ warnings.warn("off_file is a deprecated parameter, please consider using gudhi.read_points_from_off_file",
+ DeprecationWarning)
+ points = read_points_from_off_file(off_file = off_file)
+
+ # weights are set but is inconsistent with the number of points
+ if weights != None and len(weights) != len(points):
+ raise ValueError("Inconsistency between the number of points and weights")
+
+ # need to copy the points to use them without the gil
+ cdef vector[vector[double]] pts
+ cdef vector[double] wgts
pts = points
+ if weights != None:
+ wgts = weights
with nogil:
- self.this_ptr = new Alpha_complex_interface(pts, fast, exact)
+ self.this_ptr = new Alpha_complex_interface(pts, wgts, fast, exact)
def __dealloc__(self):
if self.this_ptr != NULL:
@@ -127,3 +137,28 @@ cdef class AlphaComplex:
self.this_ptr.create_simplex_tree(<Simplex_tree_interface_full_featured*>stree_int_ptr,
mas, compute_filtration)
return stree
+
+ @staticmethod
+ def set_float_relative_precision(precision):
+ """
+ :param precision: When the AlphaComplex is constructed with :code:`precision = 'safe'` (the default),
+ one can set the float relative precision of filtration values computed in
+ :func:`~gudhi.AlphaComplex.create_simplex_tree`.
+ Default is :code:`1e-5` (cf. :func:`~gudhi.AlphaComplex.get_float_relative_precision`).
+ For more details, please refer to
+ `CGAL::Lazy_exact_nt<NT>::set_relative_precision_of_to_double <https://doc.cgal.org/latest/Number_types/classCGAL_1_1Lazy__exact__nt.html>`_
+ :type precision: float
+ """
+ if precision <=0. or precision >= 1.:
+ raise ValueError("Relative precision value must be strictly greater than 0 and strictly lower than 1")
+ Alpha_complex_interface.set_float_relative_precision(precision)
+
+ @staticmethod
+ def get_float_relative_precision():
+ """
+ :returns: The float relative precision of filtration values computation in
+ :func:`~gudhi.AlphaComplex.create_simplex_tree` when the AlphaComplex is constructed with
+ :code:`precision = 'safe'` (the default).
+ :rtype: float
+ """
+ return Alpha_complex_interface.get_float_relative_precision()
diff --git a/src/python/gudhi/bottleneck.cc b/src/python/gudhi/bottleneck.cc
index 8a3d669a..040e6d37 100644
--- a/src/python/gudhi/bottleneck.cc
+++ b/src/python/gudhi/bottleneck.cc
@@ -9,18 +9,20 @@
*/
#include <gudhi/Bottleneck.h>
-
+#include <optional>
#include <pybind11_diagram_utils.h>
+#include <pybind11/stl.h>
+
+// Indices are added internally in bottleneck_distance, they are not needed in the input.
+static auto make_point(double x, double y, py::ssize_t) { return std::pair(x, y); };
// For compatibility with older versions, we want to support e=None.
-// In C++17, the recommended way is std::optional<double>.
-double bottleneck(Dgm d1, Dgm d2, py::object epsilon)
+double bottleneck(Dgm d1, Dgm d2, std::optional<double> epsilon)
{
- double e = (std::numeric_limits<double>::min)();
- if (!epsilon.is_none()) e = epsilon.cast<double>();
- // I *think* the call to request() has to be before releasing the GIL.
- auto diag1 = numpy_to_range_of_pairs(d1);
- auto diag2 = numpy_to_range_of_pairs(d2);
+ double e = epsilon.value_or((std::numeric_limits<double>::min)());
+ // I *think* the call to request() in numpy_to_range_of_pairs has to be before releasing the GIL.
+ auto diag1 = numpy_to_range_of_pairs(d1, make_point);
+ auto diag2 = numpy_to_range_of_pairs(d2, make_point);
py::gil_scoped_release release;
diff --git a/src/python/gudhi/clustering/tomato.py b/src/python/gudhi/clustering/tomato.py
index fbba3cc8..d0e9995c 100644
--- a/src/python/gudhi/clustering/tomato.py
+++ b/src/python/gudhi/clustering/tomato.py
@@ -271,7 +271,7 @@ class Tomato:
l = self.max_weight_per_cc_.min()
r = self.max_weight_per_cc_.max()
if self.diagram_.size > 0:
- plt.plot(self.diagram_[:, 0], self.diagram_[:, 1], "ro")
+ plt.plot(self.diagram_[:, 0], self.diagram_[:, 1], "o", color="red")
l = min(l, self.diagram_[:, 1].min())
r = max(r, self.diagram_[:, 0].max())
if l == r:
@@ -283,7 +283,7 @@ class Tomato:
l, r = -1.0, 1.0
plt.plot([l, r], [l, r])
plt.plot(
- self.max_weight_per_cc_, numpy.full(self.max_weight_per_cc_.shape, 1.1 * l - 0.1 * r), "ro", color="green"
+ self.max_weight_per_cc_, numpy.full(self.max_weight_per_cc_.shape, 1.1 * l - 0.1 * r), "o", color="green"
)
plt.show()
diff --git a/src/python/gudhi/cubical_complex.pyx b/src/python/gudhi/cubical_complex.pyx
index 28fbe3af..8e244bb8 100644
--- a/src/python/gudhi/cubical_complex.pyx
+++ b/src/python/gudhi/cubical_complex.pyx
@@ -35,7 +35,7 @@ cdef extern from "Cubical_complex_interface.h" namespace "Gudhi":
cdef extern from "Persistent_cohomology_interface.h" namespace "Gudhi":
cdef cppclass Cubical_complex_persistence_interface "Gudhi::Persistent_cohomology_interface<Gudhi::Cubical_complex::Cubical_complex_interface<>>":
Cubical_complex_persistence_interface(Bitmap_cubical_complex_base_interface * st, bool persistence_dim_max) nogil
- void compute_persistence(int homology_coeff_field, double min_persistence) nogil
+ void compute_persistence(int homology_coeff_field, double min_persistence) nogil except+
vector[pair[int, pair[double, double]]] get_persistence() nogil
vector[vector[int]] cofaces_of_cubical_persistence_pairs() nogil
vector[int] betti_numbers() nogil
@@ -147,7 +147,7 @@ cdef class CubicalComplex:
:func:`persistence` returns.
:param homology_coeff_field: The homology coefficient field. Must be a
- prime number
+ prime number. Default value is 11. Max is 46337.
:type homology_coeff_field: int.
:param min_persistence: The minimum persistence value to take into
account (strictly greater than min_persistence). Default value is
@@ -169,7 +169,7 @@ cdef class CubicalComplex:
"""This function computes and returns the persistence of the complex.
:param homology_coeff_field: The homology coefficient field. Must be a
- prime number
+ prime number. Default value is 11. Max is 46337.
:type homology_coeff_field: int.
:param min_persistence: The minimum persistence value to take into
account (strictly greater than min_persistence). Default value is
@@ -281,4 +281,8 @@ cdef class CubicalComplex:
launched first.
"""
assert self.pcohptr != NULL, "compute_persistence() must be called before persistence_intervals_in_dimension()"
- return np.array(self.pcohptr.intervals_in_dimension(dimension))
+ piid = np.array(self.pcohptr.intervals_in_dimension(dimension))
+ # Workaround https://github.com/GUDHI/gudhi-devel/issues/507
+ if len(piid) == 0:
+ return np.empty(shape = [0, 2])
+ return piid
diff --git a/src/python/gudhi/datasets/__init__.py b/src/python/gudhi/datasets/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/src/python/gudhi/datasets/__init__.py
diff --git a/src/python/gudhi/datasets/generators/__init__.py b/src/python/gudhi/datasets/generators/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/src/python/gudhi/datasets/generators/__init__.py
diff --git a/src/python/gudhi/datasets/generators/_points.cc b/src/python/gudhi/datasets/generators/_points.cc
new file mode 100644
index 00000000..82fea25b
--- /dev/null
+++ b/src/python/gudhi/datasets/generators/_points.cc
@@ -0,0 +1,121 @@
+/* This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+ * See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+ * Author(s): Hind Montassif
+ *
+ * Copyright (C) 2021 Inria
+ *
+ * Modification(s):
+ * - YYYY/MM Author: Description of the modification
+ */
+
+#include <pybind11/pybind11.h>
+#include <pybind11/numpy.h>
+
+#include <gudhi/random_point_generators.h>
+#include <gudhi/Debug_utils.h>
+
+#include <CGAL/Epick_d.h>
+
+namespace py = pybind11;
+
+
+typedef CGAL::Epick_d< CGAL::Dynamic_dimension_tag > Kern;
+
+py::array_t<double> generate_points_on_sphere(size_t n_samples, int ambient_dim, double radius, std::string sample) {
+
+ if (sample != "random") {
+ throw pybind11::value_error("This sample type is not supported");
+ }
+
+ py::array_t<double> points({n_samples, (size_t)ambient_dim});
+
+ py::buffer_info buf = points.request();
+ double *ptr = static_cast<double *>(buf.ptr);
+
+ GUDHI_CHECK(n_samples == buf.shape[0], "Py array first dimension not matching n_samples on sphere");
+ GUDHI_CHECK(ambient_dim == buf.shape[1], "Py array second dimension not matching the ambient space dimension");
+
+
+ std::vector<typename Kern::Point_d> points_generated;
+
+ {
+ py::gil_scoped_release release;
+ points_generated = Gudhi::generate_points_on_sphere_d<Kern>(n_samples, ambient_dim, radius);
+ }
+
+ for (size_t i = 0; i < n_samples; i++)
+ for (int j = 0; j < ambient_dim; j++)
+ ptr[i*ambient_dim+j] = points_generated[i][j];
+
+ return points;
+}
+
+py::array_t<double> generate_points_on_torus(size_t n_samples, int dim, std::string sample) {
+
+ if ( (sample != "random") && (sample != "grid")) {
+ throw pybind11::value_error("This sample type is not supported");
+ }
+
+ std::vector<typename Kern::Point_d> points_generated;
+
+ {
+ py::gil_scoped_release release;
+ points_generated = Gudhi::generate_points_on_torus_d<Kern>(n_samples, dim, sample);
+ }
+
+ size_t npoints = points_generated.size();
+
+ GUDHI_CHECK(2*dim == points_generated[0].size(), "Py array second dimension not matching the double torus dimension");
+
+ py::array_t<double> points({npoints, (size_t)2*dim});
+
+ py::buffer_info buf = points.request();
+ double *ptr = static_cast<double *>(buf.ptr);
+
+ for (size_t i = 0; i < npoints; i++)
+ for (int j = 0; j < 2*dim; j++)
+ ptr[i*(2*dim)+j] = points_generated[i][j];
+
+ return points;
+}
+
+PYBIND11_MODULE(_points, m) {
+ m.attr("__license__") = "LGPL v3";
+
+ m.def("sphere", &generate_points_on_sphere,
+ py::arg("n_samples"), py::arg("ambient_dim"),
+ py::arg("radius") = 1., py::arg("sample") = "random",
+ R"pbdoc(
+ Generate random i.i.d. points uniformly on a (d-1)-sphere in R^d
+
+ :param n_samples: The number of points to be generated.
+ :type n_samples: integer
+ :param ambient_dim: The ambient dimension d.
+ :type ambient_dim: integer
+ :param radius: The radius. Default value is `1.`.
+ :type radius: float
+ :param sample: The sample type. Default and only available value is `"random"`.
+ :type sample: string
+ :returns: the generated points on a sphere.
+ )pbdoc");
+
+ m.def("ctorus", &generate_points_on_torus,
+ py::arg("n_samples"), py::arg("dim"), py::arg("sample") = "random",
+ R"pbdoc(
+ Generate random i.i.d. points on a d-torus in R^2d or as a grid
+
+ :param n_samples: The number of points to be generated.
+ :type n_samples: integer
+ :param dim: The dimension of the torus on which points would be generated in R^2*dim.
+ :type dim: integer
+ :param sample: The sample type. Available values are: `"random"` and `"grid"`. Default value is `"random"`.
+ :type sample: string
+ :returns: the generated points on a torus.
+
+ The shape of returned numpy array is:
+
+ If sample is 'random': (n_samples, 2*dim).
+
+ If sample is 'grid': (⌊n_samples**(1./dim)⌋**dim, 2*dim), where shape[0] is rounded down to the closest perfect 'dim'th power.
+ )pbdoc");
+}
diff --git a/src/python/gudhi/datasets/generators/points.py b/src/python/gudhi/datasets/generators/points.py
new file mode 100644
index 00000000..9bb2799d
--- /dev/null
+++ b/src/python/gudhi/datasets/generators/points.py
@@ -0,0 +1,59 @@
+# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+# Author(s): Hind Montassif
+#
+# Copyright (C) 2021 Inria
+#
+# Modification(s):
+# - YYYY/MM Author: Description of the modification
+
+import numpy as np
+
+from ._points import ctorus
+from ._points import sphere
+
+def _generate_random_points_on_torus(n_samples, dim):
+
+ # Generate random angles of size n_samples*dim
+ alpha = 2*np.pi*np.random.rand(n_samples*dim)
+
+ # Based on angles, construct points of size n_samples*dim on a circle and reshape the result in a n_samples*2*dim array
+ array_points = np.column_stack([np.cos(alpha), np.sin(alpha)]).reshape(-1, 2*dim)
+
+ return array_points
+
+def _generate_grid_points_on_torus(n_samples, dim):
+
+ # Generate points on a dim-torus as a grid
+ n_samples_grid = int((n_samples+.5)**(1./dim)) # add .5 to avoid rounding down with numerical approximations
+ alpha = np.linspace(0, 2*np.pi, n_samples_grid, endpoint=False)
+
+ array_points = np.column_stack([np.cos(alpha), np.sin(alpha)])
+ array_points_idx = np.empty([n_samples_grid]*dim + [dim], dtype=int)
+ for i, x in enumerate(np.ix_(*([np.arange(n_samples_grid)]*dim))):
+ array_points_idx[...,i] = x
+ return array_points[array_points_idx].reshape(-1, 2*dim)
+
+def torus(n_samples, dim, sample='random'):
+ """
+ Generate points on a flat dim-torus in R^2dim either randomly or on a grid
+
+ :param n_samples: The number of points to be generated.
+ :param dim: The dimension of the torus on which points would be generated in R^2*dim.
+ :param sample: The sample type of the generated points. Can be 'random' or 'grid'.
+ :returns: numpy array containing the generated points on a torus.
+
+ The shape of returned numpy array is:
+
+ If sample is 'random': (n_samples, 2*dim).
+
+ If sample is 'grid': (⌊n_samples**(1./dim)⌋**dim, 2*dim), where shape[0] is rounded down to the closest perfect 'dim'th power.
+ """
+ if sample == 'random':
+ # Generate points randomly
+ return _generate_random_points_on_torus(n_samples, dim)
+ elif sample == 'grid':
+ # Generate points on a grid
+ return _generate_grid_points_on_torus(n_samples, dim)
+ else:
+ raise ValueError("Sample type '{}' is not supported".format(sample))
diff --git a/src/python/gudhi/datasets/remote.py b/src/python/gudhi/datasets/remote.py
new file mode 100644
index 00000000..f6d3fe56
--- /dev/null
+++ b/src/python/gudhi/datasets/remote.py
@@ -0,0 +1,223 @@
+# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+# Author(s): Hind Montassif
+#
+# Copyright (C) 2021 Inria
+#
+# Modification(s):
+# - YYYY/MM Author: Description of the modification
+
+from os.path import join, split, exists, expanduser
+from os import makedirs, remove, environ
+
+from urllib.request import urlretrieve
+import hashlib
+import shutil
+
+import numpy as np
+
+def _get_data_home(data_home = None):
+ """
+ Return the path of the remote datasets directory.
+ This folder is used to store remotely fetched datasets.
+ By default the datasets directory is set to a folder named 'gudhi_data' in the user home folder.
+ Alternatively, it can be set by the 'GUDHI_DATA' environment variable.
+ The '~' symbol is expanded to the user home folder.
+ If the folder does not already exist, it is automatically created.
+
+ Parameters
+ ----------
+ data_home : string
+ The path to remote datasets directory.
+ Default is `None`, meaning that the data home directory will be set to "~/gudhi_data",
+ if the 'GUDHI_DATA' environment variable does not exist.
+
+ Returns
+ -------
+ data_home: string
+ The path to remote datasets directory.
+ """
+ if data_home is None:
+ data_home = environ.get("GUDHI_DATA", join("~", "gudhi_data"))
+ data_home = expanduser(data_home)
+ makedirs(data_home, exist_ok=True)
+ return data_home
+
+
+def clear_data_home(data_home = None):
+ """
+ Delete the data home cache directory and all its content.
+
+ Parameters
+ ----------
+ data_home : string, default is None.
+ The path to remote datasets directory.
+ If `None` and the 'GUDHI_DATA' environment variable does not exist,
+ the default directory to be removed is set to "~/gudhi_data".
+ """
+ data_home = _get_data_home(data_home)
+ shutil.rmtree(data_home)
+
+def _checksum_sha256(file_path):
+ """
+ Compute the file checksum using sha256.
+
+ Parameters
+ ----------
+ file_path: string
+ Full path of the created file including filename.
+
+ Returns
+ -------
+ The hex digest of file_path.
+ """
+ sha256_hash = hashlib.sha256()
+ chunk_size = 4096
+ with open(file_path,"rb") as f:
+ # Read and update hash string value in blocks of 4K
+ while True:
+ buffer = f.read(chunk_size)
+ if not buffer:
+ break
+ sha256_hash.update(buffer)
+ return sha256_hash.hexdigest()
+
+def _fetch_remote(url, file_path, file_checksum = None):
+ """
+ Fetch the wanted dataset from the given url and save it in file_path.
+
+ Parameters
+ ----------
+ url : string
+ The url to fetch the dataset from.
+ file_path : string
+ Full path of the downloaded file including filename.
+ file_checksum : string
+ The file checksum using sha256 to check against the one computed on the downloaded file.
+ Default is 'None', which means the checksum is not checked.
+
+ Raises
+ ------
+ IOError
+ If the computed SHA256 checksum of file does not match the one given by the user.
+ """
+
+ # Get the file
+ urlretrieve(url, file_path)
+
+ if file_checksum is not None:
+ checksum = _checksum_sha256(file_path)
+ if file_checksum != checksum:
+ # Remove file and raise error
+ remove(file_path)
+ raise IOError("{} has a SHA256 checksum : {}, "
+ "different from expected : {}."
+ "The file may be corrupted or the given url may be wrong !".format(file_path, checksum, file_checksum))
+
+def _get_archive_path(file_path, label):
+ """
+ Get archive path based on file_path given by user and label.
+
+ Parameters
+ ----------
+ file_path: string
+ Full path of the file to get including filename, or None.
+ label: string
+ Label used along with 'data_home' to get archive path, in case 'file_path' is None.
+
+ Returns
+ -------
+ Full path of archive including filename.
+ """
+ if file_path is None:
+ archive_path = join(_get_data_home(), label)
+ dirname = split(archive_path)[0]
+ makedirs(dirname, exist_ok=True)
+ else:
+ archive_path = file_path
+ dirname = split(archive_path)[0]
+ makedirs(dirname, exist_ok=True)
+
+ return archive_path
+
+def fetch_spiral_2d(file_path = None):
+ """
+ Load the spiral_2d dataset.
+
+ Note that if the dataset already exists in the target location, it is not downloaded again,
+ and the corresponding array is returned from cache.
+
+ Parameters
+ ----------
+ file_path : string
+ Full path of the downloaded file including filename.
+
+ Default is None, meaning that it's set to "data_home/points/spiral_2d/spiral_2d.npy".
+
+ The "data_home" directory is set by default to "~/gudhi_data",
+ unless the 'GUDHI_DATA' environment variable is set.
+
+ Returns
+ -------
+ points: numpy array
+ Array of shape (114562, 2).
+ """
+ file_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d/spiral_2d.npy"
+ file_checksum = '2226024da76c073dd2f24b884baefbfd14928b52296df41ad2d9b9dc170f2401'
+
+ archive_path = _get_archive_path(file_path, "points/spiral_2d/spiral_2d.npy")
+
+ if not exists(archive_path):
+ _fetch_remote(file_url, archive_path, file_checksum)
+
+ return np.load(archive_path, mmap_mode='r')
+
+def fetch_bunny(file_path = None, accept_license = False):
+ """
+ Load the Stanford bunny dataset.
+
+ This dataset contains 35947 vertices.
+
+ Note that if the dataset already exists in the target location, it is not downloaded again,
+ and the corresponding array is returned from cache.
+
+ Parameters
+ ----------
+ file_path : string
+ Full path of the downloaded file including filename.
+
+ Default is None, meaning that it's set to "data_home/points/bunny/bunny.npy".
+ In this case, the LICENSE file would be downloaded as "data_home/points/bunny/bunny.LICENSE".
+
+ The "data_home" directory is set by default to "~/gudhi_data",
+ unless the 'GUDHI_DATA' environment variable is set.
+
+ accept_license : boolean
+ Flag to specify if user accepts the file LICENSE and prevents from printing the corresponding license terms.
+
+ Default is False.
+
+ Returns
+ -------
+ points: numpy array
+ Array of shape (35947, 3).
+ """
+
+ file_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/bunny.npy"
+ file_checksum = 'f382482fd89df8d6444152dc8fd454444fe597581b193fd139725a85af4a6c6e'
+ license_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/bunny.LICENSE"
+ license_checksum = 'b763dbe1b2fc6015d05cbf7bcc686412a2eb100a1f2220296e3b4a644c69633a'
+
+ archive_path = _get_archive_path(file_path, "points/bunny/bunny.npy")
+
+ if not exists(archive_path):
+ _fetch_remote(file_url, archive_path, file_checksum)
+ license_path = join(split(archive_path)[0], "bunny.LICENSE")
+ _fetch_remote(license_url, license_path, license_checksum)
+ # Print license terms unless accept_license is set to True
+ if not accept_license:
+ if exists(license_path):
+ with open(license_path, 'r') as f:
+ print(f.read())
+
+ return np.load(archive_path, mmap_mode='r')
diff --git a/src/python/gudhi/hera/bottleneck.cc b/src/python/gudhi/hera/bottleneck.cc
index 0cb562ce..9826252c 100644
--- a/src/python/gudhi/hera/bottleneck.cc
+++ b/src/python/gudhi/hera/bottleneck.cc
@@ -16,13 +16,16 @@
using py::ssize_t;
#endif
-#include <bottleneck.h> // Hera
+#include <hera/bottleneck.h>
+
+// Indices are added internally in bottleneck_distance, they are not needed in the input.
+static auto make_point(double x, double y, py::ssize_t) { return std::pair(x, y); };
double bottleneck_distance(Dgm d1, Dgm d2, double delta)
{
- // I *think* the call to request() has to be before releasing the GIL.
- auto diag1 = numpy_to_range_of_pairs(d1);
- auto diag2 = numpy_to_range_of_pairs(d2);
+ // I *think* the call to request() in numpy_to_range_of_pairs has to be before releasing the GIL.
+ auto diag1 = numpy_to_range_of_pairs(d1, make_point);
+ auto diag2 = numpy_to_range_of_pairs(d2, make_point);
py::gil_scoped_release release;
diff --git a/src/python/gudhi/hera/wasserstein.cc b/src/python/gudhi/hera/wasserstein.cc
index 1a21f02f..41e84f7b 100644
--- a/src/python/gudhi/hera/wasserstein.cc
+++ b/src/python/gudhi/hera/wasserstein.cc
@@ -8,29 +8,126 @@
* - YYYY/MM Author: Description of the modification
*/
-#include <wasserstein.h> // Hera
-
#include <pybind11_diagram_utils.h>
-double wasserstein_distance(
+#ifdef _MSC_VER
+// https://github.com/grey-narn/hera/issues/3
+// ssize_t is a non-standard type (well, posix)
+using py::ssize_t;
+#endif
+
+#include <hera/wasserstein.h>
+#include <gudhi/Debug_utils.h>
+
+// Unlike bottleneck, for wasserstein, we need to add the index ourselves (if we want the matching)
+static auto make_hera_point(double x, double y, py::ssize_t i) { return hera::DiagramPoint<double>(x, y, i); };
+
+py::object wasserstein_distance(
Dgm d1, Dgm d2,
double wasserstein_power, double internal_p,
- double delta)
+ double delta, bool return_matching)
{
- // I *think* the call to request() has to be before releasing the GIL.
- auto diag1 = numpy_to_range_of_pairs(d1);
- auto diag2 = numpy_to_range_of_pairs(d2);
-
- py::gil_scoped_release release;
-
- hera::AuctionParams<double> params;
- params.wasserstein_power = wasserstein_power;
- // hera encodes infinity as -1...
- if(std::isinf(internal_p)) internal_p = hera::get_infinity<double>();
- params.internal_p = internal_p;
- params.delta = delta;
- // The extra parameters are purposedly not exposed for now.
- return hera::wasserstein_dist(diag1, diag2, params);
+ // I *think* the call to request() in numpy_to_range_of_pairs has to be before releasing the GIL.
+ auto diag1 = numpy_to_range_of_pairs(d1, make_hera_point);
+ auto diag2 = numpy_to_range_of_pairs(d2, make_hera_point);
+ int n1 = boost::size(diag1);
+ int n2 = boost::size(diag2);
+ hera::AuctionResult<double> res;
+ double dist;
+
+ { // No Python allowed in this section
+ py::gil_scoped_release release;
+
+ hera::AuctionParams<double> params;
+ params.wasserstein_power = wasserstein_power;
+ // hera encodes infinity as -1...
+ if(std::isinf(internal_p)) internal_p = hera::get_infinity<double>();
+ params.internal_p = internal_p;
+ params.delta = delta;
+ if(return_matching) {
+ params.return_matching = true;
+ params.match_inf_points = true;
+ }
+ // The extra parameters are purposely not exposed for now.
+ res = hera::wasserstein_cost_detailed(diag1, diag2, params);
+ dist = std::pow(res.cost, 1./params.wasserstein_power);
+ }
+
+ if(!return_matching)
+ return py::cast(dist);
+
+ if(dist == std::numeric_limits<double>::infinity())
+ return py::make_tuple(dist, py::none());
+
+ // bug in Hera, matching_a_to_b_ is empty if one diagram is empty or both diagrams contain the same points
+ if(res.matching_a_to_b_.size() == 0) {
+ if(n1 == 0) { // diag1 is empty
+ py::array_t<int> matching({{ n2, 2 }}, nullptr);
+ auto m = matching.mutable_unchecked();
+ for(int j=0; j<n2; ++j){
+ m(j, 0) = -1;
+ m(j, 1) = j;
+ }
+ return py::make_tuple(dist, matching);
+ }
+ if(n2 == 0) { // diag2 is empty
+ py::array_t<int> matching({{ n1, 2 }}, nullptr);
+ auto m = matching.mutable_unchecked();
+ for(int i=0; i<n1; ++i){
+ m(i, 0) = i;
+ m(i, 1) = -1;
+ }
+ return py::make_tuple(dist, matching);
+ }
+ // The only remaining case should be that the 2 diagrams are identical, but possibly shuffled
+ GUDHI_CHECK(n1==n2, "unexpected bug in Hera?");
+ std::vector v1(boost::begin(diag1), boost::end(diag1));
+ std::vector v2(boost::begin(diag2), boost::end(diag2));
+ std::sort(v1.begin(), v1.end());
+ std::sort(v2.begin(), v2.end());
+ py::array_t<int> matching({{ n1, 2 }}, nullptr);
+ auto m = matching.mutable_unchecked();
+ for(int i=0; i<n1; ++i){
+ GUDHI_CHECK(v1[i][0]==v2[i][0] && v1[i][1]==v2[i][1], "unexpected bug in Hera?");
+ m(i, 0) = v1[i].get_id();
+ m(i, 1) = v2[i].get_id();
+ }
+ return py::make_tuple(dist, matching);
+
+ }
+
+ // bug in Hera, diagonal points are ignored and don't appear in matching_a_to_b_
+ for(auto p : diag1)
+ if(p[0] == p[1]) { auto id = p.get_id(); res.matching_a_to_b_[id] = -id-1; }
+ for(auto p : diag2)
+ if(p[0] == p[1]) { auto id = p.get_id(); res.matching_a_to_b_[-id-1] = id; }
+
+ py::array_t<int> matching({{ n1 + n2, 2 }}, nullptr);
+ auto m = matching.mutable_unchecked();
+ int cur = 0;
+ for(auto x : res.matching_a_to_b_){
+ if(x.first < 0) {
+ if(x.second < 0) {
+ } else {
+ m(cur, 0) = -1;
+ m(cur, 1) = x.second;
+ ++cur;
+ }
+ } else {
+ if(x.second < 0) {
+ m(cur, 0) = x.first;
+ m(cur, 1) = -1;
+ ++cur;
+ } else {
+ m(cur, 0) = x.first;
+ m(cur, 1) = x.second;
+ ++cur;
+ }
+ }
+ }
+ // n1+n2 was too much, it only happens if everything matches to the diagonal, so we return matching[:cur,:]
+ py::array_t<int> ret({{ cur, 2 }}, {{ matching.strides()[0], matching.strides()[1] }}, matching.data(), matching);
+ return py::make_tuple(dist, ret);
}
PYBIND11_MODULE(wasserstein, m) {
@@ -39,6 +136,7 @@ PYBIND11_MODULE(wasserstein, m) {
py::arg("order") = 1,
py::arg("internal_p") = std::numeric_limits<double>::infinity(),
py::arg("delta") = .01,
+ py::arg("matching") = false,
R"pbdoc(
Compute the Wasserstein distance between two diagrams.
Points at infinity are supported.
@@ -49,8 +147,9 @@ PYBIND11_MODULE(wasserstein, m) {
order (float): Wasserstein exponent W_q
internal_p (float): Internal Minkowski norm L^p in R^2
delta (float): Relative error 1+delta
+ matching (bool): if ``True``, computes and returns the optimal matching between X and Y, encoded as a (n x 2) np.array [...[i,j]...], meaning the i-th point in X is matched to the j-th point in Y, with the convention that (-1) represents the diagonal. If the distance between two diagrams is +inf (which happens if the cardinalities of essential parts differ) and the matching is requested, it will be set to ``None`` (any matching is optimal).
Returns:
- float: Approximate Wasserstein distance W_q(X,Y)
+ float|Tuple[float,numpy.array|None]: Approximate Wasserstein distance W_q(X,Y), and optionally the corresponding matching
)pbdoc");
}
diff --git a/src/python/gudhi/off_reader.pyx b/src/python/gudhi/off_utils.pyx
index a3200704..9276c7b0 100644
--- a/src/python/gudhi/off_reader.pyx
+++ b/src/python/gudhi/off_utils.pyx
@@ -13,8 +13,10 @@ from __future__ import print_function
from cython cimport numeric
from libcpp.vector cimport vector
from libcpp.string cimport string
+cimport cython
import errno
import os
+import numpy as np
__author__ = "Vincent Rouvreau"
__copyright__ = "Copyright (C) 2016 Inria"
@@ -24,7 +26,7 @@ cdef extern from "Off_reader_interface.h" namespace "Gudhi":
vector[vector[double]] read_points_from_OFF_file(string off_file)
def read_points_from_off_file(off_file=''):
- """Read points from OFF file.
+ """Read points from an `OFF file <fileformats.html#off-file-format>`_.
:param off_file: An OFF file style name.
:type off_file: string
@@ -39,3 +41,22 @@ def read_points_from_off_file(off_file=''):
raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
off_file)
+@cython.embedsignature(True)
+def write_points_to_off_file(fname, points):
+ """Write points to an `OFF file <fileformats.html#off-file-format>`_.
+
+ A simple wrapper for `numpy.savetxt`.
+
+ :param fname: Name of the OFF file.
+ :type fname: str or file handle
+ :param points: Point coordinates.
+ :type points: numpy array of shape (n, dim)
+ """
+ points = np.array(points, copy=False)
+ assert len(points.shape) == 2
+ dim = points.shape[1]
+ if dim == 3:
+ head = 'OFF\n{} 0 0'.format(points.shape[0])
+ else:
+ head = 'nOFF\n{} {} 0 0'.format(dim, points.shape[0])
+ np.savetxt(fname, points, header=head, comments='')
diff --git a/src/python/gudhi/periodic_cubical_complex.pyx b/src/python/gudhi/periodic_cubical_complex.pyx
index d353d2af..6c21e902 100644
--- a/src/python/gudhi/periodic_cubical_complex.pyx
+++ b/src/python/gudhi/periodic_cubical_complex.pyx
@@ -32,7 +32,7 @@ cdef extern from "Cubical_complex_interface.h" namespace "Gudhi":
cdef extern from "Persistent_cohomology_interface.h" namespace "Gudhi":
cdef cppclass Periodic_cubical_complex_persistence_interface "Gudhi::Persistent_cohomology_interface<Gudhi::Cubical_complex::Cubical_complex_interface<Gudhi::cubical_complex::Bitmap_cubical_complex_periodic_boundary_conditions_base<double>>>":
Periodic_cubical_complex_persistence_interface(Periodic_cubical_complex_base_interface * st, bool persistence_dim_max) nogil
- void compute_persistence(int homology_coeff_field, double min_persistence) nogil
+ void compute_persistence(int homology_coeff_field, double min_persistence) nogil except +
vector[pair[int, pair[double, double]]] get_persistence() nogil
vector[vector[int]] cofaces_of_cubical_persistence_pairs() nogil
vector[int] betti_numbers() nogil
@@ -148,7 +148,7 @@ cdef class PeriodicCubicalComplex:
:func:`persistence` returns.
:param homology_coeff_field: The homology coefficient field. Must be a
- prime number
+ prime number. Default value is 11. Max is 46337.
:type homology_coeff_field: int.
:param min_persistence: The minimum persistence value to take into
account (strictly greater than min_persistence). Default value is
@@ -170,7 +170,7 @@ cdef class PeriodicCubicalComplex:
"""This function computes and returns the persistence of the complex.
:param homology_coeff_field: The homology coefficient field. Must be a
- prime number
+ prime number. Default value is 11. Max is 46337.
:type homology_coeff_field: int.
:param min_persistence: The minimum persistence value to take into
account (strictly greater than min_persistence). Default value is
@@ -280,4 +280,8 @@ cdef class PeriodicCubicalComplex:
launched first.
"""
assert self.pcohptr != NULL, "compute_persistence() must be called before persistence_intervals_in_dimension()"
- return np.array(self.pcohptr.intervals_in_dimension(dimension))
+ piid = np.array(self.pcohptr.intervals_in_dimension(dimension))
+ # Workaround https://github.com/GUDHI/gudhi-devel/issues/507
+ if len(piid) == 0:
+ return np.empty(shape = [0, 2])
+ return piid
diff --git a/src/python/gudhi/persistence_graphical_tools.py b/src/python/gudhi/persistence_graphical_tools.py
index 848dc03e..e438aa66 100644
--- a/src/python/gudhi/persistence_graphical_tools.py
+++ b/src/python/gudhi/persistence_graphical_tools.py
@@ -12,6 +12,9 @@ from os import path
from math import isfinite
import numpy as np
from functools import lru_cache
+import warnings
+import errno
+import os
from gudhi.reader_utils import read_persistence_intervals_in_dimension
from gudhi.reader_utils import read_persistence_intervals_grouped_by_dimension
@@ -22,6 +25,7 @@ __license__ = "MIT"
_gudhi_matplotlib_use_tex = True
+
def __min_birth_max_death(persistence, band=0.0):
"""This function returns (min_birth, max_death) from the persistence.
@@ -44,20 +48,46 @@ def __min_birth_max_death(persistence, band=0.0):
min_birth = float(interval[1][0])
if band > 0.0:
max_death += band
+ # can happen if only points at inf death
+ if min_birth == max_death:
+ max_death = max_death + 1.0
return (min_birth, max_death)
def _array_handler(a):
- '''
+ """
:param a: if array, assumes it is a (n x 2) np.array and return a
persistence-compatible list (padding with 0), so that the
plot can be performed seamlessly.
- '''
- if isinstance(a[0][1], np.float64) or isinstance(a[0][1], float):
+ """
+ if isinstance(a[0][1], (np.floating, float)):
return [[0, x] for x in a]
else:
return a
+
+def _limit_to_max_intervals(persistence, max_intervals, key):
+ """This function returns truncated persistence if length is bigger than max_intervals.
+ :param persistence: Persistence intervals values list. Can be grouped by dimension or not.
+ :type persistence: an array of (dimension, array of (birth, death)) or an array of (birth, death).
+ :param max_intervals: maximal number of intervals to display.
+ Selected intervals are those with the longest life time. Set it
+ to 0 to see all. Default value is 1000.
+ :type max_intervals: int.
+ :param key: key function for sort algorithm.
+ :type key: function or lambda.
+ """
+ if max_intervals > 0 and max_intervals < len(persistence):
+ warnings.warn(
+ "There are %s intervals given as input, whereas max_intervals is set to %s."
+ % (len(persistence), max_intervals)
+ )
+ # Sort by life time, then takes only the max_intervals elements
+ return sorted(persistence, key=key, reverse=True)[:max_intervals]
+ else:
+ return persistence
+
+
@lru_cache(maxsize=1)
def _matplotlib_can_use_tex():
"""This function returns True if matplotlib can deal with LaTeX, False otherwise.
@@ -65,17 +95,17 @@ def _matplotlib_can_use_tex():
"""
try:
from matplotlib import checkdep_usetex
+
return checkdep_usetex(True)
- except ImportError:
- print("This function is not available, you may be missing matplotlib.")
+ except ImportError as import_error:
+ warnings.warn(f"This function is not available.\nModuleNotFoundError: No module named '{import_error.name}'.")
def plot_persistence_barcode(
persistence=[],
persistence_file="",
alpha=0.6,
- max_intervals=1000,
- max_barcodes=1000,
+ max_intervals=20000,
inf_delta=0.1,
legend=False,
colormap=None,
@@ -97,7 +127,7 @@ def plot_persistence_barcode(
:type alpha: float.
:param max_intervals: maximal number of intervals to display.
Selected intervals are those with the longest life time. Set it
- to 0 to see all. Default value is 1000.
+ to 0 to see all. Default value is 20000.
:type max_intervals: int.
:param inf_delta: Infinity is placed at :code:`((max_death - min_birth) x
inf_delta)` above :code:`max_death` value. A reasonable value is
@@ -119,99 +149,70 @@ def plot_persistence_barcode(
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib import rc
+
if _gudhi_matplotlib_use_tex and _matplotlib_can_use_tex():
- plt.rc('text', usetex=True)
- plt.rc('font', family='serif')
+ plt.rc("text", usetex=True)
+ plt.rc("font", family="serif")
else:
- plt.rc('text', usetex=False)
- plt.rc('font', family='DejaVu Sans')
+ plt.rc("text", usetex=False)
+ plt.rc("font", family="DejaVu Sans")
if persistence_file != "":
if path.isfile(persistence_file):
# Reset persistence
persistence = []
- diag = read_persistence_intervals_grouped_by_dimension(
- persistence_file=persistence_file
- )
+ diag = read_persistence_intervals_grouped_by_dimension(persistence_file=persistence_file)
for key in diag.keys():
for persistence_interval in diag[key]:
persistence.append((key, persistence_interval))
else:
- print("file " + persistence_file + " not found.")
- return None
-
- persistence = _array_handler(persistence)
-
- if max_barcodes != 1000:
- print("Deprecated parameter. It has been replaced by max_intervals")
- max_intervals = max_barcodes
-
- if max_intervals > 0 and max_intervals < len(persistence):
- # Sort by life time, then takes only the max_intervals elements
- persistence = sorted(
- persistence,
- key=lambda life_time: life_time[1][1] - life_time[1][0],
- reverse=True,
- )[:max_intervals]
-
- if colormap == None:
- colormap = plt.cm.Set1.colors
- if axes == None:
- fig, axes = plt.subplots(1, 1)
+ raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), persistence_file)
- persistence = sorted(persistence, key=lambda birth: birth[1][0])
+ try:
+ persistence = _array_handler(persistence)
+ persistence = _limit_to_max_intervals(
+ persistence, max_intervals, key=lambda life_time: life_time[1][1] - life_time[1][0]
+ )
+ (min_birth, max_death) = __min_birth_max_death(persistence)
+ persistence = sorted(persistence, key=lambda birth: birth[1][0])
+ except IndexError:
+ min_birth, max_death = 0.0, 1.0
+ pass
- (min_birth, max_death) = __min_birth_max_death(persistence)
- ind = 0
delta = (max_death - min_birth) * inf_delta
# Replace infinity values with max_death + delta for bar code to be more
# readable
infinity = max_death + delta
axis_start = min_birth - delta
- # Draw horizontal bars in loop
- for interval in reversed(persistence):
- if float(interval[1][1]) != float("inf"):
- # Finite death case
- axes.barh(
- ind,
- (interval[1][1] - interval[1][0]),
- height=0.8,
- left=interval[1][0],
- alpha=alpha,
- color=colormap[interval[0]],
- linewidth=0,
- )
- else:
- # Infinite death case for diagram to be nicer
- axes.barh(
- ind,
- (infinity - interval[1][0]),
- height=0.8,
- left=interval[1][0],
- alpha=alpha,
- color=colormap[interval[0]],
- linewidth=0,
- )
- ind = ind + 1
+
+ if axes == None:
+ _, axes = plt.subplots(1, 1)
+ if colormap == None:
+ colormap = plt.cm.Set1.colors
+
+ x=[birth for (dim,(birth,death)) in persistence]
+ y=[(death - birth) if death != float("inf") else (infinity - birth) for (dim,(birth,death)) in persistence]
+ c=[colormap[dim] for (dim,(birth,death)) in persistence]
+
+ axes.barh(range(len(x)), y, left=x, alpha=alpha, color=c, linewidth=0)
if legend:
- dimensions = list(set(item[0] for item in persistence))
+ dimensions = set(item[0] for item in persistence)
axes.legend(
- handles=[
- mpatches.Patch(color=colormap[dim], label=str(dim))
- for dim in dimensions
- ],
- loc="lower right",
+ handles=[mpatches.Patch(color=colormap[dim], label=str(dim)) for dim in dimensions], loc="lower right",
)
axes.set_title("Persistence barcode", fontsize=fontsize)
+ axes.set_yticks([])
+ axes.invert_yaxis()
# Ends plot on infinity value and starts a little bit before min_birth
- axes.axis([axis_start, infinity, 0, ind])
+ if len(x) != 0:
+ axes.set_xlim((axis_start, infinity))
return axes
- except ImportError:
- print("This function is not available, you may be missing matplotlib.")
+ except ImportError as import_error:
+ warnings.warn(f"This function is not available.\nModuleNotFoundError: No module named '{import_error.name}'.")
def plot_persistence_diagram(
@@ -219,14 +220,13 @@ def plot_persistence_diagram(
persistence_file="",
alpha=0.6,
band=0.0,
- max_intervals=1000,
- max_plots=1000,
+ max_intervals=1000000,
inf_delta=0.1,
legend=False,
colormap=None,
axes=None,
fontsize=16,
- greyblock=True
+ greyblock=True,
):
"""This function plots the persistence diagram from persistence values
list, a np.array of shape (N x 2) representing a diagram in a single
@@ -244,7 +244,7 @@ def plot_persistence_diagram(
:type band: float.
:param max_intervals: maximal number of intervals to display.
Selected intervals are those with the longest life time. Set it
- to 0 to see all. Default value is 1000.
+ to 0 to see all. Default value is 1000000.
:type max_intervals: int.
:param inf_delta: Infinity is placed at :code:`((max_death - min_birth) x
inf_delta)` above :code:`max_death` value. A reasonable value is
@@ -268,47 +268,35 @@ def plot_persistence_diagram(
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib import rc
+
if _gudhi_matplotlib_use_tex and _matplotlib_can_use_tex():
- plt.rc('text', usetex=True)
- plt.rc('font', family='serif')
+ plt.rc("text", usetex=True)
+ plt.rc("font", family="serif")
else:
- plt.rc('text', usetex=False)
- plt.rc('font', family='DejaVu Sans')
+ plt.rc("text", usetex=False)
+ plt.rc("font", family="DejaVu Sans")
if persistence_file != "":
if path.isfile(persistence_file):
# Reset persistence
persistence = []
- diag = read_persistence_intervals_grouped_by_dimension(
- persistence_file=persistence_file
- )
+ diag = read_persistence_intervals_grouped_by_dimension(persistence_file=persistence_file)
for key in diag.keys():
for persistence_interval in diag[key]:
persistence.append((key, persistence_interval))
else:
- print("file " + persistence_file + " not found.")
- return None
+ raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), persistence_file)
- persistence = _array_handler(persistence)
-
- if max_plots != 1000:
- print("Deprecated parameter. It has been replaced by max_intervals")
- max_intervals = max_plots
-
- if max_intervals > 0 and max_intervals < len(persistence):
- # Sort by life time, then takes only the max_intervals elements
- persistence = sorted(
- persistence,
- key=lambda life_time: life_time[1][1] - life_time[1][0],
- reverse=True,
- )[:max_intervals]
-
- if colormap == None:
- colormap = plt.cm.Set1.colors
- if axes == None:
- fig, axes = plt.subplots(1, 1)
+ try:
+ persistence = _array_handler(persistence)
+ persistence = _limit_to_max_intervals(
+ persistence, max_intervals, key=lambda life_time: life_time[1][1] - life_time[1][0]
+ )
+ min_birth, max_death = __min_birth_max_death(persistence, band)
+ except IndexError:
+ min_birth, max_death = 0.0, 1.0
+ pass
- (min_birth, max_death) = __min_birth_max_death(persistence, band)
delta = (max_death - min_birth) * inf_delta
# Replace infinity values with max_death + delta for diagram to be more
# readable
@@ -316,61 +304,56 @@ def plot_persistence_diagram(
axis_end = max_death + delta / 2
axis_start = min_birth - delta
+ if axes == None:
+ _, axes = plt.subplots(1, 1)
+ if colormap == None:
+ colormap = plt.cm.Set1.colors
# bootstrap band
if band > 0.0:
x = np.linspace(axis_start, infinity, 1000)
axes.fill_between(x, x, x + band, alpha=alpha, facecolor="red")
# lower diag patch
if greyblock:
- axes.add_patch(mpatches.Polygon([[axis_start, axis_start], [axis_end, axis_start], [axis_end, axis_end]], fill=True, color='lightgrey'))
- # Draw points in loop
- pts_at_infty = False # Records presence of pts at infty
- for interval in reversed(persistence):
- if float(interval[1][1]) != float("inf"):
- # Finite death case
- axes.scatter(
- interval[1][0],
- interval[1][1],
- alpha=alpha,
- color=colormap[interval[0]],
- )
- else:
- pts_at_infty = True
- # Infinite death case for diagram to be nicer
- axes.scatter(
- interval[1][0], infinity, alpha=alpha, color=colormap[interval[0]]
+ axes.add_patch(
+ mpatches.Polygon(
+ [[axis_start, axis_start], [axis_end, axis_start], [axis_end, axis_end]],
+ fill=True,
+ color="lightgrey",
)
- if pts_at_infty:
+ )
+ # line display of equation : birth = death
+ axes.plot([axis_start, axis_end], [axis_start, axis_end], linewidth=1.0, color="k")
+
+ x=[birth for (dim,(birth,death)) in persistence]
+ y=[death if death != float("inf") else infinity for (dim,(birth,death)) in persistence]
+ c=[colormap[dim] for (dim,(birth,death)) in persistence]
+
+ axes.scatter(x,y,alpha=alpha,color=c)
+ if float("inf") in (death for (dim,(birth,death)) in persistence):
# infinity line and text
- axes.plot([axis_start, axis_end], [axis_start, axis_end], linewidth=1.0, color="k")
axes.plot([axis_start, axis_end], [infinity, infinity], linewidth=1.0, color="k", alpha=alpha)
# Infinity label
yt = axes.get_yticks()
- yt = yt[np.where(yt < axis_end)] # to avoid ploting ticklabel higher than infinity
+ yt = yt[np.where(yt < axis_end)] # to avoid plotting ticklabel higher than infinity
yt = np.append(yt, infinity)
ytl = ["%.3f" % e for e in yt] # to avoid float precision error
- ytl[-1] = r'$+\infty$'
+ ytl[-1] = r"$+\infty$"
axes.set_yticks(yt)
axes.set_yticklabels(ytl)
if legend:
dimensions = list(set(item[0] for item in persistence))
- axes.legend(
- handles=[
- mpatches.Patch(color=colormap[dim], label=str(dim))
- for dim in dimensions
- ]
- )
+ axes.legend(handles=[mpatches.Patch(color=colormap[dim], label=str(dim)) for dim in dimensions])
axes.set_xlabel("Birth", fontsize=fontsize)
axes.set_ylabel("Death", fontsize=fontsize)
axes.set_title("Persistence diagram", fontsize=fontsize)
# Ends plot on infinity value and starts a little bit before min_birth
- axes.axis([axis_start, axis_end, axis_start, infinity + delta/2])
+ axes.axis([axis_start, axis_end, axis_start, infinity + delta / 2])
return axes
- except ImportError:
- print("This function is not available, you may be missing matplotlib.")
+ except ImportError as import_error:
+ warnings.warn(f"This function is not available.\nModuleNotFoundError: No module named '{import_error.name}'.")
def plot_persistence_density(
@@ -384,7 +367,7 @@ def plot_persistence_density(
legend=False,
axes=None,
fontsize=16,
- greyblock=False
+ greyblock=False,
):
"""This function plots the persistence density from persistence
values list, np.array of shape (N x 2) representing a diagram
@@ -444,12 +427,13 @@ def plot_persistence_density(
import matplotlib.patches as mpatches
from scipy.stats import kde
from matplotlib import rc
+
if _gudhi_matplotlib_use_tex and _matplotlib_can_use_tex():
- plt.rc('text', usetex=True)
- plt.rc('font', family='serif')
+ plt.rc("text", usetex=True)
+ plt.rc("font", family="serif")
else:
- plt.rc('text', usetex=False)
- plt.rc('font', family='DejaVu Sans')
+ plt.rc("text", usetex=False)
+ plt.rc("font", family="DejaVu Sans")
if persistence_file != "":
if dimension is None:
@@ -460,10 +444,16 @@ def plot_persistence_density(
persistence_file=persistence_file, only_this_dim=dimension
)
else:
- print("file " + persistence_file + " not found.")
- return None
+ raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), persistence_file)
+
+ # default cmap value cannot be done at argument definition level as matplotlib is not yet defined.
+ if cmap is None:
+ cmap = plt.cm.hot_r
+ if axes == None:
+ _, axes = plt.subplots(1, 1)
- if len(persistence) > 0:
+ try:
+ # if not read from file but given by an argument
persistence = _array_handler(persistence)
persistence_dim = np.array(
[
@@ -472,47 +462,54 @@ def plot_persistence_density(
if (dim_interval[0] == dimension) or (dimension is None)
]
)
-
- persistence_dim = persistence_dim[np.isfinite(persistence_dim[:, 1])]
- if max_intervals > 0 and max_intervals < len(persistence_dim):
- # Sort by life time, then takes only the max_intervals elements
+ persistence_dim = persistence_dim[np.isfinite(persistence_dim[:, 1])]
persistence_dim = np.array(
- sorted(
- persistence_dim,
- key=lambda life_time: life_time[1] - life_time[0],
- reverse=True,
- )[:max_intervals]
+ _limit_to_max_intervals(
+ persistence_dim, max_intervals, key=lambda life_time: life_time[1] - life_time[0]
+ )
)
- # Set as numpy array birth and death (remove undefined values - inf and NaN)
- birth = persistence_dim[:, 0]
- death = persistence_dim[:, 1]
-
- # default cmap value cannot be done at argument definition level as matplotlib is not yet defined.
- if cmap is None:
- cmap = plt.cm.hot_r
- if axes == None:
- fig, axes = plt.subplots(1, 1)
+ # Set as numpy array birth and death (remove undefined values - inf and NaN)
+ birth = persistence_dim[:, 0]
+ death = persistence_dim[:, 1]
+ birth_min = birth.min()
+ birth_max = birth.max()
+ death_min = death.min()
+ death_max = death.max()
+
+ # Evaluate a gaussian kde on a regular grid of nbins x nbins over data extents
+ k = kde.gaussian_kde([birth, death], bw_method=bw_method)
+ xi, yi = np.mgrid[
+ birth_min : birth_max : nbins * 1j, death_min : death_max : nbins * 1j,
+ ]
+ zi = k(np.vstack([xi.flatten(), yi.flatten()]))
+ # Make the plot
+ img = axes.pcolormesh(xi, yi, zi.reshape(xi.shape), cmap=cmap, shading="auto")
+ plot_success = True
+
+ # IndexError on empty diagrams, ValueError on only inf death values
+ except (IndexError, ValueError):
+ birth_min = 0.0
+ birth_max = 1.0
+ death_min = 0.0
+ death_max = 1.0
+ plot_success = False
+ pass
# line display of equation : birth = death
- x = np.linspace(death.min(), birth.max(), 1000)
+ x = np.linspace(death_min, birth_max, 1000)
axes.plot(x, x, color="k", linewidth=1.0)
- # Evaluate a gaussian kde on a regular grid of nbins x nbins over data extents
- k = kde.gaussian_kde([birth, death], bw_method=bw_method)
- xi, yi = np.mgrid[
- birth.min() : birth.max() : nbins * 1j,
- death.min() : death.max() : nbins * 1j,
- ]
- zi = k(np.vstack([xi.flatten(), yi.flatten()]))
-
- # Make the plot
- img = axes.pcolormesh(xi, yi, zi.reshape(xi.shape), cmap=cmap)
-
if greyblock:
- axes.add_patch(mpatches.Polygon([[birth.min(), birth.min()], [death.max(), birth.min()], [death.max(), death.max()]], fill=True, color='lightgrey'))
+ axes.add_patch(
+ mpatches.Polygon(
+ [[birth_min, birth_min], [death_max, birth_min], [death_max, death_max]],
+ fill=True,
+ color="lightgrey",
+ )
+ )
- if legend:
+ if plot_success and legend:
plt.colorbar(img, ax=axes)
axes.set_xlabel("Birth", fontsize=fontsize)
@@ -521,7 +518,5 @@ def plot_persistence_density(
return axes
- except ImportError:
- print(
- "This function is not available, you may be missing matplotlib and/or scipy."
- )
+ except ImportError as import_error:
+ warnings.warn(f"This function is not available.\nModuleNotFoundError: No module named '{import_error.name}'.")
diff --git a/src/python/gudhi/point_cloud/knn.py b/src/python/gudhi/point_cloud/knn.py
index 994be3b6..7dc83817 100644
--- a/src/python/gudhi/point_cloud/knn.py
+++ b/src/python/gudhi/point_cloud/knn.py
@@ -8,6 +8,7 @@
# - YYYY/MM Author: Description of the modification
import numpy
+import warnings
# TODO: https://github.com/facebookresearch/faiss
@@ -111,7 +112,7 @@ class KNearestNeighbors:
nargs = {
k: v for k, v in self.params.items() if k in {"p", "n_jobs", "metric_params", "algorithm", "leaf_size"}
}
- self.nn = NearestNeighbors(self.k, metric=self.metric, **nargs)
+ self.nn = NearestNeighbors(n_neighbors=self.k, metric=self.metric, **nargs)
self.nn.fit(X)
if self.params["implementation"] == "hnsw":
@@ -257,6 +258,9 @@ class KNearestNeighbors:
if ef is not None:
self.graph.set_ef(ef)
neighbors, distances = self.graph.knn_query(X, k, num_threads=self.params["num_threads"])
+ with warnings.catch_warnings():
+ if not(numpy.all(numpy.isfinite(distances))):
+ warnings.warn("Overflow/infinite value encountered while computing 'distances'", RuntimeWarning)
# The k nearest neighbors are always sorted. I couldn't find it in the doc, but the code calls searchKnn,
# which returns a priority_queue, and then fills the return array backwards with top/pop on the queue.
if self.return_index:
@@ -290,6 +294,9 @@ class KNearestNeighbors:
if self.return_index:
if self.return_distance:
distances, neighbors = mat.Kmin_argKmin(k, dim=1)
+ with warnings.catch_warnings():
+ if not(torch.isfinite(distances).all()):
+ warnings.warn("Overflow/infinite value encountered while computing 'distances'", RuntimeWarning)
if p != numpy.inf:
distances = distances ** (1.0 / p)
return neighbors, distances
@@ -298,13 +305,18 @@ class KNearestNeighbors:
return neighbors
if self.return_distance:
distances = mat.Kmin(k, dim=1)
+ with warnings.catch_warnings():
+ if not(torch.isfinite(distances).all()):
+ warnings.warn("Overflow/infinite value encountered while computing 'distances'", RuntimeWarning)
if p != numpy.inf:
distances = distances ** (1.0 / p)
return distances
return None
if self.params["implementation"] == "ckdtree":
- qargs = {key: val for key, val in self.params.items() if key in {"p", "eps", "n_jobs"}}
+ qargs = {key: val for key, val in self.params.items() if key in {"p", "eps"}}
+ # SciPy renamed n_jobs to workers
+ qargs["workers"] = self.params.get("workers") or self.params.get("n_jobs") or 1
distances, neighbors = self.kdtree.query(X, k=self.k, **qargs)
if k == 1:
# SciPy decided to squeeze the last dimension for k=1
diff --git a/src/python/gudhi/representations/preprocessing.py b/src/python/gudhi/representations/preprocessing.py
index a8545349..8722e162 100644
--- a/src/python/gudhi/representations/preprocessing.py
+++ b/src/python/gudhi/representations/preprocessing.py
@@ -1,10 +1,11 @@
# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
-# Author(s): Mathieu Carrière
+# Author(s): Mathieu Carrière, Vincent Rouvreau
#
# Copyright (C) 2018-2019 Inria
#
# Modification(s):
+# - 2021/10 Vincent Rouvreau: Add DimensionSelector
# - YYYY/MM Author: Description of the modification
import numpy as np
@@ -75,7 +76,7 @@ class Clamping(BaseEstimator, TransformerMixin):
Constructor for the Clamping class.
Parameters:
- limit (double): clamping value (default np.inf).
+ limit (float): clamping value (default np.inf).
"""
self.minimum = minimum
self.maximum = maximum
@@ -234,7 +235,7 @@ class ProminentPoints(BaseEstimator, TransformerMixin):
use (bool): whether to use the class or not (default False).
location (string): either "upper" or "lower" (default "upper"). Whether to keep the points that are far away ("upper") or close ("lower") to the diagonal.
num_pts (int): cardinality threshold (default 10). If location == "upper", keep the top **num_pts** points that are the farthest away from the diagonal. If location == "lower", keep the top **num_pts** points that are the closest to the diagonal.
- threshold (double): distance-to-diagonal threshold (default -1). If location == "upper", keep the points that are at least at a distance **threshold** from the diagonal. If location == "lower", keep the points that are at most at a distance **threshold** from the diagonal.
+ threshold (float): distance-to-diagonal threshold (default -1). If location == "upper", keep the points that are at least at a distance **threshold** from the diagonal. If location == "lower", keep the points that are at most at a distance **threshold** from the diagonal.
"""
self.num_pts = num_pts
self.threshold = threshold
@@ -317,7 +318,7 @@ class DiagramSelector(BaseEstimator, TransformerMixin):
Parameters:
use (bool): whether to use the class or not (default False).
- limit (double): second coordinate value that is the criterion for being an essential point (default numpy.inf).
+ limit (float): second coordinate value that is the criterion for being an essential point (default numpy.inf).
point_type (string): either "finite" or "essential". The type of the points that are going to be extracted.
"""
self.use, self.limit, self.point_type = use, limit, point_type
@@ -363,3 +364,51 @@ class DiagramSelector(BaseEstimator, TransformerMixin):
n x 2 numpy array: extracted persistence diagram.
"""
return self.fit_transform([diag])[0]
+
+
+# Mermaid sequence diagram - https://mermaid-js.github.io/mermaid-live-editor/
+# sequenceDiagram
+# USER->>DimensionSelector: fit_transform(<br/>[[array( Hi(X0) ), array( Hj(X0) ), ...],<br/> [array( Hi(X1) ), array( Hj(X1) ), ...],<br/> ...])
+# DimensionSelector->>thread1: _transform([array( Hi(X0) ), array( Hj(X0) )], ...)
+# DimensionSelector->>thread2: _transform([array( Hi(X1) ), array( Hj(X1) )], ...)
+# Note right of DimensionSelector: ...
+# thread1->>DimensionSelector: array( Hn(X0) )
+# thread2->>DimensionSelector: array( Hn(X1) )
+# Note right of DimensionSelector: ...
+# DimensionSelector->>USER: [array( Hn(X0) ), <br/> array( Hn(X1) ), <br/> ...]
+
+class DimensionSelector(BaseEstimator, TransformerMixin):
+ """
+ This is a class to select persistence diagrams in a specific dimension from its index.
+ """
+
+ def __init__(self, index=0):
+ """
+ Constructor for the DimensionSelector class.
+
+ Parameters:
+ index (int): The returned persistence diagrams dimension index. Default value is `0`.
+ """
+ self.index = index
+
+ def fit(self, X, Y=None):
+ """
+ Nothing to be done, but useful when included in a scikit-learn Pipeline.
+ """
+ return self
+
+ def transform(self, X, Y=None):
+ """
+ Select persistence diagrams from its dimension.
+
+ Parameters:
+ X (list of list of tuple): List of list of persistence pairs, i.e.
+ `[[array( Hi(X0) ), array( Hj(X0) ), ...], [array( Hi(X1) ), array( Hj(X1) ), ...], ...]`
+
+ Returns:
+ list of tuple:
+ Persistence diagrams in a specific dimension. i.e. if `index` was set to `m` and `Hn` is at index `m` of
+ the input, it returns `[array( Hn(X0) ), array( Hn(X1), ...]`
+ """
+
+ return [persistence[self.index] for persistence in X]
diff --git a/src/python/gudhi/representations/vector_methods.py b/src/python/gudhi/representations/vector_methods.py
index cdcb1fde..ce74aee5 100644
--- a/src/python/gudhi/representations/vector_methods.py
+++ b/src/python/gudhi/representations/vector_methods.py
@@ -1,17 +1,25 @@
# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
-# Author(s): Mathieu Carrière, Martin Royer
+# Author(s): Mathieu Carrière, Martin Royer, Gard Spreemann
#
# Copyright (C) 2018-2020 Inria
#
# Modification(s):
# - 2020/06 Martin: ATOL integration
+# - 2020/12 Gard: A more flexible Betti curve class capable of computing exact curves.
+# - 2021/11 Vincent Rouvreau: factorize _automatic_sample_range
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
+from sklearn.exceptions import NotFittedError
from sklearn.preprocessing import MinMaxScaler, MaxAbsScaler
-from sklearn.neighbors import DistanceMetric
from sklearn.metrics import pairwise
+try:
+ # New location since 1.0
+ from sklearn.metrics import DistanceMetric
+except ImportError:
+ # Will be removed in 1.3
+ from sklearn.neighbors import DistanceMetric
from .preprocessing import DiagramScaler, BirthPersistenceTransform
@@ -45,10 +53,14 @@ class PersistenceImage(BaseEstimator, TransformerMixin):
y (n x 1 array): persistence diagram labels (unused).
"""
if np.isnan(np.array(self.im_range)).any():
- new_X = BirthPersistenceTransform().fit_transform(X)
- pre = DiagramScaler(use=True, scalers=[([0], MinMaxScaler()), ([1], MinMaxScaler())]).fit(new_X,y)
- [mx,my],[Mx,My] = [pre.scalers[0][1].data_min_[0], pre.scalers[1][1].data_min_[0]], [pre.scalers[0][1].data_max_[0], pre.scalers[1][1].data_max_[0]]
- self.im_range = np.where(np.isnan(np.array(self.im_range)), np.array([mx, Mx, my, My]), np.array(self.im_range))
+ try:
+ new_X = BirthPersistenceTransform().fit_transform(X)
+ pre = DiagramScaler(use=True, scalers=[([0], MinMaxScaler()), ([1], MinMaxScaler())]).fit(new_X,y)
+ [mx,my],[Mx,My] = [pre.scalers[0][1].data_min_[0], pre.scalers[1][1].data_min_[0]], [pre.scalers[0][1].data_max_[0], pre.scalers[1][1].data_max_[0]]
+ self.im_range = np.where(np.isnan(np.array(self.im_range)), np.array([mx, Mx, my, My]), np.array(self.im_range))
+ except ValueError:
+ # Empty persistence diagram case - https://github.com/GUDHI/gudhi-devel/issues/507
+ pass
return self
def transform(self, X):
@@ -78,7 +90,7 @@ class PersistenceImage(BaseEstimator, TransformerMixin):
Xfit.append(image.flatten()[np.newaxis,:])
- Xfit = np.concatenate(Xfit,0)
+ Xfit = np.concatenate(Xfit, 0)
return Xfit
@@ -94,11 +106,57 @@ class PersistenceImage(BaseEstimator, TransformerMixin):
"""
return self.fit_transform([diag])[0,:]
+def _automatic_sample_range(sample_range, X):
+ """
+ Compute and returns sample range from the persistence diagrams if one of the sample_range values is numpy.nan.
+
+ Parameters:
+ sample_range (a numpy array of 2 float): minimum and maximum of all piecewise-linear function domains, of
+ the form [x_min, x_max].
+ X (list of n x 2 numpy arrays): input persistence diagrams.
+ y (n x 1 array): persistence diagram labels (unused).
+ """
+ nan_in_range = np.isnan(sample_range)
+ if nan_in_range.any():
+ try:
+ pre = DiagramScaler(use=True, scalers=[([0], MinMaxScaler()), ([1], MinMaxScaler())]).fit(X)
+ [mx,my] = [pre.scalers[0][1].data_min_[0], pre.scalers[1][1].data_min_[0]]
+ [Mx,My] = [pre.scalers[0][1].data_max_[0], pre.scalers[1][1].data_max_[0]]
+ return np.where(nan_in_range, np.array([mx, My]), sample_range)
+ except ValueError:
+ # Empty persistence diagram case - https://github.com/GUDHI/gudhi-devel/issues/507
+ pass
+ return sample_range
+
+
+def _trim_endpoints(x, are_endpoints_nan):
+ if are_endpoints_nan[0]:
+ x = x[1:]
+ if are_endpoints_nan[1]:
+ x = x[:-1]
+ return x
+
+
+def _grid_from_sample_range(self, X):
+ sample_range = np.array(self.sample_range)
+ self.nan_in_range = np.isnan(sample_range)
+ self.new_resolution = self.resolution
+ if not self.keep_endpoints:
+ self.new_resolution += self.nan_in_range.sum()
+ self.sample_range_fixed = _automatic_sample_range(sample_range, X)
+ self.grid_ = np.linspace(self.sample_range_fixed[0], self.sample_range_fixed[1], self.new_resolution)
+ if not self.keep_endpoints:
+ self.grid_ = _trim_endpoints(self.grid_, self.nan_in_range)
+
+
class Landscape(BaseEstimator, TransformerMixin):
"""
This is a class for computing persistence landscapes from a list of persistence diagrams. A persistence landscape is a collection of 1D piecewise-linear functions computed from the rank function associated to the persistence diagram. These piecewise-linear functions are then sampled evenly on a given range and the corresponding vectors of samples are concatenated and returned. See http://jmlr.org/papers/v16/bubenik15a.html for more details.
+
+ Attributes:
+ grid_ (1d array): The grid on which the landscapes are computed.
"""
- def __init__(self, num_landscapes=5, resolution=100, sample_range=[np.nan, np.nan]):
+ def __init__(self, num_landscapes=5, resolution=100, sample_range=[np.nan, np.nan], *, keep_endpoints=False):
"""
Constructor for the Landscape class.
@@ -106,10 +164,10 @@ class Landscape(BaseEstimator, TransformerMixin):
num_landscapes (int): number of piecewise-linear functions to output (default 5).
resolution (int): number of sample for all piecewise-linear functions (default 100).
sample_range ([double, double]): minimum and maximum of all piecewise-linear function domains, of the form [x_min, x_max] (default [numpy.nan, numpy.nan]). It is the interval on which samples will be drawn evenly. If one of the values is numpy.nan, it can be computed from the persistence diagrams with the fit() method.
+ keep_endpoints (bool): when computing `sample_range`, use the exact extremities (where the value is always 0). This is mostly useful for plotting, the default is to use a slightly smaller range.
"""
self.num_landscapes, self.resolution, self.sample_range = num_landscapes, resolution, sample_range
- self.nan_in_range = np.isnan(np.array(self.sample_range))
- self.new_resolution = self.resolution + self.nan_in_range.sum()
+ self.keep_endpoints = keep_endpoints
def fit(self, X, y=None):
"""
@@ -119,10 +177,7 @@ class Landscape(BaseEstimator, TransformerMixin):
X (list of n x 2 numpy arrays): input persistence diagrams.
y (n x 1 array): persistence diagram labels (unused).
"""
- if self.nan_in_range.any():
- pre = DiagramScaler(use=True, scalers=[([0], MinMaxScaler()), ([1], MinMaxScaler())]).fit(X,y)
- [mx,my],[Mx,My] = [pre.scalers[0][1].data_min_[0], pre.scalers[1][1].data_min_[0]], [pre.scalers[0][1].data_max_[0], pre.scalers[1][1].data_max_[0]]
- self.sample_range = np.where(self.nan_in_range, np.array([mx, My]), np.array(self.sample_range))
+ _grid_from_sample_range(self, X)
return self
def transform(self, X):
@@ -135,53 +190,26 @@ class Landscape(BaseEstimator, TransformerMixin):
Returns:
numpy array with shape (number of diagrams) x (number of samples = **num_landscapes** x **resolution**): output persistence landscapes.
"""
- num_diag, Xfit = len(X), []
- x_values = np.linspace(self.sample_range[0], self.sample_range[1], self.new_resolution)
- step_x = x_values[1] - x_values[0]
-
- for i in range(num_diag):
-
- diagram, num_pts_in_diag = X[i], X[i].shape[0]
-
- ls = np.zeros([self.num_landscapes, self.new_resolution])
-
- events = []
- for j in range(self.new_resolution):
- events.append([])
-
- for j in range(num_pts_in_diag):
- [px,py] = diagram[j,:2]
- min_idx = np.clip(np.ceil((px - self.sample_range[0]) / step_x).astype(int), 0, self.new_resolution)
- mid_idx = np.clip(np.ceil((0.5*(py+px) - self.sample_range[0]) / step_x).astype(int), 0, self.new_resolution)
- max_idx = np.clip(np.ceil((py - self.sample_range[0]) / step_x).astype(int), 0, self.new_resolution)
-
- if min_idx < self.new_resolution and max_idx > 0:
-
- landscape_value = self.sample_range[0] + min_idx * step_x - px
- for k in range(min_idx, mid_idx):
- events[k].append(landscape_value)
- landscape_value += step_x
-
- landscape_value = py - self.sample_range[0] - mid_idx * step_x
- for k in range(mid_idx, max_idx):
- events[k].append(landscape_value)
- landscape_value -= step_x
-
- for j in range(self.new_resolution):
- events[j].sort(reverse=True)
- for k in range( min(self.num_landscapes, len(events[j])) ):
- ls[k,j] = events[j][k]
-
- if self.nan_in_range[0]:
- ls = ls[:,1:]
- if self.nan_in_range[1]:
- ls = ls[:,:-1]
- ls = np.sqrt(2)*np.reshape(ls,[1,-1])
- Xfit.append(ls)
- Xfit = np.concatenate(Xfit,0)
-
- return Xfit
+ Xfit = []
+ x_values = self.grid_
+ for diag in X:
+ midpoints, heights = (diag[:, 0] + diag[:, 1]) / 2., (diag[:, 1] - diag[:, 0]) / 2.
+ tent_functions = np.maximum(heights[None, :] - np.abs(x_values[:, None] - midpoints[None, :]), 0)
+ n_points = diag.shape[0]
+ # Complete the array with zeros to get the right number of landscapes
+ if self.num_landscapes > n_points:
+ tent_functions = np.concatenate(
+ [tent_functions, np.zeros((tent_functions.shape[0], self.num_landscapes-n_points))],
+ axis=1
+ )
+ tent_functions.partition(tent_functions.shape[1]-self.num_landscapes, axis=1)
+ landscapes = np.sort(tent_functions[:, -self.num_landscapes:], axis=1)[:, ::-1].T
+
+ landscapes = np.sqrt(2) * np.ravel(landscapes)
+ Xfit.append(landscapes)
+
+ return np.stack(Xfit, axis=0)
def __call__(self, diag):
"""
@@ -193,13 +221,16 @@ class Landscape(BaseEstimator, TransformerMixin):
Returns:
numpy array with shape (number of samples = **num_landscapes** x **resolution**): output persistence landscape.
"""
- return self.fit_transform([diag])[0,:]
+ return self.fit_transform([diag])[0, :]
class Silhouette(BaseEstimator, TransformerMixin):
"""
This is a class for computing persistence silhouettes from a list of persistence diagrams. A persistence silhouette is computed by taking a weighted average of the collection of 1D piecewise-linear functions given by the persistence landscapes, and then by evenly sampling this average on a given range. Finally, the corresponding vector of samples is returned. See https://arxiv.org/abs/1312.0308 for more details.
+
+ Attributes:
+ grid_ (1d array): The grid on which the silhouette is computed.
"""
- def __init__(self, weight=lambda x: 1, resolution=100, sample_range=[np.nan, np.nan]):
+ def __init__(self, weight=lambda x: 1, resolution=100, sample_range=[np.nan, np.nan], *, keep_endpoints=False):
"""
Constructor for the Silhouette class.
@@ -207,8 +238,10 @@ class Silhouette(BaseEstimator, TransformerMixin):
weight (function): weight function for the persistence diagram points (default constant function, ie lambda x: 1). This function must be defined on 2D points, ie on lists or numpy arrays of the form [p_x,p_y].
resolution (int): number of samples for the weighted average (default 100).
sample_range ([double, double]): minimum and maximum for the weighted average domain, of the form [x_min, x_max] (default [numpy.nan, numpy.nan]). It is the interval on which samples will be drawn evenly. If one of the values is numpy.nan, it can be computed from the persistence diagrams with the fit() method.
+ keep_endpoints (bool): when computing `sample_range`, use the exact extremities (where the value is always 0). This is mostly useful for plotting, the default is to use a slightly smaller range.
"""
self.weight, self.resolution, self.sample_range = weight, resolution, sample_range
+ self.keep_endpoints = keep_endpoints
def fit(self, X, y=None):
"""
@@ -218,10 +251,7 @@ class Silhouette(BaseEstimator, TransformerMixin):
X (list of n x 2 numpy arrays): input persistence diagrams.
y (n x 1 array): persistence diagram labels (unused).
"""
- if np.isnan(np.array(self.sample_range)).any():
- pre = DiagramScaler(use=True, scalers=[([0], MinMaxScaler()), ([1], MinMaxScaler())]).fit(X,y)
- [mx,my],[Mx,My] = [pre.scalers[0][1].data_min_[0], pre.scalers[1][1].data_min_[0]], [pre.scalers[0][1].data_max_[0], pre.scalers[1][1].data_max_[0]]
- self.sample_range = np.where(np.isnan(np.array(self.sample_range)), np.array([mx, My]), np.array(self.sample_range))
+ _grid_from_sample_range(self, X)
return self
def transform(self, X):
@@ -234,44 +264,19 @@ class Silhouette(BaseEstimator, TransformerMixin):
Returns:
numpy array with shape (number of diagrams) x (**resolution**): output persistence silhouettes.
"""
- num_diag, Xfit = len(X), []
- x_values = np.linspace(self.sample_range[0], self.sample_range[1], self.resolution)
- step_x = x_values[1] - x_values[0]
-
- for i in range(num_diag):
-
- diagram, num_pts_in_diag = X[i], X[i].shape[0]
+ Xfit = []
+ x_values = self.grid_
- sh, weights = np.zeros(self.resolution), np.zeros(num_pts_in_diag)
- for j in range(num_pts_in_diag):
- weights[j] = self.weight(diagram[j,:])
+ for diag in X:
+ midpoints, heights = (diag[:, 0] + diag[:, 1]) / 2., (diag[:, 1] - diag[:, 0]) / 2.
+ weights = np.array([self.weight(pt) for pt in diag])
total_weight = np.sum(weights)
- for j in range(num_pts_in_diag):
-
- [px,py] = diagram[j,:2]
- weight = weights[j] / total_weight
- min_idx = np.clip(np.ceil((px - self.sample_range[0]) / step_x).astype(int), 0, self.resolution)
- mid_idx = np.clip(np.ceil((0.5*(py+px) - self.sample_range[0]) / step_x).astype(int), 0, self.resolution)
- max_idx = np.clip(np.ceil((py - self.sample_range[0]) / step_x).astype(int), 0, self.resolution)
-
- if min_idx < self.resolution and max_idx > 0:
-
- silhouette_value = self.sample_range[0] + min_idx * step_x - px
- for k in range(min_idx, mid_idx):
- sh[k] += weight * silhouette_value
- silhouette_value += step_x
-
- silhouette_value = py - self.sample_range[0] - mid_idx * step_x
- for k in range(mid_idx, max_idx):
- sh[k] += weight * silhouette_value
- silhouette_value -= step_x
-
- Xfit.append(np.reshape(np.sqrt(2) * sh, [1,-1]))
+ tent_functions = np.maximum(heights[None, :] - np.abs(x_values[:, None] - midpoints[None, :]), 0)
+ silhouette = np.sum(weights[None, :] / total_weight * tent_functions, axis=1)
+ Xfit.append(silhouette * np.sqrt(2))
- Xfit = np.concatenate(Xfit, 0)
-
- return Xfit
+ return np.stack(Xfit, axis=0)
def __call__(self, diag):
"""
@@ -285,76 +290,174 @@ class Silhouette(BaseEstimator, TransformerMixin):
"""
return self.fit_transform([diag])[0,:]
+
class BettiCurve(BaseEstimator, TransformerMixin):
"""
- This is a class for computing Betti curves from a list of persistence diagrams. A Betti curve is a 1D piecewise-constant function obtained from the rank function. It is sampled evenly on a given range and the vector of samples is returned. See https://www.researchgate.net/publication/316604237_Time_Series_Classification_via_Topological_Data_Analysis for more details.
+ Compute Betti curves from persistence diagrams. There are several modes of operation: with a given resolution (with or without a sample_range), with a predefined grid, and with none of the previous. With a predefined grid, the class computes the Betti numbers at those grid points. Without a predefined grid, if the resolution is set to None, it can be fit to a list of persistence diagrams and produce a grid that consists of (at least) the filtration values at which at least one of those persistence diagrams changes Betti numbers, and then compute the Betti numbers at those grid points. In the latter mode, the exact Betti curve is computed for the entire real line. Otherwise, if the resolution is given, the Betti curve is obtained by sampling evenly using either the given sample_range or based on the persistence diagrams.
+
+ Examples
+ --------
+ If pd is a persistence diagram and xs is a nonempty grid of finite values such that xs[0] >= pd.min(), then the results of:
+
+ >>> bc = BettiCurve(predefined_grid=xs) # doctest: +SKIP
+ >>> result = bc(pd) # doctest: +SKIP
+
+ and
+
+ >>> from scipy.interpolate import interp1d # doctest: +SKIP
+ >>> bc = BettiCurve(resolution=None, predefined_grid=None) # doctest: +SKIP
+ >>> bettis = bc.fit_transform([pd]) # doctest: +SKIP
+ >>> interp = interp1d(bc.grid_, bettis[0, :], kind="previous", fill_value="extrapolate") # doctest: +SKIP
+ >>> result = np.array(interp(xs), dtype=int) # doctest: +SKIP
+
+ are the same.
+
+ Attributes
+ ----------
+ grid_ : 1d array
+ The grid on which the Betti numbers are computed. If predefined_grid was specified, `grid_` will always be that grid, independently of data. If not and resolution is None, the grid is fitted to capture all filtration values at which the Betti numbers change.
"""
- def __init__(self, resolution=100, sample_range=[np.nan, np.nan]):
+
+ def __init__(self, resolution=100, sample_range=[np.nan, np.nan], predefined_grid=None, *, keep_endpoints=False):
"""
Constructor for the BettiCurve class.
Parameters:
- resolution (int): number of sample for the piecewise-constant function (default 100).
+ resolution (int): number of samples for the piecewise-constant function (default 100), or None for the exact curve.
sample_range ([double, double]): minimum and maximum of the piecewise-constant function domain, of the form [x_min, x_max] (default [numpy.nan, numpy.nan]). It is the interval on which samples will be drawn evenly. If one of the values is numpy.nan, it can be computed from the persistence diagrams with the fit() method.
+ predefined_grid (1d array or None, default=None): Predefined filtration grid points at which to compute the Betti curves. Must be strictly ordered. Infinities are ok. If None (default), and resolution is given, the grid will be uniform from x_min to x_max in 'resolution' steps, otherwise a grid will be computed that captures all changes in Betti numbers in the provided data.
+ keep_endpoints (bool): when computing `sample_range` (fixed `resolution`, no `predefined_grid`), use the exact extremities. This is mostly useful for plotting, the default is to use a slightly smaller range.
"""
- self.resolution, self.sample_range = resolution, sample_range
- def fit(self, X, y=None):
+ if (predefined_grid is not None) and (not isinstance(predefined_grid, np.ndarray)):
+ raise ValueError("Expected predefined_grid as array or None.")
+
+ self.predefined_grid = predefined_grid
+ self.resolution = resolution
+ self.sample_range = sample_range
+ self.keep_endpoints = keep_endpoints
+
+ def is_fitted(self):
+ return hasattr(self, "grid_")
+
+ def fit(self, X, y = None):
"""
- Fit the BettiCurve class on a list of persistence diagrams: if any of the values in **sample_range** is numpy.nan, replace it with the corresponding value computed on the given list of persistence diagrams.
+ Fit the BettiCurve class on a list of persistence diagrams: if any of the values in **sample_range** is numpy.nan, replace it with the corresponding value computed on the given list of persistence diagrams. When no predefined grid is provided and resolution set to None, compute a filtration grid that captures all changes in Betti numbers for all the given persistence diagrams.
Parameters:
- X (list of n x 2 numpy arrays): input persistence diagrams.
- y (n x 1 array): persistence diagram labels (unused).
+ X (list of 2d arrays): Persistence diagrams.
+ y (None): Ignored.
"""
- if np.isnan(np.array(self.sample_range)).any():
- pre = DiagramScaler(use=True, scalers=[([0], MinMaxScaler()), ([1], MinMaxScaler())]).fit(X,y)
- [mx,my],[Mx,My] = [pre.scalers[0][1].data_min_[0], pre.scalers[1][1].data_min_[0]], [pre.scalers[0][1].data_max_[0], pre.scalers[1][1].data_max_[0]]
- self.sample_range = np.where(np.isnan(np.array(self.sample_range)), np.array([mx, My]), np.array(self.sample_range))
+
+ if self.predefined_grid is None:
+ if self.resolution is None: # Flexible/exact version
+ events = np.unique(np.concatenate([pd.flatten() for pd in X] + [[-np.inf]], axis=0))
+ self.grid_ = np.array(events)
+ else:
+ _grid_from_sample_range(self, X)
+ else:
+ self.grid_ = self.predefined_grid # Get the predefined grid from user
+
return self
def transform(self, X):
"""
- Compute the Betti curve for each persistence diagram individually and concatenate the results.
+ Compute Betti curves.
Parameters:
- X (list of n x 2 numpy arrays): input persistence diagrams.
-
+ X (list of 2d arrays): Persistence diagrams.
+
Returns:
- numpy array with shape (number of diagrams) x (**resolution**): output Betti curves.
+ `len(X).len(self.grid_)` array of ints: Betti numbers of the given persistence diagrams at the grid points given in `self.grid_`
"""
- Xfit = []
- x_values = np.linspace(self.sample_range[0], self.sample_range[1], self.resolution)
- step_x = x_values[1] - x_values[0]
- for diagram in X:
- diagram_int = np.clip(np.ceil((diagram[:,:2] - self.sample_range[0]) / step_x), 0, self.resolution).astype(int)
- bc = np.zeros(self.resolution)
- for interval in diagram_int:
- bc[interval[0]:interval[1]] += 1
- Xfit.append(np.reshape(bc,[1,-1]))
+ if not self.is_fitted():
+ raise NotFittedError("Not fitted.")
- Xfit = np.concatenate(Xfit, 0)
+ if not X:
+ X = [np.zeros((0, 2))]
+
+ N = len(X)
- return Xfit
+ events = np.concatenate([pd.flatten(order="F") for pd in X], axis=0)
+ sorting = np.argsort(events)
+ offsets = np.zeros(1 + N, dtype=int)
+ for i in range(0, N):
+ offsets[i+1] = offsets[i] + 2*X[i].shape[0]
+ starts = offsets[0:N]
+ ends = offsets[1:N + 1] - 1
- def __call__(self, diag):
+ bettis = [[0] for i in range(0, N)]
+
+ i = 0
+ for x in self.grid_:
+ while i < len(sorting) and events[sorting[i]] <= x:
+ j = np.searchsorted(ends, sorting[i])
+ delta = 1 if sorting[i] - starts[j] < len(X[j]) else -1
+ bettis[j][-1] += delta
+ i += 1
+ for k in range(0, N):
+ bettis[k].append(bettis[k][-1])
+
+ return np.array(bettis, dtype=int)[:, 0:-1]
+
+ def fit_transform(self, X):
+ """
+ The result is the same as fit(X) followed by transform(X), but potentially faster.
"""
- Apply BettiCurve on a single persistence diagram and outputs the result.
- Parameters:
- diag (n x 2 numpy array): input persistence diagram.
+ if self.predefined_grid is None and self.resolution is None:
+ if not X:
+ X = [np.zeros((0, 2))]
- Returns:
- numpy array with shape (**resolution**): output Betti curve.
+ N = len(X)
+
+ events = np.concatenate([pd.flatten(order="F") for pd in X], axis=0)
+ sorting = np.argsort(events)
+ offsets = np.zeros(1 + N, dtype=int)
+ for i in range(0, N):
+ offsets[i+1] = offsets[i] + 2*X[i].shape[0]
+ starts = offsets[0:N]
+ ends = offsets[1:N + 1] - 1
+
+ xs = [-np.inf]
+ bettis = [[0] for i in range(0, N)]
+
+ for i in sorting:
+ j = np.searchsorted(ends, i)
+ delta = 1 if i - starts[j] < len(X[j]) else -1
+ if events[i] == xs[-1]:
+ bettis[j][-1] += delta
+ else:
+ xs.append(events[i])
+ for k in range(0, j):
+ bettis[k].append(bettis[k][-1])
+ bettis[j].append(bettis[j][-1] + delta)
+ for k in range(j+1, N):
+ bettis[k].append(bettis[k][-1])
+
+ self.grid_ = np.array(xs)
+ return np.array(bettis, dtype=int)
+
+ else:
+ return self.fit(X).transform(X)
+
+ def __call__(self, diag):
"""
- return self.fit_transform([diag])[0,:]
+ Shorthand for transform on a single persistence diagram.
+ """
+ return self.fit_transform([diag])[0, :]
+
+
class Entropy(BaseEstimator, TransformerMixin):
"""
This is a class for computing persistence entropy. Persistence entropy is a statistic for persistence diagrams inspired from Shannon entropy. This statistic can also be used to compute a feature vector, called the entropy summary function. See https://arxiv.org/pdf/1803.08304.pdf for more details. Note that a previous implementation was contributed by Manuel Soriano-Trigueros.
+
+ Attributes:
+ grid_ (1d array): In vector mode, the grid on which the entropy summary function is computed.
"""
- def __init__(self, mode="scalar", normalized=True, resolution=100, sample_range=[np.nan, np.nan]):
+ def __init__(self, mode="scalar", normalized=True, resolution=100, sample_range=[np.nan, np.nan], *, keep_endpoints=False):
"""
Constructor for the Entropy class.
@@ -363,8 +466,10 @@ class Entropy(BaseEstimator, TransformerMixin):
normalized (bool): whether to normalize the entropy summary function (default True). Used only if **mode** = "vector".
resolution (int): number of sample for the entropy summary function (default 100). Used only if **mode** = "vector".
sample_range ([double, double]): minimum and maximum of the entropy summary function domain, of the form [x_min, x_max] (default [numpy.nan, numpy.nan]). It is the interval on which samples will be drawn evenly. If one of the values is numpy.nan, it can be computed from the persistence diagrams with the fit() method. Used only if **mode** = "vector".
+ keep_endpoints (bool): when computing `sample_range`, use the exact extremities. This is mostly useful for plotting, the default is to use a slightly smaller range.
"""
self.mode, self.normalized, self.resolution, self.sample_range = mode, normalized, resolution, sample_range
+ self.keep_endpoints = keep_endpoints
def fit(self, X, y=None):
"""
@@ -374,10 +479,9 @@ class Entropy(BaseEstimator, TransformerMixin):
X (list of n x 2 numpy arrays): input persistence diagrams.
y (n x 1 array): persistence diagram labels (unused).
"""
- if np.isnan(np.array(self.sample_range)).any():
- pre = DiagramScaler(use=True, scalers=[([0], MinMaxScaler()), ([1], MinMaxScaler())]).fit(X,y)
- [mx,my],[Mx,My] = [pre.scalers[0][1].data_min_[0], pre.scalers[1][1].data_min_[0]], [pre.scalers[0][1].data_max_[0], pre.scalers[1][1].data_max_[0]]
- self.sample_range = np.where(np.isnan(np.array(self.sample_range)), np.array([mx, My]), np.array(self.sample_range))
+ if self.mode == "vector":
+ _grid_from_sample_range(self, X)
+ self.step_ = self.grid_[1] - self.grid_[0]
return self
def transform(self, X):
@@ -391,33 +495,28 @@ class Entropy(BaseEstimator, TransformerMixin):
numpy array with shape (number of diagrams) x (1 if **mode** = "scalar" else **resolution**): output entropy.
"""
num_diag, Xfit = len(X), []
- x_values = np.linspace(self.sample_range[0], self.sample_range[1], self.resolution)
- step_x = x_values[1] - x_values[0]
new_X = BirthPersistenceTransform().fit_transform(X)
for i in range(num_diag):
-
- orig_diagram, diagram, num_pts_in_diag = X[i], new_X[i], X[i].shape[0]
- new_diagram = DiagramScaler(use=True, scalers=[([1], MaxAbsScaler())]).fit_transform([diagram])[0]
-
+ orig_diagram, new_diagram, num_pts_in_diag = X[i], new_X[i], X[i].shape[0]
+
+ p = new_diagram[:,1]
+ p = p/np.sum(p)
if self.mode == "scalar":
- ent = - np.sum( np.multiply(new_diagram[:,1], np.log(new_diagram[:,1])) )
+ ent = -np.dot(p, np.log(p))
Xfit.append(np.array([[ent]]))
-
else:
ent = np.zeros(self.resolution)
for j in range(num_pts_in_diag):
[px,py] = orig_diagram[j,:2]
- min_idx = np.clip(np.ceil((px - self.sample_range[0]) / step_x).astype(int), 0, self.resolution)
- max_idx = np.clip(np.ceil((py - self.sample_range[0]) / step_x).astype(int), 0, self.resolution)
- for k in range(min_idx, max_idx):
- ent[k] += (-1) * new_diagram[j,1] * np.log(new_diagram[j,1])
- if self.normalized:
- ent = ent / np.linalg.norm(ent, ord=1)
- Xfit.append(np.reshape(ent,[1,-1]))
-
- Xfit = np.concatenate(Xfit, 0)
-
+ min_idx = np.clip(np.ceil((px - self.sample_range_fixed[0]) / self.step_).astype(int), 0, self.resolution)
+ max_idx = np.clip(np.ceil((py - self.sample_range_fixed[0]) / self.step_).astype(int), 0, self.resolution)
+ ent[min_idx:max_idx]-=p[j]*np.log(p[j])
+ if self.normalized:
+ ent = ent / np.linalg.norm(ent, ord=1)
+ Xfit.append(np.reshape(ent,[1,-1]))
+
+ Xfit = np.concatenate(Xfit, axis=0)
return Xfit
def __call__(self, diag):
@@ -478,7 +577,13 @@ class TopologicalVector(BaseEstimator, TransformerMixin):
diagram, num_pts_in_diag = X[i], X[i].shape[0]
pers = 0.5 * (diagram[:,1]-diagram[:,0])
min_pers = np.minimum(pers,np.transpose(pers))
- distances = DistanceMetric.get_metric("chebyshev").pairwise(diagram)
+ # Works fine with sklearn 1.0, but an ValueError exception is thrown on past versions
+ try:
+ distances = DistanceMetric.get_metric("chebyshev").pairwise(diagram)
+ except ValueError:
+ # Empty persistence diagram case - https://github.com/GUDHI/gudhi-devel/issues/507
+ assert len(diagram) == 0
+ distances = np.empty(shape = [0, 0])
vect = np.flip(np.sort(np.triu(np.minimum(distances, min_pers)), axis=None), 0)
dim = min(len(vect), thresh)
Xfit[i, :dim] = vect[:dim]
@@ -606,17 +711,18 @@ class Atol(BaseEstimator, TransformerMixin):
>>> c = np.array([[3, 2, -1], [1, 2, -1]])
>>> atol_vectoriser = Atol(quantiser=KMeans(n_clusters=2, random_state=202006))
>>> atol_vectoriser.fit(X=[a, b, c]).centers
- array([[ 2. , 0.66666667, 3.33333333],
- [ 2.6 , 2.8 , -0.4 ]])
+ array([[ 2.6 , 2.8 , -0.4 ],
+ [ 2. , 0.66666667, 3.33333333]])
>>> atol_vectoriser(a)
- array([1.18168665, 0.42375966])
+ array([0.42375966, 1.18168665])
>>> atol_vectoriser(c)
- array([0.02062512, 1.25157463])
+ array([1.25157463, 0.02062512])
>>> atol_vectoriser.transform(X=[a, b, c])
- array([[1.18168665, 0.42375966],
- [0.29861028, 1.06330156],
- [0.02062512, 1.25157463]])
+ array([[0.42375966, 1.18168665],
+ [1.06330156, 0.29861028],
+ [1.25157463, 0.02062512]])
"""
+ # Note the example above must be up to date with the one in tests called test_atol_doc
def __init__(self, quantiser, weighting_method="cloud", contrast="gaussian"):
"""
Constructor for the Atol measure vectorisation class.
@@ -665,6 +771,8 @@ class Atol(BaseEstimator, TransformerMixin):
measures_concat = np.concatenate(X)
self.quantiser.fit(X=measures_concat, sample_weight=sample_weight)
self.centers = self.quantiser.cluster_centers_
+ # Hack, but some people are unhappy if the order depends on the version of sklearn
+ self.centers = self.centers[np.lexsort(self.centers.T)]
if self.quantiser.n_clusters == 1:
dist_centers = pairwise.pairwise_distances(measures_concat)
np.fill_diagonal(dist_centers, 0)
diff --git a/src/python/gudhi/rips_complex.pyx b/src/python/gudhi/rips_complex.pyx
index 72e82c79..d748f91e 100644
--- a/src/python/gudhi/rips_complex.pyx
+++ b/src/python/gudhi/rips_complex.pyx
@@ -41,31 +41,30 @@ cdef class RipsComplex:
cdef Rips_complex_interface thisref
# Fake constructor that does nothing but documenting the constructor
- def __init__(self, points=None, distance_matrix=None,
+ def __init__(self, *, points=None, distance_matrix=None,
max_edge_length=float('inf'), sparse=None):
"""RipsComplex constructor.
- :param max_edge_length: Rips value.
- :type max_edge_length: float
-
:param points: A list of points in d-Dimension.
- :type points: list of list of double
+ :type points: List[List[float]]
Or
:param distance_matrix: A distance matrix (full square or lower
triangular).
- :type points: list of list of double
+ :type distance_matrix: List[List[float]]
And in both cases
+ :param max_edge_length: Rips value.
+ :type max_edge_length: float
:param sparse: If this is not None, it switches to building a sparse
Rips and represents the approximation parameter epsilon.
:type sparse: float
"""
# The real cython constructor
- def __cinit__(self, points=None, distance_matrix=None,
+ def __cinit__(self, *, points=None, distance_matrix=None,
max_edge_length=float('inf'), sparse=None):
if sparse is not None:
if distance_matrix is not None:
@@ -89,10 +88,10 @@ cdef class RipsComplex:
def create_simplex_tree(self, max_dimension=1):
"""
- :param max_dimension: graph expansion for rips until this given maximal
+ :param max_dimension: graph expansion for Rips until this given maximal
dimension.
:type max_dimension: int
- :returns: A simplex tree created from the Delaunay Triangulation.
+ :returns: A simplex tree encoding the Vietoris–Rips filtration.
:rtype: SimplexTree
"""
stree = SimplexTree()
diff --git a/src/python/gudhi/simplex_tree.pxd b/src/python/gudhi/simplex_tree.pxd
index 000323af..5309c6fa 100644
--- a/src/python/gudhi/simplex_tree.pxd
+++ b/src/python/gudhi/simplex_tree.pxd
@@ -44,7 +44,8 @@ cdef extern from "Simplex_tree_interface.h" namespace "Gudhi":
cdef cppclass Simplex_tree_interface_full_featured "Gudhi::Simplex_tree_interface<Gudhi::Simplex_tree_options_full_featured>":
- Simplex_tree() nogil
+ Simplex_tree_interface_full_featured() nogil
+ Simplex_tree_interface_full_featured(Simplex_tree_interface_full_featured&) nogil
double simplex_filtration(vector[int] simplex) nogil
void assign_simplex_filtration(vector[int] simplex, double filtration) nogil
void initialize_filtration() nogil
@@ -55,6 +56,8 @@ cdef extern from "Simplex_tree_interface.h" namespace "Gudhi":
int upper_bound_dimension() nogil
bool find_simplex(vector[int] simplex) nogil
bool insert(vector[int] simplex, double filtration) nogil
+ void insert_matrix(double* filtrations, int n, int stride0, int stride1, double max_filtration) nogil except +
+ void insert_batch_vertices(vector[int] v, double f) nogil except +
vector[pair[vector[int], double]] get_star(vector[int] simplex) nogil
vector[pair[vector[int], double]] get_cofaces(vector[int] simplex, int dimension) nogil
void expansion(int max_dim) nogil except +
@@ -62,9 +65,9 @@ cdef extern from "Simplex_tree_interface.h" namespace "Gudhi":
bool prune_above_filtration(double filtration) nogil
bool make_filtration_non_decreasing() nogil
void compute_extended_filtration() nogil
- vector[vector[pair[int, pair[double, double]]]] compute_extended_persistence_subdiagrams(vector[pair[int, pair[double, double]]] dgm, double min_persistence) nogil
Simplex_tree_interface_full_featured* collapse_edges(int nb_collapse_iteration) nogil except +
void reset_filtration(double filtration, int dimension) nogil
+ bint operator==(Simplex_tree_interface_full_featured) nogil
# Iterators over Simplex tree
pair[vector[int], double] get_simplex_and_filtration(Simplex_tree_simplex_handle f_simplex) nogil
Simplex_tree_simplices_iterator get_simplices_iterator_begin() nogil
@@ -74,11 +77,14 @@ cdef extern from "Simplex_tree_interface.h" namespace "Gudhi":
Simplex_tree_skeleton_iterator get_skeleton_iterator_begin(int dimension) nogil
Simplex_tree_skeleton_iterator get_skeleton_iterator_end(int dimension) nogil
pair[Simplex_tree_boundary_iterator, Simplex_tree_boundary_iterator] get_boundary_iterators(vector[int] simplex) nogil except +
+ # Expansion with blockers
+ ctypedef bool (*blocker_func_t)(vector[int], void *user_data)
+ void expansion_with_blockers_callback(int dimension, blocker_func_t user_func, void *user_data)
cdef extern from "Persistent_cohomology_interface.h" namespace "Gudhi":
- cdef cppclass Simplex_tree_persistence_interface "Gudhi::Persistent_cohomology_interface<Gudhi::Simplex_tree<Gudhi::Simplex_tree_options_full_featured>>":
+ cdef cppclass Simplex_tree_persistence_interface "Gudhi::Persistent_cohomology_interface<Gudhi::Simplex_tree_interface<Gudhi::Simplex_tree_options_full_featured>>":
Simplex_tree_persistence_interface(Simplex_tree_interface_full_featured * st, bool persistence_dim_max) nogil
- void compute_persistence(int homology_coeff_field, double min_persistence) nogil
+ void compute_persistence(int homology_coeff_field, double min_persistence) nogil except +
vector[pair[int, pair[double, double]]] get_persistence() nogil
vector[int] betti_numbers() nogil
vector[int] persistent_betti_numbers(double from_value, double to_value) nogil
@@ -87,3 +93,4 @@ cdef extern from "Persistent_cohomology_interface.h" namespace "Gudhi":
vector[pair[vector[int], vector[int]]] persistence_pairs() nogil
pair[vector[vector[int]], vector[vector[int]]] lower_star_generators() nogil
pair[vector[vector[int]], vector[vector[int]]] flag_generators() nogil
+ vector[vector[pair[int, pair[double, double]]]] compute_extended_persistence_subdiagrams(double min_persistence) nogil
diff --git a/src/python/gudhi/simplex_tree.pyx b/src/python/gudhi/simplex_tree.pyx
index d7991417..4cf176f5 100644
--- a/src/python/gudhi/simplex_tree.pyx
+++ b/src/python/gudhi/simplex_tree.pyx
@@ -8,15 +8,27 @@
# - YYYY/MM Author: Description of the modification
from cython.operator import dereference, preincrement
-from libc.stdint cimport intptr_t
-import numpy
-from numpy import array as np_array
-cimport simplex_tree
+from libc.stdint cimport intptr_t, int32_t, int64_t
+import numpy as np
+cimport gudhi.simplex_tree
+cimport cython
+from numpy.math cimport INFINITY
__author__ = "Vincent Rouvreau"
__copyright__ = "Copyright (C) 2016 Inria"
__license__ = "MIT"
+ctypedef fused some_int:
+ int32_t
+ int64_t
+
+ctypedef fused some_float:
+ float
+ double
+
+cdef bool callback(vector[int] simplex, void *blocker_func):
+ return (<object>blocker_func)(simplex)
+
# SimplexTree python interface
cdef class SimplexTree:
"""The simplex tree is an efficient and flexible data structure for
@@ -39,13 +51,29 @@ cdef class SimplexTree:
cdef Simplex_tree_persistence_interface * pcohptr
# Fake constructor that does nothing but documenting the constructor
- def __init__(self):
+ def __init__(self, other = None):
"""SimplexTree constructor.
+
+ :param other: If `other` is `None` (default value), an empty `SimplexTree` is created.
+ If `other` is a `SimplexTree`, the `SimplexTree` is constructed from a deep copy of `other`.
+ :type other: SimplexTree (Optional)
+ :returns: An empty or a copy simplex tree.
+ :rtype: SimplexTree
+
+ :raises TypeError: In case `other` is neither `None`, nor a `SimplexTree`.
+ :note: If the `SimplexTree` is a copy, the persistence information is not copied. If you need it in the clone,
+ you have to call :func:`compute_persistence` on it even if you had already computed it in the original.
"""
# The real cython constructor
- def __cinit__(self):
- self.thisptr = <intptr_t>(new Simplex_tree_interface_full_featured())
+ def __cinit__(self, other = None):
+ if other:
+ if isinstance(other, SimplexTree):
+ self.thisptr = _get_copy_intptr(other)
+ else:
+ raise TypeError("`other` argument requires to be of type `SimplexTree`, or `None`.")
+ else:
+ self.thisptr = <intptr_t>(new Simplex_tree_interface_full_featured())
def __dealloc__(self):
cdef Simplex_tree_interface_full_featured* ptr = self.get_ptr()
@@ -64,6 +92,21 @@ cdef class SimplexTree:
"""
return self.pcohptr != NULL
+ def copy(self):
+ """
+ :returns: A simplex tree that is a deep copy of itself.
+ :rtype: SimplexTree
+
+ :note: The persistence information is not copied. If you need it in the clone, you have to call
+ :func:`compute_persistence` on it even if you had already computed it in the original.
+ """
+ stree = SimplexTree()
+ stree.thisptr = _get_copy_intptr(self)
+ return stree
+
+ def __deepcopy__(self):
+ return self.copy()
+
def filtration(self, simplex):
"""This function returns the filtration value for a given N-simplex in
this simplicial complex, or +infinity if it is not in the complex.
@@ -195,6 +238,91 @@ cdef class SimplexTree:
"""
return self.get_ptr().insert(simplex, <double>filtration)
+ @staticmethod
+ @cython.boundscheck(False)
+ def create_from_array(filtrations, double max_filtration=INFINITY):
+ """Creates a new, empty complex and inserts vertices and edges. The vertices are numbered from 0 to n-1, and
+ the filtration values are encoded in the array, with the diagonal representing the vertices. It is the
+ caller's responsibility to ensure that this defines a filtration, which can be achieved with either::
+
+ filtrations[np.diag_indices_from(filtrations)] = filtrations.min(axis=1)
+
+ or::
+
+ diag = filtrations.diagonal()
+ filtrations = np.fmax(np.fmax(filtrations, diag[:, None]), diag[None, :])
+
+ :param filtrations: the filtration values of the vertices and edges to insert. The matrix is assumed to be symmetric.
+ :type filtrations: numpy.ndarray of shape (n,n)
+ :param max_filtration: only insert vertices and edges with filtration values no larger than max_filtration
+ :type max_filtration: float
+ :returns: the new complex
+ :rtype: SimplexTree
+ """
+ # TODO: document which half of the matrix is actually read?
+ filtrations = np.asanyarray(filtrations, dtype=float)
+ cdef double[:,:] F = filtrations
+ ret = SimplexTree()
+ cdef int n = F.shape[0]
+ assert n == F.shape[1], 'create_from_array() expects a square array'
+ with nogil:
+ ret.get_ptr().insert_matrix(&F[0,0], n, F.strides[0], F.strides[1], max_filtration)
+ return ret
+
+ def insert_edges_from_coo_matrix(self, edges):
+ """Inserts edges given by a sparse matrix in `COOrdinate format
+ <https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.html>`_.
+ If an edge is repeated, the smallest filtration value is used. Missing entries are not inserted.
+ Diagonal entries are currently interpreted as vertices, although we do not guarantee this behavior
+ in the future, and this is only useful if you want to insert vertices with a smaller filtration value
+ than the smallest edge containing it, since vertices are implicitly inserted together with the edges.
+
+ :param edges: the edges to insert and their filtration values.
+ :type edges: scipy.sparse.coo_matrix of shape (n,n)
+
+ .. seealso:: :func:`insert_batch`
+ """
+ # Without this, it could be slow if we end up inserting vertices in a bad order (flat_map).
+ self.get_ptr().insert_batch_vertices(np.unique(np.stack((edges.row, edges.col))), INFINITY)
+ # TODO: optimize this?
+ for edge in zip(edges.row, edges.col, edges.data):
+ self.get_ptr().insert((edge[0], edge[1]), edge[2])
+
+ @cython.boundscheck(False)
+ @cython.wraparound(False)
+ def insert_batch(self, some_int[:,:] vertex_array, some_float[:] filtrations):
+ """Inserts k-simplices given by a sparse array in a format similar
+ to `torch.sparse <https://pytorch.org/docs/stable/sparse.html>`_.
+ The n-th simplex has vertices `vertex_array[0,n]`, ...,
+ `vertex_array[k,n]` and filtration value `filtrations[n]`.
+ If a simplex is repeated, the smallest filtration value is used.
+ Simplices with a repeated vertex are currently interpreted as lower
+ dimensional simplices, but we do not guarantee this behavior in the
+ future. Any time a simplex is inserted, its faces are inserted as well
+ if needed to preserve a simplicial complex.
+
+ :param vertex_array: the k-simplices to insert.
+ :type vertex_array: numpy.array of shape (k+1,n)
+ :param filtrations: the filtration values.
+ :type filtrations: numpy.array of shape (n,)
+ """
+ cdef vector[int] vertices = np.unique(vertex_array)
+ cdef Py_ssize_t k = vertex_array.shape[0]
+ cdef Py_ssize_t n = vertex_array.shape[1]
+ assert filtrations.shape[0] == n, 'inconsistent sizes for vertex_array and filtrations'
+ cdef Py_ssize_t i
+ cdef Py_ssize_t j
+ cdef vector[int] v
+ with nogil:
+ # Without this, it could be slow if we end up inserting vertices in a bad order (flat_map).
+ # NaN currently does the wrong thing
+ self.get_ptr().insert_batch_vertices(vertices, INFINITY)
+ for i in range(n):
+ for j in range(k):
+ v.push_back(vertex_array[j, i])
+ self.get_ptr().insert(v, filtrations[i])
+ v.clear()
+
def get_simplices(self):
"""This function returns a generator with simplices and their given
filtration values.
@@ -343,7 +471,7 @@ cdef class SimplexTree:
"""
return self.get_ptr().prune_above_filtration(filtration)
- def expansion(self, max_dim):
+ def expansion(self, max_dimension):
"""Expands the simplex tree containing only its one skeleton
until dimension max_dim.
@@ -357,10 +485,10 @@ cdef class SimplexTree:
The simplex tree must contain no simplex of dimension bigger than
1 when calling the method.
- :param max_dim: The maximal dimension.
- :type max_dim: int
+ :param max_dimension: The maximal dimension.
+ :type max_dimension: int
"""
- cdef int maxdim = max_dim
+ cdef int maxdim = max_dimension
with nogil:
self.get_ptr().expansion(maxdim)
@@ -412,7 +540,7 @@ cdef class SimplexTree:
"""This function retrieves good values for extended persistence, and separate the diagrams into the Ordinary,
Relative, Extended+ and Extended- subdiagrams.
- :param homology_coeff_field: The homology coefficient field. Must be a prime number. Default value is 11.
+ :param homology_coeff_field: The homology coefficient field. Must be a prime number. Default value is 11. Max is 46337.
:type homology_coeff_field: int
:param min_persistence: The minimum persistence value (i.e., the absolute value of the difference between the
persistence diagram point coordinates) to take into account (strictly greater than min_persistence).
@@ -441,15 +569,35 @@ cdef class SimplexTree:
del self.pcohptr
self.pcohptr = new Simplex_tree_persistence_interface(self.get_ptr(), False)
self.pcohptr.compute_persistence(homology_coeff_field, -1.)
- persistence_result = self.pcohptr.get_persistence()
- return self.get_ptr().compute_extended_persistence_subdiagrams(persistence_result, min_persistence)
+ return self.pcohptr.compute_extended_persistence_subdiagrams(min_persistence)
+
+ def expansion_with_blocker(self, max_dim, blocker_func):
+ """Expands the Simplex_tree containing only a graph. Simplices corresponding to cliques in the graph are added
+ incrementally, faces before cofaces, unless the simplex has dimension larger than `max_dim` or `blocker_func`
+ returns `True` for this simplex.
+
+ The function identifies a candidate simplex whose faces are all already in the complex, inserts it with a
+ filtration value corresponding to the maximum of the filtration values of the faces, then calls `blocker_func`
+ with this new simplex (represented as a list of int). If `blocker_func` returns `True`, the simplex is removed,
+ otherwise it is kept. The algorithm then proceeds with the next candidate.
+ .. warning::
+ Several candidates of the same dimension may be inserted simultaneously before calling `blocker_func`, so
+ if you examine the complex in `blocker_func`, you may hit a few simplices of the same dimension that have
+ not been vetted by `blocker_func` yet, or have already been rejected but not yet removed.
+
+ :param max_dim: Expansion maximal dimension value.
+ :type max_dim: int
+ :param blocker_func: Blocker oracle.
+ :type blocker_func: Callable[[List[int]], bool]
+ """
+ self.get_ptr().expansion_with_blockers_callback(max_dim, callback, <void*>blocker_func)
def persistence(self, homology_coeff_field=11, min_persistence=0, persistence_dim_max = False):
"""This function computes and returns the persistence of the simplicial complex.
:param homology_coeff_field: The homology coefficient field. Must be a
- prime number. Default value is 11.
+ prime number. Default value is 11. Max is 46337.
:type homology_coeff_field: int
:param min_persistence: The minimum persistence value to take into
account (strictly greater than min_persistence). Default value is
@@ -472,7 +620,7 @@ cdef class SimplexTree:
when you do not want the list :func:`persistence` returns.
:param homology_coeff_field: The homology coefficient field. Must be a
- prime number. Default value is 11.
+ prime number. Default value is 11. Max is 46337.
:type homology_coeff_field: int
:param min_persistence: The minimum persistence value to take into
account (strictly greater than min_persistence). Default value is
@@ -542,7 +690,11 @@ cdef class SimplexTree:
function to be launched first.
"""
assert self.pcohptr != NULL, "compute_persistence() must be called before persistence_intervals_in_dimension()"
- return np_array(self.pcohptr.intervals_in_dimension(dimension))
+ piid = np.array(self.pcohptr.intervals_in_dimension(dimension))
+ # Workaround https://github.com/GUDHI/gudhi-devel/issues/507
+ if len(piid) == 0:
+ return np.empty(shape = [0, 2])
+ return piid
def persistence_pairs(self):
"""This function returns a list of persistence birth and death simplices pairs.
@@ -583,8 +735,8 @@ cdef class SimplexTree:
"""
assert self.pcohptr != NULL, "lower_star_persistence_generators() requires that persistence() be called first."
gen = self.pcohptr.lower_star_generators()
- normal = [np_array(d).reshape(-1,2) for d in gen.first]
- infinite = [np_array(d) for d in gen.second]
+ normal = [np.array(d).reshape(-1,2) for d in gen.first]
+ infinite = [np.array(d) for d in gen.second]
return (normal, infinite)
def flag_persistence_generators(self):
@@ -602,34 +754,33 @@ cdef class SimplexTree:
assert self.pcohptr != NULL, "flag_persistence_generators() requires that persistence() be called first."
gen = self.pcohptr.flag_generators()
if len(gen.first) == 0:
- normal0 = numpy.empty((0,3))
+ normal0 = np.empty((0,3))
normals = []
else:
l = iter(gen.first)
- normal0 = np_array(next(l)).reshape(-1,3)
- normals = [np_array(d).reshape(-1,4) for d in l]
+ normal0 = np.array(next(l)).reshape(-1,3)
+ normals = [np.array(d).reshape(-1,4) for d in l]
if len(gen.second) == 0:
- infinite0 = numpy.empty(0)
+ infinite0 = np.empty(0)
infinites = []
else:
l = iter(gen.second)
- infinite0 = np_array(next(l))
- infinites = [np_array(d).reshape(-1,2) for d in l]
+ infinite0 = np.array(next(l))
+ infinites = [np.array(d).reshape(-1,2) for d in l]
return (normal0, normals, infinite0, infinites)
def collapse_edges(self, nb_iterations = 1):
- """Assuming the simplex tree is a 1-skeleton graph, this method collapse edges (simplices of higher dimension
- are ignored) and resets the simplex tree from the remaining edges.
- A good candidate is to build a simplex tree on top of a :class:`~gudhi.RipsComplex` of dimension 1 before
- collapsing edges
+ """Assuming the complex is a graph (simplices of higher dimension are ignored), this method implicitly
+ interprets it as the 1-skeleton of a flag complex, and replaces it with another (smaller) graph whose
+ expansion has the same persistent homology, using a technique known as edge collapses
+ (see :cite:`edgecollapsearxiv`).
+
+ A natural application is to get a simplex tree of dimension 1 from :class:`~gudhi.RipsComplex`,
+ then collapse edges, perform :meth:`expansion()` and finally compute persistence
(cf. :download:`rips_complex_edge_collapse_example.py <../example/rips_complex_edge_collapse_example.py>`).
- For implementation details, please refer to :cite:`edgecollapsesocg2020`.
:param nb_iterations: The number of edge collapse iterations to perform. Default is 1.
:type nb_iterations: int
-
- :note: collapse_edges method requires `Eigen <installation.html#eigen>`_ >= 3.1.0 and an exception is thrown
- if this method is not available.
"""
# Backup old pointer
cdef Simplex_tree_interface_full_featured* ptr = self.get_ptr()
@@ -639,3 +790,13 @@ cdef class SimplexTree:
self.thisptr = <intptr_t>(ptr.collapse_edges(nb_iter))
# Delete old pointer
del ptr
+
+ def __eq__(self, other:SimplexTree):
+ """Test for structural equality
+ :returns: True if the 2 simplex trees are equal, False otherwise.
+ :rtype: bool
+ """
+ return dereference(self.get_ptr()) == dereference(other.get_ptr())
+
+cdef intptr_t _get_copy_intptr(SimplexTree stree) nogil:
+ return <intptr_t>(new Simplex_tree_interface_full_featured(dereference(stree.get_ptr())))
diff --git a/src/python/gudhi/sklearn/__init__.py b/src/python/gudhi/sklearn/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/src/python/gudhi/sklearn/__init__.py
diff --git a/src/python/gudhi/sklearn/cubical_persistence.py b/src/python/gudhi/sklearn/cubical_persistence.py
new file mode 100644
index 00000000..672af278
--- /dev/null
+++ b/src/python/gudhi/sklearn/cubical_persistence.py
@@ -0,0 +1,110 @@
+# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+# Author(s): Vincent Rouvreau
+#
+# Copyright (C) 2021 Inria
+#
+# Modification(s):
+# - YYYY/MM Author: Description of the modification
+
+from .. import CubicalComplex
+from sklearn.base import BaseEstimator, TransformerMixin
+
+import numpy as np
+# joblib is required by scikit-learn
+from joblib import Parallel, delayed
+
+# Mermaid sequence diagram - https://mermaid-js.github.io/mermaid-live-editor/
+# sequenceDiagram
+# USER->>CubicalPersistence: fit_transform(X)
+# CubicalPersistence->>thread1: _tranform(X[0])
+# CubicalPersistence->>thread2: _tranform(X[1])
+# Note right of CubicalPersistence: ...
+# thread1->>CubicalPersistence: [array( H0(X[0]) ), array( H1(X[0]) )]
+# thread2->>CubicalPersistence: [array( H0(X[1]) ), array( H1(X[1]) )]
+# Note right of CubicalPersistence: ...
+# CubicalPersistence->>USER: [[array( H0(X[0]) ), array( H1(X[0]) )],<br/> [array( H0(X[1]) ), array( H1(X[1]) )],<br/> ...]
+
+
+class CubicalPersistence(BaseEstimator, TransformerMixin):
+ """
+ This is a class for computing the persistence diagrams from a cubical complex.
+ """
+
+ def __init__(
+ self,
+ homology_dimensions,
+ newshape=None,
+ homology_coeff_field=11,
+ min_persistence=0.0,
+ n_jobs=None,
+ ):
+ """
+ Constructor for the CubicalPersistence class.
+
+ Parameters:
+ homology_dimensions (int or list of int): The returned persistence diagrams dimension(s).
+ Short circuit the use of :class:`~gudhi.representations.preprocessing.DimensionSelector` when only one
+ dimension matters (in other words, when `homology_dimensions` is an int).
+ newshape (tuple of ints): If cells filtration values require to be reshaped
+ (cf. :func:`~gudhi.sklearn.cubical_persistence.CubicalPersistence.transform`), set `newshape`
+ to perform `numpy.reshape(X, newshape, order='C')` in
+ :func:`~gudhi.sklearn.cubical_persistence.CubicalPersistence.transform` method.
+ homology_coeff_field (int): The homology coefficient field. Must be a prime number. Default value is 11.
+ min_persistence (float): The minimum persistence value to take into account (strictly greater than
+ `min_persistence`). Default value is `0.0`. Set `min_persistence` to `-1.0` to see all values.
+ n_jobs (int): cf. https://joblib.readthedocs.io/en/latest/generated/joblib.Parallel.html
+ """
+ self.homology_dimensions = homology_dimensions
+ self.newshape = newshape
+ self.homology_coeff_field = homology_coeff_field
+ self.min_persistence = min_persistence
+ self.n_jobs = n_jobs
+
+ def fit(self, X, Y=None):
+ """
+ Nothing to be done, but useful when included in a scikit-learn Pipeline.
+ """
+ return self
+
+ def __transform(self, cells):
+ cubical_complex = CubicalComplex(top_dimensional_cells=cells)
+ cubical_complex.compute_persistence(
+ homology_coeff_field=self.homology_coeff_field, min_persistence=self.min_persistence
+ )
+ return [
+ cubical_complex.persistence_intervals_in_dimension(dim) for dim in self.homology_dimensions
+ ]
+
+ def __transform_only_this_dim(self, cells):
+ cubical_complex = CubicalComplex(top_dimensional_cells=cells)
+ cubical_complex.compute_persistence(
+ homology_coeff_field=self.homology_coeff_field, min_persistence=self.min_persistence
+ )
+ return cubical_complex.persistence_intervals_in_dimension(self.homology_dimensions)
+
+ def transform(self, X, Y=None):
+ """Compute all the cubical complexes and their associated persistence diagrams.
+
+ :param X: List of cells filtration values (`numpy.reshape(X, newshape, order='C'` if `newshape` is set with a tuple of ints).
+ :type X: list of list of float OR list of numpy.ndarray
+
+ :return: Persistence diagrams in the format:
+
+ - If `homology_dimensions` was set to `n`: `[array( Hn(X[0]) ), array( Hn(X[1]) ), ...]`
+ - If `homology_dimensions` was set to `[i, j]`: `[[array( Hi(X[0]) ), array( Hj(X[0]) )], [array( Hi(X[1]) ), array( Hj(X[1]) )], ...]`
+ :rtype: list of (,2) array_like or list of list of (,2) array_like
+ """
+ if self.newshape is not None:
+ X = np.reshape(X, self.newshape, order='C')
+
+ # Depends on homology_dimensions is an integer or a list of integer (else case)
+ if isinstance(self.homology_dimensions, int):
+ # threads is preferred as cubical construction and persistence computation releases the GIL
+ return Parallel(n_jobs=self.n_jobs, prefer="threads")(
+ delayed(self.__transform_only_this_dim)(cells) for cells in X
+ )
+ else:
+ # threads is preferred as cubical construction and persistence computation releases the GIL
+ return Parallel(n_jobs=self.n_jobs, prefer="threads")(delayed(self.__transform)(cells) for cells in X)
+
diff --git a/src/python/gudhi/tensorflow/__init__.py b/src/python/gudhi/tensorflow/__init__.py
new file mode 100644
index 00000000..1599cf52
--- /dev/null
+++ b/src/python/gudhi/tensorflow/__init__.py
@@ -0,0 +1,5 @@
+from .cubical_layer import CubicalLayer
+from .lower_star_simplex_tree_layer import LowerStarSimplexTreeLayer
+from .rips_layer import RipsLayer
+
+__all__ = ["LowerStarSimplexTreeLayer", "RipsLayer", "CubicalLayer"]
diff --git a/src/python/gudhi/tensorflow/cubical_layer.py b/src/python/gudhi/tensorflow/cubical_layer.py
new file mode 100644
index 00000000..5df2c370
--- /dev/null
+++ b/src/python/gudhi/tensorflow/cubical_layer.py
@@ -0,0 +1,82 @@
+import numpy as np
+import tensorflow as tf
+from ..cubical_complex import CubicalComplex
+
+######################
+# Cubical filtration #
+######################
+
+# The parameters of the model are the pixel values.
+
+def _Cubical(Xflat, Xdim, dimensions, homology_coeff_field):
+ # Parameters: Xflat (flattened image),
+ # Xdim (shape of non-flattened image)
+ # dimensions (homology dimensions)
+
+ # Compute the persistence pairs with Gudhi
+ # We reverse the dimensions because CubicalComplex uses Fortran ordering
+ cc = CubicalComplex(dimensions=Xdim[::-1], top_dimensional_cells=Xflat)
+ cc.compute_persistence(homology_coeff_field=homology_coeff_field)
+
+ # Retrieve and output image indices/pixels corresponding to positive and negative simplices
+ cof_pp = cc.cofaces_of_persistence_pairs()
+
+ L_cofs = []
+ for dim in dimensions:
+
+ try:
+ cof = cof_pp[0][dim]
+ except IndexError:
+ cof = np.array([])
+
+ L_cofs.append(np.array(cof, dtype=np.int32))
+
+ return L_cofs
+
+class CubicalLayer(tf.keras.layers.Layer):
+ """
+ TensorFlow layer for computing the persistent homology of a cubical complex
+ """
+ def __init__(self, homology_dimensions, min_persistence=None, homology_coeff_field=11, **kwargs):
+ """
+ Constructor for the CubicalLayer class
+
+ Parameters:
+ homology_dimensions (List[int]): list of homology dimensions
+ min_persistence (List[float]): minimum distance-to-diagonal of the points in the output persistence diagrams (default None, in which case 0. is used for all dimensions)
+ homology_coeff_field (int): homology field coefficient. Must be a prime number. Default value is 11. Max is 46337.
+ """
+ super().__init__(dynamic=True, **kwargs)
+ self.dimensions = homology_dimensions
+ self.min_persistence = min_persistence if min_persistence != None else [0.] * len(self.dimensions)
+ self.hcf = homology_coeff_field
+ assert len(self.min_persistence) == len(self.dimensions)
+
+ def call(self, X):
+ """
+ Compute persistence diagram associated to a cubical complex filtered by some pixel values
+
+ Parameters:
+ X (TensorFlow variable): pixel values of the cubical complex
+
+ Returns:
+ List[Tuple[tf.Tensor,tf.Tensor]]: List of cubical persistence diagrams. The length of this list is the same than that of dimensions, i.e., there is one persistence diagram per homology dimension provided in the input list dimensions. Moreover, the finite and essential parts of the persistence diagrams are provided separately: each element of this list is a tuple of size two that contains the finite and essential parts of the corresponding persistence diagram, of shapes [num_finite_points, 2] and [num_essential_points, 1] respectively. Note that the essential part is always empty in cubical persistence diagrams, except in homology dimension zero, where the essential part always contains a single point, with abscissa equal to the smallest value in the complex, and infinite ordinate
+ """
+ # Compute pixels associated to positive and negative simplices
+ # Don't compute gradient for this operation
+ Xflat = tf.reshape(X, [-1])
+ Xdim, Xflat_numpy = X.shape, Xflat.numpy()
+ indices_list = _Cubical(Xflat_numpy, Xdim, self.dimensions, self.hcf)
+ index_essential = np.argmin(Xflat_numpy) # index of minimum pixel value for essential persistence diagram
+ # Get persistence diagram by simply picking the corresponding entries in the image
+ self.dgms = []
+ for idx_dim, dimension in enumerate(self.dimensions):
+ finite_dgm = tf.reshape(tf.gather(Xflat, indices_list[idx_dim]), [-1,2])
+ essential_dgm = tf.reshape(tf.gather(Xflat, index_essential), [-1,1]) if dimension == 0 else tf.zeros([0, 1])
+ min_pers = self.min_persistence[idx_dim]
+ if min_pers >= 0:
+ persistent_indices = tf.where(tf.math.abs(finite_dgm[:,1]-finite_dgm[:,0]) > min_pers)
+ self.dgms.append((tf.reshape(tf.gather(finite_dgm, indices=persistent_indices), [-1,2]), essential_dgm))
+ else:
+ self.dgms.append((finite_dgm, essential_dgm))
+ return self.dgms
diff --git a/src/python/gudhi/tensorflow/lower_star_simplex_tree_layer.py b/src/python/gudhi/tensorflow/lower_star_simplex_tree_layer.py
new file mode 100644
index 00000000..5a8e5b75
--- /dev/null
+++ b/src/python/gudhi/tensorflow/lower_star_simplex_tree_layer.py
@@ -0,0 +1,87 @@
+import numpy as np
+import tensorflow as tf
+
+#########################################
+# Lower star filtration on simplex tree #
+#########################################
+
+# The parameters of the model are the vertex function values of the simplex tree.
+
+def _LowerStarSimplexTree(simplextree, filtration, dimensions, homology_coeff_field):
+ # Parameters: simplextree (simplex tree on which to compute persistence)
+ # filtration (function values on the vertices of st),
+ # dimensions (homology dimensions),
+ # homology_coeff_field (homology field coefficient)
+
+ simplextree.reset_filtration(-np.inf, 0)
+
+ # Assign new filtration values
+ for i in range(simplextree.num_vertices()):
+ simplextree.assign_filtration([i], filtration[i])
+ simplextree.make_filtration_non_decreasing()
+
+ # Compute persistence diagram
+ simplextree.compute_persistence(homology_coeff_field=homology_coeff_field)
+
+ # Get vertex pairs for optimization. First, get all simplex pairs
+ pairs = simplextree.lower_star_persistence_generators()
+
+ L_indices = []
+ for dimension in dimensions:
+
+ finite_pairs = pairs[0][dimension] if len(pairs[0]) >= dimension+1 else np.empty(shape=[0,2])
+ essential_pairs = pairs[1][dimension] if len(pairs[1]) >= dimension+1 else np.empty(shape=[0,1])
+
+ finite_indices = np.array(finite_pairs.flatten(), dtype=np.int32)
+ essential_indices = np.array(essential_pairs.flatten(), dtype=np.int32)
+
+ L_indices.append((finite_indices, essential_indices))
+
+ return L_indices
+
+class LowerStarSimplexTreeLayer(tf.keras.layers.Layer):
+ """
+ TensorFlow layer for computing lower-star persistence out of a simplex tree
+ """
+ def __init__(self, simplextree, homology_dimensions, min_persistence=None, homology_coeff_field=11, **kwargs):
+ """
+ Constructor for the LowerStarSimplexTreeLayer class
+
+ Parameters:
+ simplextree (gudhi.SimplexTree): underlying simplex tree. Its vertices MUST be named with integers from 0 to n-1, where n is its number of vertices. Note that its filtration values are modified in each call of the class.
+ homology_dimensions (List[int]): list of homology dimensions
+ min_persistence (List[float]): minimum distance-to-diagonal of the points in the output persistence diagrams (default None, in which case 0. is used for all dimensions)
+ homology_coeff_field (int): homology field coefficient. Must be a prime number. Default value is 11. Max is 46337.
+ """
+ super().__init__(dynamic=True, **kwargs)
+ self.dimensions = homology_dimensions
+ self.simplextree = simplextree
+ self.min_persistence = min_persistence if min_persistence != None else [0. for _ in range(len(self.dimensions))]
+ self.hcf = homology_coeff_field
+ assert len(self.min_persistence) == len(self.dimensions)
+
+ def call(self, filtration):
+ """
+ Compute lower-star persistence diagram associated to a function defined on the vertices of the simplex tree
+
+ Parameters:
+ F (TensorFlow variable): filter function values over the vertices of the simplex tree. The ith entry of F corresponds to vertex i in self.simplextree
+
+ Returns:
+ List[Tuple[tf.Tensor,tf.Tensor]]: List of lower-star persistence diagrams. The length of this list is the same than that of dimensions, i.e., there is one persistence diagram per homology dimension provided in the input list dimensions. Moreover, the finite and essential parts of the persistence diagrams are provided separately: each element of this list is a tuple of size two that contains the finite and essential parts of the corresponding persistence diagram, of shapes [num_finite_points, 2] and [num_essential_points, 1] respectively
+ """
+ # Don't try to compute gradients for the vertex pairs
+ indices = _LowerStarSimplexTree(self.simplextree, filtration.numpy(), self.dimensions, self.hcf)
+ # Get persistence diagrams
+ self.dgms = []
+ for idx_dim, dimension in enumerate(self.dimensions):
+ finite_dgm = tf.reshape(tf.gather(filtration, indices[idx_dim][0]), [-1,2])
+ essential_dgm = tf.reshape(tf.gather(filtration, indices[idx_dim][1]), [-1,1])
+ min_pers = self.min_persistence[idx_dim]
+ if min_pers >= 0:
+ persistent_indices = tf.where(tf.math.abs(finite_dgm[:,1]-finite_dgm[:,0]) > min_pers)
+ self.dgms.append((tf.reshape(tf.gather(finite_dgm, indices=persistent_indices),[-1,2]), essential_dgm))
+ else:
+ self.dgms.append((finite_dgm, essential_dgm))
+ return self.dgms
+
diff --git a/src/python/gudhi/tensorflow/perslay.py b/src/python/gudhi/tensorflow/perslay.py
new file mode 100644
index 00000000..9976c5f3
--- /dev/null
+++ b/src/python/gudhi/tensorflow/perslay.py
@@ -0,0 +1,284 @@
+# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+# Author(s): Mathieu Carrière
+#
+# Copyright (C) 2021 Inria
+#
+# Modification(s):
+# - YYYY/MM Author: Description of the modification
+
+import tensorflow as tf
+import math
+
+class GridPerslayWeight(tf.keras.layers.Layer):
+ """
+ This is a class for computing a differentiable weight function for persistence diagram points. This function is defined from an array that contains its values on a 2D grid.
+ """
+ def __init__(self, grid, grid_bnds, **kwargs):
+ """
+ Constructor for the GridPerslayWeight class.
+
+ Parameters:
+ grid (n x n numpy array): grid of values.
+ grid_bnds (2 x 2 numpy array): boundaries of the grid, of the form [[min_x, max_x], [min_y, max_y]].
+ """
+ super().__init__(dynamic=True, **kwargs)
+ self.grid = tf.Variable(initial_value=grid, trainable=True)
+ self.grid_bnds = grid_bnds
+
+ def build(self, input_shape):
+ return self
+
+ def call(self, diagrams):
+ """
+ Apply GridPerslayWeight on a ragged tensor containing a list of persistence diagrams.
+
+ Parameters:
+ diagrams (n x None x 2): ragged tensor containing n persistence diagrams. The second dimension is ragged since persistence diagrams can have different numbers of points.
+
+ Returns:
+ weight (n x None): ragged tensor containing the weights of the points in the n persistence diagrams. The second dimension is ragged since persistence diagrams can have different numbers of points.
+ """
+ grid_shape = self.grid.shape
+ indices = []
+ for dim in range(2):
+ [m,M] = self.grid_bnds[dim]
+ coords = tf.expand_dims(diagrams[:,:,dim],-1)
+ ids = grid_shape[dim]*(coords-m)/(M-m)
+ indices.append(tf.cast(ids, tf.int32))
+ weight = tf.gather_nd(params=self.grid, indices=tf.concat(indices, axis=2))
+ return weight
+
+class GaussianMixturePerslayWeight(tf.keras.layers.Layer):
+ """
+ This is a class for computing a differentiable weight function for persistence diagram points. This function is defined from a mixture of Gaussian functions.
+ """
+ def __init__(self, gaussians, **kwargs):
+ """
+ Constructor for the GridPerslayWeight class.
+
+ Parameters:
+ gaussians (4 x n numpy array): parameters of the n Gaussian functions, of the form transpose([[mu_x^1, mu_y^1, sigma_x^1, sigma_y^1], ..., [mu_x^n, mu_y^n, sigma_x^n, sigma_y^n]]).
+ """
+ super().__init__(dynamic=True, **kwargs)
+ self.W = tf.Variable(initial_value=gaussians, trainable=True)
+
+ def build(self, input_shape):
+ return self
+
+ def call(self, diagrams):
+ """
+ Apply GaussianMixturePerslayWeight on a ragged tensor containing a list of persistence diagrams.
+
+ Parameters:
+ diagrams (n x None x 2): ragged tensor containing n persistence diagrams. The second dimension is ragged since persistence diagrams can have different numbers of points.
+
+ Returns:
+ weight (n x None): ragged tensor containing the weights of the points in the n persistence diagrams. The second dimension is ragged since persistence diagrams can have different numbers of points.
+ """
+ means = tf.expand_dims(tf.expand_dims(self.W[:2,:],0),0)
+ variances = tf.expand_dims(tf.expand_dims(self.W[2:,:],0),0)
+ diagrams = tf.expand_dims(diagrams, -1)
+ dists = tf.math.multiply(tf.math.square(diagrams-means), 1/tf.math.square(variances))
+ weight = tf.math.reduce_sum(tf.math.exp(tf.math.reduce_sum(-dists, axis=2)), axis=2)
+ return weight
+
+class PowerPerslayWeight(tf.keras.layers.Layer):
+ """
+ This is a class for computing a differentiable weight function for persistence diagram points. This function is defined as a constant multiplied by the distance to the diagonal of the persistence diagram point raised to some power.
+ """
+ def __init__(self, constant, power, **kwargs):
+ """
+ Constructor for the PowerPerslayWeight class.
+
+ Parameters:
+ constant (float): constant value.
+ power (float): power applied to the distance to the diagonal.
+ """
+ super().__init__(dynamic=True, **kwargs)
+ self.constant = tf.Variable(initial_value=constant, trainable=True)
+ self.power = power
+
+ def build(self, input_shape):
+ return self
+
+ def call(self, diagrams):
+ """
+ Apply PowerPerslayWeight on a ragged tensor containing a list of persistence diagrams.
+
+ Parameters:
+ diagrams (n x None x 2): ragged tensor containing n persistence diagrams. The second dimension is ragged since persistence diagrams can have different numbers of points.
+
+ Returns:
+ weight (n x None): ragged tensor containing the weights of the points in the n persistence diagrams. The second dimension is ragged since persistence diagrams can have different numbers of points.
+ """
+ weight = self.constant * tf.math.pow(tf.math.abs(diagrams[:,:,1]-diagrams[:,:,0]), self.power)
+ return weight
+
+
+class GaussianPerslayPhi(tf.keras.layers.Layer):
+ """
+ This is a class for computing a transformation function for persistence diagram points. This function turns persistence diagram points into 2D Gaussian functions centered on the points, that are then evaluated on a regular 2D grid.
+ """
+ def __init__(self, image_size, image_bnds, variance, **kwargs):
+ """
+ Constructor for the GaussianPerslayPhi class.
+
+ Parameters:
+ image_size (int numpy array): number of grid elements on each grid axis, of the form [n_x, n_y].
+ image_bnds (2 x 2 numpy array): boundaries of the grid, of the form [[min_x, max_x], [min_y, max_y]].
+ variance (float): variance of the Gaussian functions.
+ """
+ super().__init__(dynamic=True, **kwargs)
+ self.image_size = image_size
+ self.image_bnds = image_bnds
+ self.variance = tf.Variable(initial_value=variance, trainable=True)
+
+ def build(self, input_shape):
+ return self
+
+ def call(self, diagrams):
+ """
+ Apply GaussianPerslayPhi on a ragged tensor containing a list of persistence diagrams.
+
+ Parameters:
+ diagrams (n x None x 2): ragged tensor containing n persistence diagrams. The second dimension is ragged since persistence diagrams can have different numbers of points.
+
+ Returns:
+ output (n x None x image_size x image_size x 1): ragged tensor containing the evaluations on the 2D grid of the 2D Gaussian functions corresponding to the persistence diagram points, in the form of a 2D image with 1 channel that can be processed with, e.g., convolutional layers. The second dimension is ragged since persistence diagrams can have different numbers of points.
+ output_shape (int numpy array): shape of the output tensor.
+ """
+ diagrams_d = tf.concat([diagrams[:,:,0:1], diagrams[:,:,1:2]-diagrams[:,:,0:1]], axis=2)
+ step = [(self.image_bnds[i][1]-self.image_bnds[i][0])/self.image_size[i] for i in range(2)]
+ coords = [tf.range(self.image_bnds[i][0], self.image_bnds[i][1], step[i]) for i in range(2)]
+ M = tf.meshgrid(*coords)
+ mu = tf.concat([tf.expand_dims(tens, 0) for tens in M], axis=0)
+ for _ in range(2):
+ diagrams_d = tf.expand_dims(diagrams_d,-1)
+ dists = tf.math.square(diagrams_d-mu) / (2*tf.math.square(self.variance))
+ gauss = tf.math.exp(tf.math.reduce_sum(-dists, axis=2)) / (2*math.pi*tf.math.square(self.variance))
+ output = tf.expand_dims(gauss,-1)
+ output_shape = M[0].shape + tuple([1])
+ return output, output_shape
+
+class TentPerslayPhi(tf.keras.layers.Layer):
+ """
+ This is a class for computing a transformation function for persistence diagram points. This function turns persistence diagram points into 1D tent functions (linearly increasing on the first half of the bar corresponding to the point from zero to half of the bar length, linearly decreasing on the second half and zero elsewhere) centered on the points, that are then evaluated on a regular 1D grid.
+ """
+ def __init__(self, samples, **kwargs):
+ """
+ Constructor for the GaussianPerslayPhi class.
+
+ Parameters:
+ samples (float numpy array): grid elements on which to evaluate the tent functions, of the form [x_1, ..., x_n].
+ """
+ super().__init__(dynamic=True, **kwargs)
+ self.samples = tf.Variable(initial_value=samples, trainable=True)
+
+ def build(self, input_shape):
+ return self
+
+ def call(self, diagrams):
+ """
+ Apply TentPerslayPhi on a ragged tensor containing a list of persistence diagrams.
+
+ Parameters:
+ diagrams (n x None x 2): ragged tensor containing n persistence diagrams. The second dimension is ragged since persistence diagrams can have different numbers of points.
+
+ Returns:
+ output (n x None x num_samples): ragged tensor containing the evaluations on the 1D grid of the 1D tent functions corresponding to the persistence diagram points. The second dimension is ragged since persistence diagrams can have different numbers of points.
+ output_shape (int numpy array): shape of the output tensor.
+ """
+ samples_d = tf.expand_dims(tf.expand_dims(self.samples,0),0)
+ xs, ys = diagrams[:,:,0:1], diagrams[:,:,1:2]
+ output = tf.math.maximum(.5*(ys-xs) - tf.math.abs(samples_d-.5*(ys+xs)), tf.constant([0.]))
+ output_shape = self.samples.shape
+ return output, output_shape
+
+class FlatPerslayPhi(tf.keras.layers.Layer):
+ """
+ This is a class for computing a transformation function for persistence diagram points. This function turns persistence diagram points into 1D constant functions (that evaluate to half of the bar length on the bar corresponding to the point and zero elsewhere), that are then evaluated on a regular 1D grid.
+ """
+ def __init__(self, samples, theta, **kwargs):
+ """
+ Constructor for the FlatPerslayPhi class.
+
+ Parameters:
+ samples (float numpy array): grid elements on which to evaluate the constant functions, of the form [x_1, ..., x_n].
+ theta (float): sigmoid parameter used to approximate the constant function with a differentiable sigmoid function. The bigger the theta, the closer to a constant function the output will be.
+ """
+ super().__init__(dynamic=True, **kwargs)
+ self.samples = tf.Variable(initial_value=samples, trainable=True)
+ self.theta = tf.Variable(initial_value=theta, trainable=True)
+
+ def build(self, input_shape):
+ return self
+
+ def call(self, diagrams):
+ """
+ Apply FlatPerslayPhi on a ragged tensor containing a list of persistence diagrams.
+
+ Parameters:
+ diagrams (n x None x 2): ragged tensor containing n persistence diagrams. The second dimension is ragged since persistence diagrams can have different numbers of points.
+
+ Returns:
+ output (n x None x num_samples): ragged tensor containing the evaluations on the 1D grid of the 1D constant functions corresponding to the persistence diagram points. The second dimension is ragged since persistence diagrams can have different numbers of points.
+ output_shape (int numpy array): shape of the output tensor.
+ """
+ samples_d = tf.expand_dims(tf.expand_dims(self.samples,0),0)
+ xs, ys = diagrams[:,:,0:1], diagrams[:,:,1:2]
+ output = 1./(1.+tf.math.exp(-self.theta*(.5*(ys-xs)-tf.math.abs(samples_d-.5*(ys+xs)))))
+ output_shape = self.samples.shape
+ return output, output_shape
+
+class Perslay(tf.keras.layers.Layer):
+ """
+ This is a TensorFlow layer for vectorizing persistence diagrams in a differentiable way within a neural network. This function implements the PersLay equation, see `the corresponding article <http://proceedings.mlr.press/v108/carriere20a.html>`_.
+ """
+ def __init__(self, weight, phi, perm_op, rho, **kwargs):
+ """
+ Constructor for the Perslay class.
+
+ Parameters:
+ weight (function): weight function for the persistence diagram points. Can be either :class:`~gudhi.tensorflow.perslay.GridPerslayWeight`, :class:`~gudhi.tensorflow.perslay.GaussianMixturePerslayWeight`, :class:`~gudhi.tensorflow.perslay.PowerPerslayWeight`, or a custom TensorFlow function that takes persistence diagrams as argument (represented as an (n x None x 2) ragged tensor, where n is the number of diagrams).
+ phi (function): transformation function for the persistence diagram points. Can be either :class:`~gudhi.tensorflow.perslay.GaussianPerslayPhi`, :class:`~gudhi.tensorflow.perslay.TentPerslayPhi`, :class:`~gudhi.tensorflow.perslay.FlatPerslayPhi`, or a custom TensorFlow class (that can have trainable parameters) with a method `call` that takes persistence diagrams as argument (represented as an (n x None x 2) ragged tensor, where n is the number of diagrams).
+ perm_op (function): permutation invariant function, such as `tf.math.reduce_sum`, `tf.math.reduce_mean`, `tf.math.reduce_max`, `tf.math.reduce_min`, or a custom TensorFlow function that takes two arguments: a tensor and an axis on which to apply the permutation invariant operation. If perm_op is the string "topk" (where k is a number), this function will be computed as `tf.math.top_k` with parameter `int(k)`.
+ rho (function): postprocessing function that is applied after the permutation invariant operation. Can be any TensorFlow layer.
+ """
+ super().__init__(dynamic=True, **kwargs)
+ self.weight = weight
+ self.phi = phi
+ self.perm_op = perm_op
+ self.rho = rho
+
+ def build(self, input_shape):
+ return self
+
+ def call(self, diagrams):
+ """
+ Apply Perslay on a ragged tensor containing a list of persistence diagrams.
+
+ Parameters:
+ diagrams (n x None x 2): ragged tensor containing n persistence diagrams. The second dimension is ragged since persistence diagrams can have different numbers of points.
+
+ Returns:
+ vector (n x output_shape): tensor containing the vectorizations of the persistence diagrams.
+ """
+ vector, dim = self.phi(diagrams)
+ weight = self.weight(diagrams)
+ for _ in range(len(dim)):
+ weight = tf.expand_dims(weight, -1)
+ vector = tf.math.multiply(vector, weight)
+
+ permop = self.perm_op
+ if type(permop) == str and permop[:3] == 'top':
+ k = int(permop[3:])
+ vector = vector.to_tensor(default_value=-1e10)
+ vector = tf.math.top_k(tf.transpose(vector, perm=[0, 2, 1]), k=k).values
+ vector = tf.reshape(vector, [-1,k*dim[0]])
+ else:
+ vector = permop(vector, axis=1)
+
+ vector = self.rho(vector)
+
+ return vector
diff --git a/src/python/gudhi/tensorflow/rips_layer.py b/src/python/gudhi/tensorflow/rips_layer.py
new file mode 100644
index 00000000..2a73472c
--- /dev/null
+++ b/src/python/gudhi/tensorflow/rips_layer.py
@@ -0,0 +1,93 @@
+import numpy as np
+import tensorflow as tf
+from ..rips_complex import RipsComplex
+
+############################
+# Vietoris-Rips filtration #
+############################
+
+# The parameters of the model are the point coordinates.
+
+def _Rips(DX, max_edge, dimensions, homology_coeff_field):
+ # Parameters: DX (distance matrix),
+ # max_edge (maximum edge length for Rips filtration),
+ # dimensions (homology dimensions)
+
+ # Compute the persistence pairs with Gudhi
+ rc = RipsComplex(distance_matrix=DX, max_edge_length=max_edge)
+ st = rc.create_simplex_tree(max_dimension=max(dimensions)+1)
+ st.compute_persistence(homology_coeff_field=homology_coeff_field)
+ pairs = st.flag_persistence_generators()
+
+ L_indices = []
+ for dimension in dimensions:
+
+ if dimension == 0:
+ finite_pairs = pairs[0]
+ essential_pairs = pairs[2]
+ else:
+ finite_pairs = pairs[1][dimension-1] if len(pairs[1]) >= dimension else np.empty(shape=[0,4])
+ essential_pairs = pairs[3][dimension-1] if len(pairs[3]) >= dimension else np.empty(shape=[0,2])
+
+ finite_indices = np.array(finite_pairs.flatten(), dtype=np.int32)
+ essential_indices = np.array(essential_pairs.flatten(), dtype=np.int32)
+
+ L_indices.append((finite_indices, essential_indices))
+
+ return L_indices
+
+class RipsLayer(tf.keras.layers.Layer):
+ """
+ TensorFlow layer for computing Rips persistence out of a point cloud
+ """
+ def __init__(self, homology_dimensions, maximum_edge_length=np.inf, min_persistence=None, homology_coeff_field=11, **kwargs):
+ """
+ Constructor for the RipsLayer class
+
+ Parameters:
+ maximum_edge_length (float): maximum edge length for the Rips complex
+ homology_dimensions (List[int]): list of homology dimensions
+ min_persistence (List[float]): minimum distance-to-diagonal of the points in the output persistence diagrams (default None, in which case 0. is used for all dimensions)
+ homology_coeff_field (int): homology field coefficient. Must be a prime number. Default value is 11. Max is 46337.
+ """
+ super().__init__(dynamic=True, **kwargs)
+ self.max_edge = maximum_edge_length
+ self.dimensions = homology_dimensions
+ self.min_persistence = min_persistence if min_persistence != None else [0. for _ in range(len(self.dimensions))]
+ self.hcf = homology_coeff_field
+ assert len(self.min_persistence) == len(self.dimensions)
+
+ def call(self, X):
+ """
+ Compute Rips persistence diagram associated to a point cloud
+
+ Parameters:
+ X (TensorFlow variable): point cloud of shape [number of points, number of dimensions]
+
+ Returns:
+ List[Tuple[tf.Tensor,tf.Tensor]]: List of Rips persistence diagrams. The length of this list is the same than that of dimensions, i.e., there is one persistence diagram per homology dimension provided in the input list dimensions. Moreover, the finite and essential parts of the persistence diagrams are provided separately: each element of this list is a tuple of size two that contains the finite and essential parts of the corresponding persistence diagram, of shapes [num_finite_points, 2] and [num_essential_points, 1] respectively
+ """
+ # Compute distance matrix
+ DX = tf.norm(tf.expand_dims(X, 1)-tf.expand_dims(X, 0), axis=2)
+ # Compute vertices associated to positive and negative simplices
+ # Don't compute gradient for this operation
+ indices = _Rips(DX.numpy(), self.max_edge, self.dimensions, self.hcf)
+ # Get persistence diagrams by simply picking the corresponding entries in the distance matrix
+ self.dgms = []
+ for idx_dim, dimension in enumerate(self.dimensions):
+ cur_idx = indices[idx_dim]
+ if dimension > 0:
+ finite_dgm = tf.reshape(tf.gather_nd(DX, tf.reshape(cur_idx[0], [-1,2])), [-1,2])
+ essential_dgm = tf.reshape(tf.gather_nd(DX, tf.reshape(cur_idx[1], [-1,2])), [-1,1])
+ else:
+ reshaped_cur_idx = tf.reshape(cur_idx[0], [-1,3])
+ finite_dgm = tf.concat([tf.zeros([reshaped_cur_idx.shape[0],1]), tf.reshape(tf.gather_nd(DX, reshaped_cur_idx[:,1:]), [-1,1])], axis=1)
+ essential_dgm = tf.zeros([cur_idx[1].shape[0],1])
+ min_pers = self.min_persistence[idx_dim]
+ if min_pers >= 0:
+ persistent_indices = tf.where(tf.math.abs(finite_dgm[:,1]-finite_dgm[:,0]) > min_pers)
+ self.dgms.append((tf.reshape(tf.gather(finite_dgm, indices=persistent_indices),[-1,2]), essential_dgm))
+ else:
+ self.dgms.append((finite_dgm, essential_dgm))
+ return self.dgms
+
diff --git a/src/python/gudhi/wasserstein/barycenter.py b/src/python/gudhi/wasserstein/barycenter.py
index d67bcde7..bb6e641e 100644
--- a/src/python/gudhi/wasserstein/barycenter.py
+++ b/src/python/gudhi/wasserstein/barycenter.py
@@ -37,7 +37,7 @@ def lagrangian_barycenter(pdiagset, init=None, verbose=False):
:param init: The initial value for barycenter estimate.
If ``None``, init is made on a random diagram from the dataset.
Otherwise, it can be an ``int`` (then initialization is made on ``pdiagset[init]``)
- or a `(n x 2)` ``numpy.array`` enconding a persistence diagram with `n` points.
+ or a `(n x 2)` ``numpy.array`` encoding a persistence diagram with `n` points.
:type init: ``int``, or (n x 2) ``np.array``
:param verbose: if ``True``, returns additional information about the barycenter.
:type verbose: boolean
@@ -45,7 +45,7 @@ def lagrangian_barycenter(pdiagset, init=None, verbose=False):
(local minimum of the energy function).
If ``pdiagset`` is empty, returns ``None``.
If verbose, returns a couple ``(Y, log)`` where ``Y`` is the barycenter estimate,
- and ``log`` is a ``dict`` that contains additional informations:
+ and ``log`` is a ``dict`` that contains additional information:
- `"groupings"`, a list of list of pairs ``(i,j)``. Namely, ``G[k] = [...(i, j)...]``, where ``(i,j)`` indicates that `pdiagset[k][i]`` is matched to ``Y[j]`` if ``i = -1`` or ``j = -1``, it means they represent the diagonal.
@@ -73,7 +73,7 @@ def lagrangian_barycenter(pdiagset, init=None, verbose=False):
nb_iter = 0
- converged = False # stoping criterion
+ converged = False # stopping criterion
while not converged:
nb_iter += 1
K = len(Y) # current nb of points in Y (some might be on diagonal)
diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py
index a9d1cdff..dc18806e 100644
--- a/src/python/gudhi/wasserstein/wasserstein.py
+++ b/src/python/gudhi/wasserstein/wasserstein.py
@@ -9,6 +9,7 @@
import numpy as np
import scipy.spatial.distance as sc
+import warnings
try:
import ot
@@ -70,6 +71,7 @@ def _perstot_autodiff(X, order, internal_p):
'''
return _dist_to_diag(X, internal_p).norms.lp(order)
+
def _perstot(X, order, internal_p, enable_autodiff):
'''
:param X: (n x 2) numpy.array (points of a given diagram).
@@ -79,6 +81,9 @@ def _perstot(X, order, internal_p, enable_autodiff):
transparent to automatic differentiation.
:type enable_autodiff: bool
:returns: float, the total persistence of the diagram (that is, its distance to the empty diagram).
+
+ .. note::
+ Can be +inf if the diagram has an essential part (points with infinite coordinates).
'''
if enable_autodiff:
import eagerpy as ep
@@ -88,32 +93,163 @@ def _perstot(X, order, internal_p, enable_autodiff):
return np.linalg.norm(_dist_to_diag(X, internal_p), ord=order)
-def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enable_autodiff=False):
+def _get_essential_parts(a):
'''
- :param X: (n x 2) numpy.array encoding the (finite points of the) first diagram. Must not contain essential points
- (i.e. with infinite coordinate).
- :param Y: (m x 2) numpy.array encoding the second diagram.
- :param matching: if True, computes and returns the optimal matching between X and Y, encoded as
- a (n x 2) np.array [...[i,j]...], meaning the i-th point in X is matched to
- the j-th point in Y, with the convention (-1) represents the diagonal.
- :param order: exponent for Wasserstein; Default value is 1.
- :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2);
- Default value is `np.inf`.
- :param enable_autodiff: If X and Y are torch.tensor or tensorflow.Tensor, make the computation
+ :param a: (n x 2) numpy.array (point of a diagram)
+ :returns: five lists of indices (between 0 and len(a)) accounting for the five types of points with infinite
+ coordinates that can occur in a diagram, namely:
+ type0 : (-inf, finite)
+ type1 : (finite, +inf)
+ type2 : (-inf, +inf)
+ type3 : (-inf, -inf)
+ type4 : (+inf, +inf)
+ .. note::
+ For instance, a[_get_essential_parts(a)[0]] returns the points in a of coordinates (-inf, x) for some finite x.
+ Note also that points with (+inf, -inf) are not handled (points (x,y) in dgm satisfy by assumption (y >= x)).
+
+ Finally, we consider that points with coordinates (-inf,-inf) and (+inf, +inf) belong to the diagonal.
+ '''
+ if len(a):
+ first_coord_finite = np.isfinite(a[:,0])
+ second_coord_finite = np.isfinite(a[:,1])
+ first_coord_infinite_positive = (a[:,0] == np.inf)
+ second_coord_infinite_positive = (a[:,1] == np.inf)
+ first_coord_infinite_negative = (a[:,0] == -np.inf)
+ second_coord_infinite_negative = (a[:,1] == -np.inf)
+
+ ess_first_type = np.where(second_coord_finite & first_coord_infinite_negative)[0] # coord (-inf, x)
+ ess_second_type = np.where(first_coord_finite & second_coord_infinite_positive)[0] # coord (x, +inf)
+ ess_third_type = np.where(first_coord_infinite_negative & second_coord_infinite_positive)[0] # coord (-inf, +inf)
+
+ ess_fourth_type = np.where(first_coord_infinite_negative & second_coord_infinite_negative)[0] # coord (-inf, -inf)
+ ess_fifth_type = np.where(first_coord_infinite_positive & second_coord_infinite_positive)[0] # coord (+inf, +inf)
+ return ess_first_type, ess_second_type, ess_third_type, ess_fourth_type, ess_fifth_type
+ else:
+ return [], [], [], [], []
+
+
+def _cost_and_match_essential_parts(X, Y, idX, idY, order, axis):
+ '''
+ :param X: (n x 2) numpy.array (dgm points)
+ :param Y: (n x 2) numpy.array (dgm points)
+ :param idX: indices to consider for this one dimensional OT problem (in X)
+ :param idY: indices to consider for this one dimensional OT problem (in Y)
+ :param order: exponent for Wasserstein distance computation
+ :param axis: must be 0 or 1, correspond to the coordinate which is finite.
+ :returns: cost (float) and match for points with *one* infinite coordinate.
+
+ .. note::
+ Assume idX, idY come when calling _handle_essential_parts, thus have same length.
+ '''
+ u = X[idX, axis]
+ v = Y[idY, axis]
+
+ cost = np.sum(np.abs(np.sort(u) - np.sort(v))**(order)) # OT cost in 1D
+
+ sortidX = idX[np.argsort(u)]
+ sortidY = idY[np.argsort(v)]
+ # We return [i,j] sorted per value
+ match = list(zip(sortidX, sortidY))
+
+ return cost, match
+
+
+def _handle_essential_parts(X, Y, order):
+ '''
+ :param X: (n x 2) numpy array, first diagram.
+ :param Y: (n x 2) numpy array, second diagram.
+ :order: Wasserstein order for cost computation.
+ :returns: cost and matching due to essential parts. If cost is +inf, matching will be set to None.
+ '''
+ ess_parts_X = _get_essential_parts(X)
+ ess_parts_Y = _get_essential_parts(Y)
+
+ # Treats the case of infinite cost (cardinalities of essential parts differ).
+ for u, v in list(zip(ess_parts_X, ess_parts_Y))[:3]: # ignore types 4 and 5 as they belong to the diagonal
+ if len(u) != len(v):
+ return np.inf, None
+
+ # Now we know each essential part has the same number of points in both diagrams.
+ # Handle type 0 and type 1 essential parts (those with one finite coordinates)
+ c1, m1 = _cost_and_match_essential_parts(X, Y, ess_parts_X[0], ess_parts_Y[0], axis=1, order=order)
+ c2, m2 = _cost_and_match_essential_parts(X, Y, ess_parts_X[1], ess_parts_Y[1], axis=0, order=order)
+
+ c = c1 + c2
+ m = m1 + m2
+
+ # Handle type3 (coordinates (-inf,+inf), so we just align points)
+ m += list(zip(ess_parts_X[2], ess_parts_Y[2]))
+
+ # Handle type 4 and 5, considered as belonging to the diagonal so matched to (-1) with cost 0.
+ for z in ess_parts_X[3:]:
+ m += [(u, -1) for u in z] # points in X are matched to -1
+ for z in ess_parts_Y[3:]:
+ m += [(-1, v) for v in z] # -1 is match to points in Y
+
+ return c, np.array(m)
+
+
+def _finite_part(X):
+ '''
+ :param X: (n x 2) numpy array encoding a persistence diagram.
+ :returns: The finite part of a diagram `X` (points with finite coordinates).
+ '''
+ return X[np.where(np.isfinite(X[:,0]) & np.isfinite(X[:,1]))]
+
+
+def _warn_infty(matching):
+ '''
+ Handle essential parts with different cardinalities. Warn the user about cost being infinite and (if
+ `matching=True`) about the returned matching being `None`.
+ '''
+ if matching:
+ warnings.warn('Cardinality of essential parts differs. Distance (cost) is +inf, and the returned matching is None.')
+ return np.inf, None
+ else:
+ warnings.warn('Cardinality of essential parts differs. Distance (cost) is +inf.')
+ return np.inf
+
+
+def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enable_autodiff=False,
+ keep_essential_parts=True):
+ '''
+ Compute the Wasserstein distance between persistence diagram using Python Optimal Transport backend.
+ Diagrams can contain points with infinity coordinates (essential parts).
+ Points with (-inf,-inf) and (+inf,+inf) coordinates are considered as belonging to the diagonal.
+ If the distance between two diagrams is +inf (which happens if the cardinalities of essential
+ parts differ) and optimal matching is required, it will be set to ``None``.
+
+ :param X: The first diagram.
+ :type X: n x 2 numpy.array
+ :param Y: The second diagram.
+ :type Y: m x 2 numpy.array
+ :param matching: if ``True``, computes and returns the optimal matching between X and Y, encoded as
+ a (n x 2) np.array [...[i,j]...], meaning the i-th point in X is matched to
+ the j-th point in Y, with the convention that (-1) represents the diagonal.
+ :param order: Wasserstein exponent q (1 <= q < infinity).
+ :type order: float
+ :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2).
+ :type internal_p: float
+ :param enable_autodiff: If X and Y are ``torch.tensor`` or ``tensorflow.Tensor``, make the computation
transparent to automatic differentiation. This requires the package EagerPy and is currently incompatible
- with `matching=True`.
+ with ``matching=True`` and with ``keep_essential_parts=True``.
- .. note:: This considers the function defined on the coordinates of the off-diagonal points of X and Y
+ .. note:: This considers the function defined on the coordinates of the off-diagonal finite points of X and Y
and lets the various frameworks compute its gradient. It never pulls new points from the diagonal.
:type enable_autodiff: bool
- :returns: the Wasserstein distance of order q (1 <= q < infinity) between persistence diagrams with
+ :param keep_essential_parts: If ``False``, only considers the finite points in the diagrams.
+ Otherwise, include essential parts in cost and matching computation.
+ :type keep_essential_parts: bool
+ :returns: The Wasserstein distance of order q (1 <= q < infinity) between persistence diagrams with
respect to the internal_p-norm as ground metric.
If matching is set to True, also returns the optimal matching between X and Y.
+ If cost is +inf, any matching is optimal and thus it returns `None` instead.
'''
+
+ # First step: handle empty diagrams
n = len(X)
m = len(Y)
- # handle empty diagrams
if n == 0:
if m == 0:
if not matching:
@@ -122,16 +258,45 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab
else:
return 0., np.array([])
else:
- if not matching:
- return _perstot(Y, order, internal_p, enable_autodiff)
+ cost = _perstot(Y, order, internal_p, enable_autodiff)
+ if cost == np.inf:
+ return _warn_infty(matching)
else:
- return _perstot(Y, order, internal_p, enable_autodiff), np.array([[-1, j] for j in range(m)])
+ if not matching:
+ return cost
+ else:
+ return cost, np.array([[-1, j] for j in range(m)])
elif m == 0:
- if not matching:
- return _perstot(X, order, internal_p, enable_autodiff)
+ cost = _perstot(X, order, internal_p, enable_autodiff)
+ if cost == np.inf:
+ return _warn_infty(matching)
else:
- return _perstot(X, order, internal_p, enable_autodiff), np.array([[i, -1] for i in range(n)])
+ if not matching:
+ return cost
+ else:
+ return cost, np.array([[i, -1] for i in range(n)])
+
+ # Check essential part and enable autodiff together
+ if enable_autodiff and keep_essential_parts:
+ warnings.warn('''enable_autodiff=True and keep_essential_parts=True are incompatible together.
+ keep_essential_parts is set to False: only points with finite coordinates are considered
+ in the following.
+ ''')
+ keep_essential_parts = False
+
+ # Second step: handle essential parts if needed.
+ if keep_essential_parts:
+ essential_cost, essential_matching = _handle_essential_parts(X, Y, order=order)
+ if (essential_cost == np.inf):
+ return _warn_infty(matching) # Tells the user that cost is infty and matching (if True) is None.
+ # avoid computing transport cost between the finite parts if essential parts
+ # cardinalities do not match (saves time)
+ else:
+ essential_cost = 0
+ essential_matching = None
+
+ # Now the standard pipeline for finite parts
if enable_autodiff:
import eagerpy as ep
@@ -139,6 +304,12 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab
Y_orig = ep.astensor(Y)
X = X_orig.numpy()
Y = Y_orig.numpy()
+
+ # Extract finite points of the diagrams.
+ X, Y = _finite_part(X), _finite_part(Y)
+ n = len(X)
+ m = len(Y)
+
M = _build_dist_matrix(X, Y, order=order, internal_p=internal_p)
a = np.ones(n+1) # weight vector of the input diagram. Uniform here.
a[-1] = m
@@ -154,7 +325,10 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab
# Now we turn to -1 points encoding the diagonal
match[:,0][match[:,0] >= n] = -1
match[:,1][match[:,1] >= m] = -1
- return ot_cost ** (1./order) , match
+ # Finally incorporate the essential part matching
+ if essential_matching is not None:
+ match = np.concatenate([match, essential_matching]) if essential_matching.size else match
+ return (ot_cost + essential_cost) ** (1./order) , match
if enable_autodiff:
P = ot.emd(a=a, b=b, M=M, numItermax=2000000)
@@ -173,9 +347,9 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab
return ep.concatenate(dists).norms.lp(order).raw
# We can also concatenate the 3 vectors to compute just one norm.
- # Comptuation of the otcost using the ot.emd2 library.
+ # Comptuation of the ot cost using the ot.emd2 library.
# Note: it is the Wasserstein distance to the power q.
# The default numItermax=100000 is not sufficient for some examples with 5000 points, what is a good value?
ot_cost = ot.emd2(a, b, M, numItermax=2000000)
- return ot_cost ** (1./order)
+ return (ot_cost + essential_cost) ** (1./order)
diff --git a/src/python/gudhi/weighted_rips_complex.py b/src/python/gudhi/weighted_rips_complex.py
index 0541572b..16f63c3d 100644
--- a/src/python/gudhi/weighted_rips_complex.py
+++ b/src/python/gudhi/weighted_rips_complex.py
@@ -12,9 +12,11 @@ from gudhi import SimplexTree
class WeightedRipsComplex:
"""
Class to generate a weighted Rips complex from a distance matrix and weights on vertices,
- in the way described in :cite:`dtmfiltrations`.
+ in the way described in :cite:`dtmfiltrations` with `p=1`. The filtration value of vertex `i` is `2*weights[i]`,
+ and the filtration value of edge `ij` is `distance_matrix[i][j]+weights[i]+weights[j]`,
+ or the maximum of the filtrations of its extremities, whichever is largest.
Remark that all the filtration values are doubled compared to the definition in the paper
- for the consistency with RipsComplex.
+ for consistency with RipsComplex.
"""
def __init__(self,
distance_matrix,
diff --git a/src/python/include/Alpha_complex_factory.h b/src/python/include/Alpha_complex_factory.h
index 3405fdd6..41eb72c1 100644
--- a/src/python/include/Alpha_complex_factory.h
+++ b/src/python/include/Alpha_complex_factory.h
@@ -31,15 +31,34 @@ namespace Gudhi {
namespace alpha_complex {
-template <typename CgalPointType>
-std::vector<double> pt_cgal_to_cython(CgalPointType const& point) {
- std::vector<double> vd;
- vd.reserve(point.dimension());
- for (auto coord = point.cartesian_begin(); coord != point.cartesian_end(); coord++)
- vd.push_back(CGAL::to_double(*coord));
- return vd;
-}
+// template Functor that transforms a CGAL point to a vector of double as expected by cython
+template<typename CgalPointType, bool Weighted>
+struct Point_cgal_to_cython;
+
+// Specialized Unweighted Functor
+template<typename CgalPointType>
+struct Point_cgal_to_cython<CgalPointType, false> {
+ std::vector<double> operator()(CgalPointType const& point) const
+ {
+ std::vector<double> vd;
+ vd.reserve(point.dimension());
+ for (auto coord = point.cartesian_begin(); coord != point.cartesian_end(); coord++)
+ vd.push_back(CGAL::to_double(*coord));
+ return vd;
+ }
+};
+// Specialized Weighted Functor
+template<typename CgalPointType>
+struct Point_cgal_to_cython<CgalPointType, true> {
+ std::vector<double> operator()(CgalPointType const& weighted_point) const
+ {
+ const auto& point = weighted_point.point();
+ return Point_cgal_to_cython<decltype(point), false>()(point);
+ }
+};
+
+// Function that transforms a cython point (aka. a vector of double) to a CGAL point
template <typename CgalPointType>
static CgalPointType pt_cython_to_cgal(std::vector<double> const& vec) {
return CgalPointType(vec.size(), vec.begin(), vec.end());
@@ -51,24 +70,35 @@ class Abstract_alpha_complex {
virtual bool create_simplex_tree(Simplex_tree_interface<>* simplex_tree, double max_alpha_square,
bool default_filtration_value) = 0;
+
+ virtual std::size_t num_vertices() const = 0;
virtual ~Abstract_alpha_complex() = default;
};
-class Exact_Alphacomplex_dD final : public Abstract_alpha_complex {
+template <bool Weighted = false>
+class Exact_alpha_complex_dD final : public Abstract_alpha_complex {
private:
using Kernel = CGAL::Epeck_d<CGAL::Dynamic_dimension_tag>;
- using Point = typename Kernel::Point_d;
+ using Bare_point = typename Kernel::Point_d;
+ using Point = std::conditional_t<Weighted, typename Kernel::Weighted_point_d,
+ typename Kernel::Point_d>;
public:
- Exact_Alphacomplex_dD(const std::vector<std::vector<double>>& points, bool exact_version)
+ Exact_alpha_complex_dD(const std::vector<std::vector<double>>& points, bool exact_version)
+ : exact_version_(exact_version),
+ alpha_complex_(boost::adaptors::transform(points, pt_cython_to_cgal<Bare_point>)) {
+ }
+
+ Exact_alpha_complex_dD(const std::vector<std::vector<double>>& points,
+ const std::vector<double>& weights, bool exact_version)
: exact_version_(exact_version),
- alpha_complex_(boost::adaptors::transform(points, pt_cython_to_cgal<Point>)) {
+ alpha_complex_(boost::adaptors::transform(points, pt_cython_to_cgal<Bare_point>), weights) {
}
virtual std::vector<double> get_point(int vh) override {
- Point const& point = alpha_complex_.get_point(vh);
- return pt_cgal_to_cython(point);
+ // Can be a Weighted or a Bare point in function of Weighted
+ return Point_cgal_to_cython<Point, Weighted>()(alpha_complex_.get_point(vh));
}
virtual bool create_simplex_tree(Simplex_tree_interface<>* simplex_tree, double max_alpha_square,
@@ -76,65 +106,49 @@ class Exact_Alphacomplex_dD final : public Abstract_alpha_complex {
return alpha_complex_.create_complex(*simplex_tree, max_alpha_square, exact_version_, default_filtration_value);
}
+ virtual std::size_t num_vertices() const override {
+ return alpha_complex_.num_vertices();
+ }
+
private:
bool exact_version_;
- Alpha_complex<Kernel> alpha_complex_;
+ Alpha_complex<Kernel, Weighted> alpha_complex_;
};
-class Inexact_Alphacomplex_dD final : public Abstract_alpha_complex {
+template <bool Weighted = false>
+class Inexact_alpha_complex_dD final : public Abstract_alpha_complex {
private:
using Kernel = CGAL::Epick_d<CGAL::Dynamic_dimension_tag>;
- using Point = typename Kernel::Point_d;
+ using Bare_point = typename Kernel::Point_d;
+ using Point = std::conditional_t<Weighted, typename Kernel::Weighted_point_d,
+ typename Kernel::Point_d>;
public:
- Inexact_Alphacomplex_dD(const std::vector<std::vector<double>>& points, bool exact_version)
- : exact_version_(exact_version),
- alpha_complex_(boost::adaptors::transform(points, pt_cython_to_cgal<Point>)) {
+ Inexact_alpha_complex_dD(const std::vector<std::vector<double>>& points)
+ : alpha_complex_(boost::adaptors::transform(points, pt_cython_to_cgal<Bare_point>)) {
+ }
+
+ Inexact_alpha_complex_dD(const std::vector<std::vector<double>>& points, const std::vector<double>& weights)
+ : alpha_complex_(boost::adaptors::transform(points, pt_cython_to_cgal<Bare_point>), weights) {
}
virtual std::vector<double> get_point(int vh) override {
- Point const& point = alpha_complex_.get_point(vh);
- return pt_cgal_to_cython(point);
+ // Can be a Weighted or a Bare point in function of Weighted
+ return Point_cgal_to_cython<Point, Weighted>()(alpha_complex_.get_point(vh));
}
virtual bool create_simplex_tree(Simplex_tree_interface<>* simplex_tree, double max_alpha_square,
bool default_filtration_value) override {
- return alpha_complex_.create_complex(*simplex_tree, max_alpha_square, exact_version_, default_filtration_value);
+ return alpha_complex_.create_complex(*simplex_tree, max_alpha_square, false, default_filtration_value);
}
- private:
- bool exact_version_;
- Alpha_complex<Kernel> alpha_complex_;
-};
-
-template <complexity Complexity>
-class Alphacomplex_3D final : public Abstract_alpha_complex {
- private:
- using Point = typename Alpha_complex_3d<Complexity, false, false>::Bare_point_3;
-
- static Point pt_cython_to_cgal_3(std::vector<double> const& vec) {
- return Point(vec[0], vec[1], vec[2]);
- }
-
- public:
- Alphacomplex_3D(const std::vector<std::vector<double>>& points)
- : alpha_complex_(boost::adaptors::transform(points, pt_cython_to_cgal_3)) {
- }
-
- virtual std::vector<double> get_point(int vh) override {
- Point const& point = alpha_complex_.get_point(vh);
- return pt_cgal_to_cython(point);
- }
-
- virtual bool create_simplex_tree(Simplex_tree_interface<>* simplex_tree, double max_alpha_square,
- bool default_filtration_value) override {
- return alpha_complex_.create_complex(*simplex_tree, max_alpha_square);
+ virtual std::size_t num_vertices() const override {
+ return alpha_complex_.num_vertices();
}
private:
- Alpha_complex_3d<Complexity, false, false> alpha_complex_;
+ Alpha_complex<Kernel, Weighted> alpha_complex_;
};
-
} // namespace alpha_complex
} // namespace Gudhi
diff --git a/src/python/include/Alpha_complex_interface.h b/src/python/include/Alpha_complex_interface.h
index 23be194d..469b91ce 100644
--- a/src/python/include/Alpha_complex_interface.h
+++ b/src/python/include/Alpha_complex_interface.h
@@ -27,10 +27,23 @@ namespace alpha_complex {
class Alpha_complex_interface {
public:
- Alpha_complex_interface(const std::vector<std::vector<double>>& points, bool fast_version, bool exact_version)
- : points_(points),
- fast_version_(fast_version),
- exact_version_(exact_version) {
+ Alpha_complex_interface(const std::vector<std::vector<double>>& points,
+ const std::vector<double>& weights,
+ bool fast_version, bool exact_version) {
+ const bool weighted = (weights.size() > 0);
+ if (fast_version) {
+ if (weighted) {
+ alpha_ptr_ = std::make_unique<Inexact_alpha_complex_dD<true>>(points, weights);
+ } else {
+ alpha_ptr_ = std::make_unique<Inexact_alpha_complex_dD<false>>(points);
+ }
+ } else {
+ if (weighted) {
+ alpha_ptr_ = std::make_unique<Exact_alpha_complex_dD<true>>(points, weights, exact_version);
+ } else {
+ alpha_ptr_ = std::make_unique<Exact_alpha_complex_dD<false>>(points, exact_version);
+ }
+ }
}
std::vector<double> get_point(int vh) {
@@ -39,38 +52,23 @@ class Alpha_complex_interface {
void create_simplex_tree(Simplex_tree_interface<>* simplex_tree, double max_alpha_square,
bool default_filtration_value) {
- if (points_.size() > 0) {
- std::size_t dimension = points_[0].size();
- if (dimension == 3 && !default_filtration_value) {
- if (fast_version_)
- alpha_ptr_ = std::make_unique<Alphacomplex_3D<Gudhi::alpha_complex::complexity::FAST>>(points_);
- else if (exact_version_)
- alpha_ptr_ = std::make_unique<Alphacomplex_3D<Gudhi::alpha_complex::complexity::EXACT>>(points_);
- else
- alpha_ptr_ = std::make_unique<Alphacomplex_3D<Gudhi::alpha_complex::complexity::SAFE>>(points_);
- if (!alpha_ptr_->create_simplex_tree(simplex_tree, max_alpha_square, default_filtration_value)) {
- // create_simplex_tree will fail if all points are on a plane - Retry with dD by setting dimension to 2
- dimension--;
- alpha_ptr_.reset();
- }
- }
- // Not ** else ** because we have to take into account if 3d fails
- if (dimension != 3 || default_filtration_value) {
- if (fast_version_) {
- alpha_ptr_ = std::make_unique<Inexact_Alphacomplex_dD>(points_, exact_version_);
- } else {
- alpha_ptr_ = std::make_unique<Exact_Alphacomplex_dD>(points_, exact_version_);
- }
- alpha_ptr_->create_simplex_tree(simplex_tree, max_alpha_square, default_filtration_value);
- }
- }
+ // Nothing to be done in case of an empty point set
+ if (alpha_ptr_->num_vertices() > 0)
+ alpha_ptr_->create_simplex_tree(simplex_tree, max_alpha_square, default_filtration_value);
+ }
+
+ static void set_float_relative_precision(double precision) {
+ // cf. Exact_alpha_complex_dD kernel type in Alpha_complex_factory.h
+ CGAL::Epeck_d<CGAL::Dynamic_dimension_tag>::FT::set_relative_precision_of_to_double(precision);
+ }
+
+ static double get_float_relative_precision() {
+ // cf. Exact_alpha_complex_dD kernel type in Alpha_complex_factory.h
+ return CGAL::Epeck_d<CGAL::Dynamic_dimension_tag>::FT::get_relative_precision_of_to_double();
}
private:
std::unique_ptr<Abstract_alpha_complex> alpha_ptr_;
- std::vector<std::vector<double>> points_;
- bool fast_version_;
- bool exact_version_;
};
} // namespace alpha_complex
diff --git a/src/python/include/Persistent_cohomology_interface.h b/src/python/include/Persistent_cohomology_interface.h
index e5a3dfba..945378a0 100644
--- a/src/python/include/Persistent_cohomology_interface.h
+++ b/src/python/include/Persistent_cohomology_interface.h
@@ -12,6 +12,8 @@
#define INCLUDE_PERSISTENT_COHOMOLOGY_INTERFACE_H_
#include <gudhi/Persistent_cohomology.h>
+#include <gudhi/Simplex_tree.h> // for Extended_simplex_type
+
#include <cstdlib>
#include <vector>
@@ -223,6 +225,44 @@ persistent_cohomology::Persistent_cohomology<FilteredComplex, persistent_cohomol
return out;
}
+ using Filtration_value = typename FilteredComplex::Filtration_value;
+ using Birth_death = std::pair<Filtration_value, Filtration_value>;
+ using Persistence_subdiagrams = std::vector<std::vector<std::pair<int, Birth_death>>>;
+
+ Persistence_subdiagrams compute_extended_persistence_subdiagrams(Filtration_value min_persistence){
+ Persistence_subdiagrams pers_subs(4);
+ auto const& persistent_pairs = Base::get_persistent_pairs();
+ for (auto pair : persistent_pairs) {
+ std::pair<Filtration_value, Extended_simplex_type> px = stptr_->decode_extended_filtration(stptr_->filtration(get<0>(pair)),
+ stptr_->efd);
+ std::pair<Filtration_value, Extended_simplex_type> py = stptr_->decode_extended_filtration(stptr_->filtration(get<1>(pair)),
+ stptr_->efd);
+ std::pair<int, Birth_death> pd_point = std::make_pair(stptr_->dimension(get<0>(pair)),
+ std::make_pair(px.first, py.first));
+ if(std::abs(px.first - py.first) > min_persistence){
+ //Ordinary
+ if (px.second == Extended_simplex_type::UP && py.second == Extended_simplex_type::UP){
+ pers_subs[0].push_back(pd_point);
+ }
+ // Relative
+ else if (px.second == Extended_simplex_type::DOWN && py.second == Extended_simplex_type::DOWN){
+ pers_subs[1].push_back(pd_point);
+ }
+ else{
+ // Extended+
+ if (px.first < py.first){
+ pers_subs[2].push_back(pd_point);
+ }
+ //Extended-
+ else{
+ pers_subs[3].push_back(pd_point);
+ }
+ }
+ }
+ }
+ return pers_subs;
+ }
+
private:
// A copy
FilteredComplex* stptr_;
diff --git a/src/python/include/Simplex_tree_interface.h b/src/python/include/Simplex_tree_interface.h
index 629f6083..0317ea39 100644
--- a/src/python/include/Simplex_tree_interface.h
+++ b/src/python/include/Simplex_tree_interface.h
@@ -15,9 +15,7 @@
#include <gudhi/distance_functions.h>
#include <gudhi/Simplex_tree.h>
#include <gudhi/Points_off_io.h>
-#ifdef GUDHI_USE_EIGEN3
#include <gudhi/Flag_complex_edge_collapser.h>
-#endif
#include <iostream>
#include <vector>
@@ -42,6 +40,9 @@ class Simplex_tree_interface : public Simplex_tree<SimplexTreeOptions> {
using Complex_simplex_iterator = typename Base::Complex_simplex_iterator;
using Extended_filtration_data = typename Base::Extended_filtration_data;
using Boundary_simplex_iterator = typename Base::Boundary_simplex_iterator;
+ using Siblings = typename Base::Siblings;
+ using Node = typename Base::Node;
+ typedef bool (*blocker_func_t)(Simplex simplex, void *user_data);
public:
@@ -63,6 +64,30 @@ class Simplex_tree_interface : public Simplex_tree<SimplexTreeOptions> {
return (result.second);
}
+ void insert_matrix(double* filtrations, int n, int stride0, int stride1, double max_filtration) {
+ // We could delegate to insert_graph, but wrapping the matrix in a graph interface is too much work,
+ // and this is a bit more efficient.
+ auto& rm = this->root()->members_;
+ for(int i=0; i<n; ++i) {
+ char* p = reinterpret_cast<char*>(filtrations) + i * stride0;
+ double fv = *reinterpret_cast<double*>(p + i * stride1);
+ if(fv > max_filtration) continue;
+ auto sh = rm.emplace_hint(rm.end(), i, Node(this->root(), fv));
+ Siblings* children = nullptr;
+ // Should we make a first pass to count the number of edges so we can reserve the right space?
+ for(int j=i+1; j<n; ++j) {
+ double fe = *reinterpret_cast<double*>(p + j * stride1);
+ if(fe > max_filtration) continue;
+ if(!children) {
+ children = new Siblings(this->root(), i);
+ sh->second.assign_children(children);
+ }
+ children->members().emplace_hint(children->members().end(), j, Node(children, fe));
+ }
+ }
+
+ }
+
// Do not interface this function, only used in alpha complex interface for complex creation
bool insert_simplex(const Simplex& simplex, Filtration_value filtration = 0) {
Insertion_result result = Base::insert_simplex(simplex, filtration);
@@ -133,38 +158,7 @@ class Simplex_tree_interface : public Simplex_tree<SimplexTreeOptions> {
return;
}
- std::vector<std::vector<std::pair<int, std::pair<Filtration_value, Filtration_value>>>> compute_extended_persistence_subdiagrams(const std::vector<std::pair<int, std::pair<Filtration_value, Filtration_value>>>& dgm, Filtration_value min_persistence){
- std::vector<std::vector<std::pair<int, std::pair<Filtration_value, Filtration_value>>>> new_dgm(4);
- for (unsigned int i = 0; i < dgm.size(); i++){
- std::pair<Filtration_value, Extended_simplex_type> px = this->decode_extended_filtration(dgm[i].second.first, this->efd);
- std::pair<Filtration_value, Extended_simplex_type> py = this->decode_extended_filtration(dgm[i].second.second, this->efd);
- std::pair<int, std::pair<Filtration_value, Filtration_value>> pd_point = std::make_pair(dgm[i].first, std::make_pair(px.first, py.first));
- if(std::abs(px.first - py.first) > min_persistence){
- //Ordinary
- if (px.second == Extended_simplex_type::UP && py.second == Extended_simplex_type::UP){
- new_dgm[0].push_back(pd_point);
- }
- // Relative
- else if (px.second == Extended_simplex_type::DOWN && py.second == Extended_simplex_type::DOWN){
- new_dgm[1].push_back(pd_point);
- }
- else{
- // Extended+
- if (px.first < py.first){
- new_dgm[2].push_back(pd_point);
- }
- //Extended-
- else{
- new_dgm[3].push_back(pd_point);
- }
- }
- }
- }
- return new_dgm;
- }
-
Simplex_tree_interface* collapse_edges(int nb_collapse_iteration) {
-#ifdef GUDHI_USE_EIGEN3
using Filtered_edge = std::tuple<Vertex_handle, Vertex_handle, Filtration_value>;
std::vector<Filtered_edge> edges;
for (Simplex_handle sh : Base::skeleton_simplex_range(1)) {
@@ -178,7 +172,7 @@ class Simplex_tree_interface : public Simplex_tree<SimplexTreeOptions> {
}
for (int iteration = 0; iteration < nb_collapse_iteration; iteration++) {
- edges = Gudhi::collapse::flag_complex_collapse_edges(edges);
+ edges = Gudhi::collapse::flag_complex_collapse_edges(std::move(edges));
}
Simplex_tree_interface* collapsed_stree_ptr = new Simplex_tree_interface();
// Copy the original 0-skeleton
@@ -190,9 +184,13 @@ class Simplex_tree_interface : public Simplex_tree<SimplexTreeOptions> {
collapsed_stree_ptr->insert({std::get<0>(remaining_edge), std::get<1>(remaining_edge)}, std::get<2>(remaining_edge));
}
return collapsed_stree_ptr;
-#else
- throw std::runtime_error("Unable to collapse edges as it requires Eigen3 >= 3.1.0.");
-#endif
+ }
+
+ void expansion_with_blockers_callback(int dimension, blocker_func_t user_func, void *user_data) {
+ Base::expansion_with_blockers(dimension, [&](Simplex_handle sh){
+ Simplex simplex(Base::simplex_vertex_range(sh).begin(), Base::simplex_vertex_range(sh).end());
+ return user_func(simplex, user_data);
+ });
}
// Iterator over the simplex tree
diff --git a/src/python/include/pybind11_diagram_utils.h b/src/python/include/pybind11_diagram_utils.h
index 2d5194f4..5cb7c48b 100644
--- a/src/python/include/pybind11_diagram_utils.h
+++ b/src/python/include/pybind11_diagram_utils.h
@@ -17,16 +17,9 @@
namespace py = pybind11;
typedef py::array_t<double> Dgm;
-// Get m[i,0] and m[i,1] as a pair
-static auto pairify(void* p, py::ssize_t h, py::ssize_t w) {
- return [=](py::ssize_t i){
- char* birth = (char*)p + i * h;
- char* death = birth + w;
- return std::make_pair(*(double*)birth, *(double*)death);
- };
-}
-
-inline auto numpy_to_range_of_pairs(py::array_t<double> dgm) {
+// build_point(double birth, double death, ssize_t index) -> Point
+template<class BuildPoint>
+inline auto numpy_to_range_of_pairs(py::array_t<double> dgm, BuildPoint build_point) {
py::buffer_info buf = dgm.request();
// shape (n,2) or (0) for empty
if((buf.ndim!=2 || buf.shape[1]!=2) && (buf.ndim!=1 || buf.shape[0]!=0))
@@ -34,6 +27,16 @@ inline auto numpy_to_range_of_pairs(py::array_t<double> dgm) {
// In the case of shape (0), avoid reading non-existing strides[1] even if we won't use it.
py::ssize_t stride1 = buf.ndim == 2 ? buf.strides[1] : 0;
auto cnt = boost::counting_range<py::ssize_t>(0, buf.shape[0]);
- return boost::adaptors::transform(cnt, pairify(buf.ptr, buf.strides[0], stride1));
+
+ char* p = static_cast<char*>(buf.ptr);
+ auto h = buf.strides[0];
+ auto w = stride1;
+ // Get m[i,0] and m[i,1] as a pair
+ auto pairify = [=](py::ssize_t i){
+ char* birth = p + i * h;
+ char* death = birth + w;
+ return build_point(*(double*)birth, *(double*)death, i);
+ };
+ return boost::adaptors::transform(cnt, pairify);
// Be careful that the returned range cannot contain references to dead temporaries.
}
diff --git a/src/python/pyproject.toml b/src/python/pyproject.toml
new file mode 100644
index 00000000..55b64466
--- /dev/null
+++ b/src/python/pyproject.toml
@@ -0,0 +1,3 @@
+[build-system]
+requires = ["setuptools>=24.2.0", "wheel", "numpy>=1.15.0", "cython>=0.27", "pybind11"]
+build-backend = "setuptools.build_meta"
diff --git a/src/python/setup.py.in b/src/python/setup.py.in
index 98d058fc..6eb0db42 100644
--- a/src/python/setup.py.in
+++ b/src/python/setup.py.in
@@ -5,6 +5,7 @@
Copyright (C) 2019 Inria
Modification(s):
+ - 2021/12 Vincent Rouvreau: Python 3.5 as minimal version
- YYYY/MM Author: Description of the modification
"""
@@ -41,17 +42,12 @@ for module in cython_modules:
libraries=libraries,
library_dirs=library_dirs,
include_dirs=include_dirs,
- runtime_library_dirs=runtime_library_dirs,
- cython_directives = {'language_level': str(sys.version_info[0])},))
+ runtime_library_dirs=runtime_library_dirs,))
-ext_modules = cythonize(ext_modules)
+ext_modules = cythonize(ext_modules, compiler_directives={'language_level': '3'})
for module in pybind11_modules:
my_include_dirs = include_dirs + [pybind11.get_include(False), pybind11.get_include(True)]
- if module == 'hera/wasserstein':
- my_include_dirs = ['@HERA_WASSERSTEIN_INCLUDE_DIR@'] + my_include_dirs
- elif module == 'hera/bottleneck':
- my_include_dirs = ['@HERA_BOTTLENECK_INCLUDE_DIR@'] + my_include_dirs
ext_modules.append(Extension(
'gudhi.' + module.replace('/', '.'),
sources = [source_dir + module + '.cc'],
@@ -72,7 +68,7 @@ setup(
name = 'gudhi',
packages=find_packages(), # find_namespace_packages(include=["gudhi*"])
author='GUDHI Editorial Board',
- author_email='gudhi-contact@lists.gforge.inria.fr',
+ author_email='gudhi-contact@inria.fr',
version='@GUDHI_VERSION@',
url='https://gudhi.inria.fr/',
project_urls={
@@ -83,10 +79,11 @@ setup(
},
description='The Gudhi library is an open source library for ' \
'Computational Topology and Topological Data Analysis (TDA).',
+ data_files=[('.', ['./introduction.rst'])],
long_description_content_type='text/x-rst',
long_description=long_description,
ext_modules = ext_modules,
- install_requires = ['numpy >= 1.9',],
- setup_requires = ['cython','numpy >= 1.9','pybind11',],
+ python_requires='>=3.5.0',
+ install_requires = ['numpy >= 1.15.0',],
package_data={"": ["*.dll"], },
)
diff --git a/src/python/test/test_alpha_complex.py b/src/python/test/test_alpha_complex.py
index 814f8289..f81e6137 100755
--- a/src/python/test/test_alpha_complex.py
+++ b/src/python/test/test_alpha_complex.py
@@ -8,10 +8,12 @@
- YYYY/MM Author: Description of the modification
"""
-import gudhi as gd
+from gudhi import AlphaComplex
import math
import numpy as np
import pytest
+import warnings
+
try:
# python3
from itertools import zip_longest
@@ -19,22 +21,24 @@ except ImportError:
# python2
from itertools import izip_longest as zip_longest
-__author__ = "Vincent Rouvreau"
-__copyright__ = "Copyright (C) 2016 Inria"
-__license__ = "MIT"
def _empty_alpha(precision):
- alpha_complex = gd.AlphaComplex(points=[[0, 0]], precision = precision)
+ alpha_complex = AlphaComplex(precision = precision)
+ assert alpha_complex.__is_defined() == True
+
+def _one_2d_point_alpha(precision):
+ alpha_complex = AlphaComplex(points=[[0, 0]], precision = precision)
assert alpha_complex.__is_defined() == True
def test_empty_alpha():
for precision in ['fast', 'safe', 'exact']:
_empty_alpha(precision)
+ _one_2d_point_alpha(precision)
def _infinite_alpha(precision):
point_list = [[0, 0], [1, 0], [0, 1], [1, 1]]
- alpha_complex = gd.AlphaComplex(points=point_list, precision = precision)
+ alpha_complex = AlphaComplex(points=point_list, precision = precision)
assert alpha_complex.__is_defined() == True
simplex_tree = alpha_complex.create_simplex_tree()
@@ -69,18 +73,9 @@ def _infinite_alpha(precision):
assert point_list[1] == alpha_complex.get_point(1)
assert point_list[2] == alpha_complex.get_point(2)
assert point_list[3] == alpha_complex.get_point(3)
- try:
- alpha_complex.get_point(4) == []
- except IndexError:
- pass
- else:
- assert False
- try:
- alpha_complex.get_point(125) == []
- except IndexError:
- pass
- else:
- assert False
+
+ with pytest.raises(IndexError):
+ alpha_complex.get_point(len(point_list))
def test_infinite_alpha():
for precision in ['fast', 'safe', 'exact']:
@@ -88,7 +83,7 @@ def test_infinite_alpha():
def _filtered_alpha(precision):
point_list = [[0, 0], [1, 0], [0, 1], [1, 1]]
- filtered_alpha = gd.AlphaComplex(points=point_list, precision = precision)
+ filtered_alpha = AlphaComplex(points=point_list, precision = precision)
simplex_tree = filtered_alpha.create_simplex_tree(max_alpha_square=0.25)
@@ -99,18 +94,9 @@ def _filtered_alpha(precision):
assert point_list[1] == filtered_alpha.get_point(1)
assert point_list[2] == filtered_alpha.get_point(2)
assert point_list[3] == filtered_alpha.get_point(3)
- try:
- filtered_alpha.get_point(4) == []
- except IndexError:
- pass
- else:
- assert False
- try:
- filtered_alpha.get_point(125) == []
- except IndexError:
- pass
- else:
- assert False
+
+ with pytest.raises(IndexError):
+ filtered_alpha.get_point(len(point_list))
assert list(simplex_tree.get_filtration()) == [
([0], 0.0),
@@ -141,10 +127,10 @@ def _safe_alpha_persistence_comparison(precision):
embedding2 = [[signal[i], delayed[i]] for i in range(len(time))]
#build alpha complex and simplex tree
- alpha_complex1 = gd.AlphaComplex(points=embedding1, precision = precision)
+ alpha_complex1 = AlphaComplex(points=embedding1, precision = precision)
simplex_tree1 = alpha_complex1.create_simplex_tree()
- alpha_complex2 = gd.AlphaComplex(points=embedding2, precision = precision)
+ alpha_complex2 = AlphaComplex(points=embedding2, precision = precision)
simplex_tree2 = alpha_complex2.create_simplex_tree()
diag1 = simplex_tree1.persistence()
@@ -162,7 +148,7 @@ def test_safe_alpha_persistence_comparison():
def _delaunay_complex(precision):
point_list = [[0, 0], [1, 0], [0, 1], [1, 1]]
- filtered_alpha = gd.AlphaComplex(points=point_list, precision = precision)
+ filtered_alpha = AlphaComplex(points=point_list, precision = precision)
simplex_tree = filtered_alpha.create_simplex_tree(default_filtration_value = True)
@@ -173,18 +159,11 @@ def _delaunay_complex(precision):
assert point_list[1] == filtered_alpha.get_point(1)
assert point_list[2] == filtered_alpha.get_point(2)
assert point_list[3] == filtered_alpha.get_point(3)
- try:
- filtered_alpha.get_point(4) == []
- except IndexError:
- pass
- else:
- assert False
- try:
- filtered_alpha.get_point(125) == []
- except IndexError:
- pass
- else:
- assert False
+
+ with pytest.raises(IndexError):
+ filtered_alpha.get_point(4)
+ with pytest.raises(IndexError):
+ filtered_alpha.get_point(125)
for filtered_value in simplex_tree.get_filtration():
assert math.isnan(filtered_value[1])
@@ -198,7 +177,13 @@ def test_delaunay_complex():
_delaunay_complex(precision)
def _3d_points_on_a_plane(precision, default_filtration_value):
- alpha = gd.AlphaComplex(off_file='alphacomplexdoc.off', precision = precision)
+ alpha = AlphaComplex(points = [[1.0, 1.0 , 0.0],
+ [7.0, 0.0 , 0.0],
+ [4.0, 6.0 , 0.0],
+ [9.0, 6.0 , 0.0],
+ [0.0, 14.0, 0.0],
+ [2.0, 19.0, 0.0],
+ [9.0, 17.0, 0.0]], precision = precision)
simplex_tree = alpha.create_simplex_tree(default_filtration_value = default_filtration_value)
assert simplex_tree.dimension() == 2
@@ -206,28 +191,16 @@ def _3d_points_on_a_plane(precision, default_filtration_value):
assert simplex_tree.num_simplices() == 25
def test_3d_points_on_a_plane():
- off_file = open("alphacomplexdoc.off", "w")
- off_file.write("OFF \n" \
- "7 0 0 \n" \
- "1.0 1.0 0.0\n" \
- "7.0 0.0 0.0\n" \
- "4.0 6.0 0.0\n" \
- "9.0 6.0 0.0\n" \
- "0.0 14.0 0.0\n" \
- "2.0 19.0 0.0\n" \
- "9.0 17.0 0.0\n" )
- off_file.close()
-
for default_filtration_value in [True, False]:
for precision in ['fast', 'safe', 'exact']:
_3d_points_on_a_plane(precision, default_filtration_value)
def _3d_tetrahedrons(precision):
points = 10*np.random.rand(10, 3)
- alpha = gd.AlphaComplex(points=points, precision = precision)
+ alpha = AlphaComplex(points = points, precision = precision)
st_alpha = alpha.create_simplex_tree(default_filtration_value = False)
# New AlphaComplex for get_point to work
- delaunay = gd.AlphaComplex(points=points, precision = precision)
+ delaunay = AlphaComplex(points = points, precision = precision)
st_delaunay = delaunay.create_simplex_tree(default_filtration_value = True)
delaunay_tetra = []
@@ -256,3 +229,87 @@ def _3d_tetrahedrons(precision):
def test_3d_tetrahedrons():
for precision in ['fast', 'safe', 'exact']:
_3d_tetrahedrons(precision)
+
+def test_off_file_deprecation_warning():
+ off_file = open("alphacomplexdoc.off", "w")
+ off_file.write("OFF \n" \
+ "7 0 0 \n" \
+ "1.0 1.0 0.0\n" \
+ "7.0 0.0 0.0\n" \
+ "4.0 6.0 0.0\n" \
+ "9.0 6.0 0.0\n" \
+ "0.0 14.0 0.0\n" \
+ "2.0 19.0 0.0\n" \
+ "9.0 17.0 0.0\n" )
+ off_file.close()
+
+ with pytest.warns(DeprecationWarning):
+ alpha = AlphaComplex(off_file="alphacomplexdoc.off")
+
+def test_non_existing_off_file():
+ with pytest.warns(DeprecationWarning):
+ with pytest.raises(FileNotFoundError):
+ alpha = AlphaComplex(off_file="pouetpouettralala.toubiloubabdou")
+
+def test_inconsistency_points_and_weights():
+ points = [[1.0, 1.0 , 0.0],
+ [7.0, 0.0 , 0.0],
+ [4.0, 6.0 , 0.0],
+ [9.0, 6.0 , 0.0],
+ [0.0, 14.0, 0.0],
+ [2.0, 19.0, 0.0],
+ [9.0, 17.0, 0.0]]
+ with pytest.raises(ValueError):
+ # 7 points, 8 weights, on purpose
+ alpha = AlphaComplex(points = points,
+ weights = [1., 2., 3., 4., 5., 6., 7., 8.])
+
+ with pytest.raises(ValueError):
+ # 7 points, 6 weights, on purpose
+ alpha = AlphaComplex(points = points,
+ weights = [1., 2., 3., 4., 5., 6.])
+
+def _weighted_doc_example(precision):
+ stree = AlphaComplex(points=[[ 1., -1., -1.],
+ [-1., 1., -1.],
+ [-1., -1., 1.],
+ [ 1., 1., 1.],
+ [ 2., 2., 2.]],
+ weights = [4., 4., 4., 4., 1.],
+ precision = precision).create_simplex_tree()
+
+ assert stree.filtration([0, 1, 2, 3]) == pytest.approx(-1.)
+ assert stree.filtration([0, 1, 3, 4]) == pytest.approx(95.)
+ assert stree.filtration([0, 2, 3, 4]) == pytest.approx(95.)
+ assert stree.filtration([1, 2, 3, 4]) == pytest.approx(95.)
+
+def test_weighted_doc_example():
+ for precision in ['fast', 'safe', 'exact']:
+ _weighted_doc_example(precision)
+
+def test_float_relative_precision():
+ assert AlphaComplex.get_float_relative_precision() == 1e-5
+ # Must be > 0.
+ with pytest.raises(ValueError):
+ AlphaComplex.set_float_relative_precision(0.)
+ # Must be < 1.
+ with pytest.raises(ValueError):
+ AlphaComplex.set_float_relative_precision(1.)
+
+ points = [[1, 1], [7, 0], [4, 6], [9, 6], [0, 14], [2, 19], [9, 17]]
+ st = AlphaComplex(points=points).create_simplex_tree()
+ filtrations = list(st.get_filtration())
+
+ # Get a better precision
+ AlphaComplex.set_float_relative_precision(1e-15)
+ assert AlphaComplex.get_float_relative_precision() == 1e-15
+
+ st = AlphaComplex(points=points).create_simplex_tree()
+ filtrations_better_resolution = list(st.get_filtration())
+
+ assert len(filtrations) == len(filtrations_better_resolution)
+ for idx in range(len(filtrations)):
+ # check simplex is the same
+ assert filtrations[idx][0] == filtrations_better_resolution[idx][0]
+ # check filtration is about the same with a relative precision of the worst case
+ assert filtrations[idx][1] == pytest.approx(filtrations_better_resolution[idx][1], rel=1e-5)
diff --git a/src/python/test/test_betti_curve_representations.py b/src/python/test/test_betti_curve_representations.py
new file mode 100755
index 00000000..6a45da4d
--- /dev/null
+++ b/src/python/test/test_betti_curve_representations.py
@@ -0,0 +1,59 @@
+import numpy as np
+import scipy.interpolate
+import pytest
+
+from gudhi.representations.vector_methods import BettiCurve
+
+def test_betti_curve_is_irregular_betti_curve_followed_by_interpolation():
+ m = 10
+ n = 1000
+ pinf = 0.05
+ pzero = 0.05
+ res = 100
+
+ pds = []
+ for i in range(0, m):
+ pd = np.zeros((n, 2))
+ pd[:, 0] = np.random.uniform(0, 10, n)
+ pd[:, 1] = np.random.uniform(pd[:, 0], 10, n)
+ pd[np.random.uniform(0, 1, n) < pzero, 0] = 0
+ pd[np.random.uniform(0, 1, n) < pinf, 1] = np.inf
+ pds.append(pd)
+
+ bc = BettiCurve(resolution=None, predefined_grid=None)
+ bc.fit(pds)
+ bettis = bc.transform(pds)
+
+ bc2 = BettiCurve(resolution=None, predefined_grid=None)
+ bettis2 = bc2.fit_transform(pds)
+ assert((bc2.grid_ == bc.grid_).all())
+ assert((bettis2 == bettis).all())
+
+ for i in range(0, m):
+ grid = np.linspace(pds[i][np.isfinite(pds[i])].min(), pds[i][np.isfinite(pds[i])].max() + 1, res)
+ bc_gridded = BettiCurve(predefined_grid=grid)
+ bc_gridded.fit([])
+ bettis_gridded = bc_gridded(pds[i])
+
+ interp = scipy.interpolate.interp1d(bc.grid_, bettis[i, :], kind="previous", fill_value="extrapolate")
+ bettis_interp = np.array(interp(grid), dtype=int)
+ assert((bettis_interp == bettis_gridded).all())
+
+
+def test_empty_with_predefined_grid():
+ random_grid = np.sort(np.random.uniform(0, 1, 100))
+ bc = BettiCurve(predefined_grid=random_grid)
+ bettis = bc.fit_transform([])
+ assert((bc.grid_ == random_grid).all())
+ assert((bettis == 0).all())
+
+
+def test_empty():
+ bc = BettiCurve(resolution=None, predefined_grid=None)
+ bettis = bc.fit_transform([])
+ assert(bc.grid_ == [-np.inf])
+ assert((bettis == 0).all())
+
+def test_wrong_value_of_predefined_grid():
+ with pytest.raises(ValueError):
+ BettiCurve(predefined_grid=[1, 2, 3])
diff --git a/src/python/test/test_cubical_complex.py b/src/python/test/test_cubical_complex.py
index d0e4e9e8..29d559b3 100755
--- a/src/python/test/test_cubical_complex.py
+++ b/src/python/test/test_cubical_complex.py
@@ -174,3 +174,28 @@ def test_periodic_cofaces_of_persistence_pairs_when_pd_has_no_paired_birth_and_d
assert np.array_equal(pairs[1][0], np.array([0]))
assert np.array_equal(pairs[1][1], np.array([0, 1]))
assert np.array_equal(pairs[1][2], np.array([1]))
+
+def test_cubical_persistence_intervals_in_dimension():
+ cub = CubicalComplex(
+ dimensions=[3, 3],
+ top_dimensional_cells=[1, 2, 3, 4, 5, 6, 7, 8, 9],
+ )
+ cub.compute_persistence()
+ H0 = cub.persistence_intervals_in_dimension(0)
+ assert np.array_equal(H0, np.array([[ 1., float("inf")]]))
+ assert cub.persistence_intervals_in_dimension(1).shape == (0, 2)
+
+def test_periodic_cubical_persistence_intervals_in_dimension():
+ cub = PeriodicCubicalComplex(
+ dimensions=[3, 3],
+ top_dimensional_cells=[1, 2, 3, 4, 5, 6, 7, 8, 9],
+ periodic_dimensions = [True, True]
+ )
+ cub.compute_persistence()
+ H0 = cub.persistence_intervals_in_dimension(0)
+ assert np.array_equal(H0, np.array([[ 1., float("inf")]]))
+ H1 = cub.persistence_intervals_in_dimension(1)
+ assert np.array_equal(H1, np.array([[ 3., float("inf")], [ 7., float("inf")]]))
+ H2 = cub.persistence_intervals_in_dimension(2)
+ assert np.array_equal(H2, np.array([[ 9., float("inf")]]))
+ assert cub.persistence_intervals_in_dimension(3).shape == (0, 2)
diff --git a/src/python/test/test_datasets_generators.py b/src/python/test/test_datasets_generators.py
new file mode 100755
index 00000000..91ec4a65
--- /dev/null
+++ b/src/python/test/test_datasets_generators.py
@@ -0,0 +1,39 @@
+""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+ See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+ Author(s): Hind Montassif
+
+ Copyright (C) 2021 Inria
+
+ Modification(s):
+ - YYYY/MM Author: Description of the modification
+"""
+
+from gudhi.datasets.generators import points
+
+import pytest
+
+def test_sphere():
+ assert points.sphere(n_samples = 10, ambient_dim = 2, radius = 1., sample = 'random').shape == (10, 2)
+
+ with pytest.raises(ValueError):
+ points.sphere(n_samples = 10, ambient_dim = 2, radius = 1., sample = 'other')
+
+def _basic_torus(impl):
+ assert impl(n_samples = 64, dim = 3, sample = 'random').shape == (64, 6)
+ assert impl(n_samples = 64, dim = 3, sample = 'grid').shape == (64, 6)
+
+ assert impl(n_samples = 10, dim = 4, sample = 'random').shape == (10, 8)
+
+ # Here 1**dim < n_samples < 2**dim, the output shape is therefore (1, 2*dim) = (1, 8), where shape[0] is rounded down to the closest perfect 'dim'th power
+ assert impl(n_samples = 10, dim = 4, sample = 'grid').shape == (1, 8)
+
+ with pytest.raises(ValueError):
+ impl(n_samples = 10, dim = 4, sample = 'other')
+
+def test_torus():
+ for torus_impl in [points.torus, points.ctorus]:
+ _basic_torus(torus_impl)
+ # Check that the two versions (torus and ctorus) generate the same output
+ assert points.ctorus(n_samples = 64, dim = 3, sample = 'random').all() == points.torus(n_samples = 64, dim = 3, sample = 'random').all()
+ assert points.ctorus(n_samples = 64, dim = 3, sample = 'grid').all() == points.torus(n_samples = 64, dim = 3, sample = 'grid').all()
+ assert points.ctorus(n_samples = 10, dim = 3, sample = 'grid').all() == points.torus(n_samples = 10, dim = 3, sample = 'grid').all()
diff --git a/src/python/test/test_diff.py b/src/python/test/test_diff.py
new file mode 100644
index 00000000..dca001a9
--- /dev/null
+++ b/src/python/test/test_diff.py
@@ -0,0 +1,78 @@
+from gudhi.tensorflow import *
+import numpy as np
+import tensorflow as tf
+import gudhi as gd
+
+def test_rips_diff():
+
+ Xinit = np.array([[1.,1.],[2.,2.]], dtype=np.float32)
+ X = tf.Variable(initial_value=Xinit, trainable=True)
+ rl = RipsLayer(maximum_edge_length=2., homology_dimensions=[0])
+
+ with tf.GradientTape() as tape:
+ dgm = rl.call(X)[0][0]
+ loss = tf.math.reduce_sum(tf.square(.5*(dgm[:,1]-dgm[:,0])))
+ grads = tape.gradient(loss, [X])
+ assert tf.norm(grads[0]-tf.constant([[-.5,-.5],[.5,.5]]),1) <= 1e-6
+
+def test_cubical_diff():
+
+ Xinit = np.array([[0.,2.,2.],[2.,2.,2.],[2.,2.,1.]], dtype=np.float32)
+ X = tf.Variable(initial_value=Xinit, trainable=True)
+ cl = CubicalLayer(homology_dimensions=[0])
+
+ with tf.GradientTape() as tape:
+ dgm = cl.call(X)[0][0]
+ loss = tf.math.reduce_sum(tf.square(.5*(dgm[:,1]-dgm[:,0])))
+ grads = tape.gradient(loss, [X])
+ assert tf.norm(grads[0]-tf.constant([[0.,0.,0.],[0.,.5,0.],[0.,0.,-.5]]),1) <= 1e-6
+
+def test_nonsquare_cubical_diff():
+
+ Xinit = np.array([[-1.,1.,0.],[1.,1.,1.]], dtype=np.float32)
+ X = tf.Variable(initial_value=Xinit, trainable=True)
+ cl = CubicalLayer(homology_dimensions=[0])
+
+ with tf.GradientTape() as tape:
+ dgm = cl.call(X)[0][0]
+ loss = tf.math.reduce_sum(tf.square(.5*(dgm[:,1]-dgm[:,0])))
+ grads = tape.gradient(loss, [X])
+ assert tf.norm(grads[0]-tf.constant([[0.,0.5,-0.5],[0.,0.,0.]]),1) <= 1e-6
+
+def test_st_diff():
+
+ st = gd.SimplexTree()
+ st.insert([0])
+ st.insert([1])
+ st.insert([2])
+ st.insert([3])
+ st.insert([4])
+ st.insert([5])
+ st.insert([6])
+ st.insert([7])
+ st.insert([8])
+ st.insert([9])
+ st.insert([10])
+ st.insert([0, 1])
+ st.insert([1, 2])
+ st.insert([2, 3])
+ st.insert([3, 4])
+ st.insert([4, 5])
+ st.insert([5, 6])
+ st.insert([6, 7])
+ st.insert([7, 8])
+ st.insert([8, 9])
+ st.insert([9, 10])
+
+ Finit = np.array([6.,4.,3.,4.,5.,4.,3.,2.,3.,4.,5.], dtype=np.float32)
+ F = tf.Variable(initial_value=Finit, trainable=True)
+ sl = LowerStarSimplexTreeLayer(simplextree=st, homology_dimensions=[0])
+
+ with tf.GradientTape() as tape:
+ dgm = sl.call(F)[0][0]
+ loss = tf.math.reduce_sum(tf.square(.5*(dgm[:,1]-dgm[:,0])))
+ grads = tape.gradient(loss, [F])
+
+ assert tf.math.reduce_all(tf.math.equal(grads[0].indices, tf.constant([2,4])))
+ assert tf.math.reduce_all(tf.math.equal(grads[0].values, tf.constant([-1.,1.])))
+
diff --git a/src/python/test/test_dtm.py b/src/python/test/test_dtm.py
index 0a52279e..b276f041 100755
--- a/src/python/test/test_dtm.py
+++ b/src/python/test/test_dtm.py
@@ -13,6 +13,7 @@ import numpy
import pytest
import torch
import math
+import warnings
def test_dtm_compare_euclidean():
@@ -87,3 +88,14 @@ def test_density():
assert density == pytest.approx(expected)
density = DTMDensity(weights=[0.5, 0.5], metric="neighbors", dim=1).fit_transform(distances)
assert density == pytest.approx(expected)
+
+def test_dtm_overflow_warnings():
+ pts = numpy.array([[10., 100000000000000000000000000000.], [1000., 100000000000000000000000000.]])
+ impl_warn = ["keops", "hnsw"]
+ for impl in impl_warn:
+ with warnings.catch_warnings(record=True) as w:
+ dtm = DistanceToMeasure(2, implementation=impl)
+ r = dtm.fit_transform(pts)
+ assert len(w) == 1
+ assert issubclass(w[0].category, RuntimeWarning)
+ assert "Overflow" in str(w[0].message)
diff --git a/src/python/test/test_off.py b/src/python/test/test_off.py
new file mode 100644
index 00000000..aea1941b
--- /dev/null
+++ b/src/python/test/test_off.py
@@ -0,0 +1,21 @@
+""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+ See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+ Author(s): Marc Glisse
+
+ Copyright (C) 2022 Inria
+
+ Modification(s):
+ - YYYY/MM Author: Description of the modification
+"""
+
+import gudhi as gd
+import numpy as np
+import pytest
+
+
+def test_off_rw():
+ for dim in range(2, 6):
+ X = np.random.rand(123, dim)
+ gd.write_points_to_off_file("rand.off", X)
+ Y = gd.read_points_from_off_file("rand.off")
+ assert Y == pytest.approx(X)
diff --git a/src/python/test/test_persistence_graphical_tools.py b/src/python/test/test_persistence_graphical_tools.py
new file mode 100644
index 00000000..0e2ac3f8
--- /dev/null
+++ b/src/python/test/test_persistence_graphical_tools.py
@@ -0,0 +1,122 @@
+""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+ See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+ Author(s): Vincent Rouvreau
+
+ Copyright (C) 2021 Inria
+
+ Modification(s):
+ - YYYY/MM Author: Description of the modification
+"""
+
+import gudhi as gd
+import numpy as np
+import matplotlib as plt
+import pytest
+import warnings
+
+
+def test_array_handler():
+ diags = np.array([[1, 2], [3, 4], [5, 6]], float)
+ arr_diags = gd.persistence_graphical_tools._array_handler(diags)
+ for idx in range(len(diags)):
+ assert arr_diags[idx][0] == 0
+ np.testing.assert_array_equal(arr_diags[idx][1], diags[idx])
+
+ diags = [(1.0, 2.0), (3.0, 4.0), (5.0, 6.0)]
+ arr_diags = gd.persistence_graphical_tools._array_handler(diags)
+ for idx in range(len(diags)):
+ assert arr_diags[idx][0] == 0
+ assert arr_diags[idx][1] == diags[idx]
+
+ diags = [(0, (1.0, 2.0)), (0, (3.0, 4.0)), (0, (5.0, 6.0))]
+ assert gd.persistence_graphical_tools._array_handler(diags) == diags
+
+
+def test_min_birth_max_death():
+ diags = [
+ (0, (0.0, float("inf"))),
+ (0, (0.0983494, float("inf"))),
+ (0, (0.0, 0.122545)),
+ (0, (0.0, 0.12047)),
+ (0, (0.0, 0.118398)),
+ (0, (0.118398, 1.0)),
+ (0, (0.0, 0.117908)),
+ (0, (0.0, 0.112307)),
+ (0, (0.0, 0.107535)),
+ (0, (0.0, 0.106382)),
+ ]
+ assert gd.persistence_graphical_tools.__min_birth_max_death(diags) == (0.0, 1.0)
+ assert gd.persistence_graphical_tools.__min_birth_max_death(diags, band=4.0) == (0.0, 5.0)
+
+
+def test_limit_min_birth_max_death():
+ diags = [
+ (0, (2.0, float("inf"))),
+ (0, (2.0, float("inf"))),
+ ]
+ assert gd.persistence_graphical_tools.__min_birth_max_death(diags) == (2.0, 3.0)
+ assert gd.persistence_graphical_tools.__min_birth_max_death(diags, band=4.0) == (2.0, 6.0)
+
+
+def test_limit_to_max_intervals():
+ diags = [
+ (0, (0.0, float("inf"))),
+ (0, (0.0983494, float("inf"))),
+ (0, (0.0, 0.122545)),
+ (0, (0.0, 0.12047)),
+ (0, (0.0, 0.118398)),
+ (0, (0.118398, 1.0)),
+ (0, (0.0, 0.117908)),
+ (0, (0.0, 0.112307)),
+ (0, (0.0, 0.107535)),
+ (0, (0.0, 0.106382)),
+ ]
+ # check no warnings if max_intervals equals to the diagrams number
+ with warnings.catch_warnings():
+ warnings.simplefilter("error")
+ truncated_diags = gd.persistence_graphical_tools._limit_to_max_intervals(
+ diags, 10, key=lambda life_time: life_time[1][1] - life_time[1][0]
+ )
+ # check diagrams are not sorted
+ assert truncated_diags == diags
+
+ # check warning if max_intervals lower than the diagrams number
+ with pytest.warns(UserWarning) as record:
+ truncated_diags = gd.persistence_graphical_tools._limit_to_max_intervals(
+ diags, 5, key=lambda life_time: life_time[1][1] - life_time[1][0]
+ )
+ # check diagrams are truncated and sorted by life time
+ assert truncated_diags == [
+ (0, (0.0, float("inf"))),
+ (0, (0.0983494, float("inf"))),
+ (0, (0.118398, 1.0)),
+ (0, (0.0, 0.122545)),
+ (0, (0.0, 0.12047)),
+ ]
+ assert len(record) == 1
+
+
+def _limit_plot_persistence(function):
+ pplot = function(persistence=[])
+ assert isinstance(pplot, plt.axes.SubplotBase)
+ pplot = function(persistence=[], legend=True)
+ assert isinstance(pplot, plt.axes.SubplotBase)
+ pplot = function(persistence=[(0, float("inf"))])
+ assert isinstance(pplot, plt.axes.SubplotBase)
+ pplot = function(persistence=[(0, float("inf"))], legend=True)
+ assert isinstance(pplot, plt.axes.SubplotBase)
+
+
+def test_limit_plot_persistence():
+ for function in [gd.plot_persistence_barcode, gd.plot_persistence_diagram, gd.plot_persistence_density]:
+ _limit_plot_persistence(function)
+
+
+def _non_existing_persistence_file(function):
+ with pytest.raises(FileNotFoundError):
+ function(persistence_file="pouetpouettralala.toubiloubabdou")
+
+
+def test_non_existing_persistence_file():
+ for function in [gd.plot_persistence_barcode, gd.plot_persistence_diagram, gd.plot_persistence_density]:
+ _non_existing_persistence_file(function)
diff --git a/src/python/test/test_perslay.py b/src/python/test/test_perslay.py
new file mode 100644
index 00000000..06497712
--- /dev/null
+++ b/src/python/test/test_perslay.py
@@ -0,0 +1,147 @@
+import numpy as np
+import tensorflow as tf
+from sklearn.preprocessing import MinMaxScaler
+from gudhi.tensorflow.perslay import *
+import gudhi.representations as gdr
+
+def test_gaussian_perslay():
+
+ diagrams = [np.array([[0.,4.],[1.,2.],[3.,8.],[6.,8.]])]
+ diagrams = gdr.DiagramScaler(use=True, scalers=[([0,1], MinMaxScaler())]).fit_transform(diagrams)
+ diagrams = tf.RaggedTensor.from_tensor(tf.constant(diagrams, dtype=tf.float32))
+
+ rho = tf.identity
+ phi = GaussianPerslayPhi((5, 5), ((-.5, 1.5), (-.5, 1.5)), .1)
+ weight = PowerPerslayWeight(1.,0.)
+ perm_op = tf.math.reduce_sum
+
+ perslay = Perslay(phi=phi, weight=weight, perm_op=perm_op, rho=rho)
+ vectors = perslay(diagrams)
+
+ print(vectors.shape)
+
+ assert np.linalg.norm(vectors.numpy() - np.array(
+[[[[1.7266072e-16],
+ [4.1706043e-09],
+ [1.1336876e-08],
+ [8.5738821e-12],
+ [2.1243891e-14]],
+
+ [[4.1715076e-09],
+ [1.0074080e-01],
+ [2.7384272e-01],
+ [3.0724244e-02],
+ [7.6157507e-05]],
+
+ [[8.0382870e-06],
+ [1.5802664e+00],
+ [8.2997030e-01],
+ [1.2395413e+01],
+ [3.0724116e-02]],
+
+ [[8.0269419e-06],
+ [1.3065740e+00],
+ [9.0923014e+00],
+ [6.1664842e-02],
+ [1.3949171e-06]],
+
+ [[9.0331329e-13],
+ [1.4954816e-07],
+ [1.5145997e-04],
+ [1.0205092e-06],
+ [7.8093526e-16]]]]) <= 1e-7)
+
+test_gaussian_perslay()
+
+def test_tent_perslay():
+
+ diagrams = [np.array([[0.,4.],[1.,2.],[3.,8.],[6.,8.]])]
+ diagrams = gdr.DiagramScaler(use=True, scalers=[([0,1], MinMaxScaler())]).fit_transform(diagrams)
+ diagrams = tf.RaggedTensor.from_tensor(tf.constant(diagrams, dtype=tf.float32))
+
+ rho = tf.identity
+ phi = TentPerslayPhi(np.array(np.arange(-1.,2.,.1), dtype=np.float32))
+ weight = PowerPerslayWeight(1.,0.)
+ perm_op = 'top3'
+
+ perslay = Perslay(phi=phi, weight=weight, perm_op=perm_op, rho=rho)
+ vectors = perslay(diagrams)
+
+ assert np.linalg.norm(vectors-np.array([[0., 0., 0., 0., 0., 0.,
+ 0., 0., 0., 0., 0., 0.,
+ 0., 0., 0., 0., 0., 0.,
+ 0., 0., 0., 0., 0., 0.,
+ 0., 0., 0., 0., 0., 0.,
+ 0., 0., 0., 0.09999999, 0., 0.,
+ 0.2, 0.05, 0., 0.19999999, 0., 0.,
+ 0.09999999, 0.02500001, 0., 0.125, 0., 0.,
+ 0.22500002, 0., 0., 0.3, 0., 0.,
+ 0.19999999, 0.05000001, 0., 0.10000002, 0.10000002, 0.,
+ 0., 0., 0., 0., 0., 0.,
+ 0., 0., 0., 0., 0., 0.,
+ 0., 0., 0., 0., 0., 0.,
+ 0., 0., 0., 0., 0., 0.,
+ 0., 0., 0., 0., 0., 0. ]])) <= 1e-7
+
+def test_flat_perslay():
+
+ diagrams = [np.array([[0.,4.],[1.,2.],[3.,8.],[6.,8.]])]
+ diagrams = gdr.DiagramScaler(use=True, scalers=[([0,1], MinMaxScaler())]).fit_transform(diagrams)
+ diagrams = tf.RaggedTensor.from_tensor(tf.constant(diagrams, dtype=tf.float32))
+
+ rho = tf.identity
+ phi = FlatPerslayPhi(np.array(np.arange(-1.,2.,.1), dtype=np.float32), 100.)
+ weight = PowerPerslayWeight(1.,0.)
+ perm_op = tf.math.reduce_sum
+
+ perslay = Perslay(phi=phi, weight=weight, perm_op=perm_op, rho=rho)
+ vectors = perslay(diagrams)
+
+ assert np.linalg.norm(vectors-np.array([[0.0000000e+00, 0.0000000e+00, 1.8048651e-35, 3.9754645e-31, 8.7565101e-27,
+ 1.9287571e-22, 4.2483860e-18, 9.3576392e-14, 2.0611652e-09, 4.5398087e-05,
+ 5.0000376e-01, 1.0758128e+00, 1.9933071e+00, 1.0072457e+00, 1.9240967e+00,
+ 1.4999963e+00, 1.0000458e+00, 1.0066929e+00, 1.9933071e+00, 1.9999092e+00,
+ 1.0000000e+00, 9.0795562e-05, 4.1222914e-09, 1.8715316e-13, 8.4967405e-18,
+ 3.8574998e-22, 1.7512956e-26, 7.9508388e-31, 3.6097302e-35, 0.0000000e+00]]) <= 1e-7)
+
+def test_gmix_weight():
+
+ diagrams = [np.array([[0.,4.],[1.,2.],[3.,8.],[6.,8.]])]
+ diagrams = gdr.DiagramScaler(use=True, scalers=[([0,1], MinMaxScaler())]).fit_transform(diagrams)
+ diagrams = tf.RaggedTensor.from_tensor(tf.constant(diagrams, dtype=tf.float32))
+
+ rho = tf.identity
+ phi = FlatPerslayPhi(np.array(np.arange(-1.,2.,.1), dtype=np.float32), 100.)
+ weight = GaussianMixturePerslayWeight(np.array([[.5],[.5],[5],[5]], dtype=np.float32))
+ perm_op = tf.math.reduce_sum
+
+ perslay = Perslay(phi=phi, weight=weight, perm_op=perm_op, rho=rho)
+ vectors = perslay(diagrams)
+
+ assert np.linalg.norm(vectors-np.array([[0.0000000e+00, 0.0000000e+00, 1.7869064e-35, 3.9359080e-31, 8.6693818e-27,
+ 1.9095656e-22, 4.2061142e-18, 9.2645292e-14, 2.0406561e-09, 4.4946366e-05,
+ 4.9502861e-01, 1.0652492e+00, 1.9753191e+00, 9.9723548e-01, 1.9043801e+00,
+ 1.4844525e+00, 9.8947650e-01, 9.9604094e-01, 1.9703994e+00, 1.9769192e+00,
+ 9.8850453e-01, 8.9751818e-05, 4.0749040e-09, 1.8500175e-13, 8.3990662e-18,
+ 3.8131562e-22, 1.7311636e-26, 7.8594399e-31, 3.5682349e-35, 0.0000000e+00]]) <= 1e-7)
+
+def test_grid_weight():
+
+ diagrams = [np.array([[0.,4.],[1.,2.],[3.,8.],[6.,8.]])]
+ diagrams = gdr.DiagramScaler(use=True, scalers=[([0,1], MinMaxScaler())]).fit_transform(diagrams)
+ diagrams = tf.RaggedTensor.from_tensor(tf.constant(diagrams, dtype=tf.float32))
+
+ rho = tf.identity
+ phi = FlatPerslayPhi(np.array(np.arange(-1.,2.,.1), dtype=np.float32), 100.)
+ weight = GridPerslayWeight(np.array(np.random.uniform(size=[100,100]),dtype=np.float32),((-0.01, 1.01),(-0.01, 1.01)))
+ perm_op = tf.math.reduce_sum
+
+ perslay = Perslay(phi=phi, weight=weight, perm_op=perm_op, rho=rho)
+ vectors = perslay(diagrams)
+
+ assert np.linalg.norm(vectors-np.array([[0.0000000e+00, 0.0000000e+00, 1.5124093e-37, 3.3314498e-33, 7.3379791e-29,
+ 1.6163036e-24, 3.5601592e-20, 7.8417273e-16, 1.7272621e-11, 3.8043717e-07,
+ 4.1902456e-03, 1.7198652e-02, 1.2386327e-01, 9.2694648e-03, 1.9515079e-01,
+ 2.0629172e-01, 2.0210314e-01, 2.0442720e-01, 5.4709727e-01, 5.4939687e-01,
+ 2.7471092e-01, 2.4942532e-05, 1.1324385e-09, 5.1413016e-14, 2.3341474e-18,
+ 1.0596973e-22, 4.8110000e-27, 2.1841823e-31, 9.9163230e-36, 0.0000000e+00]]) <= 1e-7)
diff --git a/src/python/test/test_reader_utils.py b/src/python/test/test_reader_utils.py
index 90da6651..fdfddc4b 100755
--- a/src/python/test/test_reader_utils.py
+++ b/src/python/test/test_reader_utils.py
@@ -8,8 +8,9 @@
- YYYY/MM Author: Description of the modification
"""
-import gudhi
+import gudhi as gd
import numpy as np
+from pytest import raises
__author__ = "Vincent Rouvreau"
__copyright__ = "Copyright (C) 2017 Inria"
@@ -18,7 +19,7 @@ __license__ = "MIT"
def test_non_existing_csv_file():
# Try to open a non existing file
- matrix = gudhi.read_lower_triangular_matrix_from_csv_file(
+ matrix = gd.read_lower_triangular_matrix_from_csv_file(
csv_file="pouetpouettralala.toubiloubabdou"
)
assert matrix == []
@@ -29,8 +30,8 @@ def test_full_square_distance_matrix_csv_file():
test_file = open("full_square_distance_matrix.csv", "w")
test_file.write("0;1;2;3;\n1;0;4;5;\n2;4;0;6;\n3;5;6;0;")
test_file.close()
- matrix = gudhi.read_lower_triangular_matrix_from_csv_file(
- csv_file="full_square_distance_matrix.csv"
+ matrix = gd.read_lower_triangular_matrix_from_csv_file(
+ csv_file="full_square_distance_matrix.csv", separator=";"
)
assert matrix == [[], [1.0], [2.0, 4.0], [3.0, 5.0, 6.0]]
@@ -40,7 +41,7 @@ def test_lower_triangular_distance_matrix_csv_file():
test_file = open("lower_triangular_distance_matrix.csv", "w")
test_file.write("\n1,\n2,3,\n4,5,6,\n7,8,9,10,")
test_file.close()
- matrix = gudhi.read_lower_triangular_matrix_from_csv_file(
+ matrix = gd.read_lower_triangular_matrix_from_csv_file(
csv_file="lower_triangular_distance_matrix.csv", separator=","
)
assert matrix == [[], [1.0], [2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0, 10.0]]
@@ -48,11 +49,11 @@ def test_lower_triangular_distance_matrix_csv_file():
def test_non_existing_persistence_file():
# Try to open a non existing file
- persistence = gudhi.read_persistence_intervals_grouped_by_dimension(
+ persistence = gd.read_persistence_intervals_grouped_by_dimension(
persistence_file="pouetpouettralala.toubiloubabdou"
)
assert persistence == []
- persistence = gudhi.read_persistence_intervals_in_dimension(
+ persistence = gd.read_persistence_intervals_in_dimension(
persistence_file="pouetpouettralala.toubiloubabdou", only_this_dim=1
)
np.testing.assert_array_equal(persistence, [])
@@ -65,21 +66,21 @@ def test_read_persistence_intervals_without_dimension():
"# Simple persistence diagram without dimension\n2.7 3.7\n9.6 14.\n34.2 34.974\n3. inf"
)
test_file.close()
- persistence = gudhi.read_persistence_intervals_in_dimension(
+ persistence = gd.read_persistence_intervals_in_dimension(
persistence_file="persistence_intervals_without_dimension.pers"
)
np.testing.assert_array_equal(
persistence, [(2.7, 3.7), (9.6, 14.0), (34.2, 34.974), (3.0, float("Inf"))]
)
- persistence = gudhi.read_persistence_intervals_in_dimension(
+ persistence = gd.read_persistence_intervals_in_dimension(
persistence_file="persistence_intervals_without_dimension.pers", only_this_dim=0
)
np.testing.assert_array_equal(persistence, [])
- persistence = gudhi.read_persistence_intervals_in_dimension(
+ persistence = gd.read_persistence_intervals_in_dimension(
persistence_file="persistence_intervals_without_dimension.pers", only_this_dim=1
)
np.testing.assert_array_equal(persistence, [])
- persistence = gudhi.read_persistence_intervals_grouped_by_dimension(
+ persistence = gd.read_persistence_intervals_grouped_by_dimension(
persistence_file="persistence_intervals_without_dimension.pers"
)
assert persistence == {
@@ -94,29 +95,29 @@ def test_read_persistence_intervals_with_dimension():
"# Simple persistence diagram with dimension\n0 2.7 3.7\n1 9.6 14.\n3 34.2 34.974\n1 3. inf"
)
test_file.close()
- persistence = gudhi.read_persistence_intervals_in_dimension(
+ persistence = gd.read_persistence_intervals_in_dimension(
persistence_file="persistence_intervals_with_dimension.pers"
)
np.testing.assert_array_equal(
persistence, [(2.7, 3.7), (9.6, 14.0), (34.2, 34.974), (3.0, float("Inf"))]
)
- persistence = gudhi.read_persistence_intervals_in_dimension(
+ persistence = gd.read_persistence_intervals_in_dimension(
persistence_file="persistence_intervals_with_dimension.pers", only_this_dim=0
)
np.testing.assert_array_equal(persistence, [(2.7, 3.7)])
- persistence = gudhi.read_persistence_intervals_in_dimension(
+ persistence = gd.read_persistence_intervals_in_dimension(
persistence_file="persistence_intervals_with_dimension.pers", only_this_dim=1
)
np.testing.assert_array_equal(persistence, [(9.6, 14.0), (3.0, float("Inf"))])
- persistence = gudhi.read_persistence_intervals_in_dimension(
+ persistence = gd.read_persistence_intervals_in_dimension(
persistence_file="persistence_intervals_with_dimension.pers", only_this_dim=2
)
np.testing.assert_array_equal(persistence, [])
- persistence = gudhi.read_persistence_intervals_in_dimension(
+ persistence = gd.read_persistence_intervals_in_dimension(
persistence_file="persistence_intervals_with_dimension.pers", only_this_dim=3
)
np.testing.assert_array_equal(persistence, [(34.2, 34.974)])
- persistence = gudhi.read_persistence_intervals_grouped_by_dimension(
+ persistence = gd.read_persistence_intervals_grouped_by_dimension(
persistence_file="persistence_intervals_with_dimension.pers"
)
assert persistence == {
diff --git a/src/python/test/test_remote_datasets.py b/src/python/test/test_remote_datasets.py
new file mode 100644
index 00000000..e5d2de82
--- /dev/null
+++ b/src/python/test/test_remote_datasets.py
@@ -0,0 +1,87 @@
+# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+# Author(s): Hind Montassif
+#
+# Copyright (C) 2021 Inria
+#
+# Modification(s):
+# - YYYY/MM Author: Description of the modification
+
+from gudhi.datasets import remote
+
+import shutil
+import io
+import sys
+import pytest
+
+from os.path import isdir, expanduser, exists
+from os import remove, environ
+
+def test_data_home():
+ # Test _get_data_home and clear_data_home on new empty folder
+ empty_data_home = remote._get_data_home(data_home="empty_folder_for_test")
+ assert isdir(empty_data_home)
+
+ remote.clear_data_home(data_home=empty_data_home)
+ assert not isdir(empty_data_home)
+
+def test_fetch_remote():
+ # Test fetch with a wrong checksum
+ with pytest.raises(OSError):
+ remote._fetch_remote("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d/spiral_2d.npy", "tmp_spiral_2d.npy", file_checksum = 'XXXXXXXXXX')
+ assert not exists("tmp_spiral_2d.npy")
+
+def _get_bunny_license_print(accept_license = False):
+ capturedOutput = io.StringIO()
+ # Redirect stdout
+ sys.stdout = capturedOutput
+
+ bunny_arr = remote.fetch_bunny("./tmp_for_test/bunny.npy", accept_license)
+ assert bunny_arr.shape == (35947, 3)
+ del bunny_arr
+ remove("./tmp_for_test/bunny.npy")
+
+ # Reset redirect
+ sys.stdout = sys.__stdout__
+ return capturedOutput
+
+def test_print_bunny_license():
+ # Test not printing bunny.npy LICENSE when accept_license = True
+ assert "" == _get_bunny_license_print(accept_license = True).getvalue()
+ # Test printing bunny.LICENSE file when fetching bunny.npy with accept_license = False (default)
+ with open("./tmp_for_test/bunny.LICENSE") as f:
+ assert f.read().rstrip("\n") == _get_bunny_license_print().getvalue().rstrip("\n")
+ shutil.rmtree("./tmp_for_test")
+
+def test_fetch_remote_datasets_wrapped():
+ # Test fetch_spiral_2d and fetch_bunny wrapping functions with data directory different from default (twice, to test case of already fetched files)
+ # Default case is not tested because it would fail in case the user sets the 'GUDHI_DATA' environment variable locally
+ for i in range(2):
+ spiral_2d_arr = remote.fetch_spiral_2d("./another_fetch_folder_for_test/spiral_2d.npy")
+ assert spiral_2d_arr.shape == (114562, 2)
+
+ bunny_arr = remote.fetch_bunny("./another_fetch_folder_for_test/bunny.npy")
+ assert bunny_arr.shape == (35947, 3)
+
+ # Check that the directory was created
+ assert isdir("./another_fetch_folder_for_test")
+ # Check downloaded files
+ assert exists("./another_fetch_folder_for_test/spiral_2d.npy")
+ assert exists("./another_fetch_folder_for_test/bunny.npy")
+ assert exists("./another_fetch_folder_for_test/bunny.LICENSE")
+
+ # Remove test folders
+ del spiral_2d_arr
+ del bunny_arr
+ shutil.rmtree("./another_fetch_folder_for_test")
+
+def test_gudhi_data_env():
+ # Set environment variable "GUDHI_DATA"
+ environ["GUDHI_DATA"] = "./test_folder_from_env_var"
+ bunny_arr = remote.fetch_bunny()
+ assert bunny_arr.shape == (35947, 3)
+ assert exists("./test_folder_from_env_var/points/bunny/bunny.npy")
+ assert exists("./test_folder_from_env_var/points/bunny/bunny.LICENSE")
+ # Remove test folder
+ del bunny_arr
+ shutil.rmtree("./test_folder_from_env_var")
diff --git a/src/python/test/test_representations.py b/src/python/test/test_representations.py
index 43c914f3..f4ffbdc1 100755
--- a/src/python/test/test_representations.py
+++ b/src/python/test/test_representations.py
@@ -3,9 +3,23 @@ import sys
import matplotlib.pyplot as plt
import numpy as np
import pytest
+import random
from sklearn.cluster import KMeans
+# Vectorization
+from gudhi.representations import (Landscape, Silhouette, BettiCurve, ComplexPolynomial,\
+ TopologicalVector, PersistenceImage, Entropy)
+
+# Preprocessing
+from gudhi.representations import (BirthPersistenceTransform, Clamping, DiagramScaler, Padding, ProminentPoints, \
+ DiagramSelector)
+
+# Kernel
+from gudhi.representations import (PersistenceWeightedGaussianKernel, \
+ PersistenceScaleSpaceKernel, SlicedWassersteinDistance,\
+ SlicedWassersteinKernel, PersistenceFisherKernel, WassersteinDistance)
+
def test_representations_examples():
# Disable graphics for testing purposes
@@ -46,6 +60,32 @@ def test_multiple():
assert d1 == pytest.approx(d2, rel=0.02)
+# Test sorted values as points order can be inverted, and sorted test is not documentation-friendly
+# Note the test below must be up to date with the Atol class documentation
+def test_atol_doc():
+ a = np.array([[1, 2, 4], [1, 4, 0], [1, 0, 4]])
+ b = np.array([[4, 2, 0], [4, 4, 0], [4, 0, 2]])
+ c = np.array([[3, 2, -1], [1, 2, -1]])
+
+ atol_vectoriser = Atol(quantiser=KMeans(n_clusters=2, random_state=202006))
+ # Atol will do
+ # X = np.concatenate([a,b,c])
+ # kmeans = KMeans(n_clusters=2, random_state=202006).fit(X)
+ # kmeans.labels_ will be : array([1, 0, 1, 0, 0, 1, 0, 0])
+ first_cluster = np.asarray([a[0], a[2], b[2]])
+ second_cluster = np.asarray([a[1], b[0], b[2], c[0], c[1]])
+
+ # Check the center of the first_cluster and second_cluster are in Atol centers
+ centers = atol_vectoriser.fit(X=[a, b, c]).centers
+ np.isclose(centers, first_cluster.mean(axis=0)).all(1).any()
+ np.isclose(centers, second_cluster.mean(axis=0)).all(1).any()
+
+ vectorization = atol_vectoriser.transform(X=[a, b, c])
+ assert np.allclose(vectorization[0], atol_vectoriser(a))
+ assert np.allclose(vectorization[1], atol_vectoriser(b))
+ assert np.allclose(vectorization[2], atol_vectoriser(c))
+
+
def test_dummy_atol():
a = np.array([[1, 2, 4], [1, 4, 0], [1, 0, 4]])
b = np.array([[4, 2, 0], [4, 4, 0], [4, 0, 2]])
@@ -65,10 +105,165 @@ def test_dummy_atol():
from gudhi.representations.vector_methods import BettiCurve
-
def test_infinity():
a = np.array([[1.0, 8.0], [2.0, np.inf], [3.0, 4.0]])
c = BettiCurve(20, [0.0, 10.0])(a)
assert c[1] == 0
assert c[7] == 3
assert c[9] == 2
+
+def test_preprocessing_empty_diagrams():
+ empty_diag = np.empty(shape = [0, 2])
+ assert not np.any(BirthPersistenceTransform()(empty_diag))
+ assert not np.any(Clamping().fit_transform(empty_diag))
+ assert not np.any(DiagramScaler()(empty_diag))
+ assert not np.any(Padding()(empty_diag))
+ assert not np.any(ProminentPoints()(empty_diag))
+ assert not np.any(DiagramSelector()(empty_diag))
+
+def pow(n):
+ return lambda x: np.power(x[1]-x[0],n)
+
+def test_vectorization_empty_diagrams():
+ empty_diag = np.empty(shape = [0, 2])
+ random_resolution = random.randint(50,100)*10 # between 500 and 1000
+ print("resolution = ", random_resolution)
+ lsc = Landscape(resolution=random_resolution)(empty_diag)
+ assert not np.any(lsc)
+ assert lsc.shape[0]%random_resolution == 0
+ slt = Silhouette(resolution=random_resolution, weight=pow(2))(empty_diag)
+ assert not np.any(slt)
+ assert slt.shape[0] == random_resolution
+ btc = BettiCurve(resolution=random_resolution)(empty_diag)
+ assert not np.any(btc)
+ assert btc.shape[0] == random_resolution
+ cpp = ComplexPolynomial(threshold=random_resolution, polynomial_type="T")(empty_diag)
+ assert not np.any(cpp)
+ assert cpp.shape[0] == random_resolution
+ tpv = TopologicalVector(threshold=random_resolution)(empty_diag)
+ assert tpv.shape[0] == random_resolution
+ assert not np.any(tpv)
+ prmg = PersistenceImage(resolution=[random_resolution,random_resolution])(empty_diag)
+ assert not np.any(prmg)
+ assert prmg.shape[0] == random_resolution * random_resolution
+ sce = Entropy(mode="scalar", resolution=random_resolution)(empty_diag)
+ assert not np.any(sce)
+ assert sce.shape[0] == 1
+ scv = Entropy(mode="vector", normalized=False, resolution=random_resolution)(empty_diag)
+ assert not np.any(scv)
+ assert scv.shape[0] == random_resolution
+
+def test_entropy_miscalculation():
+ diag_ex = np.array([[0.0,1.0], [0.0,1.0], [0.0,2.0]])
+ def pe(pd):
+ l = pd[:,1] - pd[:,0]
+ l = l/sum(l)
+ return -np.dot(l, np.log(l))
+ sce = Entropy(mode="scalar")
+ assert [[pe(diag_ex)]] == sce.fit_transform([diag_ex])
+ sce = Entropy(mode="vector", resolution=4, normalized=False, keep_endpoints=True)
+ pef = [-1/4*np.log(1/4)-1/4*np.log(1/4)-1/2*np.log(1/2),
+ -1/4*np.log(1/4)-1/4*np.log(1/4)-1/2*np.log(1/2),
+ -1/2*np.log(1/2),
+ 0.0]
+ assert all(([pef] == sce.fit_transform([diag_ex]))[0])
+ sce = Entropy(mode="vector", resolution=4, normalized=True)
+ pefN = (sce.fit_transform([diag_ex]))[0]
+ area = np.linalg.norm(pefN, ord=1)
+ assert area==pytest.approx(1)
+
+def test_kernel_empty_diagrams():
+ empty_diag = np.empty(shape = [0, 2])
+ assert SlicedWassersteinDistance(num_directions=100)(empty_diag, empty_diag) == 0.
+ assert SlicedWassersteinKernel(num_directions=100, bandwidth=1.)(empty_diag, empty_diag) == 1.
+ assert WassersteinDistance(mode="hera", delta=0.0001)(empty_diag, empty_diag) == 0.
+ assert WassersteinDistance(mode="pot")(empty_diag, empty_diag) == 0.
+ assert BottleneckDistance(epsilon=.001)(empty_diag, empty_diag) == 0.
+ assert BottleneckDistance()(empty_diag, empty_diag) == 0.
+# PersistenceWeightedGaussianKernel(bandwidth=1., kernel_approx=None, weight=arctan(1.,1.))(empty_diag, empty_diag)
+# PersistenceWeightedGaussianKernel(kernel_approx=RBFSampler(gamma=1./2, n_components=100000).fit(np.ones([1,2])), weight=arctan(1.,1.))(empty_diag, empty_diag)
+# PersistenceScaleSpaceKernel(bandwidth=1.)(empty_diag, empty_diag)
+# PersistenceScaleSpaceKernel(kernel_approx=RBFSampler(gamma=1./2, n_components=100000).fit(np.ones([1,2])))(empty_diag, empty_diag)
+# PersistenceFisherKernel(bandwidth_fisher=1., bandwidth=1.)(empty_diag, empty_diag)
+# PersistenceFisherKernel(bandwidth_fisher=1., bandwidth=1., kernel_approx=RBFSampler(gamma=1./2, n_components=100000).fit(np.ones([1,2])))(empty_diag, empty_diag)
+
+
+def test_silhouette_permutation_invariance():
+ dgm = _n_diags(1)[0]
+ dgm_permuted = dgm[np.random.permutation(dgm.shape[0]).astype(int)]
+ random_resolution = random.randint(50, 100) * 10
+ slt = Silhouette(resolution=random_resolution, weight=pow(2))
+
+ assert np.all(np.isclose(slt(dgm), slt(dgm_permuted)))
+
+
+def test_silhouette_multiplication_invariance():
+ dgm = _n_diags(1)[0]
+ n_repetitions = np.random.randint(2, high=10)
+ dgm_augmented = np.repeat(dgm, repeats=n_repetitions, axis=0)
+
+ random_resolution = random.randint(50, 100) * 10
+ slt = Silhouette(resolution=random_resolution, weight=pow(2))
+ assert np.all(np.isclose(slt(dgm), slt(dgm_augmented)))
+
+
+def test_silhouette_numeric():
+ dgm = np.array([[2., 3.], [5., 6.]])
+ slt = Silhouette(resolution=9, weight=pow(1), sample_range=[2., 6.])
+ #slt.fit([dgm])
+ # x_values = array([2., 2.5, 3., 3.5, 4., 4.5, 5., 5.5, 6.])
+
+ expected_silhouette = np.array([0., 0.5, 0., 0., 0., 0., 0., 0.5, 0.])/np.sqrt(2)
+ output_silhouette = slt(dgm)
+ assert np.all(np.isclose(output_silhouette, expected_silhouette))
+
+
+def test_landscape_small_persistence_invariance():
+ dgm = np.array([[2., 6.], [2., 5.], [3., 7.]])
+ small_persistence_pts = np.random.rand(10, 2)
+ small_persistence_pts[:, 1] += small_persistence_pts[:, 0]
+ small_persistence_pts += np.min(dgm)
+ dgm_augmented = np.concatenate([dgm, small_persistence_pts], axis=0)
+
+ lds = Landscape(num_landscapes=2, resolution=5)
+ lds_dgm, lds_dgm_augmented = lds(dgm), lds(dgm_augmented)
+
+ assert np.all(np.isclose(lds_dgm, lds_dgm_augmented))
+
+
+def test_landscape_numeric():
+ dgm = np.array([[2., 6.], [3., 5.]])
+ lds_ref = np.array([
+ 0., 0.5, 1., 1.5, 2., 1.5, 1., 0.5, 0., # tent of [2, 6]
+ 0., 0., 0., 0.5, 1., 0.5, 0., 0., 0.,
+ 0., 0., 0., 0., 0., 0., 0., 0., 0.,
+ 0., 0., 0., 0., 0., 0., 0., 0., 0.,
+ ])
+ lds_ref *= np.sqrt(2)
+ lds = Landscape(num_landscapes=4, resolution=9, sample_range=[2., 6.])
+ lds_dgm = lds(dgm)
+ assert np.all(np.isclose(lds_dgm, lds_ref))
+
+
+def test_landscape_nan_range():
+ dgm = np.array([[2., 6.], [3., 5.]])
+ lds = Landscape(num_landscapes=2, resolution=9, sample_range=[np.nan, 6.])
+ lds_dgm = lds(dgm)
+ assert (lds.sample_range_fixed[0] == 2) & (lds.sample_range_fixed[1] == 6)
+ assert lds.new_resolution == 10
+
+def test_endpoints():
+ diags = [ np.array([[2., 3.]]) ]
+ for vec in [ Landscape(), Silhouette(), BettiCurve(), Entropy(mode="vector") ]:
+ vec.fit(diags)
+ assert vec.grid_[0] > 2 and vec.grid_[-1] < 3
+ for vec in [ Landscape(keep_endpoints=True), Silhouette(keep_endpoints=True), BettiCurve(keep_endpoints=True), Entropy(mode="vector", keep_endpoints=True)]:
+ vec.fit(diags)
+ assert vec.grid_[0] == 2 and vec.grid_[-1] == 3
+ vec = BettiCurve(resolution=None)
+ vec.fit(diags)
+ assert np.equal(vec.grid_, [-np.inf, 2., 3.]).all()
+
+def test_get_params():
+ for vec in [ Landscape(), Silhouette(), BettiCurve(), Entropy(mode="vector") ]:
+ vec.get_params()
diff --git a/src/python/test/test_representations_preprocessing.py b/src/python/test/test_representations_preprocessing.py
new file mode 100644
index 00000000..838cf30c
--- /dev/null
+++ b/src/python/test/test_representations_preprocessing.py
@@ -0,0 +1,39 @@
+""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+ See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+ Author(s): Vincent Rouvreau
+
+ Copyright (C) 2021 Inria
+
+ Modification(s):
+ - YYYY/MM Author: Description of the modification
+"""
+
+from gudhi.representations.preprocessing import DimensionSelector
+import numpy as np
+import pytest
+
+H0_0 = np.array([0.0, 0.0])
+H1_0 = np.array([1.0, 0.0])
+H0_1 = np.array([0.0, 1.0])
+H1_1 = np.array([1.0, 1.0])
+H0_2 = np.array([0.0, 2.0])
+H1_2 = np.array([1.0, 2.0])
+
+
+def test_dimension_selector():
+ X = [[H0_0, H1_0], [H0_1, H1_1], [H0_2, H1_2]]
+ ds = DimensionSelector(index=0)
+ h0 = ds.fit_transform(X)
+ np.testing.assert_array_equal(h0[0], H0_0)
+ np.testing.assert_array_equal(h0[1], H0_1)
+ np.testing.assert_array_equal(h0[2], H0_2)
+
+ ds = DimensionSelector(index=1)
+ h1 = ds.fit_transform(X)
+ np.testing.assert_array_equal(h1[0], H1_0)
+ np.testing.assert_array_equal(h1[1], H1_1)
+ np.testing.assert_array_equal(h1[2], H1_2)
+
+ ds = DimensionSelector(index=2)
+ with pytest.raises(IndexError):
+ h2 = ds.fit_transform([[H0_0, H1_0], [H0_1, H1_1], [H0_2, H1_2]])
diff --git a/src/python/test/test_rips_complex.py b/src/python/test/test_rips_complex.py
index b86e7498..a2f43a1b 100755
--- a/src/python/test/test_rips_complex.py
+++ b/src/python/test/test_rips_complex.py
@@ -133,3 +133,24 @@ def test_filtered_rips_from_distance_matrix():
assert simplex_tree.num_simplices() == 8
assert simplex_tree.num_vertices() == 4
+
+
+def test_sparse_with_multiplicity():
+ points = [
+ [3, 4],
+ [0.1, 2],
+ [0.1, 2],
+ [0.1, 2],
+ [0.1, 2],
+ [0.1, 2],
+ [0.1, 2],
+ [0.1, 2],
+ [0.1, 2],
+ [0.1, 2],
+ [0.1, 2],
+ [3, 4.1],
+ ]
+ rips = RipsComplex(points=points, sparse=0.01)
+ simplex_tree = rips.create_simplex_tree(max_dimension=2)
+ assert simplex_tree.num_simplices() == 7
+ diag = simplex_tree.persistence()
diff --git a/src/python/test/test_simplex_generators.py b/src/python/test/test_simplex_generators.py
index 8a9b4844..c567d4c1 100755
--- a/src/python/test/test_simplex_generators.py
+++ b/src/python/test/test_simplex_generators.py
@@ -14,7 +14,7 @@ import numpy as np
def test_flag_generators():
pts = np.array([[0, 0], [0, 1.01], [1, 0], [1.02, 1.03], [100, 0], [100, 3.01], [103, 0], [103.02, 3.03]])
- r = gudhi.RipsComplex(pts, max_edge_length=4)
+ r = gudhi.RipsComplex(points=pts, max_edge_length=4)
st = r.create_simplex_tree(max_dimension=50)
st.persistence()
g = st.flag_persistence_generators()
diff --git a/src/python/test/test_simplex_tree.py b/src/python/test/test_simplex_tree.py
index a3eacaa9..2ccbfbf5 100755
--- a/src/python/test/test_simplex_tree.py
+++ b/src/python/test/test_simplex_tree.py
@@ -8,7 +8,8 @@
- YYYY/MM Author: Description of the modification
"""
-from gudhi import SimplexTree, __GUDHI_USE_EIGEN3
+from gudhi import SimplexTree
+import numpy as np
import pytest
__author__ = "Vincent Rouvreau"
@@ -248,6 +249,7 @@ def test_make_filtration_non_decreasing():
assert st.filtration([3, 4]) == 2.0
assert st.filtration([4, 5]) == 2.0
+
def test_extend_filtration():
# Inserted simplex:
@@ -256,82 +258,87 @@ def test_extend_filtration():
# / \ /
# o o
# /2\ /3
- # o o
- # 1 0
-
- st = SimplexTree()
- st.insert([0,2])
- st.insert([1,2])
- st.insert([0,3])
- st.insert([2,5])
- st.insert([3,4])
- st.insert([3,5])
- st.assign_filtration([0], 1.)
- st.assign_filtration([1], 2.)
- st.assign_filtration([2], 3.)
- st.assign_filtration([3], 4.)
- st.assign_filtration([4], 5.)
- st.assign_filtration([5], 6.)
-
- assert list(st.get_filtration()) == [
- ([0, 2], 0.0),
- ([1, 2], 0.0),
- ([0, 3], 0.0),
- ([3, 4], 0.0),
- ([2, 5], 0.0),
- ([3, 5], 0.0),
- ([0], 1.0),
- ([1], 2.0),
- ([2], 3.0),
- ([3], 4.0),
- ([4], 5.0),
- ([5], 6.0)
+ # o o
+ # 1 0
+
+ st = SimplexTree()
+ st.insert([0, 2])
+ st.insert([1, 2])
+ st.insert([0, 3])
+ st.insert([2, 5])
+ st.insert([3, 4])
+ st.insert([3, 5])
+ st.assign_filtration([0], 1.0)
+ st.assign_filtration([1], 2.0)
+ st.assign_filtration([2], 3.0)
+ st.assign_filtration([3], 4.0)
+ st.assign_filtration([4], 5.0)
+ st.assign_filtration([5], 6.0)
+
+ assert list(st.get_filtration()) == [
+ ([0, 2], 0.0),
+ ([1, 2], 0.0),
+ ([0, 3], 0.0),
+ ([3, 4], 0.0),
+ ([2, 5], 0.0),
+ ([3, 5], 0.0),
+ ([0], 1.0),
+ ([1], 2.0),
+ ([2], 3.0),
+ ([3], 4.0),
+ ([4], 5.0),
+ ([5], 6.0),
]
-
+
st.extend_filtration()
-
- assert list(st.get_filtration()) == [
- ([6], -3.0),
- ([0], -2.0),
- ([1], -1.8),
- ([2], -1.6),
- ([0, 2], -1.6),
- ([1, 2], -1.6),
- ([3], -1.4),
- ([0, 3], -1.4),
- ([4], -1.2),
- ([3, 4], -1.2),
- ([5], -1.0),
- ([2, 5], -1.0),
- ([3, 5], -1.0),
- ([5, 6], 1.0),
- ([4, 6], 1.2),
- ([3, 6], 1.4),
+
+ assert list(st.get_filtration()) == [
+ ([6], -3.0),
+ ([0], -2.0),
+ ([1], -1.8),
+ ([2], -1.6),
+ ([0, 2], -1.6),
+ ([1, 2], -1.6),
+ ([3], -1.4),
+ ([0, 3], -1.4),
+ ([4], -1.2),
+ ([3, 4], -1.2),
+ ([5], -1.0),
+ ([2, 5], -1.0),
+ ([3, 5], -1.0),
+ ([5, 6], 1.0),
+ ([4, 6], 1.2),
+ ([3, 6], 1.4),
([3, 4, 6], 1.4),
- ([3, 5, 6], 1.4),
- ([2, 6], 1.6),
- ([2, 5, 6], 1.6),
- ([1, 6], 1.8),
- ([1, 2, 6], 1.8),
- ([0, 6], 2.0),
- ([0, 2, 6], 2.0),
- ([0, 3, 6], 2.0)
+ ([3, 5, 6], 1.4),
+ ([2, 6], 1.6),
+ ([2, 5, 6], 1.6),
+ ([1, 6], 1.8),
+ ([1, 2, 6], 1.8),
+ ([0, 6], 2.0),
+ ([0, 2, 6], 2.0),
+ ([0, 3, 6], 2.0),
]
- dgms = st.extended_persistence(min_persistence=-1.)
+ dgms = st.extended_persistence(min_persistence=-1.0)
+ assert len(dgms) == 4
+ # Sort by (death-birth) descending - we are only interested in those with the longest life span
+ for idx in range(4):
+ dgms[idx] = sorted(dgms[idx], key=lambda x: (-abs(x[1][0] - x[1][1])))
+
+ assert dgms[0][0][1][0] == pytest.approx(2.0)
+ assert dgms[0][0][1][1] == pytest.approx(3.0)
+ assert dgms[1][0][1][0] == pytest.approx(5.0)
+ assert dgms[1][0][1][1] == pytest.approx(4.0)
+ assert dgms[2][0][1][0] == pytest.approx(1.0)
+ assert dgms[2][0][1][1] == pytest.approx(6.0)
+ assert dgms[3][0][1][0] == pytest.approx(6.0)
+ assert dgms[3][0][1][1] == pytest.approx(1.0)
- assert dgms[0][0][1][0] == pytest.approx(2.)
- assert dgms[0][0][1][1] == pytest.approx(3.)
- assert dgms[1][0][1][0] == pytest.approx(5.)
- assert dgms[1][0][1][1] == pytest.approx(4.)
- assert dgms[2][0][1][0] == pytest.approx(1.)
- assert dgms[2][0][1][1] == pytest.approx(6.)
- assert dgms[3][0][1][0] == pytest.approx(6.)
- assert dgms[3][0][1][1] == pytest.approx(1.)
def test_simplices_iterator():
st = SimplexTree()
-
+
assert st.insert([0, 1, 2], filtration=4.0) == True
assert st.insert([2, 3, 4], filtration=2.0) == True
@@ -341,9 +348,10 @@ def test_simplices_iterator():
print("filtration is: ", simplex[1])
assert st.filtration(simplex[0]) == simplex[1]
+
def test_collapse_edges():
st = SimplexTree()
-
+
assert st.insert([0, 1], filtration=1.0) == True
assert st.insert([1, 2], filtration=1.0) == True
assert st.insert([2, 3], filtration=1.0) == True
@@ -353,38 +361,35 @@ def test_collapse_edges():
assert st.num_simplices() == 10
- if __GUDHI_USE_EIGEN3:
- st.collapse_edges()
- assert st.num_simplices() == 9
- assert st.find([1, 3]) == False
- for simplex in st.get_skeleton(0):
- assert simplex[1] == 1.
- else:
- # If no Eigen3, collapse_edges throws an exception
- with pytest.raises(RuntimeError):
- st.collapse_edges()
+ st.collapse_edges()
+ assert st.num_simplices() == 9
+ assert st.find([0, 2]) == False # [1, 3] would be fine as well
+ for simplex in st.get_skeleton(0):
+ assert simplex[1] == 1.0
+
def test_reset_filtration():
st = SimplexTree()
-
- assert st.insert([0, 1, 2], 3.) == True
- assert st.insert([0, 3], 2.) == True
- assert st.insert([3, 4, 5], 3.) == True
- assert st.insert([0, 1, 6, 7], 4.) == True
+
+ assert st.insert([0, 1, 2], 3.0) == True
+ assert st.insert([0, 3], 2.0) == True
+ assert st.insert([3, 4, 5], 3.0) == True
+ assert st.insert([0, 1, 6, 7], 4.0) == True
# Guaranteed by construction
for simplex in st.get_simplices():
- assert st.filtration(simplex[0]) >= 2.
-
+ assert st.filtration(simplex[0]) >= 2.0
+
# dimension until 5 even if simplex tree is of dimension 3 to test the limits
for dimension in range(5, -1, -1):
- st.reset_filtration(0., dimension)
+ st.reset_filtration(0.0, dimension)
for simplex in st.get_skeleton(3):
print(simplex)
if len(simplex[0]) < (dimension) + 1:
- assert st.filtration(simplex[0]) >= 2.
+ assert st.filtration(simplex[0]) >= 2.0
else:
- assert st.filtration(simplex[0]) == 0.
+ assert st.filtration(simplex[0]) == 0.0
+
def test_boundaries_iterator():
st = SimplexTree()
@@ -400,7 +405,242 @@ def test_boundaries_iterator():
list(st.get_boundaries([]))
with pytest.raises(RuntimeError):
- list(st.get_boundaries([0, 4])) # (0, 4) does not exist
+ list(st.get_boundaries([0, 4])) # (0, 4) does not exist
with pytest.raises(RuntimeError):
- list(st.get_boundaries([6])) # (6) does not exist
+ list(st.get_boundaries([6])) # (6) does not exist
+
+
+def test_persistence_intervals_in_dimension():
+ # Here is our triangulation of a 2-torus - taken from https://dioscuri-tda.org/Paris_TDA_Tutorial_2021.html
+ # 0-----3-----4-----0
+ # | \ | \ | \ | \ |
+ # | \ | \ | \| \ |
+ # 1-----8-----7-----1
+ # | \ | \ | \ | \ |
+ # | \ | \ | \ | \ |
+ # 2-----5-----6-----2
+ # | \ | \ | \ | \ |
+ # | \ | \ | \ | \ |
+ # 0-----3-----4-----0
+ st = SimplexTree()
+ st.insert([0, 1, 8])
+ st.insert([0, 3, 8])
+ st.insert([3, 7, 8])
+ st.insert([3, 4, 7])
+ st.insert([1, 4, 7])
+ st.insert([0, 1, 4])
+ st.insert([1, 2, 5])
+ st.insert([1, 5, 8])
+ st.insert([5, 6, 8])
+ st.insert([6, 7, 8])
+ st.insert([2, 6, 7])
+ st.insert([1, 2, 7])
+ st.insert([0, 2, 3])
+ st.insert([2, 3, 5])
+ st.insert([3, 4, 5])
+ st.insert([4, 5, 6])
+ st.insert([0, 4, 6])
+ st.insert([0, 2, 6])
+ st.compute_persistence(persistence_dim_max=True)
+
+ H0 = st.persistence_intervals_in_dimension(0)
+ assert np.array_equal(H0, np.array([[0.0, float("inf")]]))
+ H1 = st.persistence_intervals_in_dimension(1)
+ assert np.array_equal(H1, np.array([[0.0, float("inf")], [0.0, float("inf")]]))
+ H2 = st.persistence_intervals_in_dimension(2)
+ assert np.array_equal(H2, np.array([[0.0, float("inf")]]))
+ # Test empty case
+ assert st.persistence_intervals_in_dimension(3).shape == (0, 2)
+
+
+def test_equality_operator():
+ st1 = SimplexTree()
+ st2 = SimplexTree()
+
+ assert st1 == st2
+
+ st1.insert([1, 2, 3], 4.0)
+ assert st1 != st2
+
+ st2.insert([1, 2, 3], 4.0)
+ assert st1 == st2
+
+
+def test_simplex_tree_deep_copy():
+ st = SimplexTree()
+ st.insert([1, 2, 3], 0.0)
+ # compute persistence only on the original
+ st.compute_persistence()
+
+ st_copy = st.copy()
+ assert st_copy == st
+ st_filt_list = list(st.get_filtration())
+
+ # check persistence is not copied
+ assert st.__is_persistence_defined() == True
+ assert st_copy.__is_persistence_defined() == False
+
+ # remove something in the copy and check the copy is included in the original
+ st_copy.remove_maximal_simplex([1, 2, 3])
+ a_filt_list = list(st_copy.get_filtration())
+ assert len(a_filt_list) < len(st_filt_list)
+
+ for a_splx in a_filt_list:
+ assert a_splx in st_filt_list
+
+ # test double free
+ del st
+ del st_copy
+
+
+def test_simplex_tree_deep_copy_constructor():
+ st = SimplexTree()
+ st.insert([1, 2, 3], 0.0)
+ # compute persistence only on the original
+ st.compute_persistence()
+
+ st_copy = SimplexTree(st)
+ assert st_copy == st
+ st_filt_list = list(st.get_filtration())
+
+ # check persistence is not copied
+ assert st.__is_persistence_defined() == True
+ assert st_copy.__is_persistence_defined() == False
+
+ # remove something in the copy and check the copy is included in the original
+ st_copy.remove_maximal_simplex([1, 2, 3])
+ a_filt_list = list(st_copy.get_filtration())
+ assert len(a_filt_list) < len(st_filt_list)
+
+ for a_splx in a_filt_list:
+ assert a_splx in st_filt_list
+
+ # test double free
+ del st
+ del st_copy
+
+
+def test_simplex_tree_constructor_exception():
+ with pytest.raises(TypeError):
+ st = SimplexTree(other="Construction from a string shall raise an exception")
+
+
+def test_create_from_array():
+ a = np.array([[1, 4, 13, 6], [4, 3, 11, 5], [13, 11, 10, 12], [6, 5, 12, 2]])
+ st = SimplexTree.create_from_array(a, max_filtration=5.0)
+ assert list(st.get_filtration()) == [([0], 1.0), ([3], 2.0), ([1], 3.0), ([0, 1], 4.0), ([1, 3], 5.0)]
+
+
+def test_insert_edges_from_coo_matrix():
+ try:
+ from scipy.sparse import coo_matrix
+ from scipy.spatial import cKDTree
+ except ImportError:
+ print("Skipping, no SciPy")
+ return
+
+ st = SimplexTree()
+ st.insert([1, 2, 7], 7)
+ row = np.array([2, 5, 3])
+ col = np.array([1, 4, 6])
+ dat = np.array([1, 2, 3])
+ edges = coo_matrix((dat, (row, col)))
+ st.insert_edges_from_coo_matrix(edges)
+ assert list(st.get_filtration()) == [
+ ([1], 1.0),
+ ([2], 1.0),
+ ([1, 2], 1.0),
+ ([4], 2.0),
+ ([5], 2.0),
+ ([4, 5], 2.0),
+ ([3], 3.0),
+ ([6], 3.0),
+ ([3, 6], 3.0),
+ ([7], 7.0),
+ ([1, 7], 7.0),
+ ([2, 7], 7.0),
+ ([1, 2, 7], 7.0),
+ ]
+
+ pts = np.random.rand(100, 2)
+ tree = cKDTree(pts)
+ edges = tree.sparse_distance_matrix(tree, max_distance=0.15, output_type="coo_matrix")
+ st = SimplexTree()
+ st.insert_edges_from_coo_matrix(edges)
+ assert 100 < st.num_simplices() < 1000
+
+
+def test_insert_batch():
+ st = SimplexTree()
+ # vertices
+ st.insert_batch(np.array([[6, 1, 5]]), np.array([-5.0, 2.0, -3.0]))
+ # triangles
+ st.insert_batch(np.array([[2, 10], [5, 0], [6, 11]]), np.array([4.0, 0.0]))
+ # edges
+ st.insert_batch(np.array([[1, 5], [2, 5]]), np.array([1.0, 3.0]))
+
+ assert list(st.get_filtration()) == [
+ ([6], -5.0),
+ ([5], -3.0),
+ ([0], 0.0),
+ ([10], 0.0),
+ ([0, 10], 0.0),
+ ([11], 0.0),
+ ([0, 11], 0.0),
+ ([10, 11], 0.0),
+ ([0, 10, 11], 0.0),
+ ([1], 1.0),
+ ([2], 1.0),
+ ([1, 2], 1.0),
+ ([2, 5], 4.0),
+ ([2, 6], 4.0),
+ ([5, 6], 4.0),
+ ([2, 5, 6], 4.0),
+ ]
+
+
+def test_expansion_with_blocker():
+ st = SimplexTree()
+ st.insert([0, 1], 0)
+ st.insert([0, 2], 1)
+ st.insert([0, 3], 2)
+ st.insert([1, 2], 3)
+ st.insert([1, 3], 4)
+ st.insert([2, 3], 5)
+ st.insert([2, 4], 6)
+ st.insert([3, 6], 7)
+ st.insert([4, 5], 8)
+ st.insert([4, 6], 9)
+ st.insert([5, 6], 10)
+ st.insert([6], 10)
+
+ def blocker(simplex):
+ try:
+ # Block all simplices that contain vertex 6
+ simplex.index(6)
+ print(simplex, " is blocked")
+ return True
+ except ValueError:
+ print(simplex, " is accepted")
+ st.assign_filtration(simplex, st.filtration(simplex) + 1.0)
+ return False
+
+ st.expansion_with_blocker(2, blocker)
+ assert st.num_simplices() == 22
+ assert st.dimension() == 2
+ assert st.find([4, 5, 6]) == False
+ assert st.filtration([0, 1, 2]) == 4.0
+ assert st.filtration([0, 1, 3]) == 5.0
+ assert st.filtration([0, 2, 3]) == 6.0
+ assert st.filtration([1, 2, 3]) == 6.0
+
+ st.expansion_with_blocker(3, blocker)
+ assert st.num_simplices() == 23
+ assert st.dimension() == 3
+ assert st.find([4, 5, 6]) == False
+ assert st.filtration([0, 1, 2]) == 4.0
+ assert st.filtration([0, 1, 3]) == 5.0
+ assert st.filtration([0, 2, 3]) == 6.0
+ assert st.filtration([1, 2, 3]) == 6.0
+ assert st.filtration([0, 1, 2, 3]) == 7.0
diff --git a/src/python/test/test_sklearn_cubical_persistence.py b/src/python/test/test_sklearn_cubical_persistence.py
new file mode 100644
index 00000000..1c05a215
--- /dev/null
+++ b/src/python/test/test_sklearn_cubical_persistence.py
@@ -0,0 +1,59 @@
+""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+ See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+ Author(s): Vincent Rouvreau
+
+ Copyright (C) 2021 Inria
+
+ Modification(s):
+ - YYYY/MM Author: Description of the modification
+"""
+
+from gudhi.sklearn.cubical_persistence import CubicalPersistence
+import numpy as np
+from sklearn import datasets
+
+CUBICAL_PERSISTENCE_H0_IMG0 = np.array([[0.0, 6.0], [0.0, 8.0], [0.0, np.inf]])
+
+
+def test_simple_constructor_from_top_cells():
+ cells = datasets.load_digits().images[0]
+ cp = CubicalPersistence(homology_dimensions=0)
+ np.testing.assert_array_equal(cp._CubicalPersistence__transform_only_this_dim(cells), CUBICAL_PERSISTENCE_H0_IMG0)
+ cp = CubicalPersistence(homology_dimensions=[0, 2])
+ diags = cp._CubicalPersistence__transform(cells)
+ assert len(diags) == 2
+ np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0)
+
+
+def test_simple_constructor_from_top_cells_list():
+ digits = datasets.load_digits().images[:10]
+ cp = CubicalPersistence(homology_dimensions=0, n_jobs=-2)
+
+ diags = cp.fit_transform(digits)
+ assert len(diags) == 10
+ np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0)
+
+ cp = CubicalPersistence(homology_dimensions=[0, 1], n_jobs=-1)
+ diagsH0H1 = cp.fit_transform(digits)
+ assert len(diagsH0H1) == 10
+ for idx in range(10):
+ np.testing.assert_array_equal(diags[idx], diagsH0H1[idx][0])
+
+def test_simple_constructor_from_flattened_cells():
+ cells = datasets.load_digits().images[0]
+ # Not squared (extended) flatten cells
+ flat_cells = np.hstack((cells, np.zeros((cells.shape[0], 2)))).flatten()
+
+ cp = CubicalPersistence(homology_dimensions=0, newshape=[-1, 8, 10])
+ diags = cp.fit_transform([flat_cells])
+
+ np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0)
+
+ # Not squared (extended) non-flatten cells
+ cells = np.hstack((cells, np.zeros((cells.shape[0], 2))))
+
+ # The aim of this second part of the test is to resize even if not mandatory
+ cp = CubicalPersistence(homology_dimensions=0, newshape=[-1, 8, 10])
+ diags = cp.fit_transform([cells])
+
+ np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0)
diff --git a/src/python/test/test_subsampling.py b/src/python/test/test_subsampling.py
index 4019852e..c1cb4e3f 100755
--- a/src/python/test/test_subsampling.py
+++ b/src/python/test/test_subsampling.py
@@ -16,17 +16,9 @@ __license__ = "MIT"
def test_write_off_file_for_tests():
- file = open("subsample.off", "w")
- file.write("nOFF\n")
- file.write("2 7 0 0\n")
- file.write("1.0 1.0\n")
- file.write("7.0 0.0\n")
- file.write("4.0 6.0\n")
- file.write("9.0 6.0\n")
- file.write("0.0 14.0\n")
- file.write("2.0 19.0\n")
- file.write("9.0 17.0\n")
- file.close()
+ gudhi.write_points_to_off_file(
+ "subsample.off", [[1.0, 1.0], [7.0, 0.0], [4.0, 6.0], [9.0, 6.0], [0.0, 14.0], [2.0, 19.0], [9.0, 17.0]]
+ )
def test_simple_choose_n_farthest_points_with_a_starting_point():
@@ -34,54 +26,29 @@ def test_simple_choose_n_farthest_points_with_a_starting_point():
i = 0
for point in point_set:
# The iteration starts with the given starting point
- sub_set = gudhi.choose_n_farthest_points(
- points=point_set, nb_points=1, starting_point=i
- )
+ sub_set = gudhi.choose_n_farthest_points(points=point_set, nb_points=1, starting_point=i)
assert sub_set[0] == point_set[i]
i = i + 1
# The iteration finds then the farthest
- sub_set = gudhi.choose_n_farthest_points(
- points=point_set, nb_points=2, starting_point=1
- )
+ sub_set = gudhi.choose_n_farthest_points(points=point_set, nb_points=2, starting_point=1)
assert sub_set[1] == point_set[3]
- sub_set = gudhi.choose_n_farthest_points(
- points=point_set, nb_points=2, starting_point=3
- )
+ sub_set = gudhi.choose_n_farthest_points(points=point_set, nb_points=2, starting_point=3)
assert sub_set[1] == point_set[1]
- sub_set = gudhi.choose_n_farthest_points(
- points=point_set, nb_points=2, starting_point=0
- )
+ sub_set = gudhi.choose_n_farthest_points(points=point_set, nb_points=2, starting_point=0)
assert sub_set[1] == point_set[2]
- sub_set = gudhi.choose_n_farthest_points(
- points=point_set, nb_points=2, starting_point=2
- )
+ sub_set = gudhi.choose_n_farthest_points(points=point_set, nb_points=2, starting_point=2)
assert sub_set[1] == point_set[0]
# Test the limits
- assert (
- gudhi.choose_n_farthest_points(points=[], nb_points=0, starting_point=0) == []
- )
- assert (
- gudhi.choose_n_farthest_points(points=[], nb_points=1, starting_point=0) == []
- )
- assert (
- gudhi.choose_n_farthest_points(points=[], nb_points=0, starting_point=1) == []
- )
- assert (
- gudhi.choose_n_farthest_points(points=[], nb_points=1, starting_point=1) == []
- )
+ assert gudhi.choose_n_farthest_points(points=[], nb_points=0, starting_point=0) == []
+ assert gudhi.choose_n_farthest_points(points=[], nb_points=1, starting_point=0) == []
+ assert gudhi.choose_n_farthest_points(points=[], nb_points=0, starting_point=1) == []
+ assert gudhi.choose_n_farthest_points(points=[], nb_points=1, starting_point=1) == []
# From off file test
for i in range(0, 7):
- assert (
- len(
- gudhi.choose_n_farthest_points(
- off_file="subsample.off", nb_points=i, starting_point=i
- )
- )
- == i
- )
+ assert len(gudhi.choose_n_farthest_points(off_file="subsample.off", nb_points=i, starting_point=i)) == i
def test_simple_choose_n_farthest_points_randomed():
@@ -91,7 +58,7 @@ def test_simple_choose_n_farthest_points_randomed():
assert gudhi.choose_n_farthest_points(points=[], nb_points=1) == []
assert gudhi.choose_n_farthest_points(points=point_set, nb_points=0) == []
- # Go furter than point set on purpose
+ # Go further than point set on purpose
for iter in range(1, 10):
sub_set = gudhi.choose_n_farthest_points(points=point_set, nb_points=iter)
for sub in sub_set:
@@ -104,10 +71,7 @@ def test_simple_choose_n_farthest_points_randomed():
# From off file test
for i in range(0, 7):
- assert (
- len(gudhi.choose_n_farthest_points(off_file="subsample.off", nb_points=i))
- == i
- )
+ assert len(gudhi.choose_n_farthest_points(off_file="subsample.off", nb_points=i)) == i
def test_simple_pick_n_random_points():
@@ -117,7 +81,7 @@ def test_simple_pick_n_random_points():
assert gudhi.pick_n_random_points(points=[], nb_points=1) == []
assert gudhi.pick_n_random_points(points=point_set, nb_points=0) == []
- # Go furter than point set on purpose
+ # Go further than point set on purpose
for iter in range(1, 10):
sub_set = gudhi.pick_n_random_points(points=point_set, nb_points=iter)
for sub in sub_set:
@@ -130,9 +94,7 @@ def test_simple_pick_n_random_points():
# From off file test
for i in range(0, 7):
- assert (
- len(gudhi.pick_n_random_points(off_file="subsample.off", nb_points=i)) == i
- )
+ assert len(gudhi.pick_n_random_points(off_file="subsample.off", nb_points=i)) == i
def test_simple_sparsify_points():
@@ -152,31 +114,10 @@ def test_simple_sparsify_points():
]
assert gudhi.sparsify_point_set(points=point_set, min_squared_dist=2.001) == [[0, 1]]
- assert (
- len(gudhi.sparsify_point_set(off_file="subsample.off", min_squared_dist=0.0))
- == 7
- )
- assert (
- len(gudhi.sparsify_point_set(off_file="subsample.off", min_squared_dist=30.0))
- == 5
- )
- assert (
- len(gudhi.sparsify_point_set(off_file="subsample.off", min_squared_dist=40.1))
- == 4
- )
- assert (
- len(gudhi.sparsify_point_set(off_file="subsample.off", min_squared_dist=89.9))
- == 3
- )
- assert (
- len(gudhi.sparsify_point_set(off_file="subsample.off", min_squared_dist=100.0))
- == 2
- )
- assert (
- len(gudhi.sparsify_point_set(off_file="subsample.off", min_squared_dist=324.9))
- == 2
- )
- assert (
- len(gudhi.sparsify_point_set(off_file="subsample.off", min_squared_dist=325.01))
- == 1
- )
+ assert len(gudhi.sparsify_point_set(off_file="subsample.off", min_squared_dist=0.0)) == 7
+ assert len(gudhi.sparsify_point_set(off_file="subsample.off", min_squared_dist=30.0)) == 5
+ assert len(gudhi.sparsify_point_set(off_file="subsample.off", min_squared_dist=40.1)) == 4
+ assert len(gudhi.sparsify_point_set(off_file="subsample.off", min_squared_dist=89.9)) == 3
+ assert len(gudhi.sparsify_point_set(off_file="subsample.off", min_squared_dist=100.0)) == 2
+ assert len(gudhi.sparsify_point_set(off_file="subsample.off", min_squared_dist=324.9)) == 2
+ assert len(gudhi.sparsify_point_set(off_file="subsample.off", min_squared_dist=325.01)) == 1
diff --git a/src/python/test/test_tomato.py b/src/python/test/test_tomato.py
index ecab03c4..c571f799 100755
--- a/src/python/test/test_tomato.py
+++ b/src/python/test/test_tomato.py
@@ -37,7 +37,7 @@ def test_tomato_1():
t = Tomato(metric="euclidean", graph_type="radius", r=4.7, k=4)
t.fit(a)
assert t.max_weight_per_cc_.size == 2
- assert np.array_equal(t.neighbors_, [[0, 1, 2], [0, 1, 2], [0, 1, 2], [3, 4, 5, 6], [3, 4, 5], [3, 4, 5], [3, 6]])
+ assert t.neighbors_ == [[0, 1, 2], [0, 1, 2], [0, 1, 2], [3, 4, 5, 6], [3, 4, 5], [3, 4, 5], [3, 6]]
t.plot_diagram()
t = Tomato(graph_type="radius", r=4.7, k=4, symmetrize_graph=True)
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py
index e3b521d6..42bf3299 100755
--- a/src/python/test/test_wasserstein_distance.py
+++ b/src/python/test/test_wasserstein_distance.py
@@ -5,25 +5,105 @@
Copyright (C) 2019 Inria
Modification(s):
+ - 2020/07 Théo Lacombe: Added tests about handling essential parts in diagrams.
- YYYY/MM Author: Description of the modification
"""
-from gudhi.wasserstein.wasserstein import _proj_on_diag
+from gudhi.wasserstein.wasserstein import _proj_on_diag, _finite_part, _handle_essential_parts, _get_essential_parts
+from gudhi.wasserstein.wasserstein import _warn_infty
from gudhi.wasserstein import wasserstein_distance as pot
from gudhi.hera import wasserstein_distance as hera
import numpy as np
import pytest
+
__author__ = "Theo Lacombe"
__copyright__ = "Copyright (C) 2019 Inria"
__license__ = "MIT"
+
def test_proj_on_diag():
dgm = np.array([[1., 1.], [1., 2.], [3., 5.]])
assert np.array_equal(_proj_on_diag(dgm), [[1., 1.], [1.5, 1.5], [4., 4.]])
empty = np.empty((0, 2))
assert np.array_equal(_proj_on_diag(empty), empty)
+
+def test_finite_part():
+ diag = np.array([[0, 1], [3, 5], [2, np.inf], [3, np.inf], [-np.inf, 8], [-np.inf, 12], [-np.inf, -np.inf],
+ [np.inf, np.inf], [-np.inf, np.inf], [-np.inf, np.inf]])
+ assert np.array_equal(_finite_part(diag), [[0, 1], [3, 5]])
+
+
+def test_handle_essential_parts():
+ diag1 = np.array([[0, 1], [3, 5],
+ [2, np.inf], [3, np.inf],
+ [-np.inf, 8], [-np.inf, 12],
+ [-np.inf, -np.inf],
+ [np.inf, np.inf],
+ [-np.inf, np.inf], [-np.inf, np.inf]])
+
+ diag2 = np.array([[0, 2], [3, 5],
+ [2, np.inf], [4, np.inf],
+ [-np.inf, 8], [-np.inf, 11],
+ [-np.inf, -np.inf],
+ [np.inf, np.inf],
+ [-np.inf, np.inf], [-np.inf, np.inf]])
+
+ diag3 = np.array([[0, 2], [3, 5],
+ [2, np.inf], [4, np.inf], [6, np.inf],
+ [-np.inf, 8], [-np.inf, 11],
+ [-np.inf, -np.inf],
+ [np.inf, np.inf],
+ [-np.inf, np.inf], [-np.inf, np.inf]])
+
+ c, m = _handle_essential_parts(diag1, diag2, order=1)
+ assert c == pytest.approx(2, 0.0001) # Note: here c is only the cost due to essential part (thus 2, not 3)
+ # Similarly, the matching only corresponds to essential parts.
+ # Note that (-inf,-inf) and (+inf,+inf) coordinates are matched to the diagonal.
+ assert np.array_equal(m, [[4, 4], [5, 5], [2, 2], [3, 3], [8, 8], [9, 9], [6, -1], [7, -1], [-1, 6], [-1, 7]])
+
+ c, m = _handle_essential_parts(diag1, diag3, order=1)
+ assert c == np.inf
+ assert (m is None)
+
+
+def test_get_essential_parts():
+ diag1 = np.array([[0, 1], [3, 5], [2, np.inf], [3, np.inf], [-np.inf, 8], [-np.inf, 12], [-np.inf, -np.inf],
+ [np.inf, np.inf], [-np.inf, np.inf], [-np.inf, np.inf]])
+
+ diag2 = np.array([[0, 1], [3, 5], [2, np.inf], [3, np.inf]])
+
+ res = _get_essential_parts(diag1)
+ res2 = _get_essential_parts(diag2)
+ assert np.array_equal(res[0], [4, 5])
+ assert np.array_equal(res[1], [2, 3])
+ assert np.array_equal(res[2], [8, 9])
+ assert np.array_equal(res[3], [6] )
+ assert np.array_equal(res[4], [7] )
+
+ assert np.array_equal(res2[0], [] )
+ assert np.array_equal(res2[1], [2, 3])
+ assert np.array_equal(res2[2], [] )
+ assert np.array_equal(res2[3], [] )
+ assert np.array_equal(res2[4], [] )
+
+
+def test_warn_infty():
+ with pytest.warns(UserWarning):
+ assert _warn_infty(matching=False)==np.inf
+ c, m = _warn_infty(matching=True)
+ assert (c == np.inf)
+ assert (m is None)
+
+
+def _to_set(X):
+ return { (i, j) for i, j in X }
+
+def _same_permuted(X, Y):
+ return _to_set(X) == _to_set(Y)
+
+
def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_matching=True):
diag1 = np.array([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]])
diag2 = np.array([[2.8, 4.45], [9.5, 14.1]])
@@ -64,20 +144,46 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_mat
assert wasserstein_distance(diag4, diag5) == np.inf
assert wasserstein_distance(diag5, diag6, order=1, internal_p=np.inf) == approx(4.)
-
+ assert wasserstein_distance(diag5, emptydiag) == np.inf
if test_matching:
match = wasserstein_distance(emptydiag, emptydiag, matching=True, internal_p=1., order=2)[1]
- assert np.array_equal(match, [])
+ # Accept [] or np.array of shape (2, 0)
+ assert len(match) == 0
match = wasserstein_distance(emptydiag, emptydiag, matching=True, internal_p=np.inf, order=2.24)[1]
- assert np.array_equal(match, [])
+ assert len(match) == 0
match = wasserstein_distance(emptydiag, diag2, matching=True, internal_p=np.inf, order=2.)[1]
- assert np.array_equal(match , [[-1, 0], [-1, 1]])
+ assert _same_permuted(match, [[-1, 0], [-1, 1]])
match = wasserstein_distance(diag2, emptydiag, matching=True, internal_p=np.inf, order=2.24)[1]
- assert np.array_equal(match , [[0, -1], [1, -1]])
+ assert _same_permuted(match, [[0, -1], [1, -1]])
match = wasserstein_distance(diag1, diag2, matching=True, internal_p=2., order=2.)[1]
- assert np.array_equal(match, [[0, 0], [1, 1], [2, -1]])
-
+ assert _same_permuted(match, [[0, 0], [1, 1], [2, -1]])
+
+ if test_matching and test_infinity:
+ diag7 = np.array([[0, 3], [4, np.inf], [5, np.inf]])
+ diag8 = np.array([[0,1], [0, np.inf], [-np.inf, -np.inf], [np.inf, np.inf]])
+ diag9 = np.array([[-np.inf, -np.inf], [np.inf, np.inf]])
+ diag10 = np.array([[0,1], [-np.inf, -np.inf], [np.inf, np.inf]])
+
+ match = wasserstein_distance(diag5, diag6, matching=True, internal_p=2., order=2.)[1]
+ assert _same_permuted(match, [[0, -1], [-1,0], [-1, 1], [1, 2]])
+ match = wasserstein_distance(diag5, diag7, matching=True, internal_p=2., order=2.)[1]
+ assert (match is None)
+ cost, match = wasserstein_distance(diag7, emptydiag, matching=True, internal_p=2., order=2.3)
+ assert (cost == np.inf)
+ assert (match is None)
+ cost, match = wasserstein_distance(emptydiag, diag7, matching=True, internal_p=2.42, order=2.)
+ assert (cost == np.inf)
+ assert (match is None)
+ cost, match = wasserstein_distance(diag8, diag9, matching=True, internal_p=2., order=2.)
+ assert (cost == np.inf)
+ assert (match is None)
+ cost, match = wasserstein_distance(diag9, diag10, matching=True, internal_p=1., order=1.)
+ assert (cost == 1)
+ assert _same_permuted(match, [[0, -1],[1, -1],[-1, 0], [-1, 1], [-1, 2]]) # type 4 and 5 are match to the diag anyway.
+ cost, match = wasserstein_distance(diag9, emptydiag, matching=True, internal_p=2., order=2.)
+ assert (cost == 0.)
+ assert _same_permuted(match, [[0, -1], [1, -1]])
def hera_wrap(**extra):
@@ -85,15 +191,19 @@ def hera_wrap(**extra):
return hera(*kargs,**kwargs,**extra)
return fun
+
def pot_wrap(**extra):
def fun(*kargs,**kwargs):
return pot(*kargs,**kwargs,**extra)
return fun
+
def test_wasserstein_distance_pot():
- _basic_wasserstein(pot, 1e-15, test_infinity=False, test_matching=True)
- _basic_wasserstein(pot_wrap(enable_autodiff=True), 1e-15, test_infinity=False, test_matching=False)
+ _basic_wasserstein(pot, 1e-15, test_infinity=False, test_matching=True) # pot with its standard args
+ _basic_wasserstein(pot_wrap(enable_autodiff=True, keep_essential_parts=False), 1e-15, test_infinity=False, test_matching=False)
+
def test_wasserstein_distance_hera():
- _basic_wasserstein(hera_wrap(delta=1e-12), 1e-12, test_matching=False)
- _basic_wasserstein(hera_wrap(delta=.1), .1, test_matching=False)
+ _basic_wasserstein(hera_wrap(delta=1e-12), 1e-12, test_matching=True)
+ _basic_wasserstein(hera_wrap(delta=.1), .1, test_matching=True)
+