From 79278cfc0aecbfdbf8a66e79ccef44534abb2399 Mon Sep 17 00:00:00 2001 From: Vincent Rouvreau Date: Fri, 29 Apr 2022 17:48:19 +0200 Subject: plot_persistence_diagram and plot_persistence_barcode improvements --- src/python/gudhi/persistence_graphical_tools.py | 68 ++++++------------------- 1 file changed, 16 insertions(+), 52 deletions(-) diff --git a/src/python/gudhi/persistence_graphical_tools.py b/src/python/gudhi/persistence_graphical_tools.py index 3c166a56..604018d1 100644 --- a/src/python/gudhi/persistence_graphical_tools.py +++ b/src/python/gudhi/persistence_graphical_tools.py @@ -105,7 +105,7 @@ def plot_persistence_barcode( persistence=[], persistence_file="", alpha=0.6, - max_intervals=1000, + max_intervals=20000, inf_delta=0.1, legend=False, colormap=None, @@ -127,7 +127,7 @@ def plot_persistence_barcode( :type alpha: float. :param max_intervals: maximal number of intervals to display. Selected intervals are those with the longest life time. Set it - to 0 to see all. Default value is 1000. + to 0 to see all. Default value is 20000. :type max_intervals: int. :param inf_delta: Infinity is placed at :code:`((max_death - min_birth) x inf_delta)` above :code:`max_death` value. A reasonable value is @@ -189,36 +189,11 @@ def plot_persistence_barcode( _, axes = plt.subplots(1, 1) if colormap == None: colormap = plt.cm.Set1.colors - ind = 0 - - # Draw horizontal bars in loop - for interval in reversed(persistence): - try: - if float(interval[1][1]) != float("inf"): - # Finite death case - axes.barh( - ind, - (interval[1][1] - interval[1][0]), - height=0.8, - left=interval[1][0], - alpha=alpha, - color=colormap[interval[0]], - linewidth=0, - ) - else: - # Infinite death case for diagram to be nicer - axes.barh( - ind, - (infinity - interval[1][0]), - height=0.8, - left=interval[1][0], - alpha=alpha, - color=colormap[interval[0]], - linewidth=0, - ) - ind = ind + 1 - except IndexError: - pass + + x=[birth for (dim,(birth,death)) in persistence] + y=[(death - birth) if death != float("inf") else (infinity - birth) for (dim,(birth,death)) in persistence] + c=[colormap[dim] for (dim,(birth,death)) in persistence] + axes.barh(list(reversed(range(len(x)))), y, height=0.8, left=x, alpha=alpha, color=c, linewidth=0) if legend: dimensions = list(set(item[0] for item in persistence)) @@ -229,8 +204,8 @@ def plot_persistence_barcode( axes.set_title("Persistence barcode", fontsize=fontsize) # Ends plot on infinity value and starts a little bit before min_birth - if ind != 0: - axes.axis([axis_start, infinity, 0, ind]) + if len(x) != 0: + axes.axis([axis_start, infinity, 0, len(x)]) return axes except ImportError as import_error: @@ -242,7 +217,7 @@ def plot_persistence_diagram( persistence_file="", alpha=0.6, band=0.0, - max_intervals=1000, + max_intervals=1000000, inf_delta=0.1, legend=False, colormap=None, @@ -266,7 +241,7 @@ def plot_persistence_diagram( :type band: float. :param max_intervals: maximal number of intervals to display. Selected intervals are those with the longest life time. Set it - to 0 to see all. Default value is 1000. + to 0 to see all. Default value is 1000000. :type max_intervals: int. :param inf_delta: Infinity is placed at :code:`((max_death - min_birth) x inf_delta)` above :code:`max_death` value. A reasonable value is @@ -346,22 +321,11 @@ def plot_persistence_diagram( # line display of equation : birth = death axes.plot([axis_start, axis_end], [axis_start, axis_end], linewidth=1.0, color="k") - # Draw points in loop - pts_at_infty = False # Records presence of pts at infty - for interval in reversed(persistence): - try: - if float(interval[1][1]) != float("inf"): - # Finite death case - axes.scatter( - interval[1][0], interval[1][1], alpha=alpha, color=colormap[interval[0]], - ) - else: - pts_at_infty = True - # Infinite death case for diagram to be nicer - axes.scatter(interval[1][0], infinity, alpha=alpha, color=colormap[interval[0]]) - except IndexError: - pass - if pts_at_infty: + x=[birth for (dim,(birth,death)) in persistence] + y=[death if death != float("inf") else infinity for (dim,(birth,death)) in persistence] + c=[colormap[dim] for (dim,(birth,death)) in persistence] + axes.scatter(x,y,alpha=alpha,color=c) + if float("inf") in (death for (dim,(birth,death)) in persistence): # infinity line and text axes.plot([axis_start, axis_end], [infinity, infinity], linewidth=1.0, color="k", alpha=alpha) # Infinity label -- cgit v1.2.3