From 8227cd68d5aa7c9eeda5dd474f2536b896b6f491 Mon Sep 17 00:00:00 2001 From: ROUVREAU Vincent Date: Thu, 24 Oct 2019 08:30:06 +0200 Subject: Fix issues #113 (replace 'is not' with '!=') and #109 (replace palette with a more visible one) --- src/python/gudhi/persistence_graphical_tools.py | 57 +++++++++++-------------- 1 file changed, 25 insertions(+), 32 deletions(-) (limited to 'src') diff --git a/src/python/gudhi/persistence_graphical_tools.py b/src/python/gudhi/persistence_graphical_tools.py index 181bc8ea..a8e2051b 100644 --- a/src/python/gudhi/persistence_graphical_tools.py +++ b/src/python/gudhi/persistence_graphical_tools.py @@ -44,27 +44,6 @@ def __min_birth_max_death(persistence, band=0.0): max_death += band return (min_birth, max_death) - -""" -Only 13 colors for the palette -""" -palette = [ - "#ff0000", - "#00ff00", - "#0000ff", - "#00ffff", - "#ff00ff", - "#ffff00", - "#000000", - "#880000", - "#008800", - "#000088", - "#888800", - "#880088", - "#008888", -] - - def plot_persistence_barcode( persistence=[], persistence_file="", @@ -73,6 +52,7 @@ def plot_persistence_barcode( max_barcodes=1000, inf_delta=0.1, legend=False, + colormap=None ): """This function plots the persistence bar code from persistence values list or from a :doc:`persistence file `. @@ -95,6 +75,9 @@ def plot_persistence_barcode( :type inf_delta: float. :param legend: Display the dimension color legend (default is False). :type legend: boolean. + :param colormap: A matplotlib-like qualitative colormaps. Default is None + which means :code:`matplotlib.cm.Set1.colors`. + :type colormap: tuple of colors (3-tuple of float between 0. and 1.). :returns: A matplotlib object containing horizontal bar plot of persistence (launch `show()` method on it to display it). """ @@ -102,7 +85,7 @@ def plot_persistence_barcode( import matplotlib.pyplot as plt import matplotlib.patches as mpatches - if persistence_file is not "": + if persistence_file != "": if path.isfile(persistence_file): # Reset persistence persistence = [] @@ -116,7 +99,7 @@ def plot_persistence_barcode( print("file " + persistence_file + " not found.") return None - if max_barcodes is not 1000: + if max_barcodes != 1000: print("Deprecated parameter. It has been replaced by max_intervals") max_intervals = max_barcodes @@ -127,6 +110,9 @@ def plot_persistence_barcode( key=lambda life_time: life_time[1][1] - life_time[1][0], reverse=True, )[:max_intervals] + + if colormap == None: + colormap = plt.cm.Set1.colors persistence = sorted(persistence, key=lambda birth: birth[1][0]) @@ -147,7 +133,7 @@ def plot_persistence_barcode( height=0.8, left=interval[1][0], alpha=alpha, - color=palette[interval[0]], + color=colormap[interval[0]], linewidth=0, ) else: @@ -158,7 +144,7 @@ def plot_persistence_barcode( height=0.8, left=interval[1][0], alpha=alpha, - color=palette[interval[0]], + color=colormap[interval[0]], linewidth=0, ) ind = ind + 1 @@ -167,7 +153,7 @@ def plot_persistence_barcode( dimensions = list(set(item[0] for item in persistence)) plt.legend( handles=[ - mpatches.Patch(color=palette[dim], label=str(dim)) + mpatches.Patch(color=colormap[dim], label=str(dim)) for dim in dimensions ], loc="lower right", @@ -190,6 +176,7 @@ def plot_persistence_diagram( max_plots=1000, inf_delta=0.1, legend=False, + colormap=None ): """This function plots the persistence diagram from persistence values list or from a :doc:`persistence file `. @@ -214,6 +201,9 @@ def plot_persistence_diagram( :type inf_delta: float. :param legend: Display the dimension color legend (default is False). :type legend: boolean. + :param colormap: A matplotlib-like qualitative colormaps. Default is None + which means :code:`matplotlib.cm.Set1.colors`. + :type colormap: tuple of colors (3-tuple of float between 0. and 1.). :returns: A matplotlib object containing diagram plot of persistence (launch `show()` method on it to display it). """ @@ -221,7 +211,7 @@ def plot_persistence_diagram( import matplotlib.pyplot as plt import matplotlib.patches as mpatches - if persistence_file is not "": + if persistence_file != "": if path.isfile(persistence_file): # Reset persistence persistence = [] @@ -235,7 +225,7 @@ def plot_persistence_diagram( print("file " + persistence_file + " not found.") return None - if max_plots is not 1000: + if max_plots != 1000: print("Deprecated parameter. It has been replaced by max_intervals") max_intervals = max_plots @@ -247,6 +237,9 @@ def plot_persistence_diagram( reverse=True, )[:max_intervals] + if colormap == None: + colormap = plt.cm.Set1.colors + (min_birth, max_death) = __min_birth_max_death(persistence, band) delta = (max_death - min_birth) * inf_delta # Replace infinity values with max_death + delta for diagram to be more @@ -272,19 +265,19 @@ def plot_persistence_diagram( interval[1][0], interval[1][1], alpha=alpha, - color=palette[interval[0]], + color=colormap[interval[0]], ) else: # Infinite death case for diagram to be nicer plt.scatter( - interval[1][0], infinity, alpha=alpha, color=palette[interval[0]] + interval[1][0], infinity, alpha=alpha, color=colormap[interval[0]] ) if legend: dimensions = list(set(item[0] for item in persistence)) plt.legend( handles=[ - mpatches.Patch(color=palette[dim], label=str(dim)) + mpatches.Patch(color=colormap[dim], label=str(dim)) for dim in dimensions ] ) @@ -354,7 +347,7 @@ def plot_persistence_density( import matplotlib.pyplot as plt from scipy.stats import kde - if persistence_file is not "": + if persistence_file != "": if dimension is None: # All dimension case dimension = -1 -- cgit v1.2.3