summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorVincent Rouvreau <vincent.rouvreau@inria.fr>2022-05-10 10:00:02 +0200
committerVincent Rouvreau <vincent.rouvreau@inria.fr>2022-05-10 10:00:02 +0200
commit3952b61604db976f4864b85ebdafa0962766b8bc (patch)
tree28b1dc4406f79536c795defcd6d145fd543160e7
parent7491d4b98d5629d4efd53253188698dae1d845d0 (diff)
Fix limit tests for plot (and warning in test)
-rw-r--r--src/python/gudhi/persistence_graphical_tools.py32
-rw-r--r--src/python/test/test_persistence_graphical_tools.py6
2 files changed, 27 insertions, 11 deletions
diff --git a/src/python/gudhi/persistence_graphical_tools.py b/src/python/gudhi/persistence_graphical_tools.py
index 604018d1..930df825 100644
--- a/src/python/gudhi/persistence_graphical_tools.py
+++ b/src/python/gudhi/persistence_graphical_tools.py
@@ -190,12 +190,17 @@ def plot_persistence_barcode(
if colormap == None:
colormap = plt.cm.Set1.colors
- x=[birth for (dim,(birth,death)) in persistence]
- y=[(death - birth) if death != float("inf") else (infinity - birth) for (dim,(birth,death)) in persistence]
- c=[colormap[dim] for (dim,(birth,death)) in persistence]
+ non_empty_diagram = len(persistence[0]) > 0
+ if non_empty_diagram:
+ x=[birth for (dim,(birth,death)) in persistence]
+ y=[(death - birth) if death != float("inf") else (infinity - birth) for (dim,(birth,death)) in persistence]
+ c=[colormap[dim] for (dim,(birth,death)) in persistence]
+ else:
+ x, y, c = [], [], []
+
axes.barh(list(reversed(range(len(x)))), y, height=0.8, left=x, alpha=alpha, color=c, linewidth=0)
- if legend:
+ if non_empty_diagram and legend:
dimensions = list(set(item[0] for item in persistence))
axes.legend(
handles=[mpatches.Patch(color=colormap[dim], label=str(dim)) for dim in dimensions], loc="lower right",
@@ -321,11 +326,16 @@ def plot_persistence_diagram(
# line display of equation : birth = death
axes.plot([axis_start, axis_end], [axis_start, axis_end], linewidth=1.0, color="k")
- x=[birth for (dim,(birth,death)) in persistence]
- y=[death if death != float("inf") else infinity for (dim,(birth,death)) in persistence]
- c=[colormap[dim] for (dim,(birth,death)) in persistence]
+ non_empty_diagram = len(persistence[0]) > 0
+ if non_empty_diagram:
+ x=[birth for (dim,(birth,death)) in persistence]
+ y=[death if death != float("inf") else infinity for (dim,(birth,death)) in persistence]
+ c=[colormap[dim] for (dim,(birth,death)) in persistence]
+ else:
+ x, y, c = [], [], []
+
axes.scatter(x,y,alpha=alpha,color=c)
- if float("inf") in (death for (dim,(birth,death)) in persistence):
+ if non_empty_diagram and float("inf") in (death for (dim,(birth,death)) in persistence):
# infinity line and text
axes.plot([axis_start, axis_end], [infinity, infinity], linewidth=1.0, color="k", alpha=alpha)
# Infinity label
@@ -337,7 +347,7 @@ def plot_persistence_diagram(
axes.set_yticks(yt)
axes.set_yticklabels(ytl)
- if legend:
+ if non_empty_diagram and legend:
dimensions = list(set(item[0] for item in persistence))
axes.legend(handles=[mpatches.Patch(color=colormap[dim], label=str(dim)) for dim in dimensions])
@@ -482,6 +492,7 @@ def plot_persistence_density(
zi = k(np.vstack([xi.flatten(), yi.flatten()]))
# Make the plot
img = axes.pcolormesh(xi, yi, zi.reshape(xi.shape), cmap=cmap, shading="auto")
+ non_empty_diagram = True
# IndexError on empty diagrams, ValueError on only inf death values
except (IndexError, ValueError):
@@ -489,6 +500,7 @@ def plot_persistence_density(
birth_max = 1.0
death_min = 0.0
death_max = 1.0
+ non_empty_diagram = False
pass
# line display of equation : birth = death
@@ -504,7 +516,7 @@ def plot_persistence_density(
)
)
- if legend:
+ if non_empty_diagram and legend:
plt.colorbar(img, ax=axes)
axes.set_xlabel("Birth", fontsize=fontsize)
diff --git a/src/python/test/test_persistence_graphical_tools.py b/src/python/test/test_persistence_graphical_tools.py
index 7d9bae90..994791d7 100644
--- a/src/python/test/test_persistence_graphical_tools.py
+++ b/src/python/test/test_persistence_graphical_tools.py
@@ -15,7 +15,7 @@ import pytest
def test_array_handler():
- diags = np.array([[1, 2], [3, 4], [5, 6]], np.float)
+ diags = np.array([[1, 2], [3, 4], [5, 6]], float)
arr_diags = gd.persistence_graphical_tools._array_handler(diags)
for idx in range(len(diags)):
assert arr_diags[idx][0] == 0
@@ -98,8 +98,12 @@ def test_limit_to_max_intervals():
def _limit_plot_persistence(function):
pplot = function(persistence=[()])
assert issubclass(type(pplot), plt.axes.SubplotBase)
+ pplot = function(persistence=[()], legend=True)
+ assert issubclass(type(pplot), plt.axes.SubplotBase)
pplot = function(persistence=[(0, float("inf"))])
assert issubclass(type(pplot), plt.axes.SubplotBase)
+ pplot = function(persistence=[(0, float("inf"))], legend=True)
+ assert issubclass(type(pplot), plt.axes.SubplotBase)
def test_limit_plot_persistence():