summaryrefslogtreecommitdiff
path: root/src/cython
diff options
context:
space:
mode:
authorvrouvrea <vrouvrea@636b058d-ea47-450e-bf9e-a15bfbe3eedb>2018-08-17 09:38:55 +0000
committervrouvrea <vrouvrea@636b058d-ea47-450e-bf9e-a15bfbe3eedb>2018-08-17 09:38:55 +0000
commitb7354cd2ab0b265787348866da5d08cce15aee66 (patch)
tree7f87a6e2854d8b5668af87bd1a2ceffbd8d1e19b /src/cython
parentefab7ccf5850ccc58018b7f856d209c4094e67c5 (diff)
Code review : add a parameter to select dimension for plot_persistence_density. Default is None and mix all dimensions
git-svn-id: svn+ssh://scm.gforge.inria.fr/svnroot/gudhi/branches/plot_persistence_density_vincent@3794 636b058d-ea47-450e-bf9e-a15bfbe3eedb Former-commit-id: 493f5ecfad2bcf0ee4a40ec7b315863c471d088f
Diffstat (limited to 'src/cython')
-rwxr-xr-xsrc/cython/cython/persistence_graphical_tools.py43
1 files changed, 28 insertions, 15 deletions
diff --git a/src/cython/cython/persistence_graphical_tools.py b/src/cython/cython/persistence_graphical_tools.py
index 14c4cf3f..949f5c37 100755
--- a/src/cython/cython/persistence_graphical_tools.py
+++ b/src/cython/cython/persistence_graphical_tools.py
@@ -221,26 +221,31 @@ try:
from scipy.stats import kde
import math
- def plot_persistence_density(persistence=[], persistence_file='', nbins=300,
- max_plots=1000, cmap=plt.cm.hot_r, legend=False):
- """This function plots the persistence density from persistence values list
- or from a :doc:`persistence file <fileformats>`. Be aware that this
- function does not distinguish the dimension, it is up to you to select the
- required one.
+ def plot_persistence_density(persistence=[], persistence_file='',
+ nbins=300, max_plots=1000, dimension=None,
+ cmap=plt.cm.hot_r, legend=False):
+ """This function plots the persistence density from persistence
+ values list or from a :doc:`persistence file <fileformats>`. Be
+ aware that this function does not distinguish the dimension, it is
+ up to you to select the required one.
:param persistence: Persistence values list.
:type persistence: list of tuples(dimension, tuple(birth, death)).
- :param persistence_file: A :doc:`persistence file <fileformats>` style name
- (reset persistence if both are set).
+ :param persistence_file: A :doc:`persistence file <fileformats>`
+ style name (reset persistence if both are set).
:type persistence_file: string
- :param nbins: Evaluate a gaussian kde on a regular grid of nbins x nbins
- over data extents (default is 300)
+ :param nbins: Evaluate a gaussian kde on a regular grid of nbins x
+ nbins over data extents (default is 300)
:type nbins: int.
:param max_plots: number of maximal plots to be displayed
Set it to 0 to see all, Default value is 1000.
(persistence will be sorted by life time if max_plots is set)
:type max_plots: int.
- :param cmap: A matplotlib colormap (default is matplotlib.pyplot.cm.hot_r).
+ :param dimension: the dimension to be selected in the intervals
+ (default is None to mix all dimensions).
+ :type dimension: int.
+ :param cmap: A matplotlib colormap (default is
+ matplotlib.pyplot.cm.hot_r).
:type cmap: cf. matplotlib colormap.
:param legend: Display the color bar values (default is False).
:type legend: boolean.
@@ -259,13 +264,21 @@ try:
print("file " + persistence_file + " not found.")
return None
- if max_plots > 0 and max_plots < len(persistence):
+ persistence_dim = []
+ if dimension is not None:
+ persistence_dim = [(dim_interval) for dim_interval in persistence if (dim_interval[0] == dimension)]
+ else:
+ persistence_dim = persistence
+
+ if max_plots > 0 and max_plots < len(persistence_dim):
# Sort by life time, then takes only the max_plots elements
- persistence = sorted(persistence, key=lambda life_time: life_time[1][1]-life_time[1][0], reverse=True)[:max_plots]
+ persistence_dim = sorted(persistence_dim,
+ key=lambda life_time: life_time[1][1]-life_time[1][0],
+ reverse=True)[:max_plots]
# Set as numpy array birth and death (remove undefined values - inf and NaN)
- birth = np.asarray([(interval[1][0]) for interval in persistence if (math.isfinite(interval[1][1]) and math.isfinite(interval[1][0]))])
- death = np.asarray([(interval[1][1]) for interval in persistence if (math.isfinite(interval[1][1]) and math.isfinite(interval[1][0]))])
+ birth = np.asarray([(interval[1][0]) for interval in persistence_dim if (math.isfinite(interval[1][1]) and math.isfinite(interval[1][0]))])
+ death = np.asarray([(interval[1][1]) for interval in persistence_dim if (math.isfinite(interval[1][1]) and math.isfinite(interval[1][0]))])
# line display of equation : birth = death
x = np.linspace(death.min(), birth.max(), 1000)