summaryrefslogtreecommitdiff
path: root/src/cython/cython/persistence_graphical_tools.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/cython/cython/persistence_graphical_tools.py')
-rw-r--r--src/cython/cython/persistence_graphical_tools.py233
1 files changed, 161 insertions, 72 deletions
diff --git a/src/cython/cython/persistence_graphical_tools.py b/src/cython/cython/persistence_graphical_tools.py
index ead81d30..34803222 100644
--- a/src/cython/cython/persistence_graphical_tools.py
+++ b/src/cython/cython/persistence_graphical_tools.py
@@ -16,7 +16,8 @@ __author__ = "Vincent Rouvreau, Bertrand Michel"
__copyright__ = "Copyright (C) 2016 Inria"
__license__ = "MIT"
-def __min_birth_max_death(persistence, band=0.):
+
+def __min_birth_max_death(persistence, band=0.0):
"""This function returns (min_birth, max_death) from the persistence.
:param persistence: The persistence to plot.
@@ -29,27 +30,47 @@ def __min_birth_max_death(persistence, band=0.):
max_death = 0
min_birth = persistence[0][1][0]
for interval in reversed(persistence):
- if float(interval[1][1]) != float('inf'):
+ if float(interval[1][1]) != float("inf"):
if float(interval[1][1]) > max_death:
max_death = float(interval[1][1])
if float(interval[1][0]) > max_death:
max_death = float(interval[1][0])
if float(interval[1][0]) < min_birth:
min_birth = float(interval[1][0])
- if band > 0.:
+ if band > 0.0:
max_death += band
return (min_birth, max_death)
+
"""
Only 13 colors for the palette
"""
-palette = ['#ff0000', '#00ff00', '#0000ff', '#00ffff', '#ff00ff', '#ffff00',
- '#000000', '#880000', '#008800', '#000088', '#888800', '#880088',
- '#008888']
-
-def plot_persistence_barcode(persistence=[], persistence_file='', alpha=0.6,
- max_intervals=1000, max_barcodes=1000,
- inf_delta=0.1, legend=False):
+palette = [
+ "#ff0000",
+ "#00ff00",
+ "#0000ff",
+ "#00ffff",
+ "#ff00ff",
+ "#ffff00",
+ "#000000",
+ "#880000",
+ "#008800",
+ "#000088",
+ "#888800",
+ "#880088",
+ "#008888",
+]
+
+
+def plot_persistence_barcode(
+ persistence=[],
+ persistence_file="",
+ alpha=0.6,
+ max_intervals=1000,
+ max_barcodes=1000,
+ inf_delta=0.1,
+ legend=False,
+):
"""This function plots the persistence bar code from persistence values list
or from a :doc:`persistence file <fileformats>`.
@@ -78,11 +99,13 @@ def plot_persistence_barcode(persistence=[], persistence_file='', alpha=0.6,
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
- if persistence_file is not '':
+ if persistence_file is not "":
if path.isfile(persistence_file):
# Reset persistence
persistence = []
- diag = read_persistence_intervals_grouped_by_dimension(persistence_file=persistence_file)
+ diag = read_persistence_intervals_grouped_by_dimension(
+ persistence_file=persistence_file
+ )
for key in diag.keys():
for persistence_interval in diag[key]:
persistence.append((key, persistence_interval))
@@ -91,44 +114,62 @@ def plot_persistence_barcode(persistence=[], persistence_file='', alpha=0.6,
return None
if max_barcodes is not 1000:
- print('Deprecated parameter. It has been replaced by max_intervals')
+ print("Deprecated parameter. It has been replaced by max_intervals")
max_intervals = max_barcodes
if max_intervals > 0 and max_intervals < len(persistence):
# Sort by life time, then takes only the max_intervals elements
- persistence = sorted(persistence, key=lambda life_time: life_time[1][1]-life_time[1][0], reverse=True)[:max_intervals]
+ persistence = sorted(
+ persistence,
+ key=lambda life_time: life_time[1][1] - life_time[1][0],
+ reverse=True,
+ )[:max_intervals]
persistence = sorted(persistence, key=lambda birth: birth[1][0])
(min_birth, max_death) = __min_birth_max_death(persistence)
ind = 0
- delta = ((max_death - min_birth) * inf_delta)
+ delta = (max_death - min_birth) * inf_delta
# Replace infinity values with max_death + delta for bar code to be more
# readable
infinity = max_death + delta
axis_start = min_birth - delta
# Draw horizontal bars in loop
for interval in reversed(persistence):
- if float(interval[1][1]) != float('inf'):
+ if float(interval[1][1]) != float("inf"):
# Finite death case
- plt.barh(ind, (interval[1][1] - interval[1][0]), height=0.8,
- left = interval[1][0], alpha=alpha,
- color = palette[interval[0]],
- linewidth=0)
+ plt.barh(
+ ind,
+ (interval[1][1] - interval[1][0]),
+ height=0.8,
+ left=interval[1][0],
+ alpha=alpha,
+ color=palette[interval[0]],
+ linewidth=0,
+ )
else:
# Infinite death case for diagram to be nicer
- plt.barh(ind, (infinity - interval[1][0]), height=0.8,
- left = interval[1][0], alpha=alpha,
- color = palette[interval[0]],
- linewidth=0)
+ plt.barh(
+ ind,
+ (infinity - interval[1][0]),
+ height=0.8,
+ left=interval[1][0],
+ alpha=alpha,
+ color=palette[interval[0]],
+ linewidth=0,
+ )
ind = ind + 1
if legend:
dimensions = list(set(item[0] for item in persistence))
- plt.legend(handles=[mpatches.Patch(color=palette[dim],
- label=str(dim)) for dim in dimensions],
- loc='lower right')
- plt.title('Persistence barcode')
+ plt.legend(
+ handles=[
+ mpatches.Patch(color=palette[dim], label=str(dim))
+ for dim in dimensions
+ ],
+ loc="lower right",
+ )
+ plt.title("Persistence barcode")
# Ends plot on infinity value and starts a little bit before min_birth
plt.axis([axis_start, infinity, 0, ind])
return plt
@@ -136,8 +177,17 @@ def plot_persistence_barcode(persistence=[], persistence_file='', alpha=0.6,
except ImportError:
print("This function is not available, you may be missing matplotlib.")
-def plot_persistence_diagram(persistence=[], persistence_file='', alpha=0.6,
- band=0., max_intervals=1000, max_plots=1000, inf_delta=0.1, legend=False):
+
+def plot_persistence_diagram(
+ persistence=[],
+ persistence_file="",
+ alpha=0.6,
+ band=0.0,
+ max_intervals=1000,
+ max_plots=1000,
+ inf_delta=0.1,
+ legend=False,
+):
"""This function plots the persistence diagram from persistence values
list or from a :doc:`persistence file <fileformats>`.
@@ -168,11 +218,13 @@ def plot_persistence_diagram(persistence=[], persistence_file='', alpha=0.6,
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
- if persistence_file is not '':
+ if persistence_file is not "":
if path.isfile(persistence_file):
# Reset persistence
persistence = []
- diag = read_persistence_intervals_grouped_by_dimension(persistence_file=persistence_file)
+ diag = read_persistence_intervals_grouped_by_dimension(
+ persistence_file=persistence_file
+ )
for key in diag.keys():
for persistence_interval in diag[key]:
persistence.append((key, persistence_interval))
@@ -181,15 +233,19 @@ def plot_persistence_diagram(persistence=[], persistence_file='', alpha=0.6,
return None
if max_plots is not 1000:
- print('Deprecated parameter. It has been replaced by max_intervals')
+ print("Deprecated parameter. It has been replaced by max_intervals")
max_intervals = max_plots
if max_intervals > 0 and max_intervals < len(persistence):
# Sort by life time, then takes only the max_intervals elements
- persistence = sorted(persistence, key=lambda life_time: life_time[1][1]-life_time[1][0], reverse=True)[:max_intervals]
+ persistence = sorted(
+ persistence,
+ key=lambda life_time: life_time[1][1] - life_time[1][0],
+ reverse=True,
+ )[:max_intervals]
(min_birth, max_death) = __min_birth_max_death(persistence, band)
- delta = ((max_death - min_birth) * inf_delta)
+ delta = (max_death - min_birth) * inf_delta
# Replace infinity values with max_death + delta for diagram to be more
# readable
infinity = max_death + delta
@@ -198,31 +254,41 @@ def plot_persistence_diagram(persistence=[], persistence_file='', alpha=0.6,
# line display of equation : birth = death
x = np.linspace(axis_start, infinity, 1000)
# infinity line and text
- plt.plot(x, x, color='k', linewidth=1.0)
- plt.plot(x, [infinity] * len(x), linewidth=1.0, color='k', alpha=alpha)
- plt.text(axis_start, infinity, r'$\infty$', color='k', alpha=alpha)
+ plt.plot(x, x, color="k", linewidth=1.0)
+ plt.plot(x, [infinity] * len(x), linewidth=1.0, color="k", alpha=alpha)
+ plt.text(axis_start, infinity, r"$\infty$", color="k", alpha=alpha)
# bootstrap band
- if band > 0.:
- plt.fill_between(x, x, x+band, alpha=alpha, facecolor='red')
+ if band > 0.0:
+ plt.fill_between(x, x, x + band, alpha=alpha, facecolor="red")
# Draw points in loop
for interval in reversed(persistence):
- if float(interval[1][1]) != float('inf'):
+ if float(interval[1][1]) != float("inf"):
# Finite death case
- plt.scatter(interval[1][0], interval[1][1], alpha=alpha,
- color = palette[interval[0]])
+ plt.scatter(
+ interval[1][0],
+ interval[1][1],
+ alpha=alpha,
+ color=palette[interval[0]],
+ )
else:
# Infinite death case for diagram to be nicer
- plt.scatter(interval[1][0], infinity, alpha=alpha,
- color = palette[interval[0]])
+ plt.scatter(
+ interval[1][0], infinity, alpha=alpha, color=palette[interval[0]]
+ )
if legend:
dimensions = list(set(item[0] for item in persistence))
- plt.legend(handles=[mpatches.Patch(color=palette[dim], label=str(dim)) for dim in dimensions])
-
- plt.title('Persistence diagram')
- plt.xlabel('Birth')
- plt.ylabel('Death')
+ plt.legend(
+ handles=[
+ mpatches.Patch(color=palette[dim], label=str(dim))
+ for dim in dimensions
+ ]
+ )
+
+ plt.title("Persistence diagram")
+ plt.xlabel("Birth")
+ plt.ylabel("Death")
# Ends plot on infinity value and starts a little bit before min_birth
plt.axis([axis_start, infinity, axis_start, infinity + delta])
return plt
@@ -230,10 +296,17 @@ def plot_persistence_diagram(persistence=[], persistence_file='', alpha=0.6,
except ImportError:
print("This function is not available, you may be missing matplotlib.")
-def plot_persistence_density(persistence=[], persistence_file='',
- nbins=300, bw_method=None,
- max_intervals=1000, dimension=None,
- cmap=None, legend=False):
+
+def plot_persistence_density(
+ persistence=[],
+ persistence_file="",
+ nbins=300,
+ bw_method=None,
+ max_intervals=1000,
+ dimension=None,
+ cmap=None,
+ legend=False,
+):
"""This function plots the persistence density from persistence
values list or from a :doc:`persistence file <fileformats>`. Be
aware that this function does not distinguish the dimension, it is
@@ -278,39 +351,53 @@ def plot_persistence_density(persistence=[], persistence_file='',
import matplotlib.pyplot as plt
from scipy.stats import kde
- if persistence_file is not '':
+ if persistence_file is not "":
if dimension is None:
# All dimension case
dimension = -1
if path.isfile(persistence_file):
- persistence_dim = read_persistence_intervals_in_dimension(persistence_file=persistence_file,
- only_this_dim=dimension)
+ persistence_dim = read_persistence_intervals_in_dimension(
+ persistence_file=persistence_file, only_this_dim=dimension
+ )
print(persistence_dim)
else:
print("file " + persistence_file + " not found.")
return None
if len(persistence) > 0:
- persistence_dim = np.array([(dim_interval[1][0], dim_interval[1][1]) for dim_interval in persistence if (dim_interval[0] == dimension) or (dimension is None)])
-
- persistence_dim = persistence_dim[np.isfinite(persistence_dim[:,1])]
+ persistence_dim = np.array(
+ [
+ (dim_interval[1][0], dim_interval[1][1])
+ for dim_interval in persistence
+ if (dim_interval[0] == dimension) or (dimension is None)
+ ]
+ )
+
+ persistence_dim = persistence_dim[np.isfinite(persistence_dim[:, 1])]
if max_intervals > 0 and max_intervals < len(persistence_dim):
# Sort by life time, then takes only the max_intervals elements
- persistence_dim = np.array(sorted(persistence_dim,
- key=lambda life_time: life_time[1]-life_time[0],
- reverse=True)[:max_intervals])
+ persistence_dim = np.array(
+ sorted(
+ persistence_dim,
+ key=lambda life_time: life_time[1] - life_time[0],
+ reverse=True,
+ )[:max_intervals]
+ )
# Set as numpy array birth and death (remove undefined values - inf and NaN)
- birth = persistence_dim[:,0]
- death = persistence_dim[:,1]
+ birth = persistence_dim[:, 0]
+ death = persistence_dim[:, 1]
# line display of equation : birth = death
x = np.linspace(death.min(), birth.max(), 1000)
- plt.plot(x, x, color='k', linewidth=1.0)
+ plt.plot(x, x, color="k", linewidth=1.0)
# Evaluate a gaussian kde on a regular grid of nbins x nbins over data extents
- k = kde.gaussian_kde([birth,death], bw_method=bw_method)
- xi, yi = np.mgrid[birth.min():birth.max():nbins*1j, death.min():death.max():nbins*1j]
+ k = kde.gaussian_kde([birth, death], bw_method=bw_method)
+ xi, yi = np.mgrid[
+ birth.min() : birth.max() : nbins * 1j,
+ death.min() : death.max() : nbins * 1j,
+ ]
zi = k(np.vstack([xi.flatten(), yi.flatten()]))
# default cmap value cannot be done at argument definition level as matplotlib is not yet defined.
@@ -322,10 +409,12 @@ def plot_persistence_density(persistence=[], persistence_file='',
if legend:
plt.colorbar()
- plt.title('Persistence density')
- plt.xlabel('Birth')
- plt.ylabel('Death')
+ plt.title("Persistence density")
+ plt.xlabel("Birth")
+ plt.ylabel("Death")
return plt
except ImportError:
- print("This function is not available, you may be missing matplotlib and/or scipy.")
+ print(
+ "This function is not available, you may be missing matplotlib and/or scipy."
+ )