summaryrefslogtreecommitdiff
path: root/src/python/gudhi/persistence_graphical_tools.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/gudhi/persistence_graphical_tools.py')
-rw-r--r--src/python/gudhi/persistence_graphical_tools.py163
1 files changed, 86 insertions, 77 deletions
diff --git a/src/python/gudhi/persistence_graphical_tools.py b/src/python/gudhi/persistence_graphical_tools.py
index 181bc8ea..246280de 100644
--- a/src/python/gudhi/persistence_graphical_tools.py
+++ b/src/python/gudhi/persistence_graphical_tools.py
@@ -1,3 +1,12 @@
+# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+# Author(s): Vincent Rouvreau, Bertrand Michel
+#
+# Copyright (C) 2016 Inria
+#
+# Modification(s):
+# - YYYY/MM Author: Description of the modification
+
from os import path
from math import isfinite
import numpy as np
@@ -5,16 +14,6 @@ import numpy as np
from gudhi.reader_utils import read_persistence_intervals_in_dimension
from gudhi.reader_utils import read_persistence_intervals_grouped_by_dimension
-""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
- See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
- Author(s): Vincent Rouvreau, Bertrand Michel
-
- Copyright (C) 2016 Inria
-
- Modification(s):
- - YYYY/MM Author: Description of the modification
-"""
-
__author__ = "Vincent Rouvreau, Bertrand Michel"
__copyright__ = "Copyright (C) 2016 Inria"
__license__ = "MIT"
@@ -44,27 +43,6 @@ def __min_birth_max_death(persistence, band=0.0):
max_death += band
return (min_birth, max_death)
-
-"""
-Only 13 colors for the palette
-"""
-palette = [
- "#ff0000",
- "#00ff00",
- "#0000ff",
- "#00ffff",
- "#ff00ff",
- "#ffff00",
- "#000000",
- "#880000",
- "#008800",
- "#000088",
- "#888800",
- "#880088",
- "#008888",
-]
-
-
def plot_persistence_barcode(
persistence=[],
persistence_file="",
@@ -73,6 +51,8 @@ def plot_persistence_barcode(
max_barcodes=1000,
inf_delta=0.1,
legend=False,
+ colormap=None,
+ axes=None
):
"""This function plots the persistence bar code from persistence values list
or from a :doc:`persistence file <fileformats>`.
@@ -95,14 +75,19 @@ def plot_persistence_barcode(
:type inf_delta: float.
:param legend: Display the dimension color legend (default is False).
:type legend: boolean.
- :returns: A matplotlib object containing horizontal bar plot of persistence
- (launch `show()` method on it to display it).
+ :param colormap: A matplotlib-like qualitative colormaps. Default is None
+ which means :code:`matplotlib.cm.Set1.colors`.
+ :type colormap: tuple of colors (3-tuple of float between 0. and 1.).
+ :param axes: A matplotlib-like subplot axes. If None, the plot is drawn on
+ a new set of axes.
+ :type axes: `matplotlib.axes.Axes`
+ :returns: (`matplotlib.axes.Axes`): The axes on which the plot was drawn.
"""
try:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
- if persistence_file is not "":
+ if persistence_file != "":
if path.isfile(persistence_file):
# Reset persistence
persistence = []
@@ -116,7 +101,7 @@ def plot_persistence_barcode(
print("file " + persistence_file + " not found.")
return None
- if max_barcodes is not 1000:
+ if max_barcodes != 1000:
print("Deprecated parameter. It has been replaced by max_intervals")
max_intervals = max_barcodes
@@ -127,6 +112,11 @@ def plot_persistence_barcode(
key=lambda life_time: life_time[1][1] - life_time[1][0],
reverse=True,
)[:max_intervals]
+
+ if colormap == None:
+ colormap = plt.cm.Set1.colors
+ if axes == None:
+ fig, axes = plt.subplots(1, 1)
persistence = sorted(persistence, key=lambda birth: birth[1][0])
@@ -141,41 +131,43 @@ def plot_persistence_barcode(
for interval in reversed(persistence):
if float(interval[1][1]) != float("inf"):
# Finite death case
- plt.barh(
+ axes.barh(
ind,
(interval[1][1] - interval[1][0]),
height=0.8,
left=interval[1][0],
alpha=alpha,
- color=palette[interval[0]],
+ color=colormap[interval[0]],
linewidth=0,
)
else:
# Infinite death case for diagram to be nicer
- plt.barh(
+ axes.barh(
ind,
(infinity - interval[1][0]),
height=0.8,
left=interval[1][0],
alpha=alpha,
- color=palette[interval[0]],
+ color=colormap[interval[0]],
linewidth=0,
)
ind = ind + 1
if legend:
dimensions = list(set(item[0] for item in persistence))
- plt.legend(
+ axes.legend(
handles=[
- mpatches.Patch(color=palette[dim], label=str(dim))
+ mpatches.Patch(color=colormap[dim], label=str(dim))
for dim in dimensions
],
loc="lower right",
)
- plt.title("Persistence barcode")
+
+ axes.set_title("Persistence barcode")
+
# Ends plot on infinity value and starts a little bit before min_birth
- plt.axis([axis_start, infinity, 0, ind])
- return plt
+ axes.axis([axis_start, infinity, 0, ind])
+ return axes
except ImportError:
print("This function is not available, you may be missing matplotlib.")
@@ -190,6 +182,8 @@ def plot_persistence_diagram(
max_plots=1000,
inf_delta=0.1,
legend=False,
+ colormap=None,
+ axes=None
):
"""This function plots the persistence diagram from persistence values
list or from a :doc:`persistence file <fileformats>`.
@@ -214,14 +208,19 @@ def plot_persistence_diagram(
:type inf_delta: float.
:param legend: Display the dimension color legend (default is False).
:type legend: boolean.
- :returns: A matplotlib object containing diagram plot of persistence
- (launch `show()` method on it to display it).
+ :param colormap: A matplotlib-like qualitative colormaps. Default is None
+ which means :code:`matplotlib.cm.Set1.colors`.
+ :type colormap: tuple of colors (3-tuple of float between 0. and 1.).
+ :param axes: A matplotlib-like subplot axes. If None, the plot is drawn on
+ a new set of axes.
+ :type axes: `matplotlib.axes.Axes`
+ :returns: (`matplotlib.axes.Axes`): The axes on which the plot was drawn.
"""
try:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
- if persistence_file is not "":
+ if persistence_file != "":
if path.isfile(persistence_file):
# Reset persistence
persistence = []
@@ -235,7 +234,7 @@ def plot_persistence_diagram(
print("file " + persistence_file + " not found.")
return None
- if max_plots is not 1000:
+ if max_plots != 1000:
print("Deprecated parameter. It has been replaced by max_intervals")
max_intervals = max_plots
@@ -247,6 +246,11 @@ def plot_persistence_diagram(
reverse=True,
)[:max_intervals]
+ if colormap == None:
+ colormap = plt.cm.Set1.colors
+ if axes == None:
+ fig, axes = plt.subplots(1, 1)
+
(min_birth, max_death) = __min_birth_max_death(persistence, band)
delta = (max_death - min_birth) * inf_delta
# Replace infinity values with max_death + delta for diagram to be more
@@ -257,44 +261,44 @@ def plot_persistence_diagram(
# line display of equation : birth = death
x = np.linspace(axis_start, infinity, 1000)
# infinity line and text
- plt.plot(x, x, color="k", linewidth=1.0)
- plt.plot(x, [infinity] * len(x), linewidth=1.0, color="k", alpha=alpha)
- plt.text(axis_start, infinity, r"$\infty$", color="k", alpha=alpha)
+ axes.plot(x, x, color="k", linewidth=1.0)
+ axes.plot(x, [infinity] * len(x), linewidth=1.0, color="k", alpha=alpha)
+ axes.text(axis_start, infinity, r"$\infty$", color="k", alpha=alpha)
# bootstrap band
if band > 0.0:
- plt.fill_between(x, x, x + band, alpha=alpha, facecolor="red")
+ axes.fill_between(x, x, x + band, alpha=alpha, facecolor="red")
# Draw points in loop
for interval in reversed(persistence):
if float(interval[1][1]) != float("inf"):
# Finite death case
- plt.scatter(
+ axes.scatter(
interval[1][0],
interval[1][1],
alpha=alpha,
- color=palette[interval[0]],
+ color=colormap[interval[0]],
)
else:
# Infinite death case for diagram to be nicer
- plt.scatter(
- interval[1][0], infinity, alpha=alpha, color=palette[interval[0]]
+ axes.scatter(
+ interval[1][0], infinity, alpha=alpha, color=colormap[interval[0]]
)
if legend:
dimensions = list(set(item[0] for item in persistence))
- plt.legend(
+ axes.legend(
handles=[
- mpatches.Patch(color=palette[dim], label=str(dim))
+ mpatches.Patch(color=colormap[dim], label=str(dim))
for dim in dimensions
]
)
- plt.title("Persistence diagram")
- plt.xlabel("Birth")
- plt.ylabel("Death")
+ axes.set_xlabel("Birth")
+ axes.set_ylabel("Death")
# Ends plot on infinity value and starts a little bit before min_birth
- plt.axis([axis_start, infinity, axis_start, infinity + delta])
- return plt
+ axes.axis([axis_start, infinity, axis_start, infinity + delta])
+ axes.set_title("Persistence diagram")
+ return axes
except ImportError:
print("This function is not available, you may be missing matplotlib.")
@@ -309,6 +313,7 @@ def plot_persistence_density(
dimension=None,
cmap=None,
legend=False,
+ axes=None
):
"""This function plots the persistence density from persistence
values list or from a :doc:`persistence file <fileformats>`. Be
@@ -347,14 +352,16 @@ def plot_persistence_density(
:type cmap: cf. matplotlib colormap.
:param legend: Display the color bar values (default is False).
:type legend: boolean.
- :returns: A matplotlib object containing diagram plot of persistence
- (launch `show()` method on it to display it).
+ :param axes: A matplotlib-like subplot axes. If None, the plot is drawn on
+ a new set of axes.
+ :type axes: `matplotlib.axes.Axes`
+ :returns: (`matplotlib.axes.Axes`): The axes on which the plot was drawn.
"""
try:
import matplotlib.pyplot as plt
from scipy.stats import kde
- if persistence_file is not "":
+ if persistence_file != "":
if dimension is None:
# All dimension case
dimension = -1
@@ -362,7 +369,6 @@ def plot_persistence_density(
persistence_dim = read_persistence_intervals_in_dimension(
persistence_file=persistence_file, only_this_dim=dimension
)
- print(persistence_dim)
else:
print("file " + persistence_file + " not found.")
return None
@@ -391,9 +397,15 @@ def plot_persistence_density(
birth = persistence_dim[:, 0]
death = persistence_dim[:, 1]
+ # default cmap value cannot be done at argument definition level as matplotlib is not yet defined.
+ if cmap is None:
+ cmap = plt.cm.hot_r
+ if axes == None:
+ fig, axes = plt.subplots(1, 1)
+
# line display of equation : birth = death
x = np.linspace(death.min(), birth.max(), 1000)
- plt.plot(x, x, color="k", linewidth=1.0)
+ axes.plot(x, x, color="k", linewidth=1.0)
# Evaluate a gaussian kde on a regular grid of nbins x nbins over data extents
k = kde.gaussian_kde([birth, death], bw_method=bw_method)
@@ -403,19 +415,16 @@ def plot_persistence_density(
]
zi = k(np.vstack([xi.flatten(), yi.flatten()]))
- # default cmap value cannot be done at argument definition level as matplotlib is not yet defined.
- if cmap is None:
- cmap = plt.cm.hot_r
# Make the plot
- plt.pcolormesh(xi, yi, zi.reshape(xi.shape), cmap=cmap)
+ img = axes.pcolormesh(xi, yi, zi.reshape(xi.shape), cmap=cmap)
if legend:
- plt.colorbar()
+ plt.colorbar(img, ax=axes)
- plt.title("Persistence density")
- plt.xlabel("Birth")
- plt.ylabel("Death")
- return plt
+ axes.set_xlabel("Birth")
+ axes.set_ylabel("Death")
+ axes.set_title("Persistence density")
+ return axes
except ImportError:
print(