summaryrefslogtreecommitdiff
path: root/src/python/gudhi/persistence_graphical_tools.py
diff options
context:
space:
mode:
authorROUVREAU Vincent <vincent.rouvreau@inria.fr>2021-06-15 18:01:20 +0200
committerROUVREAU Vincent <vincent.rouvreau@inria.fr>2021-06-15 18:01:20 +0200
commit17da2842a3af3db3abec7c2ce03f77dad47ef5dc (patch)
tree8873c1b4afa6dd7938339c37f23f2f1b0ed6fad8 /src/python/gudhi/persistence_graphical_tools.py
parent845b02ff408eb50207165b8e11136e4b1888612a (diff)
Remove deprecated max_barcodes in plot_persistence_barcode and max_plots in plot_persistence_diagram. Fix #453 and factorize code
Diffstat (limited to 'src/python/gudhi/persistence_graphical_tools.py')
-rw-r--r--src/python/gudhi/persistence_graphical_tools.py63
1 files changed, 29 insertions, 34 deletions
diff --git a/src/python/gudhi/persistence_graphical_tools.py b/src/python/gudhi/persistence_graphical_tools.py
index 848dc03e..460e2558 100644
--- a/src/python/gudhi/persistence_graphical_tools.py
+++ b/src/python/gudhi/persistence_graphical_tools.py
@@ -12,6 +12,7 @@ from os import path
from math import isfinite
import numpy as np
from functools import lru_cache
+import warnings
from gudhi.reader_utils import read_persistence_intervals_in_dimension
from gudhi.reader_utils import read_persistence_intervals_grouped_by_dimension
@@ -58,6 +59,26 @@ def _array_handler(a):
else:
return a
+def _limit_to_max_intervals(persistence, max_intervals, key):
+ '''This function returns truncated persistence if length is bigger than max_intervals.
+ :param persistence: Persistence intervals values list. Can be grouped by dimension or not.
+ :type persistence: an array of (dimension, array of (birth, death)) or an array of (birth, death).
+ :param max_intervals: maximal number of intervals to display.
+ Selected intervals are those with the longest life time. Set it
+ to 0 to see all. Default value is 1000.
+ :type max_intervals: int.
+ :param key: key function for sort algorithm.
+ :type key: function or lambda.
+ '''
+ if max_intervals > 0 and max_intervals < len(persistence):
+ warnings.warn('There are %s intervals given as input, whereas max_intervals is set to %s.' %
+ (len(persistence), max_intervals)
+ )
+ # Sort by life time, then takes only the max_intervals elements
+ return sorted(persistence, key = key, reverse = True)[:max_intervals]
+ else:
+ return persistence
+
@lru_cache(maxsize=1)
def _matplotlib_can_use_tex():
"""This function returns True if matplotlib can deal with LaTeX, False otherwise.
@@ -75,7 +96,6 @@ def plot_persistence_barcode(
persistence_file="",
alpha=0.6,
max_intervals=1000,
- max_barcodes=1000,
inf_delta=0.1,
legend=False,
colormap=None,
@@ -142,18 +162,9 @@ def plot_persistence_barcode(
persistence = _array_handler(persistence)
- if max_barcodes != 1000:
- print("Deprecated parameter. It has been replaced by max_intervals")
- max_intervals = max_barcodes
-
- if max_intervals > 0 and max_intervals < len(persistence):
- # Sort by life time, then takes only the max_intervals elements
- persistence = sorted(
- persistence,
- key=lambda life_time: life_time[1][1] - life_time[1][0],
- reverse=True,
- )[:max_intervals]
-
+ persistence = _limit_to_max_intervals(persistence, max_intervals,
+ key = lambda life_time: life_time[1][1] - life_time[1][0])
+
if colormap == None:
colormap = plt.cm.Set1.colors
if axes == None:
@@ -220,7 +231,6 @@ def plot_persistence_diagram(
alpha=0.6,
band=0.0,
max_intervals=1000,
- max_plots=1000,
inf_delta=0.1,
legend=False,
colormap=None,
@@ -291,17 +301,8 @@ def plot_persistence_diagram(
persistence = _array_handler(persistence)
- if max_plots != 1000:
- print("Deprecated parameter. It has been replaced by max_intervals")
- max_intervals = max_plots
-
- if max_intervals > 0 and max_intervals < len(persistence):
- # Sort by life time, then takes only the max_intervals elements
- persistence = sorted(
- persistence,
- key=lambda life_time: life_time[1][1] - life_time[1][0],
- reverse=True,
- )[:max_intervals]
+ persistence = _limit_to_max_intervals(persistence, max_intervals,
+ key = lambda life_time: life_time[1][1] - life_time[1][0])
if colormap == None:
colormap = plt.cm.Set1.colors
@@ -474,15 +475,9 @@ def plot_persistence_density(
)
persistence_dim = persistence_dim[np.isfinite(persistence_dim[:, 1])]
- if max_intervals > 0 and max_intervals < len(persistence_dim):
- # Sort by life time, then takes only the max_intervals elements
- persistence_dim = np.array(
- sorted(
- persistence_dim,
- key=lambda life_time: life_time[1] - life_time[0],
- reverse=True,
- )[:max_intervals]
- )
+
+ persistence_dim = np.array(_limit_to_max_intervals(persistence_dim, max_intervals,
+ key = lambda life_time: life_time[1] - life_time[0]))
# Set as numpy array birth and death (remove undefined values - inf and NaN)
birth = persistence_dim[:, 0]