summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorROUVREAU Vincent <vincent.rouvreau@inria.fr>2021-06-17 15:43:30 +0200
committerROUVREAU Vincent <vincent.rouvreau@inria.fr>2021-06-17 15:43:30 +0200
commitcb108929433617f553e0a0c7185b3073cce35696 (patch)
treeead76059c5186a3d09cc160eb881f2ed3b65127a
parent486b281c726cbb6110cfe3c63b3f225690bcd348 (diff)
Fix #461 and review all error cases (no more prints, warnings and exceptions instead)
-rw-r--r--src/python/CMakeLists.txt4
-rw-r--r--src/python/gudhi/persistence_graphical_tools.py245
2 files changed, 141 insertions, 108 deletions
diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt
index 98f2b85f..1f0d74d4 100644
--- a/src/python/CMakeLists.txt
+++ b/src/python/CMakeLists.txt
@@ -542,6 +542,10 @@ if(PYTHONINTERP_FOUND)
add_gudhi_py_test(test_dtm_rips_complex)
endif()
+ # persistence graphical tools
+ if(MATPLOTLIB_FOUND)
+ add_gudhi_py_test(test_persistence_graphical_tools)
+ endif()
# Set missing or not modules
set(GUDHI_MODULES ${GUDHI_MODULES} "python" CACHE INTERNAL "GUDHI_MODULES")
diff --git a/src/python/gudhi/persistence_graphical_tools.py b/src/python/gudhi/persistence_graphical_tools.py
index b129b375..9bfc19e0 100644
--- a/src/python/gudhi/persistence_graphical_tools.py
+++ b/src/python/gudhi/persistence_graphical_tools.py
@@ -13,6 +13,8 @@ from math import isfinite
import numpy as np
from functools import lru_cache
import warnings
+import errno
+import os
from gudhi.reader_utils import read_persistence_intervals_in_dimension
from gudhi.reader_utils import read_persistence_intervals_grouped_by_dimension
@@ -45,6 +47,9 @@ def __min_birth_max_death(persistence, band=0.0):
min_birth = float(interval[1][0])
if band > 0.0:
max_death += band
+ # can happen if only points at inf death
+ if min_birth == max_death:
+ max_death = max_death + 1.
return (min_birth, max_death)
@@ -54,7 +59,7 @@ def _array_handler(a):
persistence-compatible list (padding with 0), so that the
plot can be performed seamlessly.
'''
- if isinstance(a[0][1], np.float64) or isinstance(a[0][1], float):
+ if isinstance(a[0][1], np.floating) or isinstance(a[0][1], float):
return [[0, x] for x in a]
else:
return a
@@ -88,7 +93,7 @@ def _matplotlib_can_use_tex():
from matplotlib import checkdep_usetex
return checkdep_usetex(True)
except ImportError:
- print("This function is not available, you may be missing matplotlib.")
+ warnings.warn("This function is not available, you may be missing matplotlib.")
def plot_persistence_barcode(
@@ -157,53 +162,58 @@ def plot_persistence_barcode(
for persistence_interval in diag[key]:
persistence.append((key, persistence_interval))
else:
- print("file " + persistence_file + " not found.")
- return None
+ raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), persistence_file)
- persistence = _array_handler(persistence)
-
- persistence = _limit_to_max_intervals(persistence, max_intervals,
- key = lambda life_time: life_time[1][1] - life_time[1][0])
-
- if colormap == None:
- colormap = plt.cm.Set1.colors
- if axes == None:
- _, axes = plt.subplots(1, 1)
-
- persistence = sorted(persistence, key=lambda birth: birth[1][0])
+ try:
+ persistence = _array_handler(persistence)
+ persistence = _limit_to_max_intervals(persistence, max_intervals,
+ key = lambda life_time: life_time[1][1] - life_time[1][0])
+ (min_birth, max_death) = __min_birth_max_death(persistence)
+ persistence = sorted(persistence, key=lambda birth: birth[1][0])
+ except IndexError:
+ min_birth, max_death = 0., 1.
+ pass
- (min_birth, max_death) = __min_birth_max_death(persistence)
- ind = 0
delta = (max_death - min_birth) * inf_delta
# Replace infinity values with max_death + delta for bar code to be more
# readable
infinity = max_death + delta
axis_start = min_birth - delta
+
+ if axes == None:
+ _, axes = plt.subplots(1, 1)
+ if colormap == None:
+ colormap = plt.cm.Set1.colors
+ ind = 0
+
# Draw horizontal bars in loop
for interval in reversed(persistence):
- if float(interval[1][1]) != float("inf"):
- # Finite death case
- axes.barh(
- ind,
- (interval[1][1] - interval[1][0]),
- height=0.8,
- left=interval[1][0],
- alpha=alpha,
- color=colormap[interval[0]],
- linewidth=0,
- )
- else:
- # Infinite death case for diagram to be nicer
- axes.barh(
- ind,
- (infinity - interval[1][0]),
- height=0.8,
- left=interval[1][0],
- alpha=alpha,
- color=colormap[interval[0]],
- linewidth=0,
- )
- ind = ind + 1
+ try:
+ if float(interval[1][1]) != float("inf"):
+ # Finite death case
+ axes.barh(
+ ind,
+ (interval[1][1] - interval[1][0]),
+ height=0.8,
+ left=interval[1][0],
+ alpha=alpha,
+ color=colormap[interval[0]],
+ linewidth=0,
+ )
+ else:
+ # Infinite death case for diagram to be nicer
+ axes.barh(
+ ind,
+ (infinity - interval[1][0]),
+ height=0.8,
+ left=interval[1][0],
+ alpha=alpha,
+ color=colormap[interval[0]],
+ linewidth=0,
+ )
+ ind = ind + 1
+ except IndexError:
+ pass
if legend:
dimensions = list(set(item[0] for item in persistence))
@@ -218,11 +228,12 @@ def plot_persistence_barcode(
axes.set_title("Persistence barcode", fontsize=fontsize)
# Ends plot on infinity value and starts a little bit before min_birth
- axes.axis([axis_start, infinity, 0, ind])
+ if ind != 0:
+ axes.axis([axis_start, infinity, 0, ind])
return axes
except ImportError:
- print("This function is not available, you may be missing matplotlib.")
+ warnings.warn("This function is not available, you may be missing matplotlib.")
def plot_persistence_diagram(
@@ -296,20 +307,17 @@ def plot_persistence_diagram(
for persistence_interval in diag[key]:
persistence.append((key, persistence_interval))
else:
- print("file " + persistence_file + " not found.")
- return None
-
- persistence = _array_handler(persistence)
-
- persistence = _limit_to_max_intervals(persistence, max_intervals,
- key = lambda life_time: life_time[1][1] - life_time[1][0])
+ raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), persistence_file)
- if colormap == None:
- colormap = plt.cm.Set1.colors
- if axes == None:
- _, axes = plt.subplots(1, 1)
+ try:
+ persistence = _array_handler(persistence)
+ persistence = _limit_to_max_intervals(persistence, max_intervals,
+ key = lambda life_time: life_time[1][1] - life_time[1][0])
+ min_birth, max_death = __min_birth_max_death(persistence, band)
+ except IndexError:
+ min_birth, max_death = 0., 1.
+ pass
- (min_birth, max_death) = __min_birth_max_death(persistence, band)
delta = (max_death - min_birth) * inf_delta
# Replace infinity values with max_death + delta for diagram to be more
# readable
@@ -317,33 +325,43 @@ def plot_persistence_diagram(
axis_end = max_death + delta / 2
axis_start = min_birth - delta
+ if axes == None:
+ _, axes = plt.subplots(1, 1)
+ if colormap == None:
+ colormap = plt.cm.Set1.colors
# bootstrap band
if band > 0.0:
x = np.linspace(axis_start, infinity, 1000)
axes.fill_between(x, x, x + band, alpha=alpha, facecolor="red")
# lower diag patch
if greyblock:
- axes.add_patch(mpatches.Polygon([[axis_start, axis_start], [axis_end, axis_start], [axis_end, axis_end]], fill=True, color='lightgrey'))
+ axes.add_patch(mpatches.Polygon([[axis_start, axis_start], [axis_end, axis_start], [axis_end, axis_end]],
+ fill=True, color='lightgrey'))
+ # line display of equation : birth = death
+ axes.plot([axis_start, axis_end], [axis_start, axis_end], linewidth=1.0, color="k")
+
# Draw points in loop
pts_at_infty = False # Records presence of pts at infty
for interval in reversed(persistence):
- if float(interval[1][1]) != float("inf"):
- # Finite death case
- axes.scatter(
- interval[1][0],
- interval[1][1],
- alpha=alpha,
- color=colormap[interval[0]],
- )
- else:
- pts_at_infty = True
- # Infinite death case for diagram to be nicer
- axes.scatter(
- interval[1][0], infinity, alpha=alpha, color=colormap[interval[0]]
- )
+ try:
+ if float(interval[1][1]) != float("inf"):
+ # Finite death case
+ axes.scatter(
+ interval[1][0],
+ interval[1][1],
+ alpha=alpha,
+ color=colormap[interval[0]],
+ )
+ else:
+ pts_at_infty = True
+ # Infinite death case for diagram to be nicer
+ axes.scatter(
+ interval[1][0], infinity, alpha=alpha, color=colormap[interval[0]]
+ )
+ except IndexError:
+ pass
if pts_at_infty:
# infinity line and text
- axes.plot([axis_start, axis_end], [axis_start, axis_end], linewidth=1.0, color="k")
axes.plot([axis_start, axis_end], [infinity, infinity], linewidth=1.0, color="k", alpha=alpha)
# Infinity label
yt = axes.get_yticks()
@@ -371,7 +389,7 @@ def plot_persistence_diagram(
return axes
except ImportError:
- print("This function is not available, you may be missing matplotlib.")
+ warnings.warn("This function is not available, you may be missing matplotlib.")
def plot_persistence_density(
@@ -461,27 +479,7 @@ def plot_persistence_density(
persistence_file=persistence_file, only_this_dim=dimension
)
else:
- print("file " + persistence_file + " not found.")
- return None
-
- if len(persistence) > 0:
- persistence = _array_handler(persistence)
- persistence_dim = np.array(
- [
- (dim_interval[1][0], dim_interval[1][1])
- for dim_interval in persistence
- if (dim_interval[0] == dimension) or (dimension is None)
- ]
- )
-
- persistence_dim = persistence_dim[np.isfinite(persistence_dim[:, 1])]
-
- persistence_dim = np.array(_limit_to_max_intervals(persistence_dim, max_intervals,
- key = lambda life_time: life_time[1] - life_time[0]))
-
- # Set as numpy array birth and death (remove undefined values - inf and NaN)
- birth = persistence_dim[:, 0]
- death = persistence_dim[:, 1]
+ raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), persistence_file)
# default cmap value cannot be done at argument definition level as matplotlib is not yet defined.
if cmap is None:
@@ -489,23 +487,56 @@ def plot_persistence_density(
if axes == None:
_, axes = plt.subplots(1, 1)
+ try:
+ if len(persistence) > 0:
+ # if not read from file but given by an argument
+ persistence = _array_handler(persistence)
+ persistence_dim = np.array(
+ [
+ (dim_interval[1][0], dim_interval[1][1])
+ for dim_interval in persistence
+ if (dim_interval[0] == dimension) or (dimension is None)
+ ]
+ )
+ persistence_dim = persistence_dim[np.isfinite(persistence_dim[:, 1])]
+ persistence_dim = np.array(_limit_to_max_intervals(persistence_dim, max_intervals,
+ key = lambda life_time: life_time[1] - life_time[0]))
+
+ # Set as numpy array birth and death (remove undefined values - inf and NaN)
+ birth = persistence_dim[:, 0]
+ death = persistence_dim[:, 1]
+ birth_min = birth.min()
+ birth_max = birth.max()
+ death_min = death.min()
+ death_max = death.max()
+
+ # Evaluate a gaussian kde on a regular grid of nbins x nbins over data extents
+ k = kde.gaussian_kde([birth, death], bw_method=bw_method)
+ xi, yi = np.mgrid[
+ birth_min : birth_max : nbins * 1j,
+ death_min : death_max : nbins * 1j,
+ ]
+ zi = k(np.vstack([xi.flatten(), yi.flatten()]))
+ # Make the plot
+ img = axes.pcolormesh(xi, yi, zi.reshape(xi.shape), cmap=cmap, shading='auto')
+
+ # IndexError on empty diagrams, ValueError on only inf death values
+ except (IndexError, ValueError):
+ birth_min = 0.
+ birth_max = 1.
+ death_min = 0.
+ death_max = 1.
+ pass
+
# line display of equation : birth = death
- x = np.linspace(death.min(), birth.max(), 1000)
+ x = np.linspace(death_min, birth_max, 1000)
axes.plot(x, x, color="k", linewidth=1.0)
- # Evaluate a gaussian kde on a regular grid of nbins x nbins over data extents
- k = kde.gaussian_kde([birth, death], bw_method=bw_method)
- xi, yi = np.mgrid[
- birth.min() : birth.max() : nbins * 1j,
- death.min() : death.max() : nbins * 1j,
- ]
- zi = k(np.vstack([xi.flatten(), yi.flatten()]))
-
- # Make the plot
- img = axes.pcolormesh(xi, yi, zi.reshape(xi.shape), cmap=cmap, shading='auto')
-
if greyblock:
- axes.add_patch(mpatches.Polygon([[birth.min(), birth.min()], [death.max(), birth.min()], [death.max(), death.max()]], fill=True, color='lightgrey'))
+ axes.add_patch(mpatches.Polygon([[birth_min, birth_min],
+ [death_max, birth_min],
+ [death_max, death_max]],
+ fill=True, color='lightgrey'))
if legend:
plt.colorbar(img, ax=axes)
@@ -517,6 +548,4 @@ def plot_persistence_density(
return axes
except ImportError:
- print(
- "This function is not available, you may be missing matplotlib and/or scipy."
- )
+ warnings.warn("This function is not available, you may be missing matplotlib and/or scipy.")