summaryrefslogtreecommitdiff
path: root/src/python/gudhi/persistence_graphical_tools.py
diff options
context:
space:
mode:
authortlacombe <lacombe1993@gmail.com>2020-02-25 18:19:24 +0100
committertlacombe <lacombe1993@gmail.com>2020-02-25 18:19:24 +0100
commitcdcd2904a1c682625670a62608fd781bfd571516 (patch)
treec46385ca997d25f2bbdb66bb284bab75f1ccd28b /src/python/gudhi/persistence_graphical_tools.py
parent835a831007196a4d93e57659ab8d3cdb28a4ef92 (diff)
solved scale issue and removed title/aspect as functions return ax
Diffstat (limited to 'src/python/gudhi/persistence_graphical_tools.py')
-rw-r--r--src/python/gudhi/persistence_graphical_tools.py77
1 files changed, 60 insertions, 17 deletions
diff --git a/src/python/gudhi/persistence_graphical_tools.py b/src/python/gudhi/persistence_graphical_tools.py
index 246280de..8ddfdba8 100644
--- a/src/python/gudhi/persistence_graphical_tools.py
+++ b/src/python/gudhi/persistence_graphical_tools.py
@@ -5,6 +5,7 @@
# Copyright (C) 2016 Inria
#
# Modification(s):
+# - 2020/02 Theo Lacombe: Added more options for improved rendering and more flexibility.
# - YYYY/MM Author: Description of the modification
from os import path
@@ -43,6 +44,7 @@ def __min_birth_max_death(persistence, band=0.0):
max_death += band
return (min_birth, max_death)
+
def plot_persistence_barcode(
persistence=[],
persistence_file="",
@@ -52,7 +54,8 @@ def plot_persistence_barcode(
inf_delta=0.1,
legend=False,
colormap=None,
- axes=None
+ axes=None,
+ fontsize=16,
):
"""This function plots the persistence bar code from persistence values list
or from a :doc:`persistence file <fileformats>`.
@@ -81,11 +84,16 @@ def plot_persistence_barcode(
:param axes: A matplotlib-like subplot axes. If None, the plot is drawn on
a new set of axes.
:type axes: `matplotlib.axes.Axes`
+ :param fontsize: Fontsize to use in axis.
+ :type fontsize: int
:returns: (`matplotlib.axes.Axes`): The axes on which the plot was drawn.
"""
try:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
+ from matplotlib import rc
+ plt.rc('text', usetex=True)
+ plt.rc('font', family='serif')
if persistence_file != "":
if path.isfile(persistence_file):
@@ -163,7 +171,7 @@ def plot_persistence_barcode(
loc="lower right",
)
- axes.set_title("Persistence barcode")
+ axes.set_title("Persistence barcode", fontsize=fontsize)
# Ends plot on infinity value and starts a little bit before min_birth
axes.axis([axis_start, infinity, 0, ind])
@@ -183,7 +191,9 @@ def plot_persistence_diagram(
inf_delta=0.1,
legend=False,
colormap=None,
- axes=None
+ axes=None,
+ fontsize=16,
+ greyblock=True
):
"""This function plots the persistence diagram from persistence values
list or from a :doc:`persistence file <fileformats>`.
@@ -214,11 +224,19 @@ def plot_persistence_diagram(
:param axes: A matplotlib-like subplot axes. If None, the plot is drawn on
a new set of axes.
:type axes: `matplotlib.axes.Axes`
+ :param fontsize: Fontsize to use in axis.
+ :type fontsize: int
+ :param greyblock: if we want to plot a grey patch on the lower half plane for nicer rendering. Default True.
+ :type greyblock: boolean
:returns: (`matplotlib.axes.Axes`): The axes on which the plot was drawn.
"""
try:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
+ from matplotlib import rc
+ plt.rc('text', usetex=True)
+ plt.rc('font', family='serif')
+
if persistence_file != "":
if path.isfile(persistence_file):
@@ -256,18 +274,27 @@ def plot_persistence_diagram(
# Replace infinity values with max_death + delta for diagram to be more
# readable
infinity = max_death + delta
+ axis_end = max_death + delta / 2
axis_start = min_birth - delta
- # line display of equation : birth = death
- x = np.linspace(axis_start, infinity, 1000)
# infinity line and text
- axes.plot(x, x, color="k", linewidth=1.0)
- axes.plot(x, [infinity] * len(x), linewidth=1.0, color="k", alpha=alpha)
- axes.text(axis_start, infinity, r"$\infty$", color="k", alpha=alpha)
+ axes.plot([axis_start, axis_end], [axis_start, axis_end], linewidth=1.0, color="k")
+ axes.plot([axis_start, axis_end], [infinity, infinity], linewidth=1.0, color="k", alpha=alpha)
+ # Infinity label
+ yt = axes.get_yticks()
+ yt = yt[np.where(yt < axis_end)] # to avoid ploting ticklabel higher than infinity
+ yt = np.append(yt, infinity)
+ ytl = ["%.3f" % e for e in yt] # to avoid float precision error
+ ytl[-1] = r'$+\infty$'
+ axes.set_yticks(yt)
+ axes.set_yticklabels(ytl)
# bootstrap band
if band > 0.0:
+ x = np.linspace(axis_start, infinity, 1000)
axes.fill_between(x, x, x + band, alpha=alpha, facecolor="red")
-
+ # lower diag patch
+ if greyblock:
+ axes.add_patch(mpatches.Polygon([[axis_start, axis_start], [axis_end, axis_start], [axis_end, axis_end]], fill=True, color='lightgrey'))
# Draw points in loop
for interval in reversed(persistence):
if float(interval[1][1]) != float("inf"):
@@ -293,11 +320,11 @@ def plot_persistence_diagram(
]
)
- axes.set_xlabel("Birth")
- axes.set_ylabel("Death")
+ axes.set_xlabel("Birth", fontsize=fontsize)
+ axes.set_ylabel("Death", fontsize=fontsize)
+ axes.set_title("Persistence diagram", fontsize=fontsize)
# Ends plot on infinity value and starts a little bit before min_birth
- axes.axis([axis_start, infinity, axis_start, infinity + delta])
- axes.set_title("Persistence diagram")
+ axes.axis([axis_start, axis_end, axis_start, infinity + delta/2])
return axes
except ImportError:
@@ -313,7 +340,9 @@ def plot_persistence_density(
dimension=None,
cmap=None,
legend=False,
- axes=None
+ axes=None,
+ fontsize=16,
+ greyblock=True
):
"""This function plots the persistence density from persistence
values list or from a :doc:`persistence file <fileformats>`. Be
@@ -355,11 +384,21 @@ def plot_persistence_density(
:param axes: A matplotlib-like subplot axes. If None, the plot is drawn on
a new set of axes.
:type axes: `matplotlib.axes.Axes`
+ :param fontsize: Fontsize to use in axis.
+ :type fontsize: int
+ :param greyblock: if we want to plot a grey patch on the lower half plane
+ for nicer rendering. Default True.
+ :type greyblock: boolean
:returns: (`matplotlib.axes.Axes`): The axes on which the plot was drawn.
"""
try:
import matplotlib.pyplot as plt
+ import matplotlib.patches as mpatches
from scipy.stats import kde
+ from matplotlib import rc
+ plt.rc('text', usetex=True)
+ plt.rc('font', family='serif')
+
if persistence_file != "":
if dimension is None:
@@ -418,12 +457,16 @@ def plot_persistence_density(
# Make the plot
img = axes.pcolormesh(xi, yi, zi.reshape(xi.shape), cmap=cmap)
+ if greyblock:
+ axes.add_patch(mpatches.Polygon([[birth.min(), birth.min()], [death.max(), birth.min()], [death.max(), death.max()]], fill=True, color='lightgrey'))
+
if legend:
plt.colorbar(img, ax=axes)
- axes.set_xlabel("Birth")
- axes.set_ylabel("Death")
- axes.set_title("Persistence density")
+ axes.set_xlabel("Birth", fontsize=fontsize)
+ axes.set_ylabel("Death", fontsize=fontsize)
+ axes.set_title("Persistence density", fontsize=fontsize)
+
return axes
except ImportError: