diff options
Diffstat (limited to 'src/python/gudhi/persistence_graphical_tools.py')
-rw-r--r-- | src/python/gudhi/persistence_graphical_tools.py | 368 |
1 files changed, 197 insertions, 171 deletions
diff --git a/src/python/gudhi/persistence_graphical_tools.py b/src/python/gudhi/persistence_graphical_tools.py index 848dc03e..70d550ab 100644 --- a/src/python/gudhi/persistence_graphical_tools.py +++ b/src/python/gudhi/persistence_graphical_tools.py @@ -12,6 +12,9 @@ from os import path from math import isfinite import numpy as np from functools import lru_cache +import warnings +import errno +import os from gudhi.reader_utils import read_persistence_intervals_in_dimension from gudhi.reader_utils import read_persistence_intervals_grouped_by_dimension @@ -22,6 +25,7 @@ __license__ = "MIT" _gudhi_matplotlib_use_tex = True + def __min_birth_max_death(persistence, band=0.0): """This function returns (min_birth, max_death) from the persistence. @@ -44,20 +48,46 @@ def __min_birth_max_death(persistence, band=0.0): min_birth = float(interval[1][0]) if band > 0.0: max_death += band + # can happen if only points at inf death + if min_birth == max_death: + max_death = max_death + 1.0 return (min_birth, max_death) def _array_handler(a): - ''' + """ :param a: if array, assumes it is a (n x 2) np.array and return a persistence-compatible list (padding with 0), so that the plot can be performed seamlessly. - ''' - if isinstance(a[0][1], np.float64) or isinstance(a[0][1], float): + """ + if isinstance(a[0][1], (np.floating, float)): return [[0, x] for x in a] else: return a + +def _limit_to_max_intervals(persistence, max_intervals, key): + """This function returns truncated persistence if length is bigger than max_intervals. + :param persistence: Persistence intervals values list. Can be grouped by dimension or not. + :type persistence: an array of (dimension, array of (birth, death)) or an array of (birth, death). + :param max_intervals: maximal number of intervals to display. + Selected intervals are those with the longest life time. Set it + to 0 to see all. Default value is 1000. + :type max_intervals: int. + :param key: key function for sort algorithm. + :type key: function or lambda. + """ + if max_intervals > 0 and max_intervals < len(persistence): + warnings.warn( + "There are %s intervals given as input, whereas max_intervals is set to %s." + % (len(persistence), max_intervals) + ) + # Sort by life time, then takes only the max_intervals elements + return sorted(persistence, key=key, reverse=True)[:max_intervals] + else: + return persistence + + @lru_cache(maxsize=1) def _matplotlib_can_use_tex(): """This function returns True if matplotlib can deal with LaTeX, False otherwise. @@ -65,9 +95,10 @@ def _matplotlib_can_use_tex(): """ try: from matplotlib import checkdep_usetex + return checkdep_usetex(True) except ImportError: - print("This function is not available, you may be missing matplotlib.") + warnings.warn("This function is not available, you may be missing matplotlib.") def plot_persistence_barcode( @@ -75,7 +106,6 @@ def plot_persistence_barcode( persistence_file="", alpha=0.6, max_intervals=1000, - max_barcodes=1000, inf_delta=0.1, legend=False, colormap=None, @@ -119,99 +149,92 @@ def plot_persistence_barcode( import matplotlib.pyplot as plt import matplotlib.patches as mpatches from matplotlib import rc + if _gudhi_matplotlib_use_tex and _matplotlib_can_use_tex(): - plt.rc('text', usetex=True) - plt.rc('font', family='serif') + plt.rc("text", usetex=True) + plt.rc("font", family="serif") else: - plt.rc('text', usetex=False) - plt.rc('font', family='DejaVu Sans') + plt.rc("text", usetex=False) + plt.rc("font", family="DejaVu Sans") if persistence_file != "": if path.isfile(persistence_file): # Reset persistence persistence = [] - diag = read_persistence_intervals_grouped_by_dimension( - persistence_file=persistence_file - ) + diag = read_persistence_intervals_grouped_by_dimension(persistence_file=persistence_file) for key in diag.keys(): for persistence_interval in diag[key]: persistence.append((key, persistence_interval)) else: - print("file " + persistence_file + " not found.") - return None - - persistence = _array_handler(persistence) - - if max_barcodes != 1000: - print("Deprecated parameter. It has been replaced by max_intervals") - max_intervals = max_barcodes - - if max_intervals > 0 and max_intervals < len(persistence): - # Sort by life time, then takes only the max_intervals elements - persistence = sorted( - persistence, - key=lambda life_time: life_time[1][1] - life_time[1][0], - reverse=True, - )[:max_intervals] - - if colormap == None: - colormap = plt.cm.Set1.colors - if axes == None: - fig, axes = plt.subplots(1, 1) + raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), persistence_file) - persistence = sorted(persistence, key=lambda birth: birth[1][0]) + try: + persistence = _array_handler(persistence) + persistence = _limit_to_max_intervals( + persistence, max_intervals, key=lambda life_time: life_time[1][1] - life_time[1][0] + ) + (min_birth, max_death) = __min_birth_max_death(persistence) + persistence = sorted(persistence, key=lambda birth: birth[1][0]) + except IndexError: + min_birth, max_death = 0.0, 1.0 + pass - (min_birth, max_death) = __min_birth_max_death(persistence) - ind = 0 delta = (max_death - min_birth) * inf_delta # Replace infinity values with max_death + delta for bar code to be more # readable infinity = max_death + delta axis_start = min_birth - delta + + if axes == None: + _, axes = plt.subplots(1, 1) + if colormap == None: + colormap = plt.cm.Set1.colors + ind = 0 + # Draw horizontal bars in loop for interval in reversed(persistence): - if float(interval[1][1]) != float("inf"): - # Finite death case - axes.barh( - ind, - (interval[1][1] - interval[1][0]), - height=0.8, - left=interval[1][0], - alpha=alpha, - color=colormap[interval[0]], - linewidth=0, - ) - else: - # Infinite death case for diagram to be nicer - axes.barh( - ind, - (infinity - interval[1][0]), - height=0.8, - left=interval[1][0], - alpha=alpha, - color=colormap[interval[0]], - linewidth=0, - ) - ind = ind + 1 + try: + if float(interval[1][1]) != float("inf"): + # Finite death case + axes.barh( + ind, + (interval[1][1] - interval[1][0]), + height=0.8, + left=interval[1][0], + alpha=alpha, + color=colormap[interval[0]], + linewidth=0, + ) + else: + # Infinite death case for diagram to be nicer + axes.barh( + ind, + (infinity - interval[1][0]), + height=0.8, + left=interval[1][0], + alpha=alpha, + color=colormap[interval[0]], + linewidth=0, + ) + ind = ind + 1 + except IndexError: + pass if legend: dimensions = list(set(item[0] for item in persistence)) axes.legend( - handles=[ - mpatches.Patch(color=colormap[dim], label=str(dim)) - for dim in dimensions - ], - loc="lower right", + handles=[mpatches.Patch(color=colormap[dim], label=str(dim)) for dim in dimensions], loc="lower right", ) axes.set_title("Persistence barcode", fontsize=fontsize) # Ends plot on infinity value and starts a little bit before min_birth - axes.axis([axis_start, infinity, 0, ind]) + if ind != 0: + axes.axis([axis_start, infinity, 0, ind]) return axes except ImportError: - print("This function is not available, you may be missing matplotlib.") + warnings.warn("This function is not available, you may be missing matplotlib.") def plot_persistence_diagram( @@ -220,13 +243,12 @@ def plot_persistence_diagram( alpha=0.6, band=0.0, max_intervals=1000, - max_plots=1000, inf_delta=0.1, legend=False, colormap=None, axes=None, fontsize=16, - greyblock=True + greyblock=True, ): """This function plots the persistence diagram from persistence values list, a np.array of shape (N x 2) representing a diagram in a single @@ -268,47 +290,35 @@ def plot_persistence_diagram( import matplotlib.pyplot as plt import matplotlib.patches as mpatches from matplotlib import rc + if _gudhi_matplotlib_use_tex and _matplotlib_can_use_tex(): - plt.rc('text', usetex=True) - plt.rc('font', family='serif') + plt.rc("text", usetex=True) + plt.rc("font", family="serif") else: - plt.rc('text', usetex=False) - plt.rc('font', family='DejaVu Sans') + plt.rc("text", usetex=False) + plt.rc("font", family="DejaVu Sans") if persistence_file != "": if path.isfile(persistence_file): # Reset persistence persistence = [] - diag = read_persistence_intervals_grouped_by_dimension( - persistence_file=persistence_file - ) + diag = read_persistence_intervals_grouped_by_dimension(persistence_file=persistence_file) for key in diag.keys(): for persistence_interval in diag[key]: persistence.append((key, persistence_interval)) else: - print("file " + persistence_file + " not found.") - return None - - persistence = _array_handler(persistence) - - if max_plots != 1000: - print("Deprecated parameter. It has been replaced by max_intervals") - max_intervals = max_plots + raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), persistence_file) - if max_intervals > 0 and max_intervals < len(persistence): - # Sort by life time, then takes only the max_intervals elements - persistence = sorted( - persistence, - key=lambda life_time: life_time[1][1] - life_time[1][0], - reverse=True, - )[:max_intervals] - - if colormap == None: - colormap = plt.cm.Set1.colors - if axes == None: - fig, axes = plt.subplots(1, 1) + try: + persistence = _array_handler(persistence) + persistence = _limit_to_max_intervals( + persistence, max_intervals, key=lambda life_time: life_time[1][1] - life_time[1][0] + ) + min_birth, max_death = __min_birth_max_death(persistence, band) + except IndexError: + min_birth, max_death = 0.0, 1.0 + pass - (min_birth, max_death) = __min_birth_max_death(persistence, band) delta = (max_death - min_birth) * inf_delta # Replace infinity values with max_death + delta for diagram to be more # readable @@ -316,61 +326,66 @@ def plot_persistence_diagram( axis_end = max_death + delta / 2 axis_start = min_birth - delta + if axes == None: + _, axes = plt.subplots(1, 1) + if colormap == None: + colormap = plt.cm.Set1.colors # bootstrap band if band > 0.0: x = np.linspace(axis_start, infinity, 1000) axes.fill_between(x, x, x + band, alpha=alpha, facecolor="red") # lower diag patch if greyblock: - axes.add_patch(mpatches.Polygon([[axis_start, axis_start], [axis_end, axis_start], [axis_end, axis_end]], fill=True, color='lightgrey')) + axes.add_patch( + mpatches.Polygon( + [[axis_start, axis_start], [axis_end, axis_start], [axis_end, axis_end]], + fill=True, + color="lightgrey", + ) + ) + # line display of equation : birth = death + axes.plot([axis_start, axis_end], [axis_start, axis_end], linewidth=1.0, color="k") + # Draw points in loop pts_at_infty = False # Records presence of pts at infty for interval in reversed(persistence): - if float(interval[1][1]) != float("inf"): - # Finite death case - axes.scatter( - interval[1][0], - interval[1][1], - alpha=alpha, - color=colormap[interval[0]], - ) - else: - pts_at_infty = True - # Infinite death case for diagram to be nicer - axes.scatter( - interval[1][0], infinity, alpha=alpha, color=colormap[interval[0]] - ) + try: + if float(interval[1][1]) != float("inf"): + # Finite death case + axes.scatter( + interval[1][0], interval[1][1], alpha=alpha, color=colormap[interval[0]], + ) + else: + pts_at_infty = True + # Infinite death case for diagram to be nicer + axes.scatter(interval[1][0], infinity, alpha=alpha, color=colormap[interval[0]]) + except IndexError: + pass if pts_at_infty: # infinity line and text - axes.plot([axis_start, axis_end], [axis_start, axis_end], linewidth=1.0, color="k") axes.plot([axis_start, axis_end], [infinity, infinity], linewidth=1.0, color="k", alpha=alpha) # Infinity label yt = axes.get_yticks() - yt = yt[np.where(yt < axis_end)] # to avoid ploting ticklabel higher than infinity + yt = yt[np.where(yt < axis_end)] # to avoid ploting ticklabel higher than infinity yt = np.append(yt, infinity) ytl = ["%.3f" % e for e in yt] # to avoid float precision error - ytl[-1] = r'$+\infty$' + ytl[-1] = r"$+\infty$" axes.set_yticks(yt) axes.set_yticklabels(ytl) if legend: dimensions = list(set(item[0] for item in persistence)) - axes.legend( - handles=[ - mpatches.Patch(color=colormap[dim], label=str(dim)) - for dim in dimensions - ] - ) + axes.legend(handles=[mpatches.Patch(color=colormap[dim], label=str(dim)) for dim in dimensions]) axes.set_xlabel("Birth", fontsize=fontsize) axes.set_ylabel("Death", fontsize=fontsize) axes.set_title("Persistence diagram", fontsize=fontsize) # Ends plot on infinity value and starts a little bit before min_birth - axes.axis([axis_start, axis_end, axis_start, infinity + delta/2]) + axes.axis([axis_start, axis_end, axis_start, infinity + delta / 2]) return axes except ImportError: - print("This function is not available, you may be missing matplotlib.") + warnings.warn("This function is not available, you may be missing matplotlib.") def plot_persistence_density( @@ -384,7 +399,7 @@ def plot_persistence_density( legend=False, axes=None, fontsize=16, - greyblock=False + greyblock=False, ): """This function plots the persistence density from persistence values list, np.array of shape (N x 2) representing a diagram @@ -444,12 +459,13 @@ def plot_persistence_density( import matplotlib.patches as mpatches from scipy.stats import kde from matplotlib import rc + if _gudhi_matplotlib_use_tex and _matplotlib_can_use_tex(): - plt.rc('text', usetex=True) - plt.rc('font', family='serif') + plt.rc("text", usetex=True) + plt.rc("font", family="serif") else: - plt.rc('text', usetex=False) - plt.rc('font', family='DejaVu Sans') + plt.rc("text", usetex=False) + plt.rc("font", family="DejaVu Sans") if persistence_file != "": if dimension is None: @@ -460,57 +476,69 @@ def plot_persistence_density( persistence_file=persistence_file, only_this_dim=dimension ) else: - print("file " + persistence_file + " not found.") - return None - - if len(persistence) > 0: - persistence = _array_handler(persistence) - persistence_dim = np.array( - [ - (dim_interval[1][0], dim_interval[1][1]) - for dim_interval in persistence - if (dim_interval[0] == dimension) or (dimension is None) - ] - ) - - persistence_dim = persistence_dim[np.isfinite(persistence_dim[:, 1])] - if max_intervals > 0 and max_intervals < len(persistence_dim): - # Sort by life time, then takes only the max_intervals elements - persistence_dim = np.array( - sorted( - persistence_dim, - key=lambda life_time: life_time[1] - life_time[0], - reverse=True, - )[:max_intervals] - ) - - # Set as numpy array birth and death (remove undefined values - inf and NaN) - birth = persistence_dim[:, 0] - death = persistence_dim[:, 1] + raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), persistence_file) # default cmap value cannot be done at argument definition level as matplotlib is not yet defined. if cmap is None: cmap = plt.cm.hot_r if axes == None: - fig, axes = plt.subplots(1, 1) + _, axes = plt.subplots(1, 1) + + try: + if len(persistence) > 0: + # if not read from file but given by an argument + persistence = _array_handler(persistence) + persistence_dim = np.array( + [ + (dim_interval[1][0], dim_interval[1][1]) + for dim_interval in persistence + if (dim_interval[0] == dimension) or (dimension is None) + ] + ) + persistence_dim = persistence_dim[np.isfinite(persistence_dim[:, 1])] + persistence_dim = np.array( + _limit_to_max_intervals( + persistence_dim, max_intervals, key=lambda life_time: life_time[1] - life_time[0] + ) + ) + + # Set as numpy array birth and death (remove undefined values - inf and NaN) + birth = persistence_dim[:, 0] + death = persistence_dim[:, 1] + birth_min = birth.min() + birth_max = birth.max() + death_min = death.min() + death_max = death.max() + + # Evaluate a gaussian kde on a regular grid of nbins x nbins over data extents + k = kde.gaussian_kde([birth, death], bw_method=bw_method) + xi, yi = np.mgrid[ + birth_min : birth_max : nbins * 1j, death_min : death_max : nbins * 1j, + ] + zi = k(np.vstack([xi.flatten(), yi.flatten()])) + # Make the plot + img = axes.pcolormesh(xi, yi, zi.reshape(xi.shape), cmap=cmap, shading="auto") + + # IndexError on empty diagrams, ValueError on only inf death values + except (IndexError, ValueError): + birth_min = 0.0 + birth_max = 1.0 + death_min = 0.0 + death_max = 1.0 + pass # line display of equation : birth = death - x = np.linspace(death.min(), birth.max(), 1000) + x = np.linspace(death_min, birth_max, 1000) axes.plot(x, x, color="k", linewidth=1.0) - # Evaluate a gaussian kde on a regular grid of nbins x nbins over data extents - k = kde.gaussian_kde([birth, death], bw_method=bw_method) - xi, yi = np.mgrid[ - birth.min() : birth.max() : nbins * 1j, - death.min() : death.max() : nbins * 1j, - ] - zi = k(np.vstack([xi.flatten(), yi.flatten()])) - - # Make the plot - img = axes.pcolormesh(xi, yi, zi.reshape(xi.shape), cmap=cmap) - if greyblock: - axes.add_patch(mpatches.Polygon([[birth.min(), birth.min()], [death.max(), birth.min()], [death.max(), death.max()]], fill=True, color='lightgrey')) + axes.add_patch( + mpatches.Polygon( + [[birth_min, birth_min], [death_max, birth_min], [death_max, death_max]], + fill=True, + color="lightgrey", + ) + ) if legend: plt.colorbar(img, ax=axes) @@ -522,6 +550,4 @@ def plot_persistence_density( return axes except ImportError: - print( - "This function is not available, you may be missing matplotlib and/or scipy." - ) + warnings.warn("This function is not available, you may be missing matplotlib and/or scipy.") |