summaryrefslogtreecommitdiff
path: root/src/python/gudhi/persistence_graphical_tools.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/gudhi/persistence_graphical_tools.py')
-rw-r--r--src/python/gudhi/persistence_graphical_tools.py136
1 files changed, 103 insertions, 33 deletions
diff --git a/src/python/gudhi/persistence_graphical_tools.py b/src/python/gudhi/persistence_graphical_tools.py
index 246280de..6a74a6ca 100644
--- a/src/python/gudhi/persistence_graphical_tools.py
+++ b/src/python/gudhi/persistence_graphical_tools.py
@@ -5,6 +5,7 @@
# Copyright (C) 2016 Inria
#
# Modification(s):
+# - 2020/02 Theo Lacombe: Added more options for improved rendering and more flexibility.
# - YYYY/MM Author: Description of the modification
from os import path
@@ -14,7 +15,7 @@ import numpy as np
from gudhi.reader_utils import read_persistence_intervals_in_dimension
from gudhi.reader_utils import read_persistence_intervals_grouped_by_dimension
-__author__ = "Vincent Rouvreau, Bertrand Michel"
+__author__ = "Vincent Rouvreau, Bertrand Michel, Theo Lacombe"
__copyright__ = "Copyright (C) 2016 Inria"
__license__ = "MIT"
@@ -43,6 +44,19 @@ def __min_birth_max_death(persistence, band=0.0):
max_death += band
return (min_birth, max_death)
+
+def _array_handler(a):
+ '''
+ :param a: if array, assumes it is a (n x 2) np.array and return a
+ persistence-compatible list (padding with 0), so that the
+ plot can be performed seamlessly.
+ '''
+ if isinstance(a[0][1], np.float64) or isinstance(a[0][1], float):
+ return [[0, x] for x in a]
+ else:
+ return a
+
+
def plot_persistence_barcode(
persistence=[],
persistence_file="",
@@ -52,14 +66,17 @@ def plot_persistence_barcode(
inf_delta=0.1,
legend=False,
colormap=None,
- axes=None
+ axes=None,
+ fontsize=16,
):
"""This function plots the persistence bar code from persistence values list
- or from a :doc:`persistence file <fileformats>`.
+ , a np.array of shape (N x 2) (representing a diagram
+ in a single homology dimension),
+ or from a `persistence diagram <fileformats.html#persistence-diagram>`_ file.
- :param persistence: Persistence intervals values list grouped by dimension.
- :type persistence: list of tuples(dimension, tuple(birth, death)).
- :param persistence_file: A :doc:`persistence file <fileformats>` style name
+ :param persistence: Persistence intervals values list. Can be grouped by dimension or not.
+ :type persistence: an array of (dimension, array of (birth, death)) or an array of (birth, death).
+ :param persistence_file: A `persistence diagram <fileformats.html#persistence-diagram>`_ file style name
(reset persistence if both are set).
:type persistence_file: string
:param alpha: barcode transparency value (0.0 transparent through 1.0
@@ -81,11 +98,16 @@ def plot_persistence_barcode(
:param axes: A matplotlib-like subplot axes. If None, the plot is drawn on
a new set of axes.
:type axes: `matplotlib.axes.Axes`
+ :param fontsize: Fontsize to use in axis.
+ :type fontsize: int
:returns: (`matplotlib.axes.Axes`): The axes on which the plot was drawn.
"""
try:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
+ from matplotlib import rc
+ plt.rc('text', usetex=True)
+ plt.rc('font', family='serif')
if persistence_file != "":
if path.isfile(persistence_file):
@@ -101,6 +123,8 @@ def plot_persistence_barcode(
print("file " + persistence_file + " not found.")
return None
+ persistence = _array_handler(persistence)
+
if max_barcodes != 1000:
print("Deprecated parameter. It has been replaced by max_intervals")
max_intervals = max_barcodes
@@ -163,7 +187,7 @@ def plot_persistence_barcode(
loc="lower right",
)
- axes.set_title("Persistence barcode")
+ axes.set_title("Persistence barcode", fontsize=fontsize)
# Ends plot on infinity value and starts a little bit before min_birth
axes.axis([axis_start, infinity, 0, ind])
@@ -183,14 +207,17 @@ def plot_persistence_diagram(
inf_delta=0.1,
legend=False,
colormap=None,
- axes=None
+ axes=None,
+ fontsize=16,
+ greyblock=True
):
"""This function plots the persistence diagram from persistence values
- list or from a :doc:`persistence file <fileformats>`.
+ list, a np.array of shape (N x 2) representing a diagram in a single
+ homology dimension, or from a `persistence diagram <fileformats.html#persistence-diagram>`_ file`.
- :param persistence: Persistence intervals values list grouped by dimension.
- :type persistence: list of tuples(dimension, tuple(birth, death)).
- :param persistence_file: A :doc:`persistence file <fileformats>` style name
+ :param persistence: Persistence intervals values list. Can be grouped by dimension or not.
+ :type persistence: an array of (dimension, array of (birth, death)) or an array of (birth, death).
+ :param persistence_file: A `persistence diagram <fileformats.html#persistence-diagram>`_ file style name
(reset persistence if both are set).
:type persistence_file: string
:param alpha: plot transparency value (0.0 transparent through 1.0
@@ -214,11 +241,18 @@ def plot_persistence_diagram(
:param axes: A matplotlib-like subplot axes. If None, the plot is drawn on
a new set of axes.
:type axes: `matplotlib.axes.Axes`
+ :param fontsize: Fontsize to use in axis.
+ :type fontsize: int
+ :param greyblock: if we want to plot a grey patch on the lower half plane for nicer rendering. Default True.
+ :type greyblock: boolean
:returns: (`matplotlib.axes.Axes`): The axes on which the plot was drawn.
"""
try:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
+ from matplotlib import rc
+ plt.rc('text', usetex=True)
+ plt.rc('font', family='serif')
if persistence_file != "":
if path.isfile(persistence_file):
@@ -234,6 +268,8 @@ def plot_persistence_diagram(
print("file " + persistence_file + " not found.")
return None
+ persistence = _array_handler(persistence)
+
if max_plots != 1000:
print("Deprecated parameter. It has been replaced by max_intervals")
max_intervals = max_plots
@@ -256,19 +292,18 @@ def plot_persistence_diagram(
# Replace infinity values with max_death + delta for diagram to be more
# readable
infinity = max_death + delta
+ axis_end = max_death + delta / 2
axis_start = min_birth - delta
- # line display of equation : birth = death
- x = np.linspace(axis_start, infinity, 1000)
- # infinity line and text
- axes.plot(x, x, color="k", linewidth=1.0)
- axes.plot(x, [infinity] * len(x), linewidth=1.0, color="k", alpha=alpha)
- axes.text(axis_start, infinity, r"$\infty$", color="k", alpha=alpha)
# bootstrap band
if band > 0.0:
+ x = np.linspace(axis_start, infinity, 1000)
axes.fill_between(x, x, x + band, alpha=alpha, facecolor="red")
-
+ # lower diag patch
+ if greyblock:
+ axes.add_patch(mpatches.Polygon([[axis_start, axis_start], [axis_end, axis_start], [axis_end, axis_end]], fill=True, color='lightgrey'))
# Draw points in loop
+ pts_at_infty = False # Records presence of pts at infty
for interval in reversed(persistence):
if float(interval[1][1]) != float("inf"):
# Finite death case
@@ -279,10 +314,23 @@ def plot_persistence_diagram(
color=colormap[interval[0]],
)
else:
+ pts_at_infty = True
# Infinite death case for diagram to be nicer
axes.scatter(
interval[1][0], infinity, alpha=alpha, color=colormap[interval[0]]
)
+ if pts_at_infty:
+ # infinity line and text
+ axes.plot([axis_start, axis_end], [axis_start, axis_end], linewidth=1.0, color="k")
+ axes.plot([axis_start, axis_end], [infinity, infinity], linewidth=1.0, color="k", alpha=alpha)
+ # Infinity label
+ yt = axes.get_yticks()
+ yt = yt[np.where(yt < axis_end)] # to avoid ploting ticklabel higher than infinity
+ yt = np.append(yt, infinity)
+ ytl = ["%.3f" % e for e in yt] # to avoid float precision error
+ ytl[-1] = r'$+\infty$'
+ axes.set_yticks(yt)
+ axes.set_yticklabels(ytl)
if legend:
dimensions = list(set(item[0] for item in persistence))
@@ -293,11 +341,11 @@ def plot_persistence_diagram(
]
)
- axes.set_xlabel("Birth")
- axes.set_ylabel("Death")
+ axes.set_xlabel("Birth", fontsize=fontsize)
+ axes.set_ylabel("Death", fontsize=fontsize)
+ axes.set_title("Persistence diagram", fontsize=fontsize)
# Ends plot on infinity value and starts a little bit before min_birth
- axes.axis([axis_start, infinity, axis_start, infinity + delta])
- axes.set_title("Persistence diagram")
+ axes.axis([axis_start, axis_end, axis_start, infinity + delta/2])
return axes
except ImportError:
@@ -313,18 +361,26 @@ def plot_persistence_density(
dimension=None,
cmap=None,
legend=False,
- axes=None
+ axes=None,
+ fontsize=16,
+ greyblock=False
):
"""This function plots the persistence density from persistence
- values list or from a :doc:`persistence file <fileformats>`. Be
- aware that this function does not distinguish the dimension, it is
+ values list, np.array of shape (N x 2) representing a diagram
+ in a single homology dimension,
+ or from a `persistence diagram <fileformats.html#persistence-diagram>`_ file.
+ Be aware that this function does not distinguish the dimension, it is
up to you to select the required one. This function also does not handle
degenerate data set (scipy correlation matrix inversion can fail).
- :param persistence: Persistence intervals values list grouped by dimension.
- :type persistence: list of tuples(dimension, tuple(birth, death)).
- :param persistence_file: A :doc:`persistence file <fileformats>`
- style name (reset persistence if both are set).
+ :Requires: `SciPy <installation.html#scipy>`_
+
+ :param persistence: Persistence intervals values list.
+ Can be grouped by dimension or not.
+ :type persistence: an array of (dimension, array of (birth, death))
+ or an array of (birth, death).
+ :param persistence_file: A `persistence diagram <fileformats.html#persistence-diagram>`_
+ file style name (reset persistence if both are set).
:type persistence_file: string
:param nbins: Evaluate a gaussian kde on a regular grid of nbins x
nbins over data extents (default is 300)
@@ -355,11 +411,20 @@ def plot_persistence_density(
:param axes: A matplotlib-like subplot axes. If None, the plot is drawn on
a new set of axes.
:type axes: `matplotlib.axes.Axes`
+ :param fontsize: Fontsize to use in axis.
+ :type fontsize: int
+ :param greyblock: if we want to plot a grey patch on the lower half plane
+ for nicer rendering. Default False.
+ :type greyblock: boolean
:returns: (`matplotlib.axes.Axes`): The axes on which the plot was drawn.
"""
try:
import matplotlib.pyplot as plt
+ import matplotlib.patches as mpatches
from scipy.stats import kde
+ from matplotlib import rc
+ plt.rc('text', usetex=True)
+ plt.rc('font', family='serif')
if persistence_file != "":
if dimension is None:
@@ -374,6 +439,7 @@ def plot_persistence_density(
return None
if len(persistence) > 0:
+ persistence = _array_handler(persistence)
persistence_dim = np.array(
[
(dim_interval[1][0], dim_interval[1][1])
@@ -418,12 +484,16 @@ def plot_persistence_density(
# Make the plot
img = axes.pcolormesh(xi, yi, zi.reshape(xi.shape), cmap=cmap)
+ if greyblock:
+ axes.add_patch(mpatches.Polygon([[birth.min(), birth.min()], [death.max(), birth.min()], [death.max(), death.max()]], fill=True, color='lightgrey'))
+
if legend:
plt.colorbar(img, ax=axes)
- axes.set_xlabel("Birth")
- axes.set_ylabel("Death")
- axes.set_title("Persistence density")
+ axes.set_xlabel("Birth", fontsize=fontsize)
+ axes.set_ylabel("Death", fontsize=fontsize)
+ axes.set_title("Persistence density", fontsize=fontsize)
+
return axes
except ImportError: