diff options
Diffstat (limited to 'src/python/gudhi/persistence_graphical_tools.py')
-rw-r--r-- | src/python/gudhi/persistence_graphical_tools.py | 163 |
1 files changed, 86 insertions, 77 deletions
diff --git a/src/python/gudhi/persistence_graphical_tools.py b/src/python/gudhi/persistence_graphical_tools.py index 181bc8ea..246280de 100644 --- a/src/python/gudhi/persistence_graphical_tools.py +++ b/src/python/gudhi/persistence_graphical_tools.py @@ -1,3 +1,12 @@ +# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT. +# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details. +# Author(s): Vincent Rouvreau, Bertrand Michel +# +# Copyright (C) 2016 Inria +# +# Modification(s): +# - YYYY/MM Author: Description of the modification + from os import path from math import isfinite import numpy as np @@ -5,16 +14,6 @@ import numpy as np from gudhi.reader_utils import read_persistence_intervals_in_dimension from gudhi.reader_utils import read_persistence_intervals_grouped_by_dimension -""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT. - See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details. - Author(s): Vincent Rouvreau, Bertrand Michel - - Copyright (C) 2016 Inria - - Modification(s): - - YYYY/MM Author: Description of the modification -""" - __author__ = "Vincent Rouvreau, Bertrand Michel" __copyright__ = "Copyright (C) 2016 Inria" __license__ = "MIT" @@ -44,27 +43,6 @@ def __min_birth_max_death(persistence, band=0.0): max_death += band return (min_birth, max_death) - -""" -Only 13 colors for the palette -""" -palette = [ - "#ff0000", - "#00ff00", - "#0000ff", - "#00ffff", - "#ff00ff", - "#ffff00", - "#000000", - "#880000", - "#008800", - "#000088", - "#888800", - "#880088", - "#008888", -] - - def plot_persistence_barcode( persistence=[], persistence_file="", @@ -73,6 +51,8 @@ def plot_persistence_barcode( max_barcodes=1000, inf_delta=0.1, legend=False, + colormap=None, + axes=None ): """This function plots the persistence bar code from persistence values list or from a :doc:`persistence file <fileformats>`. @@ -95,14 +75,19 @@ def plot_persistence_barcode( :type inf_delta: float. :param legend: Display the dimension color legend (default is False). :type legend: boolean. - :returns: A matplotlib object containing horizontal bar plot of persistence - (launch `show()` method on it to display it). + :param colormap: A matplotlib-like qualitative colormaps. Default is None + which means :code:`matplotlib.cm.Set1.colors`. + :type colormap: tuple of colors (3-tuple of float between 0. and 1.). + :param axes: A matplotlib-like subplot axes. If None, the plot is drawn on + a new set of axes. + :type axes: `matplotlib.axes.Axes` + :returns: (`matplotlib.axes.Axes`): The axes on which the plot was drawn. """ try: import matplotlib.pyplot as plt import matplotlib.patches as mpatches - if persistence_file is not "": + if persistence_file != "": if path.isfile(persistence_file): # Reset persistence persistence = [] @@ -116,7 +101,7 @@ def plot_persistence_barcode( print("file " + persistence_file + " not found.") return None - if max_barcodes is not 1000: + if max_barcodes != 1000: print("Deprecated parameter. It has been replaced by max_intervals") max_intervals = max_barcodes @@ -127,6 +112,11 @@ def plot_persistence_barcode( key=lambda life_time: life_time[1][1] - life_time[1][0], reverse=True, )[:max_intervals] + + if colormap == None: + colormap = plt.cm.Set1.colors + if axes == None: + fig, axes = plt.subplots(1, 1) persistence = sorted(persistence, key=lambda birth: birth[1][0]) @@ -141,41 +131,43 @@ def plot_persistence_barcode( for interval in reversed(persistence): if float(interval[1][1]) != float("inf"): # Finite death case - plt.barh( + axes.barh( ind, (interval[1][1] - interval[1][0]), height=0.8, left=interval[1][0], alpha=alpha, - color=palette[interval[0]], + color=colormap[interval[0]], linewidth=0, ) else: # Infinite death case for diagram to be nicer - plt.barh( + axes.barh( ind, (infinity - interval[1][0]), height=0.8, left=interval[1][0], alpha=alpha, - color=palette[interval[0]], + color=colormap[interval[0]], linewidth=0, ) ind = ind + 1 if legend: dimensions = list(set(item[0] for item in persistence)) - plt.legend( + axes.legend( handles=[ - mpatches.Patch(color=palette[dim], label=str(dim)) + mpatches.Patch(color=colormap[dim], label=str(dim)) for dim in dimensions ], loc="lower right", ) - plt.title("Persistence barcode") + + axes.set_title("Persistence barcode") + # Ends plot on infinity value and starts a little bit before min_birth - plt.axis([axis_start, infinity, 0, ind]) - return plt + axes.axis([axis_start, infinity, 0, ind]) + return axes except ImportError: print("This function is not available, you may be missing matplotlib.") @@ -190,6 +182,8 @@ def plot_persistence_diagram( max_plots=1000, inf_delta=0.1, legend=False, + colormap=None, + axes=None ): """This function plots the persistence diagram from persistence values list or from a :doc:`persistence file <fileformats>`. @@ -214,14 +208,19 @@ def plot_persistence_diagram( :type inf_delta: float. :param legend: Display the dimension color legend (default is False). :type legend: boolean. - :returns: A matplotlib object containing diagram plot of persistence - (launch `show()` method on it to display it). + :param colormap: A matplotlib-like qualitative colormaps. Default is None + which means :code:`matplotlib.cm.Set1.colors`. + :type colormap: tuple of colors (3-tuple of float between 0. and 1.). + :param axes: A matplotlib-like subplot axes. If None, the plot is drawn on + a new set of axes. + :type axes: `matplotlib.axes.Axes` + :returns: (`matplotlib.axes.Axes`): The axes on which the plot was drawn. """ try: import matplotlib.pyplot as plt import matplotlib.patches as mpatches - if persistence_file is not "": + if persistence_file != "": if path.isfile(persistence_file): # Reset persistence persistence = [] @@ -235,7 +234,7 @@ def plot_persistence_diagram( print("file " + persistence_file + " not found.") return None - if max_plots is not 1000: + if max_plots != 1000: print("Deprecated parameter. It has been replaced by max_intervals") max_intervals = max_plots @@ -247,6 +246,11 @@ def plot_persistence_diagram( reverse=True, )[:max_intervals] + if colormap == None: + colormap = plt.cm.Set1.colors + if axes == None: + fig, axes = plt.subplots(1, 1) + (min_birth, max_death) = __min_birth_max_death(persistence, band) delta = (max_death - min_birth) * inf_delta # Replace infinity values with max_death + delta for diagram to be more @@ -257,44 +261,44 @@ def plot_persistence_diagram( # line display of equation : birth = death x = np.linspace(axis_start, infinity, 1000) # infinity line and text - plt.plot(x, x, color="k", linewidth=1.0) - plt.plot(x, [infinity] * len(x), linewidth=1.0, color="k", alpha=alpha) - plt.text(axis_start, infinity, r"$\infty$", color="k", alpha=alpha) + axes.plot(x, x, color="k", linewidth=1.0) + axes.plot(x, [infinity] * len(x), linewidth=1.0, color="k", alpha=alpha) + axes.text(axis_start, infinity, r"$\infty$", color="k", alpha=alpha) # bootstrap band if band > 0.0: - plt.fill_between(x, x, x + band, alpha=alpha, facecolor="red") + axes.fill_between(x, x, x + band, alpha=alpha, facecolor="red") # Draw points in loop for interval in reversed(persistence): if float(interval[1][1]) != float("inf"): # Finite death case - plt.scatter( + axes.scatter( interval[1][0], interval[1][1], alpha=alpha, - color=palette[interval[0]], + color=colormap[interval[0]], ) else: # Infinite death case for diagram to be nicer - plt.scatter( - interval[1][0], infinity, alpha=alpha, color=palette[interval[0]] + axes.scatter( + interval[1][0], infinity, alpha=alpha, color=colormap[interval[0]] ) if legend: dimensions = list(set(item[0] for item in persistence)) - plt.legend( + axes.legend( handles=[ - mpatches.Patch(color=palette[dim], label=str(dim)) + mpatches.Patch(color=colormap[dim], label=str(dim)) for dim in dimensions ] ) - plt.title("Persistence diagram") - plt.xlabel("Birth") - plt.ylabel("Death") + axes.set_xlabel("Birth") + axes.set_ylabel("Death") # Ends plot on infinity value and starts a little bit before min_birth - plt.axis([axis_start, infinity, axis_start, infinity + delta]) - return plt + axes.axis([axis_start, infinity, axis_start, infinity + delta]) + axes.set_title("Persistence diagram") + return axes except ImportError: print("This function is not available, you may be missing matplotlib.") @@ -309,6 +313,7 @@ def plot_persistence_density( dimension=None, cmap=None, legend=False, + axes=None ): """This function plots the persistence density from persistence values list or from a :doc:`persistence file <fileformats>`. Be @@ -347,14 +352,16 @@ def plot_persistence_density( :type cmap: cf. matplotlib colormap. :param legend: Display the color bar values (default is False). :type legend: boolean. - :returns: A matplotlib object containing diagram plot of persistence - (launch `show()` method on it to display it). + :param axes: A matplotlib-like subplot axes. If None, the plot is drawn on + a new set of axes. + :type axes: `matplotlib.axes.Axes` + :returns: (`matplotlib.axes.Axes`): The axes on which the plot was drawn. """ try: import matplotlib.pyplot as plt from scipy.stats import kde - if persistence_file is not "": + if persistence_file != "": if dimension is None: # All dimension case dimension = -1 @@ -362,7 +369,6 @@ def plot_persistence_density( persistence_dim = read_persistence_intervals_in_dimension( persistence_file=persistence_file, only_this_dim=dimension ) - print(persistence_dim) else: print("file " + persistence_file + " not found.") return None @@ -391,9 +397,15 @@ def plot_persistence_density( birth = persistence_dim[:, 0] death = persistence_dim[:, 1] + # default cmap value cannot be done at argument definition level as matplotlib is not yet defined. + if cmap is None: + cmap = plt.cm.hot_r + if axes == None: + fig, axes = plt.subplots(1, 1) + # line display of equation : birth = death x = np.linspace(death.min(), birth.max(), 1000) - plt.plot(x, x, color="k", linewidth=1.0) + axes.plot(x, x, color="k", linewidth=1.0) # Evaluate a gaussian kde on a regular grid of nbins x nbins over data extents k = kde.gaussian_kde([birth, death], bw_method=bw_method) @@ -403,19 +415,16 @@ def plot_persistence_density( ] zi = k(np.vstack([xi.flatten(), yi.flatten()])) - # default cmap value cannot be done at argument definition level as matplotlib is not yet defined. - if cmap is None: - cmap = plt.cm.hot_r # Make the plot - plt.pcolormesh(xi, yi, zi.reshape(xi.shape), cmap=cmap) + img = axes.pcolormesh(xi, yi, zi.reshape(xi.shape), cmap=cmap) if legend: - plt.colorbar() + plt.colorbar(img, ax=axes) - plt.title("Persistence density") - plt.xlabel("Birth") - plt.ylabel("Death") - return plt + axes.set_xlabel("Birth") + axes.set_ylabel("Death") + axes.set_title("Persistence density") + return axes except ImportError: print( |