diff options
-rw-r--r-- | src/python/gudhi/persistence_graphical_tools.py | 25 | ||||
-rw-r--r-- | src/python/test/test_persistence_graphical_tools.py | 12 |
2 files changed, 22 insertions, 15 deletions
diff --git a/src/python/gudhi/persistence_graphical_tools.py b/src/python/gudhi/persistence_graphical_tools.py index 604018d1..7ed11360 100644 --- a/src/python/gudhi/persistence_graphical_tools.py +++ b/src/python/gudhi/persistence_graphical_tools.py @@ -193,6 +193,7 @@ def plot_persistence_barcode( x=[birth for (dim,(birth,death)) in persistence] y=[(death - birth) if death != float("inf") else (infinity - birth) for (dim,(birth,death)) in persistence] c=[colormap[dim] for (dim,(birth,death)) in persistence] + axes.barh(list(reversed(range(len(x)))), y, height=0.8, left=x, alpha=alpha, color=c, linewidth=0) if legend: @@ -324,6 +325,7 @@ def plot_persistence_diagram( x=[birth for (dim,(birth,death)) in persistence] y=[death if death != float("inf") else infinity for (dim,(birth,death)) in persistence] c=[colormap[dim] for (dim,(birth,death)) in persistence] + axes.scatter(x,y,alpha=alpha,color=c) if float("inf") in (death for (dim,(birth,death)) in persistence): # infinity line and text @@ -449,16 +451,15 @@ def plot_persistence_density( _, axes = plt.subplots(1, 1) try: - if len(persistence) > 0: - # if not read from file but given by an argument - persistence = _array_handler(persistence) - persistence_dim = np.array( - [ - (dim_interval[1][0], dim_interval[1][1]) - for dim_interval in persistence - if (dim_interval[0] == dimension) or (dimension is None) - ] - ) + # if not read from file but given by an argument + persistence = _array_handler(persistence) + persistence_dim = np.array( + [ + (dim_interval[1][0], dim_interval[1][1]) + for dim_interval in persistence + if (dim_interval[0] == dimension) or (dimension is None) + ] + ) persistence_dim = persistence_dim[np.isfinite(persistence_dim[:, 1])] persistence_dim = np.array( _limit_to_max_intervals( @@ -482,6 +483,7 @@ def plot_persistence_density( zi = k(np.vstack([xi.flatten(), yi.flatten()])) # Make the plot img = axes.pcolormesh(xi, yi, zi.reshape(xi.shape), cmap=cmap, shading="auto") + plot_success = True # IndexError on empty diagrams, ValueError on only inf death values except (IndexError, ValueError): @@ -489,6 +491,7 @@ def plot_persistence_density( birth_max = 1.0 death_min = 0.0 death_max = 1.0 + plot_success = False pass # line display of equation : birth = death @@ -504,7 +507,7 @@ def plot_persistence_density( ) ) - if legend: + if plot_success and legend: plt.colorbar(img, ax=axes) axes.set_xlabel("Birth", fontsize=fontsize) diff --git a/src/python/test/test_persistence_graphical_tools.py b/src/python/test/test_persistence_graphical_tools.py index 7d9bae90..c19836b7 100644 --- a/src/python/test/test_persistence_graphical_tools.py +++ b/src/python/test/test_persistence_graphical_tools.py @@ -15,7 +15,7 @@ import pytest def test_array_handler(): - diags = np.array([[1, 2], [3, 4], [5, 6]], np.float) + diags = np.array([[1, 2], [3, 4], [5, 6]], float) arr_diags = gd.persistence_graphical_tools._array_handler(diags) for idx in range(len(diags)): assert arr_diags[idx][0] == 0 @@ -96,10 +96,14 @@ def test_limit_to_max_intervals(): def _limit_plot_persistence(function): - pplot = function(persistence=[()]) - assert issubclass(type(pplot), plt.axes.SubplotBase) + pplot = function(persistence=[]) + assert isinstance(pplot, plt.axes.SubplotBase) + pplot = function(persistence=[], legend=True) + assert isinstance(pplot, plt.axes.SubplotBase) pplot = function(persistence=[(0, float("inf"))]) - assert issubclass(type(pplot), plt.axes.SubplotBase) + assert isinstance(pplot, plt.axes.SubplotBase) + pplot = function(persistence=[(0, float("inf"))], legend=True) + assert isinstance(pplot, plt.axes.SubplotBase) def test_limit_plot_persistence(): |