diff options
Diffstat (limited to 'src/cython/cython/persistence_graphical_tools.py')
-rwxr-xr-x | src/cython/cython/persistence_graphical_tools.py | 152 |
1 files changed, 152 insertions, 0 deletions
diff --git a/src/cython/cython/persistence_graphical_tools.py b/src/cython/cython/persistence_graphical_tools.py new file mode 100755 index 00000000..a984633e --- /dev/null +++ b/src/cython/cython/persistence_graphical_tools.py @@ -0,0 +1,152 @@ +import matplotlib.pyplot as plt +import numpy as np + +"""This file is part of the Gudhi Library. The Gudhi library + (Geometric Understanding in Higher Dimensions) is a generic C++ + library for computational topology. + + Author(s): Vincent Rouvreau + + Copyright (C) 2016 INRIA + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +""" + +__author__ = "Vincent Rouvreau" +__copyright__ = "Copyright (C) 2016 INRIA" +__license__ = "GPL v3" + +def __min_birth_max_death(persistence): + """This function returns (min_birth, max_death) from the persistence. + + :param persistence: The persistence to plot. + :type persistence: list of tuples(dimension, tuple(birth, death)). + :returns: (float, float) -- (min_birth, max_death). + """ + # Look for minimum birth date and maximum death date for plot optimisation + max_death = 0 + min_birth = persistence[0][1][0] + for interval in reversed(persistence): + if float(interval[1][1]) != float('inf'): + if float(interval[1][1]) > max_death: + max_death = float(interval[1][1]) + if float(interval[1][0]) > max_death: + max_death = float(interval[1][0]) + if float(interval[1][0]) < min_birth: + min_birth = float(interval[1][0]) + return (min_birth, max_death) + +""" +Only 13 colors for the palette +""" +palette = ['#ff0000', '#00ff00', '#0000ff', '#00ffff', '#ff00ff', '#ffff00', + '#000000', '#880000', '#008800', '#000088', '#888800', '#880088', + '#008888'] + +def show_palette_values(alpha=0.6): + """This function shows palette color values in function of the dimension. + + :param alpha: alpha value in [0.0, 1.0] for horizontal bars (default is 0.6). + :type alpha: float. + :returns: plot -- An horizontal bar plot of dimensions color. + """ + colors = [] + for color in palette: + colors.append(color) + + y_pos = np.arange(len(palette)) + + plt.barh(y_pos, y_pos + 1, align='center', alpha=alpha, color=colors) + plt.ylabel('Dimension') + plt.title('Dimension palette values') + + plt.show() + +def plot_persistence_barcode(persistence, alpha=0.6): + """This function plots the persistence bar code. + + :param persistence: The persistence to plot. + :type persistence: list of tuples(dimension, tuple(birth, death)). + :param alpha: alpha value in [0.0, 1.0] for horizontal bars (default is 0.6). + :type alpha: float. + :returns: plot -- An horizontal bar plot of persistence. + """ + (min_birth, max_death) = __min_birth_max_death(persistence) + ind = 0 + delta = ((max_death - min_birth) / 10.0) + # Replace infinity values with max_death + delta for bar code to be more + # readable + infinity = max_death + delta + axis_start = min_birth - delta + # Draw horizontal bars in loop + for interval in reversed(persistence): + if float(interval[1][1]) != float('inf'): + # Finite death case + plt.barh(ind, (interval[1][1] - interval[1][0]), height=0.8, + left = interval[1][0], alpha=alpha, + color = palette[interval[0]]) + else: + # Infinite death case for diagram to be nicer + plt.barh(ind, (infinity - interval[1][0]), height=0.8, + left = interval[1][0], alpha=alpha, + color = palette[interval[0]]) + ind = ind + 1 + + plt.title('Persistence barcode') + # Ends plot on infinity value and starts a little bit before min_birth + plt.axis([axis_start, infinity, 0, ind]) + plt.show() + +def plot_persistence_diagram(persistence, alpha=0.6): + """This function plots the persistence diagram. + + :param persistence: The persistence to plot. + :type persistence: list of tuples(dimension, tuple(birth, death)). + :param alpha: alpha value in [0.0, 1.0] for points and horizontal infinity line (default is 0.6). + :type alpha: float. + :returns: plot -- An diagram plot of persistence. + """ + (min_birth, max_death) = __min_birth_max_death(persistence) + ind = 0 + delta = ((max_death - min_birth) / 10.0) + # Replace infinity values with max_death + delta for diagram to be more + # readable + infinity = max_death + delta + axis_start = min_birth - delta + + # line display of equation : birth = death + x = np.linspace(axis_start, infinity, 1000) + # infinity line and text + plt.plot(x, x, color='k', linewidth=1.0) + plt.plot(x, [infinity] * len(x), linewidth=1.0, color='k', alpha=alpha) + plt.text(axis_start, infinity, r'$\infty$', color='k', alpha=alpha) + + # Draw points in loop + for interval in reversed(persistence): + if float(interval[1][1]) != float('inf'): + # Finite death case + plt.scatter(interval[1][0], interval[1][1], alpha=alpha, + color = palette[interval[0]]) + else: + # Infinite death case for diagram to be nicer + plt.scatter(interval[1][0], infinity, alpha=alpha, + color = palette[interval[0]]) + ind = ind + 1 + + plt.title('Persistence diagram') + plt.xlabel('Birth') + plt.ylabel('Death') + # Ends plot on infinity value and starts a little bit before min_birth + plt.axis([axis_start, infinity, axis_start, infinity + delta]) + plt.show() |