diff options
Diffstat (limited to 'src/python')
-rw-r--r-- | src/python/CMakeLists.txt | 16 | ||||
-rw-r--r-- | src/python/doc/alpha_complex_user.rst | 4 | ||||
-rw-r--r-- | src/python/doc/nerve_gic_complex_user.rst | 2 | ||||
-rw-r--r-- | src/python/doc/persistence_graphical_tools_user.rst | 2 | ||||
-rwxr-xr-x | src/python/example/rips_complex_diagram_persistence_from_correlation_matrix_file_example.py | 2 | ||||
-rw-r--r-- | src/python/gudhi/__init__.py.in | 4 | ||||
-rw-r--r-- | src/python/gudhi/hera/wasserstein.cc | 2 | ||||
-rw-r--r-- | src/python/gudhi/persistence_graphical_tools.py | 349 | ||||
-rw-r--r-- | src/python/gudhi/simplex_tree.pyx | 15 | ||||
-rw-r--r-- | src/python/gudhi/wasserstein/barycenter.py | 6 | ||||
-rw-r--r-- | src/python/include/Simplex_tree_interface.h | 8 | ||||
-rw-r--r-- | src/python/test/test_persistence_graphical_tools.py | 121 | ||||
-rwxr-xr-x | src/python/test/test_simplex_tree.py | 19 | ||||
-rwxr-xr-x | src/python/test/test_subsampling.py | 4 |
14 files changed, 326 insertions, 228 deletions
diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt index e31af02e..af0b6115 100644 --- a/src/python/CMakeLists.txt +++ b/src/python/CMakeLists.txt @@ -148,10 +148,6 @@ if(PYTHONINTERP_FOUND) add_gudhi_debug_info("Eigen3 version ${EIGEN3_VERSION}") # No problem, even if no CGAL found set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-DCGAL_EIGEN3_ENABLED', ") - set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-DGUDHI_USE_EIGEN3', ") - set(GUDHI_USE_EIGEN3 "True") - else (EIGEN3_FOUND) - set(GUDHI_USE_EIGEN3 "False") endif (EIGEN3_FOUND) set(GUDHI_CYTHON_MODULES "${GUDHI_CYTHON_MODULES}'off_reader', ") @@ -333,9 +329,9 @@ if(PYTHONINTERP_FOUND) if(NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 5.1.0) set (GUDHI_SPHINX_MESSAGE "Generating API documentation with Sphinx in ${CMAKE_CURRENT_BINARY_DIR}/sphinx/") # User warning - Sphinx is a static pages generator, and configured to work fine with user_version - # Images and biblio warnings because not found on developper version + # Images and biblio warnings because not found on developer version if (GUDHI_PYTHON_PATH STREQUAL "src/python") - set (GUDHI_SPHINX_MESSAGE "${GUDHI_SPHINX_MESSAGE} \n WARNING : Sphinx is configured for user version, you run it on developper version. Images and biblio will miss") + set (GUDHI_SPHINX_MESSAGE "${GUDHI_SPHINX_MESSAGE} \n WARNING : Sphinx is configured for user version, you run it on developer version. Images and biblio will miss") endif() # sphinx target requires gudhi.so, because conf.py reads gudhi version from it add_custom_target(sphinx @@ -488,7 +484,7 @@ if(PYTHONINTERP_FOUND) add_gudhi_py_test(test_euclidean_witness_complex) # Datasets generators - add_gudhi_py_test(test_datasets_generators) # TODO separate full python datasets generators in another test file independant from CGAL ? + add_gudhi_py_test(test_datasets_generators) # TODO separate full python datasets generators in another test file independent from CGAL ? endif (NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 4.11.0) @@ -595,6 +591,10 @@ if(PYTHONINTERP_FOUND) add_gudhi_py_test(test_dtm_rips_complex) endif() + # persistence graphical tools + if(MATPLOTLIB_FOUND) + add_gudhi_py_test(test_persistence_graphical_tools) + endif() # Set missing or not modules set(GUDHI_MODULES ${GUDHI_MODULES} "python" CACHE INTERNAL "GUDHI_MODULES") @@ -605,4 +605,4 @@ if(PYTHONINTERP_FOUND) else(PYTHONINTERP_FOUND) message("++ Python module will not be compiled because no Python interpreter was found") set(GUDHI_MISSING_MODULES ${GUDHI_MISSING_MODULES} "python" CACHE INTERNAL "GUDHI_MISSING_MODULES") -endif(PYTHONINTERP_FOUND)
\ No newline at end of file +endif(PYTHONINTERP_FOUND) diff --git a/src/python/doc/alpha_complex_user.rst b/src/python/doc/alpha_complex_user.rst index b060c86e..9e67d38a 100644 --- a/src/python/doc/alpha_complex_user.rst +++ b/src/python/doc/alpha_complex_user.rst @@ -178,11 +178,11 @@ Weighted version ^^^^^^^^^^^^^^^^ A weighted version for Alpha complex is available. It is like a usual Alpha complex, but based on a -`CGAL regular triangulation <https://doc.cgal.org/latest/Triangulation/index.html#title20>`_. +`CGAL regular triangulation <https://doc.cgal.org/latest/Triangulation/index.html#TriangulationSecRT>`_. This example builds the weighted alpha-complex of a small molecule, where atoms have different sizes. It is taken from -`CGAL 3d weighted alpha shapes <https://doc.cgal.org/latest/Alpha_shapes_3/index.html#title13>`_. +`CGAL 3d weighted alpha shapes <https://doc.cgal.org/latest/Alpha_shapes_3/index.html#AlphaShape_3DExampleforWeightedAlphaShapes>`_. Then, it is asked to display information about the alpha complex. diff --git a/src/python/doc/nerve_gic_complex_user.rst b/src/python/doc/nerve_gic_complex_user.rst index 0b820abf..8633cadb 100644 --- a/src/python/doc/nerve_gic_complex_user.rst +++ b/src/python/doc/nerve_gic_complex_user.rst @@ -12,7 +12,7 @@ Definition Visualizations of the simplicial complexes can be done with either neato (from `graphviz <http://www.graphviz.org/>`_), `geomview <http://www.geomview.org/>`_, -`KeplerMapper <https://github.com/MLWave/kepler-mapper>`_. +`KeplerMapper <https://github.com/scikit-tda/kepler-mapper>`_. Input point clouds are assumed to be OFF files (cf. `OFF file format <fileformats.html#off-file-format>`_). Covers diff --git a/src/python/doc/persistence_graphical_tools_user.rst b/src/python/doc/persistence_graphical_tools_user.rst index d95b9d2b..e1d28c71 100644 --- a/src/python/doc/persistence_graphical_tools_user.rst +++ b/src/python/doc/persistence_graphical_tools_user.rst @@ -60,7 +60,7 @@ of shape (N x 2) encoding a persistence diagram (in a given dimension). import matplotlib.pyplot as plt import gudhi import numpy as np - d = np.array([[0, 1], [1, 2], [1, np.inf]]) + d = np.array([[0., 1.], [1., 2.], [1., np.inf]]) gudhi.plot_persistence_diagram(d) plt.show() diff --git a/src/python/example/rips_complex_diagram_persistence_from_correlation_matrix_file_example.py b/src/python/example/rips_complex_diagram_persistence_from_correlation_matrix_file_example.py index ea2eb7e1..0b35dbc5 100755 --- a/src/python/example/rips_complex_diagram_persistence_from_correlation_matrix_file_example.py +++ b/src/python/example/rips_complex_diagram_persistence_from_correlation_matrix_file_example.py @@ -40,7 +40,7 @@ parser.add_argument( args = parser.parse_args() if not (-1.0 < args.min_edge_correlation < 1.0): - print("Wrong value of the treshold corelation (should be between -1 and 1).") + print("Wrong value of the threshold corelation (should be between -1 and 1).") sys.exit(1) print("#####################################################################") diff --git a/src/python/gudhi/__init__.py.in b/src/python/gudhi/__init__.py.in index 3043201a..79e12fbc 100644 --- a/src/python/gudhi/__init__.py.in +++ b/src/python/gudhi/__init__.py.in @@ -23,10 +23,6 @@ __all__ = [@GUDHI_PYTHON_MODULES@ @GUDHI_PYTHON_MODULES_EXTRA@] __available_modules = '' __missing_modules = '' -# For unitary tests purpose -# could use "if 'collapse_edges' in gudhi.__all__" when collapse edges will have a python module -__GUDHI_USE_EIGEN3 = @GUDHI_USE_EIGEN3@ - # Try to import * from gudhi.__module_name for default modules. # Extra modules require an explicit import by the user (mostly because of # unusual dependencies, but also to avoid cluttering namespace gudhi and diff --git a/src/python/gudhi/hera/wasserstein.cc b/src/python/gudhi/hera/wasserstein.cc index 1a21f02f..fa0cf8aa 100644 --- a/src/python/gudhi/hera/wasserstein.cc +++ b/src/python/gudhi/hera/wasserstein.cc @@ -29,7 +29,7 @@ double wasserstein_distance( if(std::isinf(internal_p)) internal_p = hera::get_infinity<double>(); params.internal_p = internal_p; params.delta = delta; - // The extra parameters are purposedly not exposed for now. + // The extra parameters are purposely not exposed for now. return hera::wasserstein_dist(diag1, diag2, params); } diff --git a/src/python/gudhi/persistence_graphical_tools.py b/src/python/gudhi/persistence_graphical_tools.py index 848dc03e..21275cdd 100644 --- a/src/python/gudhi/persistence_graphical_tools.py +++ b/src/python/gudhi/persistence_graphical_tools.py @@ -12,6 +12,9 @@ from os import path from math import isfinite import numpy as np from functools import lru_cache +import warnings +import errno +import os from gudhi.reader_utils import read_persistence_intervals_in_dimension from gudhi.reader_utils import read_persistence_intervals_grouped_by_dimension @@ -22,6 +25,7 @@ __license__ = "MIT" _gudhi_matplotlib_use_tex = True + def __min_birth_max_death(persistence, band=0.0): """This function returns (min_birth, max_death) from the persistence. @@ -44,20 +48,46 @@ def __min_birth_max_death(persistence, band=0.0): min_birth = float(interval[1][0]) if band > 0.0: max_death += band + # can happen if only points at inf death + if min_birth == max_death: + max_death = max_death + 1.0 return (min_birth, max_death) def _array_handler(a): - ''' + """ :param a: if array, assumes it is a (n x 2) np.array and return a persistence-compatible list (padding with 0), so that the plot can be performed seamlessly. - ''' - if isinstance(a[0][1], np.float64) or isinstance(a[0][1], float): + """ + if isinstance(a[0][1], (np.floating, float)): return [[0, x] for x in a] else: return a + +def _limit_to_max_intervals(persistence, max_intervals, key): + """This function returns truncated persistence if length is bigger than max_intervals. + :param persistence: Persistence intervals values list. Can be grouped by dimension or not. + :type persistence: an array of (dimension, array of (birth, death)) or an array of (birth, death). + :param max_intervals: maximal number of intervals to display. + Selected intervals are those with the longest life time. Set it + to 0 to see all. Default value is 1000. + :type max_intervals: int. + :param key: key function for sort algorithm. + :type key: function or lambda. + """ + if max_intervals > 0 and max_intervals < len(persistence): + warnings.warn( + "There are %s intervals given as input, whereas max_intervals is set to %s." + % (len(persistence), max_intervals) + ) + # Sort by life time, then takes only the max_intervals elements + return sorted(persistence, key=key, reverse=True)[:max_intervals] + else: + return persistence + + @lru_cache(maxsize=1) def _matplotlib_can_use_tex(): """This function returns True if matplotlib can deal with LaTeX, False otherwise. @@ -65,17 +95,17 @@ def _matplotlib_can_use_tex(): """ try: from matplotlib import checkdep_usetex + return checkdep_usetex(True) - except ImportError: - print("This function is not available, you may be missing matplotlib.") + except ImportError as import_error: + warnings.warn(f"This function is not available.\nModuleNotFoundError: No module named '{import_error.name}'.") def plot_persistence_barcode( persistence=[], persistence_file="", alpha=0.6, - max_intervals=1000, - max_barcodes=1000, + max_intervals=20000, inf_delta=0.1, legend=False, colormap=None, @@ -97,7 +127,7 @@ def plot_persistence_barcode( :type alpha: float. :param max_intervals: maximal number of intervals to display. Selected intervals are those with the longest life time. Set it - to 0 to see all. Default value is 1000. + to 0 to see all. Default value is 20000. :type max_intervals: int. :param inf_delta: Infinity is placed at :code:`((max_death - min_birth) x inf_delta)` above :code:`max_death` value. A reasonable value is @@ -119,99 +149,68 @@ def plot_persistence_barcode( import matplotlib.pyplot as plt import matplotlib.patches as mpatches from matplotlib import rc + if _gudhi_matplotlib_use_tex and _matplotlib_can_use_tex(): - plt.rc('text', usetex=True) - plt.rc('font', family='serif') + plt.rc("text", usetex=True) + plt.rc("font", family="serif") else: - plt.rc('text', usetex=False) - plt.rc('font', family='DejaVu Sans') + plt.rc("text", usetex=False) + plt.rc("font", family="DejaVu Sans") if persistence_file != "": if path.isfile(persistence_file): # Reset persistence persistence = [] - diag = read_persistence_intervals_grouped_by_dimension( - persistence_file=persistence_file - ) + diag = read_persistence_intervals_grouped_by_dimension(persistence_file=persistence_file) for key in diag.keys(): for persistence_interval in diag[key]: persistence.append((key, persistence_interval)) else: - print("file " + persistence_file + " not found.") - return None - - persistence = _array_handler(persistence) - - if max_barcodes != 1000: - print("Deprecated parameter. It has been replaced by max_intervals") - max_intervals = max_barcodes - - if max_intervals > 0 and max_intervals < len(persistence): - # Sort by life time, then takes only the max_intervals elements - persistence = sorted( - persistence, - key=lambda life_time: life_time[1][1] - life_time[1][0], - reverse=True, - )[:max_intervals] - - if colormap == None: - colormap = plt.cm.Set1.colors - if axes == None: - fig, axes = plt.subplots(1, 1) + raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), persistence_file) - persistence = sorted(persistence, key=lambda birth: birth[1][0]) + try: + persistence = _array_handler(persistence) + persistence = _limit_to_max_intervals( + persistence, max_intervals, key=lambda life_time: life_time[1][1] - life_time[1][0] + ) + (min_birth, max_death) = __min_birth_max_death(persistence) + persistence = sorted(persistence, key=lambda birth: birth[1][0]) + except IndexError: + min_birth, max_death = 0.0, 1.0 + pass - (min_birth, max_death) = __min_birth_max_death(persistence) - ind = 0 delta = (max_death - min_birth) * inf_delta # Replace infinity values with max_death + delta for bar code to be more # readable infinity = max_death + delta axis_start = min_birth - delta - # Draw horizontal bars in loop - for interval in reversed(persistence): - if float(interval[1][1]) != float("inf"): - # Finite death case - axes.barh( - ind, - (interval[1][1] - interval[1][0]), - height=0.8, - left=interval[1][0], - alpha=alpha, - color=colormap[interval[0]], - linewidth=0, - ) - else: - # Infinite death case for diagram to be nicer - axes.barh( - ind, - (infinity - interval[1][0]), - height=0.8, - left=interval[1][0], - alpha=alpha, - color=colormap[interval[0]], - linewidth=0, - ) - ind = ind + 1 + + if axes == None: + _, axes = plt.subplots(1, 1) + if colormap == None: + colormap = plt.cm.Set1.colors + + x=[birth for (dim,(birth,death)) in persistence] + y=[(death - birth) if death != float("inf") else (infinity - birth) for (dim,(birth,death)) in persistence] + c=[colormap[dim] for (dim,(birth,death)) in persistence] + + axes.barh(list(reversed(range(len(x)))), y, height=0.8, left=x, alpha=alpha, color=c, linewidth=0) if legend: dimensions = list(set(item[0] for item in persistence)) axes.legend( - handles=[ - mpatches.Patch(color=colormap[dim], label=str(dim)) - for dim in dimensions - ], - loc="lower right", + handles=[mpatches.Patch(color=colormap[dim], label=str(dim)) for dim in dimensions], loc="lower right", ) axes.set_title("Persistence barcode", fontsize=fontsize) # Ends plot on infinity value and starts a little bit before min_birth - axes.axis([axis_start, infinity, 0, ind]) + if len(x) != 0: + axes.axis([axis_start, infinity, 0, len(x)]) return axes - except ImportError: - print("This function is not available, you may be missing matplotlib.") + except ImportError as import_error: + warnings.warn(f"This function is not available.\nModuleNotFoundError: No module named '{import_error.name}'.") def plot_persistence_diagram( @@ -219,14 +218,13 @@ def plot_persistence_diagram( persistence_file="", alpha=0.6, band=0.0, - max_intervals=1000, - max_plots=1000, + max_intervals=1000000, inf_delta=0.1, legend=False, colormap=None, axes=None, fontsize=16, - greyblock=True + greyblock=True, ): """This function plots the persistence diagram from persistence values list, a np.array of shape (N x 2) representing a diagram in a single @@ -244,7 +242,7 @@ def plot_persistence_diagram( :type band: float. :param max_intervals: maximal number of intervals to display. Selected intervals are those with the longest life time. Set it - to 0 to see all. Default value is 1000. + to 0 to see all. Default value is 1000000. :type max_intervals: int. :param inf_delta: Infinity is placed at :code:`((max_death - min_birth) x inf_delta)` above :code:`max_death` value. A reasonable value is @@ -268,47 +266,35 @@ def plot_persistence_diagram( import matplotlib.pyplot as plt import matplotlib.patches as mpatches from matplotlib import rc + if _gudhi_matplotlib_use_tex and _matplotlib_can_use_tex(): - plt.rc('text', usetex=True) - plt.rc('font', family='serif') + plt.rc("text", usetex=True) + plt.rc("font", family="serif") else: - plt.rc('text', usetex=False) - plt.rc('font', family='DejaVu Sans') + plt.rc("text", usetex=False) + plt.rc("font", family="DejaVu Sans") if persistence_file != "": if path.isfile(persistence_file): # Reset persistence persistence = [] - diag = read_persistence_intervals_grouped_by_dimension( - persistence_file=persistence_file - ) + diag = read_persistence_intervals_grouped_by_dimension(persistence_file=persistence_file) for key in diag.keys(): for persistence_interval in diag[key]: persistence.append((key, persistence_interval)) else: - print("file " + persistence_file + " not found.") - return None - - persistence = _array_handler(persistence) + raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), persistence_file) - if max_plots != 1000: - print("Deprecated parameter. It has been replaced by max_intervals") - max_intervals = max_plots - - if max_intervals > 0 and max_intervals < len(persistence): - # Sort by life time, then takes only the max_intervals elements - persistence = sorted( - persistence, - key=lambda life_time: life_time[1][1] - life_time[1][0], - reverse=True, - )[:max_intervals] - - if colormap == None: - colormap = plt.cm.Set1.colors - if axes == None: - fig, axes = plt.subplots(1, 1) + try: + persistence = _array_handler(persistence) + persistence = _limit_to_max_intervals( + persistence, max_intervals, key=lambda life_time: life_time[1][1] - life_time[1][0] + ) + min_birth, max_death = __min_birth_max_death(persistence, band) + except IndexError: + min_birth, max_death = 0.0, 1.0 + pass - (min_birth, max_death) = __min_birth_max_death(persistence, band) delta = (max_death - min_birth) * inf_delta # Replace infinity values with max_death + delta for diagram to be more # readable @@ -316,61 +302,56 @@ def plot_persistence_diagram( axis_end = max_death + delta / 2 axis_start = min_birth - delta + if axes == None: + _, axes = plt.subplots(1, 1) + if colormap == None: + colormap = plt.cm.Set1.colors # bootstrap band if band > 0.0: x = np.linspace(axis_start, infinity, 1000) axes.fill_between(x, x, x + band, alpha=alpha, facecolor="red") # lower diag patch if greyblock: - axes.add_patch(mpatches.Polygon([[axis_start, axis_start], [axis_end, axis_start], [axis_end, axis_end]], fill=True, color='lightgrey')) - # Draw points in loop - pts_at_infty = False # Records presence of pts at infty - for interval in reversed(persistence): - if float(interval[1][1]) != float("inf"): - # Finite death case - axes.scatter( - interval[1][0], - interval[1][1], - alpha=alpha, - color=colormap[interval[0]], + axes.add_patch( + mpatches.Polygon( + [[axis_start, axis_start], [axis_end, axis_start], [axis_end, axis_end]], + fill=True, + color="lightgrey", ) - else: - pts_at_infty = True - # Infinite death case for diagram to be nicer - axes.scatter( - interval[1][0], infinity, alpha=alpha, color=colormap[interval[0]] - ) - if pts_at_infty: + ) + # line display of equation : birth = death + axes.plot([axis_start, axis_end], [axis_start, axis_end], linewidth=1.0, color="k") + + x=[birth for (dim,(birth,death)) in persistence] + y=[death if death != float("inf") else infinity for (dim,(birth,death)) in persistence] + c=[colormap[dim] for (dim,(birth,death)) in persistence] + + axes.scatter(x,y,alpha=alpha,color=c) + if float("inf") in (death for (dim,(birth,death)) in persistence): # infinity line and text - axes.plot([axis_start, axis_end], [axis_start, axis_end], linewidth=1.0, color="k") axes.plot([axis_start, axis_end], [infinity, infinity], linewidth=1.0, color="k", alpha=alpha) # Infinity label yt = axes.get_yticks() - yt = yt[np.where(yt < axis_end)] # to avoid ploting ticklabel higher than infinity + yt = yt[np.where(yt < axis_end)] # to avoid plotting ticklabel higher than infinity yt = np.append(yt, infinity) ytl = ["%.3f" % e for e in yt] # to avoid float precision error - ytl[-1] = r'$+\infty$' + ytl[-1] = r"$+\infty$" axes.set_yticks(yt) axes.set_yticklabels(ytl) if legend: dimensions = list(set(item[0] for item in persistence)) - axes.legend( - handles=[ - mpatches.Patch(color=colormap[dim], label=str(dim)) - for dim in dimensions - ] - ) + axes.legend(handles=[mpatches.Patch(color=colormap[dim], label=str(dim)) for dim in dimensions]) axes.set_xlabel("Birth", fontsize=fontsize) axes.set_ylabel("Death", fontsize=fontsize) axes.set_title("Persistence diagram", fontsize=fontsize) # Ends plot on infinity value and starts a little bit before min_birth - axes.axis([axis_start, axis_end, axis_start, infinity + delta/2]) + axes.axis([axis_start, axis_end, axis_start, infinity + delta / 2]) return axes - except ImportError: - print("This function is not available, you may be missing matplotlib.") + except ImportError as import_error: + warnings.warn(f"This function is not available.\nModuleNotFoundError: No module named '{import_error.name}'.") def plot_persistence_density( @@ -384,7 +365,7 @@ def plot_persistence_density( legend=False, axes=None, fontsize=16, - greyblock=False + greyblock=False, ): """This function plots the persistence density from persistence values list, np.array of shape (N x 2) representing a diagram @@ -444,12 +425,13 @@ def plot_persistence_density( import matplotlib.patches as mpatches from scipy.stats import kde from matplotlib import rc + if _gudhi_matplotlib_use_tex and _matplotlib_can_use_tex(): - plt.rc('text', usetex=True) - plt.rc('font', family='serif') + plt.rc("text", usetex=True) + plt.rc("font", family="serif") else: - plt.rc('text', usetex=False) - plt.rc('font', family='DejaVu Sans') + plt.rc("text", usetex=False) + plt.rc("font", family="DejaVu Sans") if persistence_file != "": if dimension is None: @@ -460,10 +442,16 @@ def plot_persistence_density( persistence_file=persistence_file, only_this_dim=dimension ) else: - print("file " + persistence_file + " not found.") - return None + raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), persistence_file) + + # default cmap value cannot be done at argument definition level as matplotlib is not yet defined. + if cmap is None: + cmap = plt.cm.hot_r + if axes == None: + _, axes = plt.subplots(1, 1) - if len(persistence) > 0: + try: + # if not read from file but given by an argument persistence = _array_handler(persistence) persistence_dim = np.array( [ @@ -472,47 +460,54 @@ def plot_persistence_density( if (dim_interval[0] == dimension) or (dimension is None) ] ) - - persistence_dim = persistence_dim[np.isfinite(persistence_dim[:, 1])] - if max_intervals > 0 and max_intervals < len(persistence_dim): - # Sort by life time, then takes only the max_intervals elements + persistence_dim = persistence_dim[np.isfinite(persistence_dim[:, 1])] persistence_dim = np.array( - sorted( - persistence_dim, - key=lambda life_time: life_time[1] - life_time[0], - reverse=True, - )[:max_intervals] + _limit_to_max_intervals( + persistence_dim, max_intervals, key=lambda life_time: life_time[1] - life_time[0] + ) ) - # Set as numpy array birth and death (remove undefined values - inf and NaN) - birth = persistence_dim[:, 0] - death = persistence_dim[:, 1] - - # default cmap value cannot be done at argument definition level as matplotlib is not yet defined. - if cmap is None: - cmap = plt.cm.hot_r - if axes == None: - fig, axes = plt.subplots(1, 1) + # Set as numpy array birth and death (remove undefined values - inf and NaN) + birth = persistence_dim[:, 0] + death = persistence_dim[:, 1] + birth_min = birth.min() + birth_max = birth.max() + death_min = death.min() + death_max = death.max() + + # Evaluate a gaussian kde on a regular grid of nbins x nbins over data extents + k = kde.gaussian_kde([birth, death], bw_method=bw_method) + xi, yi = np.mgrid[ + birth_min : birth_max : nbins * 1j, death_min : death_max : nbins * 1j, + ] + zi = k(np.vstack([xi.flatten(), yi.flatten()])) + # Make the plot + img = axes.pcolormesh(xi, yi, zi.reshape(xi.shape), cmap=cmap, shading="auto") + plot_success = True + + # IndexError on empty diagrams, ValueError on only inf death values + except (IndexError, ValueError): + birth_min = 0.0 + birth_max = 1.0 + death_min = 0.0 + death_max = 1.0 + plot_success = False + pass # line display of equation : birth = death - x = np.linspace(death.min(), birth.max(), 1000) + x = np.linspace(death_min, birth_max, 1000) axes.plot(x, x, color="k", linewidth=1.0) - # Evaluate a gaussian kde on a regular grid of nbins x nbins over data extents - k = kde.gaussian_kde([birth, death], bw_method=bw_method) - xi, yi = np.mgrid[ - birth.min() : birth.max() : nbins * 1j, - death.min() : death.max() : nbins * 1j, - ] - zi = k(np.vstack([xi.flatten(), yi.flatten()])) - - # Make the plot - img = axes.pcolormesh(xi, yi, zi.reshape(xi.shape), cmap=cmap) - if greyblock: - axes.add_patch(mpatches.Polygon([[birth.min(), birth.min()], [death.max(), birth.min()], [death.max(), death.max()]], fill=True, color='lightgrey')) + axes.add_patch( + mpatches.Polygon( + [[birth_min, birth_min], [death_max, birth_min], [death_max, death_max]], + fill=True, + color="lightgrey", + ) + ) - if legend: + if plot_success and legend: plt.colorbar(img, ax=axes) axes.set_xlabel("Birth", fontsize=fontsize) @@ -521,7 +516,5 @@ def plot_persistence_density( return axes - except ImportError: - print( - "This function is not available, you may be missing matplotlib and/or scipy." - ) + except ImportError as import_error: + warnings.warn(f"This function is not available.\nModuleNotFoundError: No module named '{import_error.name}'.") diff --git a/src/python/gudhi/simplex_tree.pyx b/src/python/gudhi/simplex_tree.pyx index 97c7d2c8..521a7763 100644 --- a/src/python/gudhi/simplex_tree.pyx +++ b/src/python/gudhi/simplex_tree.pyx @@ -675,18 +675,17 @@ cdef class SimplexTree: return (normal0, normals, infinite0, infinites) def collapse_edges(self, nb_iterations = 1): - """Assuming the simplex tree is a 1-skeleton graph, this method collapse edges (simplices of higher dimension - are ignored) and resets the simplex tree from the remaining edges. - A good candidate is to build a simplex tree on top of a :class:`~gudhi.RipsComplex` of dimension 1 before - collapsing edges + """Assuming the complex is a graph (simplices of higher dimension are ignored), this method implicitly + interprets it as the 1-skeleton of a flag complex, and replaces it with another (smaller) graph whose + expansion has the same persistent homology, using a technique known as edge collapses + (see :cite:`edgecollapsearxiv`). + + A natural application is to get a simplex tree of dimension 1 from :class:`~gudhi.RipsComplex`, + then collapse edges, perform :meth:`expansion()` and finally compute persistence (cf. :download:`rips_complex_edge_collapse_example.py <../example/rips_complex_edge_collapse_example.py>`). - For implementation details, please refer to :cite:`edgecollapsesocg2020`. :param nb_iterations: The number of edge collapse iterations to perform. Default is 1. :type nb_iterations: int - - :note: collapse_edges method requires `Eigen <installation.html#eigen>`_ >= 3.1.0 and an exception is thrown - if this method is not available. """ # Backup old pointer cdef Simplex_tree_interface_full_featured* ptr = self.get_ptr() diff --git a/src/python/gudhi/wasserstein/barycenter.py b/src/python/gudhi/wasserstein/barycenter.py index d67bcde7..bb6e641e 100644 --- a/src/python/gudhi/wasserstein/barycenter.py +++ b/src/python/gudhi/wasserstein/barycenter.py @@ -37,7 +37,7 @@ def lagrangian_barycenter(pdiagset, init=None, verbose=False): :param init: The initial value for barycenter estimate. If ``None``, init is made on a random diagram from the dataset. Otherwise, it can be an ``int`` (then initialization is made on ``pdiagset[init]``) - or a `(n x 2)` ``numpy.array`` enconding a persistence diagram with `n` points. + or a `(n x 2)` ``numpy.array`` encoding a persistence diagram with `n` points. :type init: ``int``, or (n x 2) ``np.array`` :param verbose: if ``True``, returns additional information about the barycenter. :type verbose: boolean @@ -45,7 +45,7 @@ def lagrangian_barycenter(pdiagset, init=None, verbose=False): (local minimum of the energy function). If ``pdiagset`` is empty, returns ``None``. If verbose, returns a couple ``(Y, log)`` where ``Y`` is the barycenter estimate, - and ``log`` is a ``dict`` that contains additional informations: + and ``log`` is a ``dict`` that contains additional information: - `"groupings"`, a list of list of pairs ``(i,j)``. Namely, ``G[k] = [...(i, j)...]``, where ``(i,j)`` indicates that `pdiagset[k][i]`` is matched to ``Y[j]`` if ``i = -1`` or ``j = -1``, it means they represent the diagonal. @@ -73,7 +73,7 @@ def lagrangian_barycenter(pdiagset, init=None, verbose=False): nb_iter = 0 - converged = False # stoping criterion + converged = False # stopping criterion while not converged: nb_iter += 1 K = len(Y) # current nb of points in Y (some might be on diagonal) diff --git a/src/python/include/Simplex_tree_interface.h b/src/python/include/Simplex_tree_interface.h index aec29f9b..3848c5ad 100644 --- a/src/python/include/Simplex_tree_interface.h +++ b/src/python/include/Simplex_tree_interface.h @@ -15,9 +15,7 @@ #include <gudhi/distance_functions.h> #include <gudhi/Simplex_tree.h> #include <gudhi/Points_off_io.h> -#ifdef GUDHI_USE_EIGEN3 #include <gudhi/Flag_complex_edge_collapser.h> -#endif #include <iostream> #include <vector> @@ -135,7 +133,6 @@ class Simplex_tree_interface : public Simplex_tree<SimplexTreeOptions> { } Simplex_tree_interface* collapse_edges(int nb_collapse_iteration) { -#ifdef GUDHI_USE_EIGEN3 using Filtered_edge = std::tuple<Vertex_handle, Vertex_handle, Filtration_value>; std::vector<Filtered_edge> edges; for (Simplex_handle sh : Base::skeleton_simplex_range(1)) { @@ -149,7 +146,7 @@ class Simplex_tree_interface : public Simplex_tree<SimplexTreeOptions> { } for (int iteration = 0; iteration < nb_collapse_iteration; iteration++) { - edges = Gudhi::collapse::flag_complex_collapse_edges(edges); + edges = Gudhi::collapse::flag_complex_collapse_edges(std::move(edges)); } Simplex_tree_interface* collapsed_stree_ptr = new Simplex_tree_interface(); // Copy the original 0-skeleton @@ -161,9 +158,6 @@ class Simplex_tree_interface : public Simplex_tree<SimplexTreeOptions> { collapsed_stree_ptr->insert({std::get<0>(remaining_edge), std::get<1>(remaining_edge)}, std::get<2>(remaining_edge)); } return collapsed_stree_ptr; -#else - throw std::runtime_error("Unable to collapse edges as it requires Eigen3 >= 3.1.0."); -#endif } void expansion_with_blockers_callback(int dimension, blocker_func_t user_func, void *user_data) { diff --git a/src/python/test/test_persistence_graphical_tools.py b/src/python/test/test_persistence_graphical_tools.py new file mode 100644 index 00000000..c19836b7 --- /dev/null +++ b/src/python/test/test_persistence_graphical_tools.py @@ -0,0 +1,121 @@ +""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT. + See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details. + Author(s): Vincent Rouvreau + + Copyright (C) 2021 Inria + + Modification(s): + - YYYY/MM Author: Description of the modification +""" + +import gudhi as gd +import numpy as np +import matplotlib as plt +import pytest + + +def test_array_handler(): + diags = np.array([[1, 2], [3, 4], [5, 6]], float) + arr_diags = gd.persistence_graphical_tools._array_handler(diags) + for idx in range(len(diags)): + assert arr_diags[idx][0] == 0 + np.testing.assert_array_equal(arr_diags[idx][1], diags[idx]) + + diags = [(1.0, 2.0), (3.0, 4.0), (5.0, 6.0)] + arr_diags = gd.persistence_graphical_tools._array_handler(diags) + for idx in range(len(diags)): + assert arr_diags[idx][0] == 0 + assert arr_diags[idx][1] == diags[idx] + + diags = [(0, (1.0, 2.0)), (0, (3.0, 4.0)), (0, (5.0, 6.0))] + assert gd.persistence_graphical_tools._array_handler(diags) == diags + + +def test_min_birth_max_death(): + diags = [ + (0, (0.0, float("inf"))), + (0, (0.0983494, float("inf"))), + (0, (0.0, 0.122545)), + (0, (0.0, 0.12047)), + (0, (0.0, 0.118398)), + (0, (0.118398, 1.0)), + (0, (0.0, 0.117908)), + (0, (0.0, 0.112307)), + (0, (0.0, 0.107535)), + (0, (0.0, 0.106382)), + ] + assert gd.persistence_graphical_tools.__min_birth_max_death(diags) == (0.0, 1.0) + assert gd.persistence_graphical_tools.__min_birth_max_death(diags, band=4.0) == (0.0, 5.0) + + +def test_limit_min_birth_max_death(): + diags = [ + (0, (2.0, float("inf"))), + (0, (2.0, float("inf"))), + ] + assert gd.persistence_graphical_tools.__min_birth_max_death(diags) == (2.0, 3.0) + assert gd.persistence_graphical_tools.__min_birth_max_death(diags, band=4.0) == (2.0, 6.0) + + +def test_limit_to_max_intervals(): + diags = [ + (0, (0.0, float("inf"))), + (0, (0.0983494, float("inf"))), + (0, (0.0, 0.122545)), + (0, (0.0, 0.12047)), + (0, (0.0, 0.118398)), + (0, (0.118398, 1.0)), + (0, (0.0, 0.117908)), + (0, (0.0, 0.112307)), + (0, (0.0, 0.107535)), + (0, (0.0, 0.106382)), + ] + # check no warnings if max_intervals equals to the diagrams number + with pytest.warns(None) as record: + truncated_diags = gd.persistence_graphical_tools._limit_to_max_intervals( + diags, 10, key=lambda life_time: life_time[1][1] - life_time[1][0] + ) + # check diagrams are not sorted + assert truncated_diags == diags + assert len(record) == 0 + + # check warning if max_intervals lower than the diagrams number + with pytest.warns(UserWarning) as record: + truncated_diags = gd.persistence_graphical_tools._limit_to_max_intervals( + diags, 5, key=lambda life_time: life_time[1][1] - life_time[1][0] + ) + # check diagrams are truncated and sorted by life time + assert truncated_diags == [ + (0, (0.0, float("inf"))), + (0, (0.0983494, float("inf"))), + (0, (0.118398, 1.0)), + (0, (0.0, 0.122545)), + (0, (0.0, 0.12047)), + ] + assert len(record) == 1 + + +def _limit_plot_persistence(function): + pplot = function(persistence=[]) + assert isinstance(pplot, plt.axes.SubplotBase) + pplot = function(persistence=[], legend=True) + assert isinstance(pplot, plt.axes.SubplotBase) + pplot = function(persistence=[(0, float("inf"))]) + assert isinstance(pplot, plt.axes.SubplotBase) + pplot = function(persistence=[(0, float("inf"))], legend=True) + assert isinstance(pplot, plt.axes.SubplotBase) + + +def test_limit_plot_persistence(): + for function in [gd.plot_persistence_barcode, gd.plot_persistence_diagram, gd.plot_persistence_density]: + _limit_plot_persistence(function) + + +def _non_existing_persistence_file(function): + with pytest.raises(FileNotFoundError): + function(persistence_file="pouetpouettralala.toubiloubabdou") + + +def test_non_existing_persistence_file(): + for function in [gd.plot_persistence_barcode, gd.plot_persistence_diagram, gd.plot_persistence_density]: + _non_existing_persistence_file(function) diff --git a/src/python/test/test_simplex_tree.py b/src/python/test/test_simplex_tree.py index 7abb4366..54bafed5 100755 --- a/src/python/test/test_simplex_tree.py +++ b/src/python/test/test_simplex_tree.py @@ -8,7 +8,7 @@ - YYYY/MM Author: Description of the modification """ -from gudhi import SimplexTree, __GUDHI_USE_EIGEN3 +from gudhi import SimplexTree import numpy as np import pytest @@ -358,16 +358,11 @@ def test_collapse_edges(): assert st.num_simplices() == 10 - if __GUDHI_USE_EIGEN3: - st.collapse_edges() - assert st.num_simplices() == 9 - assert st.find([1, 3]) == False - for simplex in st.get_skeleton(0): - assert simplex[1] == 1. - else: - # If no Eigen3, collapse_edges throws an exception - with pytest.raises(RuntimeError): - st.collapse_edges() + st.collapse_edges() + assert st.num_simplices() == 9 + assert st.find([0, 2]) == False # [1, 3] would be fine as well + for simplex in st.get_skeleton(0): + assert simplex[1] == 1. def test_reset_filtration(): st = SimplexTree() @@ -537,7 +532,7 @@ def test_expansion_with_blocker(): def blocker(simplex): try: - # Block all simplices that countains vertex 6 + # Block all simplices that contain vertex 6 simplex.index(6) print(simplex, ' is blocked') return True diff --git a/src/python/test/test_subsampling.py b/src/python/test/test_subsampling.py index 4019852e..3431f372 100755 --- a/src/python/test/test_subsampling.py +++ b/src/python/test/test_subsampling.py @@ -91,7 +91,7 @@ def test_simple_choose_n_farthest_points_randomed(): assert gudhi.choose_n_farthest_points(points=[], nb_points=1) == [] assert gudhi.choose_n_farthest_points(points=point_set, nb_points=0) == [] - # Go furter than point set on purpose + # Go further than point set on purpose for iter in range(1, 10): sub_set = gudhi.choose_n_farthest_points(points=point_set, nb_points=iter) for sub in sub_set: @@ -117,7 +117,7 @@ def test_simple_pick_n_random_points(): assert gudhi.pick_n_random_points(points=[], nb_points=1) == [] assert gudhi.pick_n_random_points(points=point_set, nb_points=0) == [] - # Go furter than point set on purpose + # Go further than point set on purpose for iter in range(1, 10): sub_set = gudhi.pick_n_random_points(points=point_set, nb_points=iter) for sub in sub_set: |