summaryrefslogtreecommitdiff
path: root/src/python/gudhi/persistence_graphical_tools.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/gudhi/persistence_graphical_tools.py')
-rw-r--r--src/python/gudhi/persistence_graphical_tools.py519
1 files changed, 309 insertions, 210 deletions
diff --git a/src/python/gudhi/persistence_graphical_tools.py b/src/python/gudhi/persistence_graphical_tools.py
index 181bc8ea..e438aa66 100644
--- a/src/python/gudhi/persistence_graphical_tools.py
+++ b/src/python/gudhi/persistence_graphical_tools.py
@@ -1,24 +1,30 @@
+# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+# Author(s): Vincent Rouvreau, Bertrand Michel
+#
+# Copyright (C) 2016 Inria
+#
+# Modification(s):
+# - 2020/02 Theo Lacombe: Added more options for improved rendering and more flexibility.
+# - YYYY/MM Author: Description of the modification
+
from os import path
from math import isfinite
import numpy as np
+from functools import lru_cache
+import warnings
+import errno
+import os
from gudhi.reader_utils import read_persistence_intervals_in_dimension
from gudhi.reader_utils import read_persistence_intervals_grouped_by_dimension
-""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
- See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
- Author(s): Vincent Rouvreau, Bertrand Michel
-
- Copyright (C) 2016 Inria
-
- Modification(s):
- - YYYY/MM Author: Description of the modification
-"""
-
-__author__ = "Vincent Rouvreau, Bertrand Michel"
+__author__ = "Vincent Rouvreau, Bertrand Michel, Theo Lacombe"
__copyright__ = "Copyright (C) 2016 Inria"
__license__ = "MIT"
+_gudhi_matplotlib_use_tex = True
+
def __min_birth_max_death(persistence, band=0.0):
"""This function returns (min_birth, max_death) from the persistence.
@@ -42,44 +48,78 @@ def __min_birth_max_death(persistence, band=0.0):
min_birth = float(interval[1][0])
if band > 0.0:
max_death += band
+ # can happen if only points at inf death
+ if min_birth == max_death:
+ max_death = max_death + 1.0
return (min_birth, max_death)
-"""
-Only 13 colors for the palette
-"""
-palette = [
- "#ff0000",
- "#00ff00",
- "#0000ff",
- "#00ffff",
- "#ff00ff",
- "#ffff00",
- "#000000",
- "#880000",
- "#008800",
- "#000088",
- "#888800",
- "#880088",
- "#008888",
-]
+def _array_handler(a):
+ """
+ :param a: if array, assumes it is a (n x 2) np.array and return a
+ persistence-compatible list (padding with 0), so that the
+ plot can be performed seamlessly.
+ """
+ if isinstance(a[0][1], (np.floating, float)):
+ return [[0, x] for x in a]
+ else:
+ return a
+
+
+def _limit_to_max_intervals(persistence, max_intervals, key):
+ """This function returns truncated persistence if length is bigger than max_intervals.
+ :param persistence: Persistence intervals values list. Can be grouped by dimension or not.
+ :type persistence: an array of (dimension, array of (birth, death)) or an array of (birth, death).
+ :param max_intervals: maximal number of intervals to display.
+ Selected intervals are those with the longest life time. Set it
+ to 0 to see all. Default value is 1000.
+ :type max_intervals: int.
+ :param key: key function for sort algorithm.
+ :type key: function or lambda.
+ """
+ if max_intervals > 0 and max_intervals < len(persistence):
+ warnings.warn(
+ "There are %s intervals given as input, whereas max_intervals is set to %s."
+ % (len(persistence), max_intervals)
+ )
+ # Sort by life time, then takes only the max_intervals elements
+ return sorted(persistence, key=key, reverse=True)[:max_intervals]
+ else:
+ return persistence
+
+
+@lru_cache(maxsize=1)
+def _matplotlib_can_use_tex():
+ """This function returns True if matplotlib can deal with LaTeX, False otherwise.
+ The returned value is cached.
+ """
+ try:
+ from matplotlib import checkdep_usetex
+
+ return checkdep_usetex(True)
+ except ImportError as import_error:
+ warnings.warn(f"This function is not available.\nModuleNotFoundError: No module named '{import_error.name}'.")
def plot_persistence_barcode(
persistence=[],
persistence_file="",
alpha=0.6,
- max_intervals=1000,
- max_barcodes=1000,
+ max_intervals=20000,
inf_delta=0.1,
legend=False,
+ colormap=None,
+ axes=None,
+ fontsize=16,
):
"""This function plots the persistence bar code from persistence values list
- or from a :doc:`persistence file <fileformats>`.
+ , a np.array of shape (N x 2) (representing a diagram
+ in a single homology dimension),
+ or from a `persistence diagram <fileformats.html#persistence-diagram>`_ file.
- :param persistence: Persistence intervals values list grouped by dimension.
- :type persistence: list of tuples(dimension, tuple(birth, death)).
- :param persistence_file: A :doc:`persistence file <fileformats>` style name
+ :param persistence: Persistence intervals values list. Can be grouped by dimension or not.
+ :type persistence: an array of (dimension, array of (birth, death)) or an array of (birth, death).
+ :param persistence_file: A `persistence diagram <fileformats.html#persistence-diagram>`_ file style name
(reset persistence if both are set).
:type persistence_file: string
:param alpha: barcode transparency value (0.0 transparent through 1.0
@@ -87,7 +127,7 @@ def plot_persistence_barcode(
:type alpha: float.
:param max_intervals: maximal number of intervals to display.
Selected intervals are those with the longest life time. Set it
- to 0 to see all. Default value is 1000.
+ to 0 to see all. Default value is 20000.
:type max_intervals: int.
:param inf_delta: Infinity is placed at :code:`((max_death - min_birth) x
inf_delta)` above :code:`max_death` value. A reasonable value is
@@ -95,90 +135,84 @@ def plot_persistence_barcode(
:type inf_delta: float.
:param legend: Display the dimension color legend (default is False).
:type legend: boolean.
- :returns: A matplotlib object containing horizontal bar plot of persistence
- (launch `show()` method on it to display it).
+ :param colormap: A matplotlib-like qualitative colormaps. Default is None
+ which means :code:`matplotlib.cm.Set1.colors`.
+ :type colormap: tuple of colors (3-tuple of float between 0. and 1.).
+ :param axes: A matplotlib-like subplot axes. If None, the plot is drawn on
+ a new set of axes.
+ :type axes: `matplotlib.axes.Axes`
+ :param fontsize: Fontsize to use in axis.
+ :type fontsize: int
+ :returns: (`matplotlib.axes.Axes`): The axes on which the plot was drawn.
"""
try:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
+ from matplotlib import rc
- if persistence_file is not "":
+ if _gudhi_matplotlib_use_tex and _matplotlib_can_use_tex():
+ plt.rc("text", usetex=True)
+ plt.rc("font", family="serif")
+ else:
+ plt.rc("text", usetex=False)
+ plt.rc("font", family="DejaVu Sans")
+
+ if persistence_file != "":
if path.isfile(persistence_file):
# Reset persistence
persistence = []
- diag = read_persistence_intervals_grouped_by_dimension(
- persistence_file=persistence_file
- )
+ diag = read_persistence_intervals_grouped_by_dimension(persistence_file=persistence_file)
for key in diag.keys():
for persistence_interval in diag[key]:
persistence.append((key, persistence_interval))
else:
- print("file " + persistence_file + " not found.")
- return None
-
- if max_barcodes is not 1000:
- print("Deprecated parameter. It has been replaced by max_intervals")
- max_intervals = max_barcodes
-
- if max_intervals > 0 and max_intervals < len(persistence):
- # Sort by life time, then takes only the max_intervals elements
- persistence = sorted(
- persistence,
- key=lambda life_time: life_time[1][1] - life_time[1][0],
- reverse=True,
- )[:max_intervals]
+ raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), persistence_file)
- persistence = sorted(persistence, key=lambda birth: birth[1][0])
+ try:
+ persistence = _array_handler(persistence)
+ persistence = _limit_to_max_intervals(
+ persistence, max_intervals, key=lambda life_time: life_time[1][1] - life_time[1][0]
+ )
+ (min_birth, max_death) = __min_birth_max_death(persistence)
+ persistence = sorted(persistence, key=lambda birth: birth[1][0])
+ except IndexError:
+ min_birth, max_death = 0.0, 1.0
+ pass
- (min_birth, max_death) = __min_birth_max_death(persistence)
- ind = 0
delta = (max_death - min_birth) * inf_delta
# Replace infinity values with max_death + delta for bar code to be more
# readable
infinity = max_death + delta
axis_start = min_birth - delta
- # Draw horizontal bars in loop
- for interval in reversed(persistence):
- if float(interval[1][1]) != float("inf"):
- # Finite death case
- plt.barh(
- ind,
- (interval[1][1] - interval[1][0]),
- height=0.8,
- left=interval[1][0],
- alpha=alpha,
- color=palette[interval[0]],
- linewidth=0,
- )
- else:
- # Infinite death case for diagram to be nicer
- plt.barh(
- ind,
- (infinity - interval[1][0]),
- height=0.8,
- left=interval[1][0],
- alpha=alpha,
- color=palette[interval[0]],
- linewidth=0,
- )
- ind = ind + 1
+
+ if axes == None:
+ _, axes = plt.subplots(1, 1)
+ if colormap == None:
+ colormap = plt.cm.Set1.colors
+
+ x=[birth for (dim,(birth,death)) in persistence]
+ y=[(death - birth) if death != float("inf") else (infinity - birth) for (dim,(birth,death)) in persistence]
+ c=[colormap[dim] for (dim,(birth,death)) in persistence]
+
+ axes.barh(range(len(x)), y, left=x, alpha=alpha, color=c, linewidth=0)
if legend:
- dimensions = list(set(item[0] for item in persistence))
- plt.legend(
- handles=[
- mpatches.Patch(color=palette[dim], label=str(dim))
- for dim in dimensions
- ],
- loc="lower right",
+ dimensions = set(item[0] for item in persistence)
+ axes.legend(
+ handles=[mpatches.Patch(color=colormap[dim], label=str(dim)) for dim in dimensions], loc="lower right",
)
- plt.title("Persistence barcode")
+
+ axes.set_title("Persistence barcode", fontsize=fontsize)
+ axes.set_yticks([])
+ axes.invert_yaxis()
+
# Ends plot on infinity value and starts a little bit before min_birth
- plt.axis([axis_start, infinity, 0, ind])
- return plt
+ if len(x) != 0:
+ axes.set_xlim((axis_start, infinity))
+ return axes
- except ImportError:
- print("This function is not available, you may be missing matplotlib.")
+ except ImportError as import_error:
+ warnings.warn(f"This function is not available.\nModuleNotFoundError: No module named '{import_error.name}'.")
def plot_persistence_diagram(
@@ -186,17 +220,21 @@ def plot_persistence_diagram(
persistence_file="",
alpha=0.6,
band=0.0,
- max_intervals=1000,
- max_plots=1000,
+ max_intervals=1000000,
inf_delta=0.1,
legend=False,
+ colormap=None,
+ axes=None,
+ fontsize=16,
+ greyblock=True,
):
"""This function plots the persistence diagram from persistence values
- list or from a :doc:`persistence file <fileformats>`.
+ list, a np.array of shape (N x 2) representing a diagram in a single
+ homology dimension, or from a `persistence diagram <fileformats.html#persistence-diagram>`_ file`.
- :param persistence: Persistence intervals values list grouped by dimension.
- :type persistence: list of tuples(dimension, tuple(birth, death)).
- :param persistence_file: A :doc:`persistence file <fileformats>` style name
+ :param persistence: Persistence intervals values list. Can be grouped by dimension or not.
+ :type persistence: an array of (dimension, array of (birth, death)) or an array of (birth, death).
+ :param persistence_file: A `persistence diagram <fileformats.html#persistence-diagram>`_ file style name
(reset persistence if both are set).
:type persistence_file: string
:param alpha: plot transparency value (0.0 transparent through 1.0
@@ -206,7 +244,7 @@ def plot_persistence_diagram(
:type band: float.
:param max_intervals: maximal number of intervals to display.
Selected intervals are those with the longest life time. Set it
- to 0 to see all. Default value is 1000.
+ to 0 to see all. Default value is 1000000.
:type max_intervals: int.
:param inf_delta: Infinity is placed at :code:`((max_death - min_birth) x
inf_delta)` above :code:`max_death` value. A reasonable value is
@@ -214,90 +252,108 @@ def plot_persistence_diagram(
:type inf_delta: float.
:param legend: Display the dimension color legend (default is False).
:type legend: boolean.
- :returns: A matplotlib object containing diagram plot of persistence
- (launch `show()` method on it to display it).
+ :param colormap: A matplotlib-like qualitative colormaps. Default is None
+ which means :code:`matplotlib.cm.Set1.colors`.
+ :type colormap: tuple of colors (3-tuple of float between 0. and 1.).
+ :param axes: A matplotlib-like subplot axes. If None, the plot is drawn on
+ a new set of axes.
+ :type axes: `matplotlib.axes.Axes`
+ :param fontsize: Fontsize to use in axis.
+ :type fontsize: int
+ :param greyblock: if we want to plot a grey patch on the lower half plane for nicer rendering. Default True.
+ :type greyblock: boolean
+ :returns: (`matplotlib.axes.Axes`): The axes on which the plot was drawn.
"""
try:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
+ from matplotlib import rc
- if persistence_file is not "":
+ if _gudhi_matplotlib_use_tex and _matplotlib_can_use_tex():
+ plt.rc("text", usetex=True)
+ plt.rc("font", family="serif")
+ else:
+ plt.rc("text", usetex=False)
+ plt.rc("font", family="DejaVu Sans")
+
+ if persistence_file != "":
if path.isfile(persistence_file):
# Reset persistence
persistence = []
- diag = read_persistence_intervals_grouped_by_dimension(
- persistence_file=persistence_file
- )
+ diag = read_persistence_intervals_grouped_by_dimension(persistence_file=persistence_file)
for key in diag.keys():
for persistence_interval in diag[key]:
persistence.append((key, persistence_interval))
else:
- print("file " + persistence_file + " not found.")
- return None
-
- if max_plots is not 1000:
- print("Deprecated parameter. It has been replaced by max_intervals")
- max_intervals = max_plots
-
- if max_intervals > 0 and max_intervals < len(persistence):
- # Sort by life time, then takes only the max_intervals elements
- persistence = sorted(
- persistence,
- key=lambda life_time: life_time[1][1] - life_time[1][0],
- reverse=True,
- )[:max_intervals]
-
- (min_birth, max_death) = __min_birth_max_death(persistence, band)
+ raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), persistence_file)
+
+ try:
+ persistence = _array_handler(persistence)
+ persistence = _limit_to_max_intervals(
+ persistence, max_intervals, key=lambda life_time: life_time[1][1] - life_time[1][0]
+ )
+ min_birth, max_death = __min_birth_max_death(persistence, band)
+ except IndexError:
+ min_birth, max_death = 0.0, 1.0
+ pass
+
delta = (max_death - min_birth) * inf_delta
# Replace infinity values with max_death + delta for diagram to be more
# readable
infinity = max_death + delta
+ axis_end = max_death + delta / 2
axis_start = min_birth - delta
- # line display of equation : birth = death
- x = np.linspace(axis_start, infinity, 1000)
- # infinity line and text
- plt.plot(x, x, color="k", linewidth=1.0)
- plt.plot(x, [infinity] * len(x), linewidth=1.0, color="k", alpha=alpha)
- plt.text(axis_start, infinity, r"$\infty$", color="k", alpha=alpha)
+ if axes == None:
+ _, axes = plt.subplots(1, 1)
+ if colormap == None:
+ colormap = plt.cm.Set1.colors
# bootstrap band
if band > 0.0:
- plt.fill_between(x, x, x + band, alpha=alpha, facecolor="red")
-
- # Draw points in loop
- for interval in reversed(persistence):
- if float(interval[1][1]) != float("inf"):
- # Finite death case
- plt.scatter(
- interval[1][0],
- interval[1][1],
- alpha=alpha,
- color=palette[interval[0]],
- )
- else:
- # Infinite death case for diagram to be nicer
- plt.scatter(
- interval[1][0], infinity, alpha=alpha, color=palette[interval[0]]
+ x = np.linspace(axis_start, infinity, 1000)
+ axes.fill_between(x, x, x + band, alpha=alpha, facecolor="red")
+ # lower diag patch
+ if greyblock:
+ axes.add_patch(
+ mpatches.Polygon(
+ [[axis_start, axis_start], [axis_end, axis_start], [axis_end, axis_end]],
+ fill=True,
+ color="lightgrey",
)
+ )
+ # line display of equation : birth = death
+ axes.plot([axis_start, axis_end], [axis_start, axis_end], linewidth=1.0, color="k")
+
+ x=[birth for (dim,(birth,death)) in persistence]
+ y=[death if death != float("inf") else infinity for (dim,(birth,death)) in persistence]
+ c=[colormap[dim] for (dim,(birth,death)) in persistence]
+
+ axes.scatter(x,y,alpha=alpha,color=c)
+ if float("inf") in (death for (dim,(birth,death)) in persistence):
+ # infinity line and text
+ axes.plot([axis_start, axis_end], [infinity, infinity], linewidth=1.0, color="k", alpha=alpha)
+ # Infinity label
+ yt = axes.get_yticks()
+ yt = yt[np.where(yt < axis_end)] # to avoid plotting ticklabel higher than infinity
+ yt = np.append(yt, infinity)
+ ytl = ["%.3f" % e for e in yt] # to avoid float precision error
+ ytl[-1] = r"$+\infty$"
+ axes.set_yticks(yt)
+ axes.set_yticklabels(ytl)
if legend:
dimensions = list(set(item[0] for item in persistence))
- plt.legend(
- handles=[
- mpatches.Patch(color=palette[dim], label=str(dim))
- for dim in dimensions
- ]
- )
+ axes.legend(handles=[mpatches.Patch(color=colormap[dim], label=str(dim)) for dim in dimensions])
- plt.title("Persistence diagram")
- plt.xlabel("Birth")
- plt.ylabel("Death")
+ axes.set_xlabel("Birth", fontsize=fontsize)
+ axes.set_ylabel("Death", fontsize=fontsize)
+ axes.set_title("Persistence diagram", fontsize=fontsize)
# Ends plot on infinity value and starts a little bit before min_birth
- plt.axis([axis_start, infinity, axis_start, infinity + delta])
- return plt
+ axes.axis([axis_start, axis_end, axis_start, infinity + delta / 2])
+ return axes
- except ImportError:
- print("This function is not available, you may be missing matplotlib.")
+ except ImportError as import_error:
+ warnings.warn(f"This function is not available.\nModuleNotFoundError: No module named '{import_error.name}'.")
def plot_persistence_density(
@@ -309,17 +365,26 @@ def plot_persistence_density(
dimension=None,
cmap=None,
legend=False,
+ axes=None,
+ fontsize=16,
+ greyblock=False,
):
"""This function plots the persistence density from persistence
- values list or from a :doc:`persistence file <fileformats>`. Be
- aware that this function does not distinguish the dimension, it is
+ values list, np.array of shape (N x 2) representing a diagram
+ in a single homology dimension,
+ or from a `persistence diagram <fileformats.html#persistence-diagram>`_ file.
+ Be aware that this function does not distinguish the dimension, it is
up to you to select the required one. This function also does not handle
degenerate data set (scipy correlation matrix inversion can fail).
- :param persistence: Persistence intervals values list grouped by dimension.
- :type persistence: list of tuples(dimension, tuple(birth, death)).
- :param persistence_file: A :doc:`persistence file <fileformats>`
- style name (reset persistence if both are set).
+ :Requires: `SciPy <installation.html#scipy>`_
+
+ :param persistence: Persistence intervals values list.
+ Can be grouped by dimension or not.
+ :type persistence: an array of (dimension, array of (birth, death))
+ or an array of (birth, death).
+ :param persistence_file: A `persistence diagram <fileformats.html#persistence-diagram>`_
+ file style name (reset persistence if both are set).
:type persistence_file: string
:param nbins: Evaluate a gaussian kde on a regular grid of nbins x
nbins over data extents (default is 300)
@@ -347,14 +412,30 @@ def plot_persistence_density(
:type cmap: cf. matplotlib colormap.
:param legend: Display the color bar values (default is False).
:type legend: boolean.
- :returns: A matplotlib object containing diagram plot of persistence
- (launch `show()` method on it to display it).
+ :param axes: A matplotlib-like subplot axes. If None, the plot is drawn on
+ a new set of axes.
+ :type axes: `matplotlib.axes.Axes`
+ :param fontsize: Fontsize to use in axis.
+ :type fontsize: int
+ :param greyblock: if we want to plot a grey patch on the lower half plane
+ for nicer rendering. Default False.
+ :type greyblock: boolean
+ :returns: (`matplotlib.axes.Axes`): The axes on which the plot was drawn.
"""
try:
import matplotlib.pyplot as plt
+ import matplotlib.patches as mpatches
from scipy.stats import kde
+ from matplotlib import rc
+
+ if _gudhi_matplotlib_use_tex and _matplotlib_can_use_tex():
+ plt.rc("text", usetex=True)
+ plt.rc("font", family="serif")
+ else:
+ plt.rc("text", usetex=False)
+ plt.rc("font", family="DejaVu Sans")
- if persistence_file is not "":
+ if persistence_file != "":
if dimension is None:
# All dimension case
dimension = -1
@@ -362,12 +443,18 @@ def plot_persistence_density(
persistence_dim = read_persistence_intervals_in_dimension(
persistence_file=persistence_file, only_this_dim=dimension
)
- print(persistence_dim)
else:
- print("file " + persistence_file + " not found.")
- return None
+ raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), persistence_file)
- if len(persistence) > 0:
+ # default cmap value cannot be done at argument definition level as matplotlib is not yet defined.
+ if cmap is None:
+ cmap = plt.cm.hot_r
+ if axes == None:
+ _, axes = plt.subplots(1, 1)
+
+ try:
+ # if not read from file but given by an argument
+ persistence = _array_handler(persistence)
persistence_dim = np.array(
[
(dim_interval[1][0], dim_interval[1][1])
@@ -375,49 +462,61 @@ def plot_persistence_density(
if (dim_interval[0] == dimension) or (dimension is None)
]
)
-
- persistence_dim = persistence_dim[np.isfinite(persistence_dim[:, 1])]
- if max_intervals > 0 and max_intervals < len(persistence_dim):
- # Sort by life time, then takes only the max_intervals elements
+ persistence_dim = persistence_dim[np.isfinite(persistence_dim[:, 1])]
persistence_dim = np.array(
- sorted(
- persistence_dim,
- key=lambda life_time: life_time[1] - life_time[0],
- reverse=True,
- )[:max_intervals]
+ _limit_to_max_intervals(
+ persistence_dim, max_intervals, key=lambda life_time: life_time[1] - life_time[0]
+ )
)
- # Set as numpy array birth and death (remove undefined values - inf and NaN)
- birth = persistence_dim[:, 0]
- death = persistence_dim[:, 1]
+ # Set as numpy array birth and death (remove undefined values - inf and NaN)
+ birth = persistence_dim[:, 0]
+ death = persistence_dim[:, 1]
+ birth_min = birth.min()
+ birth_max = birth.max()
+ death_min = death.min()
+ death_max = death.max()
+
+ # Evaluate a gaussian kde on a regular grid of nbins x nbins over data extents
+ k = kde.gaussian_kde([birth, death], bw_method=bw_method)
+ xi, yi = np.mgrid[
+ birth_min : birth_max : nbins * 1j, death_min : death_max : nbins * 1j,
+ ]
+ zi = k(np.vstack([xi.flatten(), yi.flatten()]))
+ # Make the plot
+ img = axes.pcolormesh(xi, yi, zi.reshape(xi.shape), cmap=cmap, shading="auto")
+ plot_success = True
+
+ # IndexError on empty diagrams, ValueError on only inf death values
+ except (IndexError, ValueError):
+ birth_min = 0.0
+ birth_max = 1.0
+ death_min = 0.0
+ death_max = 1.0
+ plot_success = False
+ pass
# line display of equation : birth = death
- x = np.linspace(death.min(), birth.max(), 1000)
- plt.plot(x, x, color="k", linewidth=1.0)
-
- # Evaluate a gaussian kde on a regular grid of nbins x nbins over data extents
- k = kde.gaussian_kde([birth, death], bw_method=bw_method)
- xi, yi = np.mgrid[
- birth.min() : birth.max() : nbins * 1j,
- death.min() : death.max() : nbins * 1j,
- ]
- zi = k(np.vstack([xi.flatten(), yi.flatten()]))
+ x = np.linspace(death_min, birth_max, 1000)
+ axes.plot(x, x, color="k", linewidth=1.0)
+
+ if greyblock:
+ axes.add_patch(
+ mpatches.Polygon(
+ [[birth_min, birth_min], [death_max, birth_min], [death_max, death_max]],
+ fill=True,
+ color="lightgrey",
+ )
+ )
- # default cmap value cannot be done at argument definition level as matplotlib is not yet defined.
- if cmap is None:
- cmap = plt.cm.hot_r
- # Make the plot
- plt.pcolormesh(xi, yi, zi.reshape(xi.shape), cmap=cmap)
+ if plot_success and legend:
+ plt.colorbar(img, ax=axes)
- if legend:
- plt.colorbar()
+ axes.set_xlabel("Birth", fontsize=fontsize)
+ axes.set_ylabel("Death", fontsize=fontsize)
+ axes.set_title("Persistence density", fontsize=fontsize)
- plt.title("Persistence density")
- plt.xlabel("Birth")
- plt.ylabel("Death")
- return plt
+ return axes
- except ImportError:
- print(
- "This function is not available, you may be missing matplotlib and/or scipy."
- )
+ except ImportError as import_error:
+ warnings.warn(f"This function is not available.\nModuleNotFoundError: No module named '{import_error.name}'.")