diff options
author | ROUVREAU Vincent <vincent.rouvreau@inria.fr> | 2021-06-17 15:43:30 +0200 |
---|---|---|
committer | ROUVREAU Vincent <vincent.rouvreau@inria.fr> | 2021-06-17 15:43:30 +0200 |
commit | cb108929433617f553e0a0c7185b3073cce35696 (patch) | |
tree | ead76059c5186a3d09cc160eb881f2ed3b65127a | |
parent | 486b281c726cbb6110cfe3c63b3f225690bcd348 (diff) |
Fix #461 and review all error cases (no more prints, warnings and exceptions instead)
-rw-r--r-- | src/python/CMakeLists.txt | 4 | ||||
-rw-r--r-- | src/python/gudhi/persistence_graphical_tools.py | 245 |
2 files changed, 141 insertions, 108 deletions
diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt index 98f2b85f..1f0d74d4 100644 --- a/src/python/CMakeLists.txt +++ b/src/python/CMakeLists.txt @@ -542,6 +542,10 @@ if(PYTHONINTERP_FOUND) add_gudhi_py_test(test_dtm_rips_complex) endif() + # persistence graphical tools + if(MATPLOTLIB_FOUND) + add_gudhi_py_test(test_persistence_graphical_tools) + endif() # Set missing or not modules set(GUDHI_MODULES ${GUDHI_MODULES} "python" CACHE INTERNAL "GUDHI_MODULES") diff --git a/src/python/gudhi/persistence_graphical_tools.py b/src/python/gudhi/persistence_graphical_tools.py index b129b375..9bfc19e0 100644 --- a/src/python/gudhi/persistence_graphical_tools.py +++ b/src/python/gudhi/persistence_graphical_tools.py @@ -13,6 +13,8 @@ from math import isfinite import numpy as np from functools import lru_cache import warnings +import errno +import os from gudhi.reader_utils import read_persistence_intervals_in_dimension from gudhi.reader_utils import read_persistence_intervals_grouped_by_dimension @@ -45,6 +47,9 @@ def __min_birth_max_death(persistence, band=0.0): min_birth = float(interval[1][0]) if band > 0.0: max_death += band + # can happen if only points at inf death + if min_birth == max_death: + max_death = max_death + 1. return (min_birth, max_death) @@ -54,7 +59,7 @@ def _array_handler(a): persistence-compatible list (padding with 0), so that the plot can be performed seamlessly. ''' - if isinstance(a[0][1], np.float64) or isinstance(a[0][1], float): + if isinstance(a[0][1], np.floating) or isinstance(a[0][1], float): return [[0, x] for x in a] else: return a @@ -88,7 +93,7 @@ def _matplotlib_can_use_tex(): from matplotlib import checkdep_usetex return checkdep_usetex(True) except ImportError: - print("This function is not available, you may be missing matplotlib.") + warnings.warn("This function is not available, you may be missing matplotlib.") def plot_persistence_barcode( @@ -157,53 +162,58 @@ def plot_persistence_barcode( for persistence_interval in diag[key]: persistence.append((key, persistence_interval)) else: - print("file " + persistence_file + " not found.") - return None + raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), persistence_file) - persistence = _array_handler(persistence) - - persistence = _limit_to_max_intervals(persistence, max_intervals, - key = lambda life_time: life_time[1][1] - life_time[1][0]) - - if colormap == None: - colormap = plt.cm.Set1.colors - if axes == None: - _, axes = plt.subplots(1, 1) - - persistence = sorted(persistence, key=lambda birth: birth[1][0]) + try: + persistence = _array_handler(persistence) + persistence = _limit_to_max_intervals(persistence, max_intervals, + key = lambda life_time: life_time[1][1] - life_time[1][0]) + (min_birth, max_death) = __min_birth_max_death(persistence) + persistence = sorted(persistence, key=lambda birth: birth[1][0]) + except IndexError: + min_birth, max_death = 0., 1. + pass - (min_birth, max_death) = __min_birth_max_death(persistence) - ind = 0 delta = (max_death - min_birth) * inf_delta # Replace infinity values with max_death + delta for bar code to be more # readable infinity = max_death + delta axis_start = min_birth - delta + + if axes == None: + _, axes = plt.subplots(1, 1) + if colormap == None: + colormap = plt.cm.Set1.colors + ind = 0 + # Draw horizontal bars in loop for interval in reversed(persistence): - if float(interval[1][1]) != float("inf"): - # Finite death case - axes.barh( - ind, - (interval[1][1] - interval[1][0]), - height=0.8, - left=interval[1][0], - alpha=alpha, - color=colormap[interval[0]], - linewidth=0, - ) - else: - # Infinite death case for diagram to be nicer - axes.barh( - ind, - (infinity - interval[1][0]), - height=0.8, - left=interval[1][0], - alpha=alpha, - color=colormap[interval[0]], - linewidth=0, - ) - ind = ind + 1 + try: + if float(interval[1][1]) != float("inf"): + # Finite death case + axes.barh( + ind, + (interval[1][1] - interval[1][0]), + height=0.8, + left=interval[1][0], + alpha=alpha, + color=colormap[interval[0]], + linewidth=0, + ) + else: + # Infinite death case for diagram to be nicer + axes.barh( + ind, + (infinity - interval[1][0]), + height=0.8, + left=interval[1][0], + alpha=alpha, + color=colormap[interval[0]], + linewidth=0, + ) + ind = ind + 1 + except IndexError: + pass if legend: dimensions = list(set(item[0] for item in persistence)) @@ -218,11 +228,12 @@ def plot_persistence_barcode( axes.set_title("Persistence barcode", fontsize=fontsize) # Ends plot on infinity value and starts a little bit before min_birth - axes.axis([axis_start, infinity, 0, ind]) + if ind != 0: + axes.axis([axis_start, infinity, 0, ind]) return axes except ImportError: - print("This function is not available, you may be missing matplotlib.") + warnings.warn("This function is not available, you may be missing matplotlib.") def plot_persistence_diagram( @@ -296,20 +307,17 @@ def plot_persistence_diagram( for persistence_interval in diag[key]: persistence.append((key, persistence_interval)) else: - print("file " + persistence_file + " not found.") - return None - - persistence = _array_handler(persistence) - - persistence = _limit_to_max_intervals(persistence, max_intervals, - key = lambda life_time: life_time[1][1] - life_time[1][0]) + raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), persistence_file) - if colormap == None: - colormap = plt.cm.Set1.colors - if axes == None: - _, axes = plt.subplots(1, 1) + try: + persistence = _array_handler(persistence) + persistence = _limit_to_max_intervals(persistence, max_intervals, + key = lambda life_time: life_time[1][1] - life_time[1][0]) + min_birth, max_death = __min_birth_max_death(persistence, band) + except IndexError: + min_birth, max_death = 0., 1. + pass - (min_birth, max_death) = __min_birth_max_death(persistence, band) delta = (max_death - min_birth) * inf_delta # Replace infinity values with max_death + delta for diagram to be more # readable @@ -317,33 +325,43 @@ def plot_persistence_diagram( axis_end = max_death + delta / 2 axis_start = min_birth - delta + if axes == None: + _, axes = plt.subplots(1, 1) + if colormap == None: + colormap = plt.cm.Set1.colors # bootstrap band if band > 0.0: x = np.linspace(axis_start, infinity, 1000) axes.fill_between(x, x, x + band, alpha=alpha, facecolor="red") # lower diag patch if greyblock: - axes.add_patch(mpatches.Polygon([[axis_start, axis_start], [axis_end, axis_start], [axis_end, axis_end]], fill=True, color='lightgrey')) + axes.add_patch(mpatches.Polygon([[axis_start, axis_start], [axis_end, axis_start], [axis_end, axis_end]], + fill=True, color='lightgrey')) + # line display of equation : birth = death + axes.plot([axis_start, axis_end], [axis_start, axis_end], linewidth=1.0, color="k") + # Draw points in loop pts_at_infty = False # Records presence of pts at infty for interval in reversed(persistence): - if float(interval[1][1]) != float("inf"): - # Finite death case - axes.scatter( - interval[1][0], - interval[1][1], - alpha=alpha, - color=colormap[interval[0]], - ) - else: - pts_at_infty = True - # Infinite death case for diagram to be nicer - axes.scatter( - interval[1][0], infinity, alpha=alpha, color=colormap[interval[0]] - ) + try: + if float(interval[1][1]) != float("inf"): + # Finite death case + axes.scatter( + interval[1][0], + interval[1][1], + alpha=alpha, + color=colormap[interval[0]], + ) + else: + pts_at_infty = True + # Infinite death case for diagram to be nicer + axes.scatter( + interval[1][0], infinity, alpha=alpha, color=colormap[interval[0]] + ) + except IndexError: + pass if pts_at_infty: # infinity line and text - axes.plot([axis_start, axis_end], [axis_start, axis_end], linewidth=1.0, color="k") axes.plot([axis_start, axis_end], [infinity, infinity], linewidth=1.0, color="k", alpha=alpha) # Infinity label yt = axes.get_yticks() @@ -371,7 +389,7 @@ def plot_persistence_diagram( return axes except ImportError: - print("This function is not available, you may be missing matplotlib.") + warnings.warn("This function is not available, you may be missing matplotlib.") def plot_persistence_density( @@ -461,27 +479,7 @@ def plot_persistence_density( persistence_file=persistence_file, only_this_dim=dimension ) else: - print("file " + persistence_file + " not found.") - return None - - if len(persistence) > 0: - persistence = _array_handler(persistence) - persistence_dim = np.array( - [ - (dim_interval[1][0], dim_interval[1][1]) - for dim_interval in persistence - if (dim_interval[0] == dimension) or (dimension is None) - ] - ) - - persistence_dim = persistence_dim[np.isfinite(persistence_dim[:, 1])] - - persistence_dim = np.array(_limit_to_max_intervals(persistence_dim, max_intervals, - key = lambda life_time: life_time[1] - life_time[0])) - - # Set as numpy array birth and death (remove undefined values - inf and NaN) - birth = persistence_dim[:, 0] - death = persistence_dim[:, 1] + raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), persistence_file) # default cmap value cannot be done at argument definition level as matplotlib is not yet defined. if cmap is None: @@ -489,23 +487,56 @@ def plot_persistence_density( if axes == None: _, axes = plt.subplots(1, 1) + try: + if len(persistence) > 0: + # if not read from file but given by an argument + persistence = _array_handler(persistence) + persistence_dim = np.array( + [ + (dim_interval[1][0], dim_interval[1][1]) + for dim_interval in persistence + if (dim_interval[0] == dimension) or (dimension is None) + ] + ) + persistence_dim = persistence_dim[np.isfinite(persistence_dim[:, 1])] + persistence_dim = np.array(_limit_to_max_intervals(persistence_dim, max_intervals, + key = lambda life_time: life_time[1] - life_time[0])) + + # Set as numpy array birth and death (remove undefined values - inf and NaN) + birth = persistence_dim[:, 0] + death = persistence_dim[:, 1] + birth_min = birth.min() + birth_max = birth.max() + death_min = death.min() + death_max = death.max() + + # Evaluate a gaussian kde on a regular grid of nbins x nbins over data extents + k = kde.gaussian_kde([birth, death], bw_method=bw_method) + xi, yi = np.mgrid[ + birth_min : birth_max : nbins * 1j, + death_min : death_max : nbins * 1j, + ] + zi = k(np.vstack([xi.flatten(), yi.flatten()])) + # Make the plot + img = axes.pcolormesh(xi, yi, zi.reshape(xi.shape), cmap=cmap, shading='auto') + + # IndexError on empty diagrams, ValueError on only inf death values + except (IndexError, ValueError): + birth_min = 0. + birth_max = 1. + death_min = 0. + death_max = 1. + pass + # line display of equation : birth = death - x = np.linspace(death.min(), birth.max(), 1000) + x = np.linspace(death_min, birth_max, 1000) axes.plot(x, x, color="k", linewidth=1.0) - # Evaluate a gaussian kde on a regular grid of nbins x nbins over data extents - k = kde.gaussian_kde([birth, death], bw_method=bw_method) - xi, yi = np.mgrid[ - birth.min() : birth.max() : nbins * 1j, - death.min() : death.max() : nbins * 1j, - ] - zi = k(np.vstack([xi.flatten(), yi.flatten()])) - - # Make the plot - img = axes.pcolormesh(xi, yi, zi.reshape(xi.shape), cmap=cmap, shading='auto') - if greyblock: - axes.add_patch(mpatches.Polygon([[birth.min(), birth.min()], [death.max(), birth.min()], [death.max(), death.max()]], fill=True, color='lightgrey')) + axes.add_patch(mpatches.Polygon([[birth_min, birth_min], + [death_max, birth_min], + [death_max, death_max]], + fill=True, color='lightgrey')) if legend: plt.colorbar(img, ax=axes) @@ -517,6 +548,4 @@ def plot_persistence_density( return axes except ImportError: - print( - "This function is not available, you may be missing matplotlib and/or scipy." - ) + warnings.warn("This function is not available, you may be missing matplotlib and/or scipy.") |