From 5bc96fdf837e4acb80b1333b9db63ddf5802edc8 Mon Sep 17 00:00:00 2001 From: tlacombe Date: Mon, 9 Mar 2020 12:12:13 +0100 Subject: removed infty line plot in plot_diagram if no pts at infty --- src/python/gudhi/persistence_graphical_tools.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) (limited to 'src/python') diff --git a/src/python/gudhi/persistence_graphical_tools.py b/src/python/gudhi/persistence_graphical_tools.py index 8c38b684..cc3db467 100644 --- a/src/python/gudhi/persistence_graphical_tools.py +++ b/src/python/gudhi/persistence_graphical_tools.py @@ -296,17 +296,6 @@ def plot_persistence_diagram( axis_end = max_death + delta / 2 axis_start = min_birth - delta - # infinity line and text - axes.plot([axis_start, axis_end], [axis_start, axis_end], linewidth=1.0, color="k") - axes.plot([axis_start, axis_end], [infinity, infinity], linewidth=1.0, color="k", alpha=alpha) - # Infinity label - yt = axes.get_yticks() - yt = yt[np.where(yt < axis_end)] # to avoid ploting ticklabel higher than infinity - yt = np.append(yt, infinity) - ytl = ["%.3f" % e for e in yt] # to avoid float precision error - ytl[-1] = r'$+\infty$' - axes.set_yticks(yt) - axes.set_yticklabels(ytl) # bootstrap band if band > 0.0: x = np.linspace(axis_start, infinity, 1000) @@ -315,6 +304,7 @@ def plot_persistence_diagram( if greyblock: axes.add_patch(mpatches.Polygon([[axis_start, axis_start], [axis_end, axis_start], [axis_end, axis_end]], fill=True, color='lightgrey')) # Draw points in loop + pts_at_infty = False # Records presence of pts at infty for interval in reversed(persistence): if float(interval[1][1]) != float("inf"): # Finite death case @@ -325,10 +315,23 @@ def plot_persistence_diagram( color=colormap[interval[0]], ) else: + pts_at_infty = True # Infinite death case for diagram to be nicer axes.scatter( interval[1][0], infinity, alpha=alpha, color=colormap[interval[0]] ) + if pts_at_infty: + # infinity line and text + axes.plot([axis_start, axis_end], [axis_start, axis_end], linewidth=1.0, color="k") + axes.plot([axis_start, axis_end], [infinity, infinity], linewidth=1.0, color="k", alpha=alpha) + # Infinity label + yt = axes.get_yticks() + yt = yt[np.where(yt < axis_end)] # to avoid ploting ticklabel higher than infinity + yt = np.append(yt, infinity) + ytl = ["%.3f" % e for e in yt] # to avoid float precision error + ytl[-1] = r'$+\infty$' + axes.set_yticks(yt) + axes.set_yticklabels(ytl) if legend: dimensions = list(set(item[0] for item in persistence)) -- cgit v1.2.3