summaryrefslogtreecommitdiff
path: root/cython/cython/persistence_graphical_tools.py
diff options
context:
space:
mode:
Diffstat (limited to 'cython/cython/persistence_graphical_tools.py')
-rwxr-xr-xcython/cython/persistence_graphical_tools.py76
1 files changed, 63 insertions, 13 deletions
diff --git a/cython/cython/persistence_graphical_tools.py b/cython/cython/persistence_graphical_tools.py
index a984633e..fb837e29 100755
--- a/cython/cython/persistence_graphical_tools.py
+++ b/cython/cython/persistence_graphical_tools.py
@@ -1,11 +1,12 @@
import matplotlib.pyplot as plt
import numpy as np
+import os
"""This file is part of the Gudhi Library. The Gudhi library
(Geometric Understanding in Higher Dimensions) is a generic C++
library for computational topology.
- Author(s): Vincent Rouvreau
+ Author(s): Vincent Rouvreau, Bertrand Michel
Copyright (C) 2016 INRIA
@@ -23,15 +24,17 @@ import numpy as np
along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
-__author__ = "Vincent Rouvreau"
+__author__ = "Vincent Rouvreau, Bertrand Michel"
__copyright__ = "Copyright (C) 2016 INRIA"
__license__ = "GPL v3"
-def __min_birth_max_death(persistence):
+def __min_birth_max_death(persistence, band_boot=0.):
"""This function returns (min_birth, max_death) from the persistence.
:param persistence: The persistence to plot.
:type persistence: list of tuples(dimension, tuple(birth, death)).
+ :param band_boot: bootstrap band
+ :type band_boot: float.
:returns: (float, float) -- (min_birth, max_death).
"""
# Look for minimum birth date and maximum death date for plot optimisation
@@ -45,6 +48,8 @@ def __min_birth_max_death(persistence):
max_death = float(interval[1][0])
if float(interval[1][0]) < min_birth:
min_birth = float(interval[1][0])
+ if band_boot > 0.:
+ max_death += band_boot
return (min_birth, max_death)
"""
@@ -59,7 +64,7 @@ def show_palette_values(alpha=0.6):
:param alpha: alpha value in [0.0, 1.0] for horizontal bars (default is 0.6).
:type alpha: float.
- :returns: plot -- An horizontal bar plot of dimensions color.
+ :returns: plot the dimension palette values.
"""
colors = []
for color in palette:
@@ -70,18 +75,38 @@ def show_palette_values(alpha=0.6):
plt.barh(y_pos, y_pos + 1, align='center', alpha=alpha, color=colors)
plt.ylabel('Dimension')
plt.title('Dimension palette values')
+ return plt
- plt.show()
-
-def plot_persistence_barcode(persistence, alpha=0.6):
+def plot_persistence_barcode(persistence=[], persistence_file='', alpha=0.6, max_barcodes=0):
"""This function plots the persistence bar code.
:param persistence: The persistence to plot.
:type persistence: list of tuples(dimension, tuple(birth, death)).
+ :param persistence_file: A persistence file style name (reset persistence if both are set).
+ :type persistence_file: string
:param alpha: alpha value in [0.0, 1.0] for horizontal bars (default is 0.6).
:type alpha: float.
+ :param max_barcodes: number of maximal barcodes to be displayed
+ (persistence will be sorted by life time if max_barcodes is set)
+ :type max_barcodes: int.
:returns: plot -- An horizontal bar plot of persistence.
"""
+ if persistence_file is not '':
+ if os.path.isfile(persistence_file):
+ # Reset persistence
+ persistence = []
+ diag = read_persistence_intervals_grouped_by_dimension(persistence_file=persistence_file)
+ for key in diag.keys():
+ for persistence_interval in diag[key]:
+ persistence.append((key, persistence_interval))
+ else:
+ print("file " + persistence_file + " not found.")
+ return None
+
+ if max_barcodes > 0 and max_barcodes < len(persistence):
+ # Sort by life time, then takes only the max_plots elements
+ persistence = sorted(persistence, key=lambda life_time: life_time[1][1]-life_time[1][0], reverse=True)[:max_barcodes]
+
(min_birth, max_death) = __min_birth_max_death(persistence)
ind = 0
delta = ((max_death - min_birth) / 10.0)
@@ -106,18 +131,40 @@ def plot_persistence_barcode(persistence, alpha=0.6):
plt.title('Persistence barcode')
# Ends plot on infinity value and starts a little bit before min_birth
plt.axis([axis_start, infinity, 0, ind])
- plt.show()
+ return plt
-def plot_persistence_diagram(persistence, alpha=0.6):
- """This function plots the persistence diagram.
+def plot_persistence_diagram(persistence=[], persistence_file='', alpha=0.6, band_boot=0., max_plots=0):
+ """This function plots the persistence diagram with an optional confidence band.
:param persistence: The persistence to plot.
:type persistence: list of tuples(dimension, tuple(birth, death)).
+ :param persistence_file: A persistence file style name (reset persistence if both are set).
+ :type persistence_file: string
:param alpha: alpha value in [0.0, 1.0] for points and horizontal infinity line (default is 0.6).
:type alpha: float.
- :returns: plot -- An diagram plot of persistence.
+ :param band_boot: bootstrap band (not displayed if :math:`\leq` 0.)
+ :type band_boot: float.
+ :param max_plots: number of maximal plots to be displayed
+ :type max_plots: int.
+ :returns: plot -- A diagram plot of persistence.
"""
- (min_birth, max_death) = __min_birth_max_death(persistence)
+ if persistence_file is not '':
+ if os.path.isfile(persistence_file):
+ # Reset persistence
+ persistence = []
+ diag = read_persistence_intervals_grouped_by_dimension(persistence_file=persistence_file)
+ for key in diag.keys():
+ for persistence_interval in diag[key]:
+ persistence.append((key, persistence_interval))
+ else:
+ print("file " + persistence_file + " not found.")
+ return None
+
+ if max_plots > 0 and max_plots < len(persistence):
+ # Sort by life time, then takes only the max_plots elements
+ persistence = sorted(persistence, key=lambda life_time: life_time[1][1]-life_time[1][0], reverse=True)[:max_plots]
+
+ (min_birth, max_death) = __min_birth_max_death(persistence, band_boot)
ind = 0
delta = ((max_death - min_birth) / 10.0)
# Replace infinity values with max_death + delta for diagram to be more
@@ -131,6 +178,9 @@ def plot_persistence_diagram(persistence, alpha=0.6):
plt.plot(x, x, color='k', linewidth=1.0)
plt.plot(x, [infinity] * len(x), linewidth=1.0, color='k', alpha=alpha)
plt.text(axis_start, infinity, r'$\infty$', color='k', alpha=alpha)
+ # bootstrap band
+ if band_boot > 0.:
+ plt.fill_between(x, x, x+band_boot, alpha=alpha, facecolor='red')
# Draw points in loop
for interval in reversed(persistence):
@@ -149,4 +199,4 @@ def plot_persistence_diagram(persistence, alpha=0.6):
plt.ylabel('Death')
# Ends plot on infinity value and starts a little bit before min_birth
plt.axis([axis_start, infinity, axis_start, infinity + delta])
- plt.show()
+ return plt