diff options
author | Marc Glisse <marc.glisse@inria.fr> | 2022-05-17 09:18:18 +0200 |
---|---|---|
committer | Marc Glisse <marc.glisse@inria.fr> | 2022-05-17 09:18:18 +0200 |
commit | e45fd615787821c40446f68b576e9e4ff8cb1a27 (patch) | |
tree | aab80cd4a82f1a9b760e1114b9a3fdbf11a8da77 /src/python | |
parent | e495c4de49a9a226359fbf21966f4ebc4b3fc31b (diff) | |
parent | 91a95c2709d293c31e5dc64fd4f1b8d370513605 (diff) |
Merge remote-tracking branch 'origin/master' into collapse-pr2
Diffstat (limited to 'src/python')
-rw-r--r-- | src/python/CMakeLists.txt | 16 | ||||
-rw-r--r-- | src/python/doc/alpha_complex_user.rst | 3 | ||||
-rw-r--r-- | src/python/doc/persistence_graphical_tools_user.rst | 2 | ||||
-rw-r--r-- | src/python/gudhi/alpha_complex.pyx | 29 | ||||
-rw-r--r-- | src/python/gudhi/persistence_graphical_tools.py | 349 | ||||
-rw-r--r-- | src/python/gudhi/simplex_tree.pxd | 4 | ||||
-rw-r--r-- | src/python/gudhi/simplex_tree.pyx | 64 | ||||
-rw-r--r-- | src/python/include/Alpha_complex_interface.h | 10 | ||||
-rw-r--r-- | src/python/include/Simplex_tree_interface.h | 8 | ||||
-rwxr-xr-x | src/python/test/test_alpha_complex.py | 27 | ||||
-rw-r--r-- | src/python/test/test_persistence_graphical_tools.py | 121 | ||||
-rwxr-xr-x | src/python/test/test_simplex_tree.py | 101 |
12 files changed, 550 insertions, 184 deletions
diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt index 0b51df41..54221151 100644 --- a/src/python/CMakeLists.txt +++ b/src/python/CMakeLists.txt @@ -176,6 +176,15 @@ if(PYTHONINTERP_FOUND) set(GUDHI_CYTHON_MODULES "${GUDHI_CYTHON_MODULES}'alpha_complex', ") endif () + # from windows vcpkg eigen 3.4.0#2 : build fails with + # error C2440: '<function-style-cast>': cannot convert from 'Eigen::EigenBase<Derived>::Index' to '__gmp_expr<mpq_t,mpq_t>' + # cf. https://gitlab.com/libeigen/eigen/-/issues/2476 + # Workaround is to compile with '-DEIGEN_DEFAULT_DENSE_INDEX_TYPE=int' + if (FORCE_EIGEN_DEFAULT_DENSE_INDEX_TYPE_TO_INT) + set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-DEIGEN_DEFAULT_DENSE_INDEX_TYPE=int', ") + endif() + + add_gudhi_debug_info("Boost version ${Boost_VERSION}") if(CGAL_FOUND) if(NOT CGAL_VERSION VERSION_LESS 5.3.0) @@ -211,13 +220,14 @@ if(PYTHONINTERP_FOUND) endif(NOT GMP_LIBRARIES_DIR) add_GUDHI_PYTHON_lib_dir(${GMP_LIBRARIES_DIR}) message("** Add gmp ${GMP_LIBRARIES_DIR}") + # When FORCE_CGAL_NOT_TO_BUILD_WITH_GMPXX is set, not defining CGAL_USE_GMPXX is sufficient enough if(GMPXX_FOUND) add_gudhi_debug_info("GMPXX_LIBRARIES = ${GMPXX_LIBRARIES}") set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-DCGAL_USE_GMPXX', ") add_GUDHI_PYTHON_lib("${GMPXX_LIBRARIES}") add_GUDHI_PYTHON_lib_dir(${GMPXX_LIBRARIES_DIR}) message("** Add gmpxx ${GMPXX_LIBRARIES_DIR}") - endif(GMPXX_FOUND) + endif() endif(GMP_FOUND) if(MPFR_FOUND) add_gudhi_debug_info("MPFR_LIBRARIES = ${MPFR_LIBRARIES}") @@ -581,6 +591,10 @@ if(PYTHONINTERP_FOUND) add_gudhi_py_test(test_dtm_rips_complex) endif() + # persistence graphical tools + if(MATPLOTLIB_FOUND) + add_gudhi_py_test(test_persistence_graphical_tools) + endif() # Set missing or not modules set(GUDHI_MODULES ${GUDHI_MODULES} "python" CACHE INTERNAL "GUDHI_MODULES") diff --git a/src/python/doc/alpha_complex_user.rst b/src/python/doc/alpha_complex_user.rst index cfd22742..b060c86e 100644 --- a/src/python/doc/alpha_complex_user.rst +++ b/src/python/doc/alpha_complex_user.rst @@ -27,7 +27,8 @@ Remarks If you pass :code:`precision = 'exact'` to :func:`~gudhi.AlphaComplex.__init__`, the filtration values are the exact ones converted to float. This can be very slow. If you pass :code:`precision = 'safe'` (the default), the filtration values are only - guaranteed to have a small multiplicative error compared to the exact value. + guaranteed to have a small multiplicative error compared to the exact value, see + :func:`~gudhi.AlphaComplex.set_float_relative_precision` to modify the precision. A drawback, when computing persistence, is that an empty exact interval [10^12,10^12] may become a non-empty approximate interval [10^12,10^12+10^6]. Using :code:`precision = 'fast'` makes the computations slightly faster, and the combinatorics are still exact, but diff --git a/src/python/doc/persistence_graphical_tools_user.rst b/src/python/doc/persistence_graphical_tools_user.rst index d95b9d2b..e1d28c71 100644 --- a/src/python/doc/persistence_graphical_tools_user.rst +++ b/src/python/doc/persistence_graphical_tools_user.rst @@ -60,7 +60,7 @@ of shape (N x 2) encoding a persistence diagram (in a given dimension). import matplotlib.pyplot as plt import gudhi import numpy as np - d = np.array([[0, 1], [1, 2], [1, np.inf]]) + d = np.array([[0., 1.], [1., 2.], [1., np.inf]]) gudhi.plot_persistence_diagram(d) plt.show() diff --git a/src/python/gudhi/alpha_complex.pyx b/src/python/gudhi/alpha_complex.pyx index a4888914..375e1561 100644 --- a/src/python/gudhi/alpha_complex.pyx +++ b/src/python/gudhi/alpha_complex.pyx @@ -31,6 +31,10 @@ cdef extern from "Alpha_complex_interface.h" namespace "Gudhi": Alpha_complex_interface(vector[vector[double]] points, vector[double] weights, bool fast_version, bool exact_version) nogil except + vector[double] get_point(int vertex) nogil except + void create_simplex_tree(Simplex_tree_interface_full_featured* simplex_tree, double max_alpha_square, bool default_filtration_value) nogil except + + @staticmethod + void set_float_relative_precision(double precision) nogil + @staticmethod + double get_float_relative_precision() nogil # AlphaComplex python interface cdef class AlphaComplex: @@ -133,3 +137,28 @@ cdef class AlphaComplex: self.this_ptr.create_simplex_tree(<Simplex_tree_interface_full_featured*>stree_int_ptr, mas, compute_filtration) return stree + + @staticmethod + def set_float_relative_precision(precision): + """ + :param precision: When the AlphaComplex is constructed with :code:`precision = 'safe'` (the default), + one can set the float relative precision of filtration values computed in + :func:`~gudhi.AlphaComplex.create_simplex_tree`. + Default is :code:`1e-5` (cf. :func:`~gudhi.AlphaComplex.get_float_relative_precision`). + For more details, please refer to + `CGAL::Lazy_exact_nt<NT>::set_relative_precision_of_to_double <https://doc.cgal.org/latest/Number_types/classCGAL_1_1Lazy__exact__nt.html>`_ + :type precision: float + """ + if precision <=0. or precision >= 1.: + raise ValueError("Relative precision value must be strictly greater than 0 and strictly lower than 1") + Alpha_complex_interface.set_float_relative_precision(precision) + + @staticmethod + def get_float_relative_precision(): + """ + :returns: The float relative precision of filtration values computation in + :func:`~gudhi.AlphaComplex.create_simplex_tree` when the AlphaComplex is constructed with + :code:`precision = 'safe'` (the default). + :rtype: float + """ + return Alpha_complex_interface.get_float_relative_precision() diff --git a/src/python/gudhi/persistence_graphical_tools.py b/src/python/gudhi/persistence_graphical_tools.py index 848dc03e..7ed11360 100644 --- a/src/python/gudhi/persistence_graphical_tools.py +++ b/src/python/gudhi/persistence_graphical_tools.py @@ -12,6 +12,9 @@ from os import path from math import isfinite import numpy as np from functools import lru_cache +import warnings +import errno +import os from gudhi.reader_utils import read_persistence_intervals_in_dimension from gudhi.reader_utils import read_persistence_intervals_grouped_by_dimension @@ -22,6 +25,7 @@ __license__ = "MIT" _gudhi_matplotlib_use_tex = True + def __min_birth_max_death(persistence, band=0.0): """This function returns (min_birth, max_death) from the persistence. @@ -44,20 +48,46 @@ def __min_birth_max_death(persistence, band=0.0): min_birth = float(interval[1][0]) if band > 0.0: max_death += band + # can happen if only points at inf death + if min_birth == max_death: + max_death = max_death + 1.0 return (min_birth, max_death) def _array_handler(a): - ''' + """ :param a: if array, assumes it is a (n x 2) np.array and return a persistence-compatible list (padding with 0), so that the plot can be performed seamlessly. - ''' - if isinstance(a[0][1], np.float64) or isinstance(a[0][1], float): + """ + if isinstance(a[0][1], (np.floating, float)): return [[0, x] for x in a] else: return a + +def _limit_to_max_intervals(persistence, max_intervals, key): + """This function returns truncated persistence if length is bigger than max_intervals. + :param persistence: Persistence intervals values list. Can be grouped by dimension or not. + :type persistence: an array of (dimension, array of (birth, death)) or an array of (birth, death). + :param max_intervals: maximal number of intervals to display. + Selected intervals are those with the longest life time. Set it + to 0 to see all. Default value is 1000. + :type max_intervals: int. + :param key: key function for sort algorithm. + :type key: function or lambda. + """ + if max_intervals > 0 and max_intervals < len(persistence): + warnings.warn( + "There are %s intervals given as input, whereas max_intervals is set to %s." + % (len(persistence), max_intervals) + ) + # Sort by life time, then takes only the max_intervals elements + return sorted(persistence, key=key, reverse=True)[:max_intervals] + else: + return persistence + + @lru_cache(maxsize=1) def _matplotlib_can_use_tex(): """This function returns True if matplotlib can deal with LaTeX, False otherwise. @@ -65,17 +95,17 @@ def _matplotlib_can_use_tex(): """ try: from matplotlib import checkdep_usetex + return checkdep_usetex(True) - except ImportError: - print("This function is not available, you may be missing matplotlib.") + except ImportError as import_error: + warnings.warn(f"This function is not available.\nModuleNotFoundError: No module named '{import_error.name}'.") def plot_persistence_barcode( persistence=[], persistence_file="", alpha=0.6, - max_intervals=1000, - max_barcodes=1000, + max_intervals=20000, inf_delta=0.1, legend=False, colormap=None, @@ -97,7 +127,7 @@ def plot_persistence_barcode( :type alpha: float. :param max_intervals: maximal number of intervals to display. Selected intervals are those with the longest life time. Set it - to 0 to see all. Default value is 1000. + to 0 to see all. Default value is 20000. :type max_intervals: int. :param inf_delta: Infinity is placed at :code:`((max_death - min_birth) x inf_delta)` above :code:`max_death` value. A reasonable value is @@ -119,99 +149,68 @@ def plot_persistence_barcode( import matplotlib.pyplot as plt import matplotlib.patches as mpatches from matplotlib import rc + if _gudhi_matplotlib_use_tex and _matplotlib_can_use_tex(): - plt.rc('text', usetex=True) - plt.rc('font', family='serif') + plt.rc("text", usetex=True) + plt.rc("font", family="serif") else: - plt.rc('text', usetex=False) - plt.rc('font', family='DejaVu Sans') + plt.rc("text", usetex=False) + plt.rc("font", family="DejaVu Sans") if persistence_file != "": if path.isfile(persistence_file): # Reset persistence persistence = [] - diag = read_persistence_intervals_grouped_by_dimension( - persistence_file=persistence_file - ) + diag = read_persistence_intervals_grouped_by_dimension(persistence_file=persistence_file) for key in diag.keys(): for persistence_interval in diag[key]: persistence.append((key, persistence_interval)) else: - print("file " + persistence_file + " not found.") - return None - - persistence = _array_handler(persistence) - - if max_barcodes != 1000: - print("Deprecated parameter. It has been replaced by max_intervals") - max_intervals = max_barcodes - - if max_intervals > 0 and max_intervals < len(persistence): - # Sort by life time, then takes only the max_intervals elements - persistence = sorted( - persistence, - key=lambda life_time: life_time[1][1] - life_time[1][0], - reverse=True, - )[:max_intervals] - - if colormap == None: - colormap = plt.cm.Set1.colors - if axes == None: - fig, axes = plt.subplots(1, 1) + raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), persistence_file) - persistence = sorted(persistence, key=lambda birth: birth[1][0]) + try: + persistence = _array_handler(persistence) + persistence = _limit_to_max_intervals( + persistence, max_intervals, key=lambda life_time: life_time[1][1] - life_time[1][0] + ) + (min_birth, max_death) = __min_birth_max_death(persistence) + persistence = sorted(persistence, key=lambda birth: birth[1][0]) + except IndexError: + min_birth, max_death = 0.0, 1.0 + pass - (min_birth, max_death) = __min_birth_max_death(persistence) - ind = 0 delta = (max_death - min_birth) * inf_delta # Replace infinity values with max_death + delta for bar code to be more # readable infinity = max_death + delta axis_start = min_birth - delta - # Draw horizontal bars in loop - for interval in reversed(persistence): - if float(interval[1][1]) != float("inf"): - # Finite death case - axes.barh( - ind, - (interval[1][1] - interval[1][0]), - height=0.8, - left=interval[1][0], - alpha=alpha, - color=colormap[interval[0]], - linewidth=0, - ) - else: - # Infinite death case for diagram to be nicer - axes.barh( - ind, - (infinity - interval[1][0]), - height=0.8, - left=interval[1][0], - alpha=alpha, - color=colormap[interval[0]], - linewidth=0, - ) - ind = ind + 1 + + if axes == None: + _, axes = plt.subplots(1, 1) + if colormap == None: + colormap = plt.cm.Set1.colors + + x=[birth for (dim,(birth,death)) in persistence] + y=[(death - birth) if death != float("inf") else (infinity - birth) for (dim,(birth,death)) in persistence] + c=[colormap[dim] for (dim,(birth,death)) in persistence] + + axes.barh(list(reversed(range(len(x)))), y, height=0.8, left=x, alpha=alpha, color=c, linewidth=0) if legend: dimensions = list(set(item[0] for item in persistence)) axes.legend( - handles=[ - mpatches.Patch(color=colormap[dim], label=str(dim)) - for dim in dimensions - ], - loc="lower right", + handles=[mpatches.Patch(color=colormap[dim], label=str(dim)) for dim in dimensions], loc="lower right", ) axes.set_title("Persistence barcode", fontsize=fontsize) # Ends plot on infinity value and starts a little bit before min_birth - axes.axis([axis_start, infinity, 0, ind]) + if len(x) != 0: + axes.axis([axis_start, infinity, 0, len(x)]) return axes - except ImportError: - print("This function is not available, you may be missing matplotlib.") + except ImportError as import_error: + warnings.warn(f"This function is not available.\nModuleNotFoundError: No module named '{import_error.name}'.") def plot_persistence_diagram( @@ -219,14 +218,13 @@ def plot_persistence_diagram( persistence_file="", alpha=0.6, band=0.0, - max_intervals=1000, - max_plots=1000, + max_intervals=1000000, inf_delta=0.1, legend=False, colormap=None, axes=None, fontsize=16, - greyblock=True + greyblock=True, ): """This function plots the persistence diagram from persistence values list, a np.array of shape (N x 2) representing a diagram in a single @@ -244,7 +242,7 @@ def plot_persistence_diagram( :type band: float. :param max_intervals: maximal number of intervals to display. Selected intervals are those with the longest life time. Set it - to 0 to see all. Default value is 1000. + to 0 to see all. Default value is 1000000. :type max_intervals: int. :param inf_delta: Infinity is placed at :code:`((max_death - min_birth) x inf_delta)` above :code:`max_death` value. A reasonable value is @@ -268,47 +266,35 @@ def plot_persistence_diagram( import matplotlib.pyplot as plt import matplotlib.patches as mpatches from matplotlib import rc + if _gudhi_matplotlib_use_tex and _matplotlib_can_use_tex(): - plt.rc('text', usetex=True) - plt.rc('font', family='serif') + plt.rc("text", usetex=True) + plt.rc("font", family="serif") else: - plt.rc('text', usetex=False) - plt.rc('font', family='DejaVu Sans') + plt.rc("text", usetex=False) + plt.rc("font", family="DejaVu Sans") if persistence_file != "": if path.isfile(persistence_file): # Reset persistence persistence = [] - diag = read_persistence_intervals_grouped_by_dimension( - persistence_file=persistence_file - ) + diag = read_persistence_intervals_grouped_by_dimension(persistence_file=persistence_file) for key in diag.keys(): for persistence_interval in diag[key]: persistence.append((key, persistence_interval)) else: - print("file " + persistence_file + " not found.") - return None - - persistence = _array_handler(persistence) + raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), persistence_file) - if max_plots != 1000: - print("Deprecated parameter. It has been replaced by max_intervals") - max_intervals = max_plots - - if max_intervals > 0 and max_intervals < len(persistence): - # Sort by life time, then takes only the max_intervals elements - persistence = sorted( - persistence, - key=lambda life_time: life_time[1][1] - life_time[1][0], - reverse=True, - )[:max_intervals] - - if colormap == None: - colormap = plt.cm.Set1.colors - if axes == None: - fig, axes = plt.subplots(1, 1) + try: + persistence = _array_handler(persistence) + persistence = _limit_to_max_intervals( + persistence, max_intervals, key=lambda life_time: life_time[1][1] - life_time[1][0] + ) + min_birth, max_death = __min_birth_max_death(persistence, band) + except IndexError: + min_birth, max_death = 0.0, 1.0 + pass - (min_birth, max_death) = __min_birth_max_death(persistence, band) delta = (max_death - min_birth) * inf_delta # Replace infinity values with max_death + delta for diagram to be more # readable @@ -316,61 +302,56 @@ def plot_persistence_diagram( axis_end = max_death + delta / 2 axis_start = min_birth - delta + if axes == None: + _, axes = plt.subplots(1, 1) + if colormap == None: + colormap = plt.cm.Set1.colors # bootstrap band if band > 0.0: x = np.linspace(axis_start, infinity, 1000) axes.fill_between(x, x, x + band, alpha=alpha, facecolor="red") # lower diag patch if greyblock: - axes.add_patch(mpatches.Polygon([[axis_start, axis_start], [axis_end, axis_start], [axis_end, axis_end]], fill=True, color='lightgrey')) - # Draw points in loop - pts_at_infty = False # Records presence of pts at infty - for interval in reversed(persistence): - if float(interval[1][1]) != float("inf"): - # Finite death case - axes.scatter( - interval[1][0], - interval[1][1], - alpha=alpha, - color=colormap[interval[0]], + axes.add_patch( + mpatches.Polygon( + [[axis_start, axis_start], [axis_end, axis_start], [axis_end, axis_end]], + fill=True, + color="lightgrey", ) - else: - pts_at_infty = True - # Infinite death case for diagram to be nicer - axes.scatter( - interval[1][0], infinity, alpha=alpha, color=colormap[interval[0]] - ) - if pts_at_infty: + ) + # line display of equation : birth = death + axes.plot([axis_start, axis_end], [axis_start, axis_end], linewidth=1.0, color="k") + + x=[birth for (dim,(birth,death)) in persistence] + y=[death if death != float("inf") else infinity for (dim,(birth,death)) in persistence] + c=[colormap[dim] for (dim,(birth,death)) in persistence] + + axes.scatter(x,y,alpha=alpha,color=c) + if float("inf") in (death for (dim,(birth,death)) in persistence): # infinity line and text - axes.plot([axis_start, axis_end], [axis_start, axis_end], linewidth=1.0, color="k") axes.plot([axis_start, axis_end], [infinity, infinity], linewidth=1.0, color="k", alpha=alpha) # Infinity label yt = axes.get_yticks() - yt = yt[np.where(yt < axis_end)] # to avoid ploting ticklabel higher than infinity + yt = yt[np.where(yt < axis_end)] # to avoid ploting ticklabel higher than infinity yt = np.append(yt, infinity) ytl = ["%.3f" % e for e in yt] # to avoid float precision error - ytl[-1] = r'$+\infty$' + ytl[-1] = r"$+\infty$" axes.set_yticks(yt) axes.set_yticklabels(ytl) if legend: dimensions = list(set(item[0] for item in persistence)) - axes.legend( - handles=[ - mpatches.Patch(color=colormap[dim], label=str(dim)) - for dim in dimensions - ] - ) + axes.legend(handles=[mpatches.Patch(color=colormap[dim], label=str(dim)) for dim in dimensions]) axes.set_xlabel("Birth", fontsize=fontsize) axes.set_ylabel("Death", fontsize=fontsize) axes.set_title("Persistence diagram", fontsize=fontsize) # Ends plot on infinity value and starts a little bit before min_birth - axes.axis([axis_start, axis_end, axis_start, infinity + delta/2]) + axes.axis([axis_start, axis_end, axis_start, infinity + delta / 2]) return axes - except ImportError: - print("This function is not available, you may be missing matplotlib.") + except ImportError as import_error: + warnings.warn(f"This function is not available.\nModuleNotFoundError: No module named '{import_error.name}'.") def plot_persistence_density( @@ -384,7 +365,7 @@ def plot_persistence_density( legend=False, axes=None, fontsize=16, - greyblock=False + greyblock=False, ): """This function plots the persistence density from persistence values list, np.array of shape (N x 2) representing a diagram @@ -444,12 +425,13 @@ def plot_persistence_density( import matplotlib.patches as mpatches from scipy.stats import kde from matplotlib import rc + if _gudhi_matplotlib_use_tex and _matplotlib_can_use_tex(): - plt.rc('text', usetex=True) - plt.rc('font', family='serif') + plt.rc("text", usetex=True) + plt.rc("font", family="serif") else: - plt.rc('text', usetex=False) - plt.rc('font', family='DejaVu Sans') + plt.rc("text", usetex=False) + plt.rc("font", family="DejaVu Sans") if persistence_file != "": if dimension is None: @@ -460,10 +442,16 @@ def plot_persistence_density( persistence_file=persistence_file, only_this_dim=dimension ) else: - print("file " + persistence_file + " not found.") - return None + raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), persistence_file) + + # default cmap value cannot be done at argument definition level as matplotlib is not yet defined. + if cmap is None: + cmap = plt.cm.hot_r + if axes == None: + _, axes = plt.subplots(1, 1) - if len(persistence) > 0: + try: + # if not read from file but given by an argument persistence = _array_handler(persistence) persistence_dim = np.array( [ @@ -472,47 +460,54 @@ def plot_persistence_density( if (dim_interval[0] == dimension) or (dimension is None) ] ) - - persistence_dim = persistence_dim[np.isfinite(persistence_dim[:, 1])] - if max_intervals > 0 and max_intervals < len(persistence_dim): - # Sort by life time, then takes only the max_intervals elements + persistence_dim = persistence_dim[np.isfinite(persistence_dim[:, 1])] persistence_dim = np.array( - sorted( - persistence_dim, - key=lambda life_time: life_time[1] - life_time[0], - reverse=True, - )[:max_intervals] + _limit_to_max_intervals( + persistence_dim, max_intervals, key=lambda life_time: life_time[1] - life_time[0] + ) ) - # Set as numpy array birth and death (remove undefined values - inf and NaN) - birth = persistence_dim[:, 0] - death = persistence_dim[:, 1] - - # default cmap value cannot be done at argument definition level as matplotlib is not yet defined. - if cmap is None: - cmap = plt.cm.hot_r - if axes == None: - fig, axes = plt.subplots(1, 1) + # Set as numpy array birth and death (remove undefined values - inf and NaN) + birth = persistence_dim[:, 0] + death = persistence_dim[:, 1] + birth_min = birth.min() + birth_max = birth.max() + death_min = death.min() + death_max = death.max() + + # Evaluate a gaussian kde on a regular grid of nbins x nbins over data extents + k = kde.gaussian_kde([birth, death], bw_method=bw_method) + xi, yi = np.mgrid[ + birth_min : birth_max : nbins * 1j, death_min : death_max : nbins * 1j, + ] + zi = k(np.vstack([xi.flatten(), yi.flatten()])) + # Make the plot + img = axes.pcolormesh(xi, yi, zi.reshape(xi.shape), cmap=cmap, shading="auto") + plot_success = True + + # IndexError on empty diagrams, ValueError on only inf death values + except (IndexError, ValueError): + birth_min = 0.0 + birth_max = 1.0 + death_min = 0.0 + death_max = 1.0 + plot_success = False + pass # line display of equation : birth = death - x = np.linspace(death.min(), birth.max(), 1000) + x = np.linspace(death_min, birth_max, 1000) axes.plot(x, x, color="k", linewidth=1.0) - # Evaluate a gaussian kde on a regular grid of nbins x nbins over data extents - k = kde.gaussian_kde([birth, death], bw_method=bw_method) - xi, yi = np.mgrid[ - birth.min() : birth.max() : nbins * 1j, - death.min() : death.max() : nbins * 1j, - ] - zi = k(np.vstack([xi.flatten(), yi.flatten()])) - - # Make the plot - img = axes.pcolormesh(xi, yi, zi.reshape(xi.shape), cmap=cmap) - if greyblock: - axes.add_patch(mpatches.Polygon([[birth.min(), birth.min()], [death.max(), birth.min()], [death.max(), death.max()]], fill=True, color='lightgrey')) + axes.add_patch( + mpatches.Polygon( + [[birth_min, birth_min], [death_max, birth_min], [death_max, death_max]], + fill=True, + color="lightgrey", + ) + ) - if legend: + if plot_success and legend: plt.colorbar(img, ax=axes) axes.set_xlabel("Birth", fontsize=fontsize) @@ -521,7 +516,5 @@ def plot_persistence_density( return axes - except ImportError: - print( - "This function is not available, you may be missing matplotlib and/or scipy." - ) + except ImportError as import_error: + warnings.warn(f"This function is not available.\nModuleNotFoundError: No module named '{import_error.name}'.") diff --git a/src/python/gudhi/simplex_tree.pxd b/src/python/gudhi/simplex_tree.pxd index 70311ead..4f229663 100644 --- a/src/python/gudhi/simplex_tree.pxd +++ b/src/python/gudhi/simplex_tree.pxd @@ -45,6 +45,7 @@ cdef extern from "Simplex_tree_interface.h" namespace "Gudhi": cdef cppclass Simplex_tree_interface_full_featured "Gudhi::Simplex_tree_interface<Gudhi::Simplex_tree_options_full_featured>": Simplex_tree_interface_full_featured() nogil + Simplex_tree_interface_full_featured(Simplex_tree_interface_full_featured&) nogil double simplex_filtration(vector[int] simplex) nogil void assign_simplex_filtration(vector[int] simplex, double filtration) nogil void initialize_filtration() nogil @@ -75,6 +76,9 @@ cdef extern from "Simplex_tree_interface.h" namespace "Gudhi": Simplex_tree_skeleton_iterator get_skeleton_iterator_begin(int dimension) nogil Simplex_tree_skeleton_iterator get_skeleton_iterator_end(int dimension) nogil pair[Simplex_tree_boundary_iterator, Simplex_tree_boundary_iterator] get_boundary_iterators(vector[int] simplex) nogil except + + # Expansion with blockers + ctypedef bool (*blocker_func_t)(vector[int], void *user_data) + void expansion_with_blockers_callback(int dimension, blocker_func_t user_func, void *user_data) cdef extern from "Persistent_cohomology_interface.h" namespace "Gudhi": cdef cppclass Simplex_tree_persistence_interface "Gudhi::Persistent_cohomology_interface<Gudhi::Simplex_tree<Gudhi::Simplex_tree_options_full_featured>>": diff --git a/src/python/gudhi/simplex_tree.pyx b/src/python/gudhi/simplex_tree.pyx index 3d7c28b6..2c53a872 100644 --- a/src/python/gudhi/simplex_tree.pyx +++ b/src/python/gudhi/simplex_tree.pyx @@ -16,6 +16,9 @@ __author__ = "Vincent Rouvreau" __copyright__ = "Copyright (C) 2016 Inria" __license__ = "MIT" +cdef bool callback(vector[int] simplex, void *blocker_func): + return (<object>blocker_func)(simplex) + # SimplexTree python interface cdef class SimplexTree: """The simplex tree is an efficient and flexible data structure for @@ -38,13 +41,29 @@ cdef class SimplexTree: cdef Simplex_tree_persistence_interface * pcohptr # Fake constructor that does nothing but documenting the constructor - def __init__(self): + def __init__(self, other = None): """SimplexTree constructor. + + :param other: If `other` is `None` (default value), an empty `SimplexTree` is created. + If `other` is a `SimplexTree`, the `SimplexTree` is constructed from a deep copy of `other`. + :type other: SimplexTree (Optional) + :returns: An empty or a copy simplex tree. + :rtype: SimplexTree + + :raises TypeError: In case `other` is neither `None`, nor a `SimplexTree`. + :note: If the `SimplexTree` is a copy, the persistence information is not copied. If you need it in the clone, + you have to call :func:`compute_persistence` on it even if you had already computed it in the original. """ # The real cython constructor - def __cinit__(self): - self.thisptr = <intptr_t>(new Simplex_tree_interface_full_featured()) + def __cinit__(self, other = None): + if other: + if isinstance(other, SimplexTree): + self.thisptr = _get_copy_intptr(other) + else: + raise TypeError("`other` argument requires to be of type `SimplexTree`, or `None`.") + else: + self.thisptr = <intptr_t>(new Simplex_tree_interface_full_featured()) def __dealloc__(self): cdef Simplex_tree_interface_full_featured* ptr = self.get_ptr() @@ -63,6 +82,21 @@ cdef class SimplexTree: """ return self.pcohptr != NULL + def copy(self): + """ + :returns: A simplex tree that is a deep copy of itself. + :rtype: SimplexTree + + :note: The persistence information is not copied. If you need it in the clone, you have to call + :func:`compute_persistence` on it even if you had already computed it in the original. + """ + stree = SimplexTree() + stree.thisptr = _get_copy_intptr(self) + return stree + + def __deepcopy__(self): + return self.copy() + def filtration(self, simplex): """This function returns the filtration value for a given N-simplex in this simplicial complex, or +infinity if it is not in the complex. @@ -443,6 +477,27 @@ cdef class SimplexTree: persistence_result = self.pcohptr.get_persistence() return self.get_ptr().compute_extended_persistence_subdiagrams(persistence_result, min_persistence) + def expansion_with_blocker(self, max_dim, blocker_func): + """Expands the Simplex_tree containing only a graph. Simplices corresponding to cliques in the graph are added + incrementally, faces before cofaces, unless the simplex has dimension larger than `max_dim` or `blocker_func` + returns `True` for this simplex. + + The function identifies a candidate simplex whose faces are all already in the complex, inserts it with a + filtration value corresponding to the maximum of the filtration values of the faces, then calls `blocker_func` + with this new simplex (represented as a list of int). If `blocker_func` returns `True`, the simplex is removed, + otherwise it is kept. The algorithm then proceeds with the next candidate. + + .. warning:: + Several candidates of the same dimension may be inserted simultaneously before calling `block_simplex`, so + if you examine the complex in `block_simplex`, you may hit a few simplices of the same dimension that have + not been vetted by `block_simplex` yet, or have already been rejected but not yet removed. + + :param max_dim: Expansion maximal dimension value. + :type max_dim: int + :param blocker_func: Blocker oracle. + :type blocker_func: Callable[[List[int]], bool] + """ + self.get_ptr().expansion_with_blockers_callback(max_dim, callback, <void*>blocker_func) def persistence(self, homology_coeff_field=11, min_persistence=0, persistence_dim_max = False): """This function computes and returns the persistence of the simplicial complex. @@ -648,3 +703,6 @@ cdef class SimplexTree: :rtype: bool """ return dereference(self.get_ptr()) == dereference(other.get_ptr()) + +cdef intptr_t _get_copy_intptr(SimplexTree stree) nogil: + return <intptr_t>(new Simplex_tree_interface_full_featured(dereference(stree.get_ptr()))) diff --git a/src/python/include/Alpha_complex_interface.h b/src/python/include/Alpha_complex_interface.h index 671af4a4..469b91ce 100644 --- a/src/python/include/Alpha_complex_interface.h +++ b/src/python/include/Alpha_complex_interface.h @@ -57,6 +57,16 @@ class Alpha_complex_interface { alpha_ptr_->create_simplex_tree(simplex_tree, max_alpha_square, default_filtration_value); } + static void set_float_relative_precision(double precision) { + // cf. Exact_alpha_complex_dD kernel type in Alpha_complex_factory.h + CGAL::Epeck_d<CGAL::Dynamic_dimension_tag>::FT::set_relative_precision_of_to_double(precision); + } + + static double get_float_relative_precision() { + // cf. Exact_alpha_complex_dD kernel type in Alpha_complex_factory.h + return CGAL::Epeck_d<CGAL::Dynamic_dimension_tag>::FT::get_relative_precision_of_to_double(); + } + private: std::unique_ptr<Abstract_alpha_complex> alpha_ptr_; }; diff --git a/src/python/include/Simplex_tree_interface.h b/src/python/include/Simplex_tree_interface.h index 4d8f8537..7f9b0067 100644 --- a/src/python/include/Simplex_tree_interface.h +++ b/src/python/include/Simplex_tree_interface.h @@ -40,6 +40,7 @@ class Simplex_tree_interface : public Simplex_tree<SimplexTreeOptions> { using Complex_simplex_iterator = typename Base::Complex_simplex_iterator; using Extended_filtration_data = typename Base::Extended_filtration_data; using Boundary_simplex_iterator = typename Base::Boundary_simplex_iterator; + typedef bool (*blocker_func_t)(Simplex simplex, void *user_data); public: @@ -189,6 +190,13 @@ class Simplex_tree_interface : public Simplex_tree<SimplexTreeOptions> { return collapsed_stree_ptr; } + void expansion_with_blockers_callback(int dimension, blocker_func_t user_func, void *user_data) { + Base::expansion_with_blockers(dimension, [&](Simplex_handle sh){ + Simplex simplex(Base::simplex_vertex_range(sh).begin(), Base::simplex_vertex_range(sh).end()); + return user_func(simplex, user_data); + }); + } + // Iterator over the simplex tree Complex_simplex_iterator get_simplices_iterator_begin() { // this specific case works because the range is just a pair of iterators - won't work if range was a vector diff --git a/src/python/test/test_alpha_complex.py b/src/python/test/test_alpha_complex.py index f15284f3..f81e6137 100755 --- a/src/python/test/test_alpha_complex.py +++ b/src/python/test/test_alpha_complex.py @@ -286,3 +286,30 @@ def _weighted_doc_example(precision): def test_weighted_doc_example(): for precision in ['fast', 'safe', 'exact']: _weighted_doc_example(precision) + +def test_float_relative_precision(): + assert AlphaComplex.get_float_relative_precision() == 1e-5 + # Must be > 0. + with pytest.raises(ValueError): + AlphaComplex.set_float_relative_precision(0.) + # Must be < 1. + with pytest.raises(ValueError): + AlphaComplex.set_float_relative_precision(1.) + + points = [[1, 1], [7, 0], [4, 6], [9, 6], [0, 14], [2, 19], [9, 17]] + st = AlphaComplex(points=points).create_simplex_tree() + filtrations = list(st.get_filtration()) + + # Get a better precision + AlphaComplex.set_float_relative_precision(1e-15) + assert AlphaComplex.get_float_relative_precision() == 1e-15 + + st = AlphaComplex(points=points).create_simplex_tree() + filtrations_better_resolution = list(st.get_filtration()) + + assert len(filtrations) == len(filtrations_better_resolution) + for idx in range(len(filtrations)): + # check simplex is the same + assert filtrations[idx][0] == filtrations_better_resolution[idx][0] + # check filtration is about the same with a relative precision of the worst case + assert filtrations[idx][1] == pytest.approx(filtrations_better_resolution[idx][1], rel=1e-5) diff --git a/src/python/test/test_persistence_graphical_tools.py b/src/python/test/test_persistence_graphical_tools.py new file mode 100644 index 00000000..c19836b7 --- /dev/null +++ b/src/python/test/test_persistence_graphical_tools.py @@ -0,0 +1,121 @@ +""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT. + See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details. + Author(s): Vincent Rouvreau + + Copyright (C) 2021 Inria + + Modification(s): + - YYYY/MM Author: Description of the modification +""" + +import gudhi as gd +import numpy as np +import matplotlib as plt +import pytest + + +def test_array_handler(): + diags = np.array([[1, 2], [3, 4], [5, 6]], float) + arr_diags = gd.persistence_graphical_tools._array_handler(diags) + for idx in range(len(diags)): + assert arr_diags[idx][0] == 0 + np.testing.assert_array_equal(arr_diags[idx][1], diags[idx]) + + diags = [(1.0, 2.0), (3.0, 4.0), (5.0, 6.0)] + arr_diags = gd.persistence_graphical_tools._array_handler(diags) + for idx in range(len(diags)): + assert arr_diags[idx][0] == 0 + assert arr_diags[idx][1] == diags[idx] + + diags = [(0, (1.0, 2.0)), (0, (3.0, 4.0)), (0, (5.0, 6.0))] + assert gd.persistence_graphical_tools._array_handler(diags) == diags + + +def test_min_birth_max_death(): + diags = [ + (0, (0.0, float("inf"))), + (0, (0.0983494, float("inf"))), + (0, (0.0, 0.122545)), + (0, (0.0, 0.12047)), + (0, (0.0, 0.118398)), + (0, (0.118398, 1.0)), + (0, (0.0, 0.117908)), + (0, (0.0, 0.112307)), + (0, (0.0, 0.107535)), + (0, (0.0, 0.106382)), + ] + assert gd.persistence_graphical_tools.__min_birth_max_death(diags) == (0.0, 1.0) + assert gd.persistence_graphical_tools.__min_birth_max_death(diags, band=4.0) == (0.0, 5.0) + + +def test_limit_min_birth_max_death(): + diags = [ + (0, (2.0, float("inf"))), + (0, (2.0, float("inf"))), + ] + assert gd.persistence_graphical_tools.__min_birth_max_death(diags) == (2.0, 3.0) + assert gd.persistence_graphical_tools.__min_birth_max_death(diags, band=4.0) == (2.0, 6.0) + + +def test_limit_to_max_intervals(): + diags = [ + (0, (0.0, float("inf"))), + (0, (0.0983494, float("inf"))), + (0, (0.0, 0.122545)), + (0, (0.0, 0.12047)), + (0, (0.0, 0.118398)), + (0, (0.118398, 1.0)), + (0, (0.0, 0.117908)), + (0, (0.0, 0.112307)), + (0, (0.0, 0.107535)), + (0, (0.0, 0.106382)), + ] + # check no warnings if max_intervals equals to the diagrams number + with pytest.warns(None) as record: + truncated_diags = gd.persistence_graphical_tools._limit_to_max_intervals( + diags, 10, key=lambda life_time: life_time[1][1] - life_time[1][0] + ) + # check diagrams are not sorted + assert truncated_diags == diags + assert len(record) == 0 + + # check warning if max_intervals lower than the diagrams number + with pytest.warns(UserWarning) as record: + truncated_diags = gd.persistence_graphical_tools._limit_to_max_intervals( + diags, 5, key=lambda life_time: life_time[1][1] - life_time[1][0] + ) + # check diagrams are truncated and sorted by life time + assert truncated_diags == [ + (0, (0.0, float("inf"))), + (0, (0.0983494, float("inf"))), + (0, (0.118398, 1.0)), + (0, (0.0, 0.122545)), + (0, (0.0, 0.12047)), + ] + assert len(record) == 1 + + +def _limit_plot_persistence(function): + pplot = function(persistence=[]) + assert isinstance(pplot, plt.axes.SubplotBase) + pplot = function(persistence=[], legend=True) + assert isinstance(pplot, plt.axes.SubplotBase) + pplot = function(persistence=[(0, float("inf"))]) + assert isinstance(pplot, plt.axes.SubplotBase) + pplot = function(persistence=[(0, float("inf"))], legend=True) + assert isinstance(pplot, plt.axes.SubplotBase) + + +def test_limit_plot_persistence(): + for function in [gd.plot_persistence_barcode, gd.plot_persistence_diagram, gd.plot_persistence_density]: + _limit_plot_persistence(function) + + +def _non_existing_persistence_file(function): + with pytest.raises(FileNotFoundError): + function(persistence_file="pouetpouettralala.toubiloubabdou") + + +def test_non_existing_persistence_file(): + for function in [gd.plot_persistence_barcode, gd.plot_persistence_diagram, gd.plot_persistence_density]: + _non_existing_persistence_file(function) diff --git a/src/python/test/test_simplex_tree.py b/src/python/test/test_simplex_tree.py index f130cf1f..688f4fd6 100755 --- a/src/python/test/test_simplex_tree.py +++ b/src/python/test/test_simplex_tree.py @@ -454,3 +454,104 @@ def test_equality_operator(): st2.insert([1,2,3], 4.) assert st1 == st2 + +def test_simplex_tree_deep_copy(): + st = SimplexTree() + st.insert([1, 2, 3], 0.) + # compute persistence only on the original + st.compute_persistence() + + st_copy = st.copy() + assert st_copy == st + st_filt_list = list(st.get_filtration()) + + # check persistence is not copied + assert st.__is_persistence_defined() == True + assert st_copy.__is_persistence_defined() == False + + # remove something in the copy and check the copy is included in the original + st_copy.remove_maximal_simplex([1, 2, 3]) + a_filt_list = list(st_copy.get_filtration()) + assert len(a_filt_list) < len(st_filt_list) + + for a_splx in a_filt_list: + assert a_splx in st_filt_list + + # test double free + del st + del st_copy + +def test_simplex_tree_deep_copy_constructor(): + st = SimplexTree() + st.insert([1, 2, 3], 0.) + # compute persistence only on the original + st.compute_persistence() + + st_copy = SimplexTree(st) + assert st_copy == st + st_filt_list = list(st.get_filtration()) + + # check persistence is not copied + assert st.__is_persistence_defined() == True + assert st_copy.__is_persistence_defined() == False + + # remove something in the copy and check the copy is included in the original + st_copy.remove_maximal_simplex([1, 2, 3]) + a_filt_list = list(st_copy.get_filtration()) + assert len(a_filt_list) < len(st_filt_list) + + for a_splx in a_filt_list: + assert a_splx in st_filt_list + + # test double free + del st + del st_copy + +def test_simplex_tree_constructor_exception(): + with pytest.raises(TypeError): + st = SimplexTree(other = "Construction from a string shall raise an exception") + +def test_expansion_with_blocker(): + st=SimplexTree() + st.insert([0,1],0) + st.insert([0,2],1) + st.insert([0,3],2) + st.insert([1,2],3) + st.insert([1,3],4) + st.insert([2,3],5) + st.insert([2,4],6) + st.insert([3,6],7) + st.insert([4,5],8) + st.insert([4,6],9) + st.insert([5,6],10) + st.insert([6],10) + + def blocker(simplex): + try: + # Block all simplices that countains vertex 6 + simplex.index(6) + print(simplex, ' is blocked') + return True + except ValueError: + print(simplex, ' is accepted') + st.assign_filtration(simplex, st.filtration(simplex) + 1.) + return False + + st.expansion_with_blocker(2, blocker) + assert st.num_simplices() == 22 + assert st.dimension() == 2 + assert st.find([4,5,6]) == False + assert st.filtration([0,1,2]) == 4. + assert st.filtration([0,1,3]) == 5. + assert st.filtration([0,2,3]) == 6. + assert st.filtration([1,2,3]) == 6. + + st.expansion_with_blocker(3, blocker) + assert st.num_simplices() == 23 + assert st.dimension() == 3 + assert st.find([4,5,6]) == False + assert st.filtration([0,1,2]) == 4. + assert st.filtration([0,1,3]) == 5. + assert st.filtration([0,2,3]) == 6. + assert st.filtration([1,2,3]) == 6. + assert st.filtration([0,1,2,3]) == 7. |