From 6bc1f27bbab03139718db674f98d748a7aeaced3 Mon Sep 17 00:00:00 2001 From: ROUVREAU Vincent Date: Wed, 20 Nov 2019 14:12:01 +0100 Subject: Use matplotlib axes to be able to subplot persistence graphical tools --- src/python/gudhi/persistence_graphical_tools.py | 90 +++++++++++++++---------- 1 file changed, 54 insertions(+), 36 deletions(-) (limited to 'src/python/gudhi') diff --git a/src/python/gudhi/persistence_graphical_tools.py b/src/python/gudhi/persistence_graphical_tools.py index 23725ca7..c9dab323 100644 --- a/src/python/gudhi/persistence_graphical_tools.py +++ b/src/python/gudhi/persistence_graphical_tools.py @@ -51,7 +51,8 @@ def plot_persistence_barcode( max_barcodes=1000, inf_delta=0.1, legend=False, - colormap=None + colormap=None, + axes=None ): """This function plots the persistence bar code from persistence values list or from a :doc:`persistence file `. @@ -77,8 +78,10 @@ def plot_persistence_barcode( :param colormap: A matplotlib-like qualitative colormaps. Default is None which means :code:`matplotlib.cm.Set1.colors`. :type colormap: tuple of colors (3-tuple of float between 0. and 1.). - :returns: A matplotlib object containing horizontal bar plot of persistence - (launch `show()` method on it to display it). + :param axes: A matplotlib-like subplot axes. If None, the plot is drawn on + a new set of axes. + :type axes: `matplotlib.axes.Axes` + :returns: (`matplotlib.axes.Axes`): The axes on which the plot was drawn. """ try: import matplotlib.pyplot as plt @@ -112,6 +115,8 @@ def plot_persistence_barcode( if colormap == None: colormap = plt.cm.Set1.colors + if axes == None: + fig, axes = plt.subplots(1, 1) persistence = sorted(persistence, key=lambda birth: birth[1][0]) @@ -126,7 +131,7 @@ def plot_persistence_barcode( for interval in reversed(persistence): if float(interval[1][1]) != float("inf"): # Finite death case - plt.barh( + axes.barh( ind, (interval[1][1] - interval[1][0]), height=0.8, @@ -137,7 +142,7 @@ def plot_persistence_barcode( ) else: # Infinite death case for diagram to be nicer - plt.barh( + axes.barh( ind, (infinity - interval[1][0]), height=0.8, @@ -150,17 +155,19 @@ def plot_persistence_barcode( if legend: dimensions = list(set(item[0] for item in persistence)) - plt.legend( + axes.legend( handles=[ mpatches.Patch(color=colormap[dim], label=str(dim)) for dim in dimensions ], loc="lower right", ) - plt.title("Persistence barcode") + + axes.set_title("Persistence barcode") + # Ends plot on infinity value and starts a little bit before min_birth - plt.axis([axis_start, infinity, 0, ind]) - return plt + axes.axis([axis_start, infinity, 0, ind]) + return axes except ImportError: print("This function is not available, you may be missing matplotlib.") @@ -175,7 +182,8 @@ def plot_persistence_diagram( max_plots=1000, inf_delta=0.1, legend=False, - colormap=None + colormap=None, + axes=None ): """This function plots the persistence diagram from persistence values list or from a :doc:`persistence file `. @@ -203,8 +211,10 @@ def plot_persistence_diagram( :param colormap: A matplotlib-like qualitative colormaps. Default is None which means :code:`matplotlib.cm.Set1.colors`. :type colormap: tuple of colors (3-tuple of float between 0. and 1.). - :returns: A matplotlib object containing diagram plot of persistence - (launch `show()` method on it to display it). + :param axes: A matplotlib-like subplot axes. If None, the plot is drawn on + a new set of axes. + :type axes: `matplotlib.axes.Axes` + :returns: (`matplotlib.axes.Axes`): The axes on which the plot was drawn. """ try: import matplotlib.pyplot as plt @@ -238,6 +248,8 @@ def plot_persistence_diagram( if colormap == None: colormap = plt.cm.Set1.colors + if axes == None: + fig, axes = plt.subplots(1, 1) (min_birth, max_death) = __min_birth_max_death(persistence, band) delta = (max_death - min_birth) * inf_delta @@ -249,18 +261,18 @@ def plot_persistence_diagram( # line display of equation : birth = death x = np.linspace(axis_start, infinity, 1000) # infinity line and text - plt.plot(x, x, color="k", linewidth=1.0) - plt.plot(x, [infinity] * len(x), linewidth=1.0, color="k", alpha=alpha) - plt.text(axis_start, infinity, r"$\infty$", color="k", alpha=alpha) + axes.plot(x, x, color="k", linewidth=1.0) + axes.plot(x, [infinity] * len(x), linewidth=1.0, color="k", alpha=alpha) + axes.text(axis_start, infinity, r"$\infty$", color="k", alpha=alpha) # bootstrap band if band > 0.0: - plt.fill_between(x, x, x + band, alpha=alpha, facecolor="red") + axes.fill_between(x, x, x + band, alpha=alpha, facecolor="red") # Draw points in loop for interval in reversed(persistence): if float(interval[1][1]) != float("inf"): # Finite death case - plt.scatter( + axes.scatter( interval[1][0], interval[1][1], alpha=alpha, @@ -268,25 +280,25 @@ def plot_persistence_diagram( ) else: # Infinite death case for diagram to be nicer - plt.scatter( + axes.scatter( interval[1][0], infinity, alpha=alpha, color=colormap[interval[0]] ) if legend: dimensions = list(set(item[0] for item in persistence)) - plt.legend( + axes.legend( handles=[ mpatches.Patch(color=colormap[dim], label=str(dim)) for dim in dimensions ] ) - plt.title("Persistence diagram") - plt.xlabel("Birth") - plt.ylabel("Death") + axes.set_xlabel("Birth") + axes.set_ylabel("Death") # Ends plot on infinity value and starts a little bit before min_birth - plt.axis([axis_start, infinity, axis_start, infinity + delta]) - return plt + axes.axis([axis_start, infinity, axis_start, infinity + delta]) + axes.set_title("Persistence diagram") + return axes except ImportError: print("This function is not available, you may be missing matplotlib.") @@ -301,6 +313,7 @@ def plot_persistence_density( dimension=None, cmap=None, legend=False, + axes=None ): """This function plots the persistence density from persistence values list or from a :doc:`persistence file `. Be @@ -339,8 +352,10 @@ def plot_persistence_density( :type cmap: cf. matplotlib colormap. :param legend: Display the color bar values (default is False). :type legend: boolean. - :returns: A matplotlib object containing diagram plot of persistence - (launch `show()` method on it to display it). + :param axes: A matplotlib-like subplot axes. If None, the plot is drawn on + a new set of axes. + :type axes: `matplotlib.axes.Axes` + :returns: (`matplotlib.axes.Axes`): The axes on which the plot was drawn. """ try: import matplotlib.pyplot as plt @@ -383,9 +398,15 @@ def plot_persistence_density( birth = persistence_dim[:, 0] death = persistence_dim[:, 1] + # default cmap value cannot be done at argument definition level as matplotlib is not yet defined. + if cmap is None: + cmap = plt.cm.hot_r + if axes == None: + fig, axes = plt.subplots(1, 1) + # line display of equation : birth = death x = np.linspace(death.min(), birth.max(), 1000) - plt.plot(x, x, color="k", linewidth=1.0) + axes.plot(x, x, color="k", linewidth=1.0) # Evaluate a gaussian kde on a regular grid of nbins x nbins over data extents k = kde.gaussian_kde([birth, death], bw_method=bw_method) @@ -395,19 +416,16 @@ def plot_persistence_density( ] zi = k(np.vstack([xi.flatten(), yi.flatten()])) - # default cmap value cannot be done at argument definition level as matplotlib is not yet defined. - if cmap is None: - cmap = plt.cm.hot_r # Make the plot - plt.pcolormesh(xi, yi, zi.reshape(xi.shape), cmap=cmap) + axes.pcolormesh(xi, yi, zi.reshape(xi.shape), cmap=cmap) if legend: - plt.colorbar() + axes.colorbar() - plt.title("Persistence density") - plt.xlabel("Birth") - plt.ylabel("Death") - return plt + axes.set_xlabel("Birth") + axes.set_ylabel("Death") + axes.set_title("Persistence density") + return axes except ImportError: print( -- cgit v1.2.3