From 3952b61604db976f4864b85ebdafa0962766b8bc Mon Sep 17 00:00:00 2001 From: Vincent Rouvreau Date: Tue, 10 May 2022 10:00:02 +0200 Subject: Fix limit tests for plot (and warning in test) --- src/python/gudhi/persistence_graphical_tools.py | 32 +++++++++++++++------- .../test/test_persistence_graphical_tools.py | 6 +++- 2 files changed, 27 insertions(+), 11 deletions(-) diff --git a/src/python/gudhi/persistence_graphical_tools.py b/src/python/gudhi/persistence_graphical_tools.py index 604018d1..930df825 100644 --- a/src/python/gudhi/persistence_graphical_tools.py +++ b/src/python/gudhi/persistence_graphical_tools.py @@ -190,12 +190,17 @@ def plot_persistence_barcode( if colormap == None: colormap = plt.cm.Set1.colors - x=[birth for (dim,(birth,death)) in persistence] - y=[(death - birth) if death != float("inf") else (infinity - birth) for (dim,(birth,death)) in persistence] - c=[colormap[dim] for (dim,(birth,death)) in persistence] + non_empty_diagram = len(persistence[0]) > 0 + if non_empty_diagram: + x=[birth for (dim,(birth,death)) in persistence] + y=[(death - birth) if death != float("inf") else (infinity - birth) for (dim,(birth,death)) in persistence] + c=[colormap[dim] for (dim,(birth,death)) in persistence] + else: + x, y, c = [], [], [] + axes.barh(list(reversed(range(len(x)))), y, height=0.8, left=x, alpha=alpha, color=c, linewidth=0) - if legend: + if non_empty_diagram and legend: dimensions = list(set(item[0] for item in persistence)) axes.legend( handles=[mpatches.Patch(color=colormap[dim], label=str(dim)) for dim in dimensions], loc="lower right", @@ -321,11 +326,16 @@ def plot_persistence_diagram( # line display of equation : birth = death axes.plot([axis_start, axis_end], [axis_start, axis_end], linewidth=1.0, color="k") - x=[birth for (dim,(birth,death)) in persistence] - y=[death if death != float("inf") else infinity for (dim,(birth,death)) in persistence] - c=[colormap[dim] for (dim,(birth,death)) in persistence] + non_empty_diagram = len(persistence[0]) > 0 + if non_empty_diagram: + x=[birth for (dim,(birth,death)) in persistence] + y=[death if death != float("inf") else infinity for (dim,(birth,death)) in persistence] + c=[colormap[dim] for (dim,(birth,death)) in persistence] + else: + x, y, c = [], [], [] + axes.scatter(x,y,alpha=alpha,color=c) - if float("inf") in (death for (dim,(birth,death)) in persistence): + if non_empty_diagram and float("inf") in (death for (dim,(birth,death)) in persistence): # infinity line and text axes.plot([axis_start, axis_end], [infinity, infinity], linewidth=1.0, color="k", alpha=alpha) # Infinity label @@ -337,7 +347,7 @@ def plot_persistence_diagram( axes.set_yticks(yt) axes.set_yticklabels(ytl) - if legend: + if non_empty_diagram and legend: dimensions = list(set(item[0] for item in persistence)) axes.legend(handles=[mpatches.Patch(color=colormap[dim], label=str(dim)) for dim in dimensions]) @@ -482,6 +492,7 @@ def plot_persistence_density( zi = k(np.vstack([xi.flatten(), yi.flatten()])) # Make the plot img = axes.pcolormesh(xi, yi, zi.reshape(xi.shape), cmap=cmap, shading="auto") + non_empty_diagram = True # IndexError on empty diagrams, ValueError on only inf death values except (IndexError, ValueError): @@ -489,6 +500,7 @@ def plot_persistence_density( birth_max = 1.0 death_min = 0.0 death_max = 1.0 + non_empty_diagram = False pass # line display of equation : birth = death @@ -504,7 +516,7 @@ def plot_persistence_density( ) ) - if legend: + if non_empty_diagram and legend: plt.colorbar(img, ax=axes) axes.set_xlabel("Birth", fontsize=fontsize) diff --git a/src/python/test/test_persistence_graphical_tools.py b/src/python/test/test_persistence_graphical_tools.py index 7d9bae90..994791d7 100644 --- a/src/python/test/test_persistence_graphical_tools.py +++ b/src/python/test/test_persistence_graphical_tools.py @@ -15,7 +15,7 @@ import pytest def test_array_handler(): - diags = np.array([[1, 2], [3, 4], [5, 6]], np.float) + diags = np.array([[1, 2], [3, 4], [5, 6]], float) arr_diags = gd.persistence_graphical_tools._array_handler(diags) for idx in range(len(diags)): assert arr_diags[idx][0] == 0 @@ -98,8 +98,12 @@ def test_limit_to_max_intervals(): def _limit_plot_persistence(function): pplot = function(persistence=[()]) assert issubclass(type(pplot), plt.axes.SubplotBase) + pplot = function(persistence=[()], legend=True) + assert issubclass(type(pplot), plt.axes.SubplotBase) pplot = function(persistence=[(0, float("inf"))]) assert issubclass(type(pplot), plt.axes.SubplotBase) + pplot = function(persistence=[(0, float("inf"))], legend=True) + assert issubclass(type(pplot), plt.axes.SubplotBase) def test_limit_plot_persistence(): -- cgit v1.2.3