diff options
Diffstat (limited to 'src/cython/cython/persistence_graphical_tools.py')
-rw-r--r-- | src/cython/cython/persistence_graphical_tools.py | 233 |
1 files changed, 161 insertions, 72 deletions
diff --git a/src/cython/cython/persistence_graphical_tools.py b/src/cython/cython/persistence_graphical_tools.py index ead81d30..34803222 100644 --- a/src/cython/cython/persistence_graphical_tools.py +++ b/src/cython/cython/persistence_graphical_tools.py @@ -16,7 +16,8 @@ __author__ = "Vincent Rouvreau, Bertrand Michel" __copyright__ = "Copyright (C) 2016 Inria" __license__ = "MIT" -def __min_birth_max_death(persistence, band=0.): + +def __min_birth_max_death(persistence, band=0.0): """This function returns (min_birth, max_death) from the persistence. :param persistence: The persistence to plot. @@ -29,27 +30,47 @@ def __min_birth_max_death(persistence, band=0.): max_death = 0 min_birth = persistence[0][1][0] for interval in reversed(persistence): - if float(interval[1][1]) != float('inf'): + if float(interval[1][1]) != float("inf"): if float(interval[1][1]) > max_death: max_death = float(interval[1][1]) if float(interval[1][0]) > max_death: max_death = float(interval[1][0]) if float(interval[1][0]) < min_birth: min_birth = float(interval[1][0]) - if band > 0.: + if band > 0.0: max_death += band return (min_birth, max_death) + """ Only 13 colors for the palette """ -palette = ['#ff0000', '#00ff00', '#0000ff', '#00ffff', '#ff00ff', '#ffff00', - '#000000', '#880000', '#008800', '#000088', '#888800', '#880088', - '#008888'] - -def plot_persistence_barcode(persistence=[], persistence_file='', alpha=0.6, - max_intervals=1000, max_barcodes=1000, - inf_delta=0.1, legend=False): +palette = [ + "#ff0000", + "#00ff00", + "#0000ff", + "#00ffff", + "#ff00ff", + "#ffff00", + "#000000", + "#880000", + "#008800", + "#000088", + "#888800", + "#880088", + "#008888", +] + + +def plot_persistence_barcode( + persistence=[], + persistence_file="", + alpha=0.6, + max_intervals=1000, + max_barcodes=1000, + inf_delta=0.1, + legend=False, +): """This function plots the persistence bar code from persistence values list or from a :doc:`persistence file <fileformats>`. @@ -78,11 +99,13 @@ def plot_persistence_barcode(persistence=[], persistence_file='', alpha=0.6, import matplotlib.pyplot as plt import matplotlib.patches as mpatches - if persistence_file is not '': + if persistence_file is not "": if path.isfile(persistence_file): # Reset persistence persistence = [] - diag = read_persistence_intervals_grouped_by_dimension(persistence_file=persistence_file) + diag = read_persistence_intervals_grouped_by_dimension( + persistence_file=persistence_file + ) for key in diag.keys(): for persistence_interval in diag[key]: persistence.append((key, persistence_interval)) @@ -91,44 +114,62 @@ def plot_persistence_barcode(persistence=[], persistence_file='', alpha=0.6, return None if max_barcodes is not 1000: - print('Deprecated parameter. It has been replaced by max_intervals') + print("Deprecated parameter. It has been replaced by max_intervals") max_intervals = max_barcodes if max_intervals > 0 and max_intervals < len(persistence): # Sort by life time, then takes only the max_intervals elements - persistence = sorted(persistence, key=lambda life_time: life_time[1][1]-life_time[1][0], reverse=True)[:max_intervals] + persistence = sorted( + persistence, + key=lambda life_time: life_time[1][1] - life_time[1][0], + reverse=True, + )[:max_intervals] persistence = sorted(persistence, key=lambda birth: birth[1][0]) (min_birth, max_death) = __min_birth_max_death(persistence) ind = 0 - delta = ((max_death - min_birth) * inf_delta) + delta = (max_death - min_birth) * inf_delta # Replace infinity values with max_death + delta for bar code to be more # readable infinity = max_death + delta axis_start = min_birth - delta # Draw horizontal bars in loop for interval in reversed(persistence): - if float(interval[1][1]) != float('inf'): + if float(interval[1][1]) != float("inf"): # Finite death case - plt.barh(ind, (interval[1][1] - interval[1][0]), height=0.8, - left = interval[1][0], alpha=alpha, - color = palette[interval[0]], - linewidth=0) + plt.barh( + ind, + (interval[1][1] - interval[1][0]), + height=0.8, + left=interval[1][0], + alpha=alpha, + color=palette[interval[0]], + linewidth=0, + ) else: # Infinite death case for diagram to be nicer - plt.barh(ind, (infinity - interval[1][0]), height=0.8, - left = interval[1][0], alpha=alpha, - color = palette[interval[0]], - linewidth=0) + plt.barh( + ind, + (infinity - interval[1][0]), + height=0.8, + left=interval[1][0], + alpha=alpha, + color=palette[interval[0]], + linewidth=0, + ) ind = ind + 1 if legend: dimensions = list(set(item[0] for item in persistence)) - plt.legend(handles=[mpatches.Patch(color=palette[dim], - label=str(dim)) for dim in dimensions], - loc='lower right') - plt.title('Persistence barcode') + plt.legend( + handles=[ + mpatches.Patch(color=palette[dim], label=str(dim)) + for dim in dimensions + ], + loc="lower right", + ) + plt.title("Persistence barcode") # Ends plot on infinity value and starts a little bit before min_birth plt.axis([axis_start, infinity, 0, ind]) return plt @@ -136,8 +177,17 @@ def plot_persistence_barcode(persistence=[], persistence_file='', alpha=0.6, except ImportError: print("This function is not available, you may be missing matplotlib.") -def plot_persistence_diagram(persistence=[], persistence_file='', alpha=0.6, - band=0., max_intervals=1000, max_plots=1000, inf_delta=0.1, legend=False): + +def plot_persistence_diagram( + persistence=[], + persistence_file="", + alpha=0.6, + band=0.0, + max_intervals=1000, + max_plots=1000, + inf_delta=0.1, + legend=False, +): """This function plots the persistence diagram from persistence values list or from a :doc:`persistence file <fileformats>`. @@ -168,11 +218,13 @@ def plot_persistence_diagram(persistence=[], persistence_file='', alpha=0.6, import matplotlib.pyplot as plt import matplotlib.patches as mpatches - if persistence_file is not '': + if persistence_file is not "": if path.isfile(persistence_file): # Reset persistence persistence = [] - diag = read_persistence_intervals_grouped_by_dimension(persistence_file=persistence_file) + diag = read_persistence_intervals_grouped_by_dimension( + persistence_file=persistence_file + ) for key in diag.keys(): for persistence_interval in diag[key]: persistence.append((key, persistence_interval)) @@ -181,15 +233,19 @@ def plot_persistence_diagram(persistence=[], persistence_file='', alpha=0.6, return None if max_plots is not 1000: - print('Deprecated parameter. It has been replaced by max_intervals') + print("Deprecated parameter. It has been replaced by max_intervals") max_intervals = max_plots if max_intervals > 0 and max_intervals < len(persistence): # Sort by life time, then takes only the max_intervals elements - persistence = sorted(persistence, key=lambda life_time: life_time[1][1]-life_time[1][0], reverse=True)[:max_intervals] + persistence = sorted( + persistence, + key=lambda life_time: life_time[1][1] - life_time[1][0], + reverse=True, + )[:max_intervals] (min_birth, max_death) = __min_birth_max_death(persistence, band) - delta = ((max_death - min_birth) * inf_delta) + delta = (max_death - min_birth) * inf_delta # Replace infinity values with max_death + delta for diagram to be more # readable infinity = max_death + delta @@ -198,31 +254,41 @@ def plot_persistence_diagram(persistence=[], persistence_file='', alpha=0.6, # line display of equation : birth = death x = np.linspace(axis_start, infinity, 1000) # infinity line and text - plt.plot(x, x, color='k', linewidth=1.0) - plt.plot(x, [infinity] * len(x), linewidth=1.0, color='k', alpha=alpha) - plt.text(axis_start, infinity, r'$\infty$', color='k', alpha=alpha) + plt.plot(x, x, color="k", linewidth=1.0) + plt.plot(x, [infinity] * len(x), linewidth=1.0, color="k", alpha=alpha) + plt.text(axis_start, infinity, r"$\infty$", color="k", alpha=alpha) # bootstrap band - if band > 0.: - plt.fill_between(x, x, x+band, alpha=alpha, facecolor='red') + if band > 0.0: + plt.fill_between(x, x, x + band, alpha=alpha, facecolor="red") # Draw points in loop for interval in reversed(persistence): - if float(interval[1][1]) != float('inf'): + if float(interval[1][1]) != float("inf"): # Finite death case - plt.scatter(interval[1][0], interval[1][1], alpha=alpha, - color = palette[interval[0]]) + plt.scatter( + interval[1][0], + interval[1][1], + alpha=alpha, + color=palette[interval[0]], + ) else: # Infinite death case for diagram to be nicer - plt.scatter(interval[1][0], infinity, alpha=alpha, - color = palette[interval[0]]) + plt.scatter( + interval[1][0], infinity, alpha=alpha, color=palette[interval[0]] + ) if legend: dimensions = list(set(item[0] for item in persistence)) - plt.legend(handles=[mpatches.Patch(color=palette[dim], label=str(dim)) for dim in dimensions]) - - plt.title('Persistence diagram') - plt.xlabel('Birth') - plt.ylabel('Death') + plt.legend( + handles=[ + mpatches.Patch(color=palette[dim], label=str(dim)) + for dim in dimensions + ] + ) + + plt.title("Persistence diagram") + plt.xlabel("Birth") + plt.ylabel("Death") # Ends plot on infinity value and starts a little bit before min_birth plt.axis([axis_start, infinity, axis_start, infinity + delta]) return plt @@ -230,10 +296,17 @@ def plot_persistence_diagram(persistence=[], persistence_file='', alpha=0.6, except ImportError: print("This function is not available, you may be missing matplotlib.") -def plot_persistence_density(persistence=[], persistence_file='', - nbins=300, bw_method=None, - max_intervals=1000, dimension=None, - cmap=None, legend=False): + +def plot_persistence_density( + persistence=[], + persistence_file="", + nbins=300, + bw_method=None, + max_intervals=1000, + dimension=None, + cmap=None, + legend=False, +): """This function plots the persistence density from persistence values list or from a :doc:`persistence file <fileformats>`. Be aware that this function does not distinguish the dimension, it is @@ -278,39 +351,53 @@ def plot_persistence_density(persistence=[], persistence_file='', import matplotlib.pyplot as plt from scipy.stats import kde - if persistence_file is not '': + if persistence_file is not "": if dimension is None: # All dimension case dimension = -1 if path.isfile(persistence_file): - persistence_dim = read_persistence_intervals_in_dimension(persistence_file=persistence_file, - only_this_dim=dimension) + persistence_dim = read_persistence_intervals_in_dimension( + persistence_file=persistence_file, only_this_dim=dimension + ) print(persistence_dim) else: print("file " + persistence_file + " not found.") return None if len(persistence) > 0: - persistence_dim = np.array([(dim_interval[1][0], dim_interval[1][1]) for dim_interval in persistence if (dim_interval[0] == dimension) or (dimension is None)]) - - persistence_dim = persistence_dim[np.isfinite(persistence_dim[:,1])] + persistence_dim = np.array( + [ + (dim_interval[1][0], dim_interval[1][1]) + for dim_interval in persistence + if (dim_interval[0] == dimension) or (dimension is None) + ] + ) + + persistence_dim = persistence_dim[np.isfinite(persistence_dim[:, 1])] if max_intervals > 0 and max_intervals < len(persistence_dim): # Sort by life time, then takes only the max_intervals elements - persistence_dim = np.array(sorted(persistence_dim, - key=lambda life_time: life_time[1]-life_time[0], - reverse=True)[:max_intervals]) + persistence_dim = np.array( + sorted( + persistence_dim, + key=lambda life_time: life_time[1] - life_time[0], + reverse=True, + )[:max_intervals] + ) # Set as numpy array birth and death (remove undefined values - inf and NaN) - birth = persistence_dim[:,0] - death = persistence_dim[:,1] + birth = persistence_dim[:, 0] + death = persistence_dim[:, 1] # line display of equation : birth = death x = np.linspace(death.min(), birth.max(), 1000) - plt.plot(x, x, color='k', linewidth=1.0) + plt.plot(x, x, color="k", linewidth=1.0) # Evaluate a gaussian kde on a regular grid of nbins x nbins over data extents - k = kde.gaussian_kde([birth,death], bw_method=bw_method) - xi, yi = np.mgrid[birth.min():birth.max():nbins*1j, death.min():death.max():nbins*1j] + k = kde.gaussian_kde([birth, death], bw_method=bw_method) + xi, yi = np.mgrid[ + birth.min() : birth.max() : nbins * 1j, + death.min() : death.max() : nbins * 1j, + ] zi = k(np.vstack([xi.flatten(), yi.flatten()])) # default cmap value cannot be done at argument definition level as matplotlib is not yet defined. @@ -322,10 +409,12 @@ def plot_persistence_density(persistence=[], persistence_file='', if legend: plt.colorbar() - plt.title('Persistence density') - plt.xlabel('Birth') - plt.ylabel('Death') + plt.title("Persistence density") + plt.xlabel("Birth") + plt.ylabel("Death") return plt except ImportError: - print("This function is not available, you may be missing matplotlib and/or scipy.") + print( + "This function is not available, you may be missing matplotlib and/or scipy." + ) |