diff options
Diffstat (limited to 'src/cython/cython/persistence_graphical_tools.py')
-rw-r--r-- | src/cython/cython/persistence_graphical_tools.py | 55 |
1 files changed, 25 insertions, 30 deletions
diff --git a/src/cython/cython/persistence_graphical_tools.py b/src/cython/cython/persistence_graphical_tools.py index d7be936f..7bb69840 100644 --- a/src/cython/cython/persistence_graphical_tools.py +++ b/src/cython/cython/persistence_graphical_tools.py @@ -1,10 +1,14 @@ +from os import path +from math import isfinite +import numpy as np + """This file is part of the Gudhi Library. The Gudhi library (Geometric Understanding in Higher Dimensions) is a generic C++ library for computational topology. Author(s): Vincent Rouvreau, Bertrand Michel - Copyright (C) 2016 Inria + Copyright (C) 2019 Inria This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by @@ -85,11 +89,9 @@ def plot_persistence_barcode(persistence=[], persistence_file='', alpha=0.6, try: import matplotlib.pyplot as plt import matplotlib.patches as mpatches - import numpy as np - import os if persistence_file is not '': - if os.path.isfile(persistence_file): + if path.isfile(persistence_file): # Reset persistence persistence = [] diag = read_persistence_intervals_grouped_by_dimension(persistence_file=persistence_file) @@ -144,7 +146,7 @@ def plot_persistence_barcode(persistence=[], persistence_file='', alpha=0.6, return plt except ImportError: - print("This function is not available, you may be missing numpy and/or matplotlib.") + print("This function is not available, you may be missing matplotlib.") def plot_persistence_diagram(persistence=[], persistence_file='', alpha=0.6, band=0., max_intervals=1000, max_plots=1000, inf_delta=0.1, legend=False): @@ -177,11 +179,9 @@ def plot_persistence_diagram(persistence=[], persistence_file='', alpha=0.6, try: import matplotlib.pyplot as plt import matplotlib.patches as mpatches - import numpy as np - import os if persistence_file is not '': - if os.path.isfile(persistence_file): + if path.isfile(persistence_file): # Reset persistence persistence = [] diag = read_persistence_intervals_grouped_by_dimension(persistence_file=persistence_file) @@ -240,7 +240,7 @@ def plot_persistence_diagram(persistence=[], persistence_file='', alpha=0.6, return plt except ImportError: - print("This function is not available, you may be missing numpy and/or matplotlib.") + print("This function is not available, you may be missing matplotlib.") def plot_persistence_density(persistence=[], persistence_file='', nbins=300, bw_method=None, @@ -288,38 +288,33 @@ def plot_persistence_density(persistence=[], persistence_file='', """ try: import matplotlib.pyplot as plt - import numpy as np from scipy.stats import kde - import os - import math if persistence_file is not '': - if os.path.isfile(persistence_file): - # Reset persistence - persistence = [] - diag = read_persistence_intervals_grouped_by_dimension(persistence_file=persistence_file) - for key in diag.keys(): - for persistence_interval in diag[key]: - persistence.append((key, persistence_interval)) + if dimension is None: + # All dimension case + dimension = -1 + if path.isfile(persistence_file): + persistence_dim = read_persistence_intervals_in_dimension(persistence_file=persistence_file, + only_this_dim=dimension) + print(persistence_dim) else: print("file " + persistence_file + " not found.") return None - persistence_dim = [] - if dimension is not None: - persistence_dim = [(dim_interval) for dim_interval in persistence if (dim_interval[0] == dimension)] - else: - persistence_dim = persistence + if len(persistence) > 0: + persistence_dim = np.array([(dim_interval[1][0], dim_interval[1][1]) for dim_interval in persistence if (dim_interval[0] == dimension) or (dimension is None)]) + persistence_dim = persistence_dim[np.isfinite(persistence_dim[:,1])] if max_intervals > 0 and max_intervals < len(persistence_dim): # Sort by life time, then takes only the max_intervals elements - persistence_dim = sorted(persistence_dim, - key=lambda life_time: life_time[1][1]-life_time[1][0], - reverse=True)[:max_intervals] + persistence_dim = np.array(sorted(persistence_dim, + key=lambda life_time: life_time[1]-life_time[0], + reverse=True)[:max_intervals]) # Set as numpy array birth and death (remove undefined values - inf and NaN) - birth = np.asarray([(interval[1][0]) for interval in persistence_dim if (math.isfinite(interval[1][1]) and math.isfinite(interval[1][0]))]) - death = np.asarray([(interval[1][1]) for interval in persistence_dim if (math.isfinite(interval[1][1]) and math.isfinite(interval[1][0]))]) + birth = persistence_dim[:,0] + death = persistence_dim[:,1] # line display of equation : birth = death x = np.linspace(death.min(), birth.max(), 1000) @@ -345,4 +340,4 @@ def plot_persistence_density(persistence=[], persistence_file='', return plt except ImportError: - print("This function is not available, you may be missing numpy, matplotlib and/or scipy.") + print("This function is not available, you may be missing matplotlib and/or scipy.") |