summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/python/gudhi/persistence_graphical_tools.py25
-rw-r--r--src/python/test/test_persistence_graphical_tools.py12
2 files changed, 22 insertions, 15 deletions
diff --git a/src/python/gudhi/persistence_graphical_tools.py b/src/python/gudhi/persistence_graphical_tools.py
index 604018d1..7ed11360 100644
--- a/src/python/gudhi/persistence_graphical_tools.py
+++ b/src/python/gudhi/persistence_graphical_tools.py
@@ -193,6 +193,7 @@ def plot_persistence_barcode(
x=[birth for (dim,(birth,death)) in persistence]
y=[(death - birth) if death != float("inf") else (infinity - birth) for (dim,(birth,death)) in persistence]
c=[colormap[dim] for (dim,(birth,death)) in persistence]
+
axes.barh(list(reversed(range(len(x)))), y, height=0.8, left=x, alpha=alpha, color=c, linewidth=0)
if legend:
@@ -324,6 +325,7 @@ def plot_persistence_diagram(
x=[birth for (dim,(birth,death)) in persistence]
y=[death if death != float("inf") else infinity for (dim,(birth,death)) in persistence]
c=[colormap[dim] for (dim,(birth,death)) in persistence]
+
axes.scatter(x,y,alpha=alpha,color=c)
if float("inf") in (death for (dim,(birth,death)) in persistence):
# infinity line and text
@@ -449,16 +451,15 @@ def plot_persistence_density(
_, axes = plt.subplots(1, 1)
try:
- if len(persistence) > 0:
- # if not read from file but given by an argument
- persistence = _array_handler(persistence)
- persistence_dim = np.array(
- [
- (dim_interval[1][0], dim_interval[1][1])
- for dim_interval in persistence
- if (dim_interval[0] == dimension) or (dimension is None)
- ]
- )
+ # if not read from file but given by an argument
+ persistence = _array_handler(persistence)
+ persistence_dim = np.array(
+ [
+ (dim_interval[1][0], dim_interval[1][1])
+ for dim_interval in persistence
+ if (dim_interval[0] == dimension) or (dimension is None)
+ ]
+ )
persistence_dim = persistence_dim[np.isfinite(persistence_dim[:, 1])]
persistence_dim = np.array(
_limit_to_max_intervals(
@@ -482,6 +483,7 @@ def plot_persistence_density(
zi = k(np.vstack([xi.flatten(), yi.flatten()]))
# Make the plot
img = axes.pcolormesh(xi, yi, zi.reshape(xi.shape), cmap=cmap, shading="auto")
+ plot_success = True
# IndexError on empty diagrams, ValueError on only inf death values
except (IndexError, ValueError):
@@ -489,6 +491,7 @@ def plot_persistence_density(
birth_max = 1.0
death_min = 0.0
death_max = 1.0
+ plot_success = False
pass
# line display of equation : birth = death
@@ -504,7 +507,7 @@ def plot_persistence_density(
)
)
- if legend:
+ if plot_success and legend:
plt.colorbar(img, ax=axes)
axes.set_xlabel("Birth", fontsize=fontsize)
diff --git a/src/python/test/test_persistence_graphical_tools.py b/src/python/test/test_persistence_graphical_tools.py
index 7d9bae90..c19836b7 100644
--- a/src/python/test/test_persistence_graphical_tools.py
+++ b/src/python/test/test_persistence_graphical_tools.py
@@ -15,7 +15,7 @@ import pytest
def test_array_handler():
- diags = np.array([[1, 2], [3, 4], [5, 6]], np.float)
+ diags = np.array([[1, 2], [3, 4], [5, 6]], float)
arr_diags = gd.persistence_graphical_tools._array_handler(diags)
for idx in range(len(diags)):
assert arr_diags[idx][0] == 0
@@ -96,10 +96,14 @@ def test_limit_to_max_intervals():
def _limit_plot_persistence(function):
- pplot = function(persistence=[()])
- assert issubclass(type(pplot), plt.axes.SubplotBase)
+ pplot = function(persistence=[])
+ assert isinstance(pplot, plt.axes.SubplotBase)
+ pplot = function(persistence=[], legend=True)
+ assert isinstance(pplot, plt.axes.SubplotBase)
pplot = function(persistence=[(0, float("inf"))])
- assert issubclass(type(pplot), plt.axes.SubplotBase)
+ assert isinstance(pplot, plt.axes.SubplotBase)
+ pplot = function(persistence=[(0, float("inf"))], legend=True)
+ assert isinstance(pplot, plt.axes.SubplotBase)
def test_limit_plot_persistence():