From cdcd2904a1c682625670a62608fd781bfd571516 Mon Sep 17 00:00:00 2001 From: tlacombe Date: Tue, 25 Feb 2020 18:19:24 +0100 Subject: solved scale issue and removed title/aspect as functions return ax --- src/python/gudhi/persistence_graphical_tools.py | 77 +++++++++++++++++++------ 1 file changed, 60 insertions(+), 17 deletions(-) (limited to 'src/python') diff --git a/src/python/gudhi/persistence_graphical_tools.py b/src/python/gudhi/persistence_graphical_tools.py index 246280de..8ddfdba8 100644 --- a/src/python/gudhi/persistence_graphical_tools.py +++ b/src/python/gudhi/persistence_graphical_tools.py @@ -5,6 +5,7 @@ # Copyright (C) 2016 Inria # # Modification(s): +# - 2020/02 Theo Lacombe: Added more options for improved rendering and more flexibility. # - YYYY/MM Author: Description of the modification from os import path @@ -43,6 +44,7 @@ def __min_birth_max_death(persistence, band=0.0): max_death += band return (min_birth, max_death) + def plot_persistence_barcode( persistence=[], persistence_file="", @@ -52,7 +54,8 @@ def plot_persistence_barcode( inf_delta=0.1, legend=False, colormap=None, - axes=None + axes=None, + fontsize=16, ): """This function plots the persistence bar code from persistence values list or from a :doc:`persistence file `. @@ -81,11 +84,16 @@ def plot_persistence_barcode( :param axes: A matplotlib-like subplot axes. If None, the plot is drawn on a new set of axes. :type axes: `matplotlib.axes.Axes` + :param fontsize: Fontsize to use in axis. + :type fontsize: int :returns: (`matplotlib.axes.Axes`): The axes on which the plot was drawn. """ try: import matplotlib.pyplot as plt import matplotlib.patches as mpatches + from matplotlib import rc + plt.rc('text', usetex=True) + plt.rc('font', family='serif') if persistence_file != "": if path.isfile(persistence_file): @@ -163,7 +171,7 @@ def plot_persistence_barcode( loc="lower right", ) - axes.set_title("Persistence barcode") + axes.set_title("Persistence barcode", fontsize=fontsize) # Ends plot on infinity value and starts a little bit before min_birth axes.axis([axis_start, infinity, 0, ind]) @@ -183,7 +191,9 @@ def plot_persistence_diagram( inf_delta=0.1, legend=False, colormap=None, - axes=None + axes=None, + fontsize=16, + greyblock=True ): """This function plots the persistence diagram from persistence values list or from a :doc:`persistence file `. @@ -214,11 +224,19 @@ def plot_persistence_diagram( :param axes: A matplotlib-like subplot axes. If None, the plot is drawn on a new set of axes. :type axes: `matplotlib.axes.Axes` + :param fontsize: Fontsize to use in axis. + :type fontsize: int + :param greyblock: if we want to plot a grey patch on the lower half plane for nicer rendering. Default True. + :type greyblock: boolean :returns: (`matplotlib.axes.Axes`): The axes on which the plot was drawn. """ try: import matplotlib.pyplot as plt import matplotlib.patches as mpatches + from matplotlib import rc + plt.rc('text', usetex=True) + plt.rc('font', family='serif') + if persistence_file != "": if path.isfile(persistence_file): @@ -256,18 +274,27 @@ def plot_persistence_diagram( # Replace infinity values with max_death + delta for diagram to be more # readable infinity = max_death + delta + axis_end = max_death + delta / 2 axis_start = min_birth - delta - # line display of equation : birth = death - x = np.linspace(axis_start, infinity, 1000) # infinity line and text - axes.plot(x, x, color="k", linewidth=1.0) - axes.plot(x, [infinity] * len(x), linewidth=1.0, color="k", alpha=alpha) - axes.text(axis_start, infinity, r"$\infty$", color="k", alpha=alpha) + axes.plot([axis_start, axis_end], [axis_start, axis_end], linewidth=1.0, color="k") + axes.plot([axis_start, axis_end], [infinity, infinity], linewidth=1.0, color="k", alpha=alpha) + # Infinity label + yt = axes.get_yticks() + yt = yt[np.where(yt < axis_end)] # to avoid ploting ticklabel higher than infinity + yt = np.append(yt, infinity) + ytl = ["%.3f" % e for e in yt] # to avoid float precision error + ytl[-1] = r'$+\infty$' + axes.set_yticks(yt) + axes.set_yticklabels(ytl) # bootstrap band if band > 0.0: + x = np.linspace(axis_start, infinity, 1000) axes.fill_between(x, x, x + band, alpha=alpha, facecolor="red") - + # lower diag patch + if greyblock: + axes.add_patch(mpatches.Polygon([[axis_start, axis_start], [axis_end, axis_start], [axis_end, axis_end]], fill=True, color='lightgrey')) # Draw points in loop for interval in reversed(persistence): if float(interval[1][1]) != float("inf"): @@ -293,11 +320,11 @@ def plot_persistence_diagram( ] ) - axes.set_xlabel("Birth") - axes.set_ylabel("Death") + axes.set_xlabel("Birth", fontsize=fontsize) + axes.set_ylabel("Death", fontsize=fontsize) + axes.set_title("Persistence diagram", fontsize=fontsize) # Ends plot on infinity value and starts a little bit before min_birth - axes.axis([axis_start, infinity, axis_start, infinity + delta]) - axes.set_title("Persistence diagram") + axes.axis([axis_start, axis_end, axis_start, infinity + delta/2]) return axes except ImportError: @@ -313,7 +340,9 @@ def plot_persistence_density( dimension=None, cmap=None, legend=False, - axes=None + axes=None, + fontsize=16, + greyblock=True ): """This function plots the persistence density from persistence values list or from a :doc:`persistence file `. Be @@ -355,11 +384,21 @@ def plot_persistence_density( :param axes: A matplotlib-like subplot axes. If None, the plot is drawn on a new set of axes. :type axes: `matplotlib.axes.Axes` + :param fontsize: Fontsize to use in axis. + :type fontsize: int + :param greyblock: if we want to plot a grey patch on the lower half plane + for nicer rendering. Default True. + :type greyblock: boolean :returns: (`matplotlib.axes.Axes`): The axes on which the plot was drawn. """ try: import matplotlib.pyplot as plt + import matplotlib.patches as mpatches from scipy.stats import kde + from matplotlib import rc + plt.rc('text', usetex=True) + plt.rc('font', family='serif') + if persistence_file != "": if dimension is None: @@ -418,12 +457,16 @@ def plot_persistence_density( # Make the plot img = axes.pcolormesh(xi, yi, zi.reshape(xi.shape), cmap=cmap) + if greyblock: + axes.add_patch(mpatches.Polygon([[birth.min(), birth.min()], [death.max(), birth.min()], [death.max(), death.max()]], fill=True, color='lightgrey')) + if legend: plt.colorbar(img, ax=axes) - axes.set_xlabel("Birth") - axes.set_ylabel("Death") - axes.set_title("Persistence density") + axes.set_xlabel("Birth", fontsize=fontsize) + axes.set_ylabel("Death", fontsize=fontsize) + axes.set_title("Persistence density", fontsize=fontsize) + return axes except ImportError: -- cgit v1.2.3