diff options
Diffstat (limited to 'src/cython/cython/persistence_graphical_tools.py')
-rwxr-xr-x | src/cython/cython/persistence_graphical_tools.py | 55 |
1 files changed, 48 insertions, 7 deletions
diff --git a/src/cython/cython/persistence_graphical_tools.py b/src/cython/cython/persistence_graphical_tools.py index da709b8a..fb837e29 100755 --- a/src/cython/cython/persistence_graphical_tools.py +++ b/src/cython/cython/persistence_graphical_tools.py @@ -1,5 +1,6 @@ import matplotlib.pyplot as plt import numpy as np +import os """This file is part of the Gudhi Library. The Gudhi library (Geometric Understanding in Higher Dimensions) is a generic C++ @@ -23,7 +24,7 @@ import numpy as np along with this program. If not, see <http://www.gnu.org/licenses/>. """ -__author__ = "Vincent Rouvreau" +__author__ = "Vincent Rouvreau, Bertrand Michel" __copyright__ = "Copyright (C) 2016 INRIA" __license__ = "GPL v3" @@ -63,7 +64,7 @@ def show_palette_values(alpha=0.6): :param alpha: alpha value in [0.0, 1.0] for horizontal bars (default is 0.6). :type alpha: float. - :returns: plot -- An horizontal bar plot of dimensions color. + :returns: plot the dimension palette values. """ colors = [] for color in palette: @@ -74,18 +75,38 @@ def show_palette_values(alpha=0.6): plt.barh(y_pos, y_pos + 1, align='center', alpha=alpha, color=colors) plt.ylabel('Dimension') plt.title('Dimension palette values') + return plt - plt.show() - -def plot_persistence_barcode(persistence, alpha=0.6): +def plot_persistence_barcode(persistence=[], persistence_file='', alpha=0.6, max_barcodes=0): """This function plots the persistence bar code. :param persistence: The persistence to plot. :type persistence: list of tuples(dimension, tuple(birth, death)). + :param persistence_file: A persistence file style name (reset persistence if both are set). + :type persistence_file: string :param alpha: alpha value in [0.0, 1.0] for horizontal bars (default is 0.6). :type alpha: float. + :param max_barcodes: number of maximal barcodes to be displayed + (persistence will be sorted by life time if max_barcodes is set) + :type max_barcodes: int. :returns: plot -- An horizontal bar plot of persistence. """ + if persistence_file is not '': + if os.path.isfile(persistence_file): + # Reset persistence + persistence = [] + diag = read_persistence_intervals_grouped_by_dimension(persistence_file=persistence_file) + for key in diag.keys(): + for persistence_interval in diag[key]: + persistence.append((key, persistence_interval)) + else: + print("file " + persistence_file + " not found.") + return None + + if max_barcodes > 0 and max_barcodes < len(persistence): + # Sort by life time, then takes only the max_plots elements + persistence = sorted(persistence, key=lambda life_time: life_time[1][1]-life_time[1][0], reverse=True)[:max_barcodes] + (min_birth, max_death) = __min_birth_max_death(persistence) ind = 0 delta = ((max_death - min_birth) / 10.0) @@ -110,19 +131,39 @@ def plot_persistence_barcode(persistence, alpha=0.6): plt.title('Persistence barcode') # Ends plot on infinity value and starts a little bit before min_birth plt.axis([axis_start, infinity, 0, ind]) - plt.show() + return plt -def plot_persistence_diagram(persistence, alpha=0.6, band_boot=0.): +def plot_persistence_diagram(persistence=[], persistence_file='', alpha=0.6, band_boot=0., max_plots=0): """This function plots the persistence diagram with an optional confidence band. :param persistence: The persistence to plot. :type persistence: list of tuples(dimension, tuple(birth, death)). + :param persistence_file: A persistence file style name (reset persistence if both are set). + :type persistence_file: string :param alpha: alpha value in [0.0, 1.0] for points and horizontal infinity line (default is 0.6). :type alpha: float. :param band_boot: bootstrap band (not displayed if :math:`\leq` 0.) :type band_boot: float. + :param max_plots: number of maximal plots to be displayed + :type max_plots: int. :returns: plot -- A diagram plot of persistence. """ + if persistence_file is not '': + if os.path.isfile(persistence_file): + # Reset persistence + persistence = [] + diag = read_persistence_intervals_grouped_by_dimension(persistence_file=persistence_file) + for key in diag.keys(): + for persistence_interval in diag[key]: + persistence.append((key, persistence_interval)) + else: + print("file " + persistence_file + " not found.") + return None + + if max_plots > 0 and max_plots < len(persistence): + # Sort by life time, then takes only the max_plots elements + persistence = sorted(persistence, key=lambda life_time: life_time[1][1]-life_time[1][0], reverse=True)[:max_plots] + (min_birth, max_death) = __min_birth_max_death(persistence, band_boot) ind = 0 delta = ((max_death - min_birth) / 10.0) |