diff options
author | ROUVREAU Vincent <vincent.rouvreau@inria.fr> | 2021-06-15 18:01:20 +0200 |
---|---|---|
committer | ROUVREAU Vincent <vincent.rouvreau@inria.fr> | 2021-06-15 18:01:20 +0200 |
commit | 17da2842a3af3db3abec7c2ce03f77dad47ef5dc (patch) | |
tree | 8873c1b4afa6dd7938339c37f23f2f1b0ed6fad8 /src/python/gudhi/persistence_graphical_tools.py | |
parent | 845b02ff408eb50207165b8e11136e4b1888612a (diff) |
Remove deprecated max_barcodes in plot_persistence_barcode and max_plots in plot_persistence_diagram. Fix #453 and factorize code
Diffstat (limited to 'src/python/gudhi/persistence_graphical_tools.py')
-rw-r--r-- | src/python/gudhi/persistence_graphical_tools.py | 63 |
1 files changed, 29 insertions, 34 deletions
diff --git a/src/python/gudhi/persistence_graphical_tools.py b/src/python/gudhi/persistence_graphical_tools.py index 848dc03e..460e2558 100644 --- a/src/python/gudhi/persistence_graphical_tools.py +++ b/src/python/gudhi/persistence_graphical_tools.py @@ -12,6 +12,7 @@ from os import path from math import isfinite import numpy as np from functools import lru_cache +import warnings from gudhi.reader_utils import read_persistence_intervals_in_dimension from gudhi.reader_utils import read_persistence_intervals_grouped_by_dimension @@ -58,6 +59,26 @@ def _array_handler(a): else: return a +def _limit_to_max_intervals(persistence, max_intervals, key): + '''This function returns truncated persistence if length is bigger than max_intervals. + :param persistence: Persistence intervals values list. Can be grouped by dimension or not. + :type persistence: an array of (dimension, array of (birth, death)) or an array of (birth, death). + :param max_intervals: maximal number of intervals to display. + Selected intervals are those with the longest life time. Set it + to 0 to see all. Default value is 1000. + :type max_intervals: int. + :param key: key function for sort algorithm. + :type key: function or lambda. + ''' + if max_intervals > 0 and max_intervals < len(persistence): + warnings.warn('There are %s intervals given as input, whereas max_intervals is set to %s.' % + (len(persistence), max_intervals) + ) + # Sort by life time, then takes only the max_intervals elements + return sorted(persistence, key = key, reverse = True)[:max_intervals] + else: + return persistence + @lru_cache(maxsize=1) def _matplotlib_can_use_tex(): """This function returns True if matplotlib can deal with LaTeX, False otherwise. @@ -75,7 +96,6 @@ def plot_persistence_barcode( persistence_file="", alpha=0.6, max_intervals=1000, - max_barcodes=1000, inf_delta=0.1, legend=False, colormap=None, @@ -142,18 +162,9 @@ def plot_persistence_barcode( persistence = _array_handler(persistence) - if max_barcodes != 1000: - print("Deprecated parameter. It has been replaced by max_intervals") - max_intervals = max_barcodes - - if max_intervals > 0 and max_intervals < len(persistence): - # Sort by life time, then takes only the max_intervals elements - persistence = sorted( - persistence, - key=lambda life_time: life_time[1][1] - life_time[1][0], - reverse=True, - )[:max_intervals] - + persistence = _limit_to_max_intervals(persistence, max_intervals, + key = lambda life_time: life_time[1][1] - life_time[1][0]) + if colormap == None: colormap = plt.cm.Set1.colors if axes == None: @@ -220,7 +231,6 @@ def plot_persistence_diagram( alpha=0.6, band=0.0, max_intervals=1000, - max_plots=1000, inf_delta=0.1, legend=False, colormap=None, @@ -291,17 +301,8 @@ def plot_persistence_diagram( persistence = _array_handler(persistence) - if max_plots != 1000: - print("Deprecated parameter. It has been replaced by max_intervals") - max_intervals = max_plots - - if max_intervals > 0 and max_intervals < len(persistence): - # Sort by life time, then takes only the max_intervals elements - persistence = sorted( - persistence, - key=lambda life_time: life_time[1][1] - life_time[1][0], - reverse=True, - )[:max_intervals] + persistence = _limit_to_max_intervals(persistence, max_intervals, + key = lambda life_time: life_time[1][1] - life_time[1][0]) if colormap == None: colormap = plt.cm.Set1.colors @@ -474,15 +475,9 @@ def plot_persistence_density( ) persistence_dim = persistence_dim[np.isfinite(persistence_dim[:, 1])] - if max_intervals > 0 and max_intervals < len(persistence_dim): - # Sort by life time, then takes only the max_intervals elements - persistence_dim = np.array( - sorted( - persistence_dim, - key=lambda life_time: life_time[1] - life_time[0], - reverse=True, - )[:max_intervals] - ) + + persistence_dim = np.array(_limit_to_max_intervals(persistence_dim, max_intervals, + key = lambda life_time: life_time[1] - life_time[0])) # Set as numpy array birth and death (remove undefined values - inf and NaN) birth = persistence_dim[:, 0] |