summaryrefslogtreecommitdiff
path: root/src/python/gudhi/persistence_graphical_tools.py
diff options
context:
space:
mode:
authorROUVREAU Vincent <vincent.rouvreau@inria.fr>2019-10-24 08:30:06 +0200
committerROUVREAU Vincent <vincent.rouvreau@inria.fr>2019-10-24 08:30:06 +0200
commit8227cd68d5aa7c9eeda5dd474f2536b896b6f491 (patch)
treecd36c9880d5a914ccaff4d37bfb22c0133a15af5 /src/python/gudhi/persistence_graphical_tools.py
parent0c6641b4e109f5116d6dad04bbab9bde0d56e347 (diff)
Fix issues #113 (replace 'is not' with '!=') and #109 (replace palette with a more visible one)
Diffstat (limited to 'src/python/gudhi/persistence_graphical_tools.py')
-rw-r--r--src/python/gudhi/persistence_graphical_tools.py57
1 files changed, 25 insertions, 32 deletions
diff --git a/src/python/gudhi/persistence_graphical_tools.py b/src/python/gudhi/persistence_graphical_tools.py
index 181bc8ea..a8e2051b 100644
--- a/src/python/gudhi/persistence_graphical_tools.py
+++ b/src/python/gudhi/persistence_graphical_tools.py
@@ -44,27 +44,6 @@ def __min_birth_max_death(persistence, band=0.0):
max_death += band
return (min_birth, max_death)
-
-"""
-Only 13 colors for the palette
-"""
-palette = [
- "#ff0000",
- "#00ff00",
- "#0000ff",
- "#00ffff",
- "#ff00ff",
- "#ffff00",
- "#000000",
- "#880000",
- "#008800",
- "#000088",
- "#888800",
- "#880088",
- "#008888",
-]
-
-
def plot_persistence_barcode(
persistence=[],
persistence_file="",
@@ -73,6 +52,7 @@ def plot_persistence_barcode(
max_barcodes=1000,
inf_delta=0.1,
legend=False,
+ colormap=None
):
"""This function plots the persistence bar code from persistence values list
or from a :doc:`persistence file <fileformats>`.
@@ -95,6 +75,9 @@ def plot_persistence_barcode(
:type inf_delta: float.
:param legend: Display the dimension color legend (default is False).
:type legend: boolean.
+ :param colormap: A matplotlib-like qualitative colormaps. Default is None
+ which means :code:`matplotlib.cm.Set1.colors`.
+ :type colormap: tuple of colors (3-tuple of float between 0. and 1.).
:returns: A matplotlib object containing horizontal bar plot of persistence
(launch `show()` method on it to display it).
"""
@@ -102,7 +85,7 @@ def plot_persistence_barcode(
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
- if persistence_file is not "":
+ if persistence_file != "":
if path.isfile(persistence_file):
# Reset persistence
persistence = []
@@ -116,7 +99,7 @@ def plot_persistence_barcode(
print("file " + persistence_file + " not found.")
return None
- if max_barcodes is not 1000:
+ if max_barcodes != 1000:
print("Deprecated parameter. It has been replaced by max_intervals")
max_intervals = max_barcodes
@@ -127,6 +110,9 @@ def plot_persistence_barcode(
key=lambda life_time: life_time[1][1] - life_time[1][0],
reverse=True,
)[:max_intervals]
+
+ if colormap == None:
+ colormap = plt.cm.Set1.colors
persistence = sorted(persistence, key=lambda birth: birth[1][0])
@@ -147,7 +133,7 @@ def plot_persistence_barcode(
height=0.8,
left=interval[1][0],
alpha=alpha,
- color=palette[interval[0]],
+ color=colormap[interval[0]],
linewidth=0,
)
else:
@@ -158,7 +144,7 @@ def plot_persistence_barcode(
height=0.8,
left=interval[1][0],
alpha=alpha,
- color=palette[interval[0]],
+ color=colormap[interval[0]],
linewidth=0,
)
ind = ind + 1
@@ -167,7 +153,7 @@ def plot_persistence_barcode(
dimensions = list(set(item[0] for item in persistence))
plt.legend(
handles=[
- mpatches.Patch(color=palette[dim], label=str(dim))
+ mpatches.Patch(color=colormap[dim], label=str(dim))
for dim in dimensions
],
loc="lower right",
@@ -190,6 +176,7 @@ def plot_persistence_diagram(
max_plots=1000,
inf_delta=0.1,
legend=False,
+ colormap=None
):
"""This function plots the persistence diagram from persistence values
list or from a :doc:`persistence file <fileformats>`.
@@ -214,6 +201,9 @@ def plot_persistence_diagram(
:type inf_delta: float.
:param legend: Display the dimension color legend (default is False).
:type legend: boolean.
+ :param colormap: A matplotlib-like qualitative colormaps. Default is None
+ which means :code:`matplotlib.cm.Set1.colors`.
+ :type colormap: tuple of colors (3-tuple of float between 0. and 1.).
:returns: A matplotlib object containing diagram plot of persistence
(launch `show()` method on it to display it).
"""
@@ -221,7 +211,7 @@ def plot_persistence_diagram(
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
- if persistence_file is not "":
+ if persistence_file != "":
if path.isfile(persistence_file):
# Reset persistence
persistence = []
@@ -235,7 +225,7 @@ def plot_persistence_diagram(
print("file " + persistence_file + " not found.")
return None
- if max_plots is not 1000:
+ if max_plots != 1000:
print("Deprecated parameter. It has been replaced by max_intervals")
max_intervals = max_plots
@@ -247,6 +237,9 @@ def plot_persistence_diagram(
reverse=True,
)[:max_intervals]
+ if colormap == None:
+ colormap = plt.cm.Set1.colors
+
(min_birth, max_death) = __min_birth_max_death(persistence, band)
delta = (max_death - min_birth) * inf_delta
# Replace infinity values with max_death + delta for diagram to be more
@@ -272,19 +265,19 @@ def plot_persistence_diagram(
interval[1][0],
interval[1][1],
alpha=alpha,
- color=palette[interval[0]],
+ color=colormap[interval[0]],
)
else:
# Infinite death case for diagram to be nicer
plt.scatter(
- interval[1][0], infinity, alpha=alpha, color=palette[interval[0]]
+ interval[1][0], infinity, alpha=alpha, color=colormap[interval[0]]
)
if legend:
dimensions = list(set(item[0] for item in persistence))
plt.legend(
handles=[
- mpatches.Patch(color=palette[dim], label=str(dim))
+ mpatches.Patch(color=colormap[dim], label=str(dim))
for dim in dimensions
]
)
@@ -354,7 +347,7 @@ def plot_persistence_density(
import matplotlib.pyplot as plt
from scipy.stats import kde
- if persistence_file is not "":
+ if persistence_file != "":
if dimension is None:
# All dimension case
dimension = -1