"""This file is part of the Gudhi Library. The Gudhi library
   (Geometric Understanding in Higher Dimensions) is a generic C++
   library for computational topology.

   Author(s):       Vincent Rouvreau, Bertrand Michel

   Copyright (C) 2016 Inria

   This program is free software: you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation, either version 3 of the License, or
   (at your option) any later version.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with this program.  If not, see <http://www.gnu.org/licenses/>.
"""

__author__ = "Vincent Rouvreau, Bertrand Michel"
__copyright__ = "Copyright (C) 2016 Inria"
__license__ = "GPL v3"

# The plotting functions need matplotlib (and numpy). When they cannot be
# imported, the ImportError is deliberately swallowed at the bottom of this
# file and the functions below are simply not defined.
try:
    import matplotlib.pyplot as plt
    import matplotlib.patches as mpatches
    import numpy as np
    import os

    def __min_birth_max_death(persistence, band=0.):
        """This function returns (min_birth, max_death) from the persistence.

        Births of every interval are taken into account (including intervals
        with an infinite death); deaths only count when they are finite.

        :param persistence: The persistence to plot.
        :type persistence: list of tuples(dimension, tuple(birth, death)).
        :param band: band width added to max_death when strictly positive.
        :type band: float.
        :returns: (float, float) -- (min_birth, max_death). For an empty
            persistence, min_birth is 0. and max_death is 0. (plus band when
            band > 0.).
        """
        # Look for minimum birth date and maximum death date for plot
        # optimisation.
        max_death = 0.
        # Guard against an empty persistence (used to raise IndexError).
        min_birth = float(persistence[0][1][0]) if persistence else 0.
        for interval in persistence:
            birth = float(interval[1][0])
            death = float(interval[1][1])
            if death != float('inf') and death > max_death:
                max_death = death
            # A birth may lie beyond every finite death (e.g. an interval that
            # never dies), so births also bound max_death.
            if birth > max_death:
                max_death = birth
            if birth < min_birth:
                min_birth = birth
        if band > 0.:
            max_death += band
        return (min_birth, max_death)

    # Only 13 colors in the palette: dimension d is drawn with
    # palette[d % len(palette)], so colors repeat from dimension 13 on.
    palette = ['#ff0000', '#00ff00', '#0000ff', '#00ffff', '#ff00ff',
               '#ffff00', '#000000', '#880000', '#008800', '#000088',
               '#888800', '#880088', '#008888']

    def __palette_color(dimension):
        """Return the palette color used to draw intervals of `dimension`.

        Wraps around the palette instead of raising IndexError for
        dimensions >= len(palette).
        """
        return palette[int(dimension) % len(palette)]

    def __persistence_from_file(persistence_file):
        """Read a persistence file into a persistence values list.

        :param persistence_file: Path of the persistence file.
        :type persistence_file: string
        :returns: list of tuples(dimension, tuple(birth, death)), or None
            (after printing a message) when the file does not exist.
        """
        if not os.path.isfile(persistence_file):
            print("file " + persistence_file + " not found.")
            return None
        # NOTE(review): read_persistence_intervals_grouped_by_dimension is not
        # imported in this file; presumably it is provided by the module this
        # file is included into -- confirm.
        diag = read_persistence_intervals_grouped_by_dimension(
            persistence_file=persistence_file)
        return [(key, persistence_interval)
                for key in diag.keys()
                for persistence_interval in diag[key]]

    def __plot_bounds(min_birth, max_death, inf_delta):
        """Return (delta, infinity, axis_start) used to lay out a plot.

        delta = (max_death - min_birth) x inf_delta; infinity = max_death +
        delta is where infinite deaths are drawn; axis_start = min_birth -
        delta is where the axis starts.
        """
        delta = (max_death - min_birth) * inf_delta
        if delta == 0.:
            # Degenerate range (single interval, all values equal, empty
            # persistence): use inf_delta as an absolute offset so the axis
            # limits differ and infinite bars/points remain visible.
            delta = inf_delta
        return (delta, max_death + delta, min_birth - delta)

    def plot_persistence_barcode(persistence=[], persistence_file='', alpha=0.6,
                                 max_barcodes=1000, inf_delta=0.1, legend=False):
        """This function plots the persistence bar code from persistence values
        list or from a persistence file.

        :param persistence: Persistence values list.
        :type persistence: list of tuples(dimension, tuple(birth, death)).
        :param persistence_file: A persistence file style name (reset
            persistence if both are set).
        :type persistence_file: string
        :param alpha: barcode transparency value (0.0 transparent through 1.0
            opaque - default is 0.6).
        :type alpha: float.
        :param max_barcodes: number of maximal barcodes to be displayed. Set it
            to 0 to see all, Default value is 1000. (persistence will be sorted
            by life time if max_barcodes is set)
        :type max_barcodes: int.
        :param inf_delta: Infinity is placed at
            max_death + (max_death - min_birth) x inf_delta. A reasonable value
            is between 0.05 and 0.5 - default is 0.1.
        :type inf_delta: float.
        :param legend: Display a legend mapping colors to dimensions (default
            is False).
        :type legend: boolean.
        :returns: The matplotlib.pyplot module with the horizontal bar plot of
            persistence drawn on it (launch `show()` method on it to display
            it), or None when persistence_file does not exist.
        """
        if persistence_file:
            # Reset persistence with the file content.
            persistence = __persistence_from_file(persistence_file)
            if persistence is None:
                return None

        if 0 < max_barcodes < len(persistence):
            # Sort by life time, then take only the max_barcodes elements.
            persistence = sorted(
                persistence,
                key=lambda life_time: life_time[1][1] - life_time[1][0],
                reverse=True)[:max_barcodes]
        # Bars are stacked by birth date. sorted() builds a new list, so the
        # (shared) default argument is never mutated.
        persistence = sorted(persistence, key=lambda birth: birth[1][0])

        (min_birth, max_death) = __min_birth_max_death(persistence)
        # Replace infinity values with max_death + delta for bar code to be
        # more readable, and start the axis a little bit before min_birth.
        (_, infinity, axis_start) = __plot_bounds(min_birth, max_death,
                                                  inf_delta)

        # One bar per row: row 0 (bottom) holds the latest birth.
        for ind, interval in enumerate(reversed(persistence)):
            birth = interval[1][0]
            death = interval[1][1]
            if float(death) == float('inf'):
                # Infinite death case: draw the bar up to "infinity".
                death = infinity
            plt.barh(ind, (death - birth), height=0.8, left=birth,
                     alpha=alpha, color=__palette_color(interval[0]),
                     linewidth=0)

        if legend:
            dimensions = sorted(set(item[0] for item in persistence))
            plt.legend(handles=[mpatches.Patch(color=__palette_color(dim),
                                               label=str(dim))
                                for dim in dimensions],
                       loc='lower right')

        plt.title('Persistence barcode')
        # Ends plot on infinity value and starts a little bit before min_birth
        plt.axis([axis_start, infinity, 0, len(persistence)])
        return plt

    def plot_persistence_diagram(persistence=[], persistence_file='', alpha=0.6,
                                 band=0., max_plots=1000, inf_delta=0.1,
                                 legend=False):
        r"""This function plots the persistence diagram from persistence values
        list or from a persistence file.

        :param persistence: Persistence values list.
        :type persistence: list of tuples(dimension, tuple(birth, death)).
        :param persistence_file: A persistence file style name (reset
            persistence if both are set).
        :type persistence_file: string
        :param alpha: plot transparency value (0.0 transparent through 1.0
            opaque - default is 0.6).
        :type alpha: float.
        :param band: width of the band drawn above the diagonal (not displayed
            if :math:`\leq` 0. - default is 0.)
        :type band: float.
        :param max_plots: number of maximal plots to be displayed. Set it to 0
            to see all, Default value is 1000. (persistence will be sorted by
            life time if max_plots is set)
        :type max_plots: int.
        :param inf_delta: Infinity is placed at
            max_death + (max_death - min_birth) x inf_delta. A reasonable value
            is between 0.05 and 0.5 - default is 0.1.
        :type inf_delta: float.
        :param legend: Display a legend mapping colors to dimensions (default
            is False).
        :type legend: boolean.
        :returns: The matplotlib.pyplot module with the diagram plot of
            persistence drawn on it (launch `show()` method on it to display
            it), or None when persistence_file does not exist.
        """
        if persistence_file:
            # Reset persistence with the file content.
            persistence = __persistence_from_file(persistence_file)
            if persistence is None:
                return None

        if 0 < max_plots < len(persistence):
            # Sort by life time, then take only the max_plots elements.
            persistence = sorted(
                persistence,
                key=lambda life_time: life_time[1][1] - life_time[1][0],
                reverse=True)[:max_plots]

        (min_birth, max_death) = __min_birth_max_death(persistence, band)
        # Replace infinity values with max_death + delta for diagram to be
        # more readable, and start the axes a little bit before min_birth.
        (delta, infinity, axis_start) = __plot_bounds(min_birth, max_death,
                                                      inf_delta)

        # line display of equation : birth = death
        x = np.linspace(axis_start, infinity, 1000)
        plt.plot(x, x, color='k', linewidth=1.0)
        # infinity line and text
        plt.plot(x, [infinity] * len(x), linewidth=1.0, color='k',
                 alpha=alpha)
        plt.text(axis_start, infinity, r'$\infty$', color='k', alpha=alpha)
        # bootstrap band
        if band > 0.:
            plt.fill_between(x, x, x + band, alpha=alpha, facecolor='red')

        # Draw one point per interval
        for interval in reversed(persistence):
            death = interval[1][1]
            if float(death) == float('inf'):
                # Infinite death case: put the point on the infinity line.
                death = infinity
            plt.scatter(interval[1][0], death, alpha=alpha,
                        color=__palette_color(interval[0]))

        if legend:
            dimensions = sorted(set(item[0] for item in persistence))
            plt.legend(handles=[mpatches.Patch(color=__palette_color(dim),
                                               label=str(dim))
                                for dim in dimensions])

        plt.title('Persistence diagram')
        plt.xlabel('Birth')
        plt.ylabel('Death')
        # Ends plot on infinity value and starts a little bit before min_birth
        plt.axis([axis_start, infinity, axis_start, infinity + delta])
        return plt

except ImportError:
    # Continue in case of import error, functions won't be available
    pass