summaryrefslogtreecommitdiff
path: root/src/cython/cython/persistence_graphical_tools.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/cython/cython/persistence_graphical_tools.py')
-rw-r--r--src/cython/cython/persistence_graphical_tools.py55
1 files changed, 25 insertions, 30 deletions
diff --git a/src/cython/cython/persistence_graphical_tools.py b/src/cython/cython/persistence_graphical_tools.py
index d7be936f..7bb69840 100644
--- a/src/cython/cython/persistence_graphical_tools.py
+++ b/src/cython/cython/persistence_graphical_tools.py
@@ -1,10 +1,14 @@
+from os import path
+from math import isfinite
+import numpy as np
+
"""This file is part of the Gudhi Library. The Gudhi library
(Geometric Understanding in Higher Dimensions) is a generic C++
library for computational topology.
Author(s): Vincent Rouvreau, Bertrand Michel
- Copyright (C) 2016 Inria
+ Copyright (C) 2019 Inria
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
@@ -85,11 +89,9 @@ def plot_persistence_barcode(persistence=[], persistence_file='', alpha=0.6,
try:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
- import numpy as np
- import os
if persistence_file is not '':
- if os.path.isfile(persistence_file):
+ if path.isfile(persistence_file):
# Reset persistence
persistence = []
diag = read_persistence_intervals_grouped_by_dimension(persistence_file=persistence_file)
@@ -144,7 +146,7 @@ def plot_persistence_barcode(persistence=[], persistence_file='', alpha=0.6,
return plt
except ImportError:
- print("This function is not available, you may be missing numpy and/or matplotlib.")
+ print("This function is not available, you may be missing matplotlib.")
def plot_persistence_diagram(persistence=[], persistence_file='', alpha=0.6,
band=0., max_intervals=1000, max_plots=1000, inf_delta=0.1, legend=False):
@@ -177,11 +179,9 @@ def plot_persistence_diagram(persistence=[], persistence_file='', alpha=0.6,
try:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
- import numpy as np
- import os
if persistence_file is not '':
- if os.path.isfile(persistence_file):
+ if path.isfile(persistence_file):
# Reset persistence
persistence = []
diag = read_persistence_intervals_grouped_by_dimension(persistence_file=persistence_file)
@@ -240,7 +240,7 @@ def plot_persistence_diagram(persistence=[], persistence_file='', alpha=0.6,
return plt
except ImportError:
- print("This function is not available, you may be missing numpy and/or matplotlib.")
+ print("This function is not available, you may be missing matplotlib.")
def plot_persistence_density(persistence=[], persistence_file='',
nbins=300, bw_method=None,
@@ -288,38 +288,33 @@ def plot_persistence_density(persistence=[], persistence_file='',
"""
try:
import matplotlib.pyplot as plt
- import numpy as np
from scipy.stats import kde
- import os
- import math
if persistence_file is not '':
- if os.path.isfile(persistence_file):
- # Reset persistence
- persistence = []
- diag = read_persistence_intervals_grouped_by_dimension(persistence_file=persistence_file)
- for key in diag.keys():
- for persistence_interval in diag[key]:
- persistence.append((key, persistence_interval))
+ if dimension is None:
+ # All dimension case
+ dimension = -1
+ if path.isfile(persistence_file):
+ persistence_dim = read_persistence_intervals_in_dimension(persistence_file=persistence_file,
+ only_this_dim=dimension)
+ print(persistence_dim)
else:
print("file " + persistence_file + " not found.")
return None
- persistence_dim = []
- if dimension is not None:
- persistence_dim = [(dim_interval) for dim_interval in persistence if (dim_interval[0] == dimension)]
- else:
- persistence_dim = persistence
+ if len(persistence) > 0:
+ persistence_dim = np.array([(dim_interval[1][0], dim_interval[1][1]) for dim_interval in persistence if (dim_interval[0] == dimension) or (dimension is None)])
+ persistence_dim = persistence_dim[np.isfinite(persistence_dim[:,1])]
if max_intervals > 0 and max_intervals < len(persistence_dim):
# Sort by life time, then takes only the max_intervals elements
- persistence_dim = sorted(persistence_dim,
- key=lambda life_time: life_time[1][1]-life_time[1][0],
- reverse=True)[:max_intervals]
+ persistence_dim = np.array(sorted(persistence_dim,
+ key=lambda life_time: life_time[1]-life_time[0],
+ reverse=True)[:max_intervals])
# Set as numpy array birth and death (remove undefined values - inf and NaN)
- birth = np.asarray([(interval[1][0]) for interval in persistence_dim if (math.isfinite(interval[1][1]) and math.isfinite(interval[1][0]))])
- death = np.asarray([(interval[1][1]) for interval in persistence_dim if (math.isfinite(interval[1][1]) and math.isfinite(interval[1][0]))])
+ birth = persistence_dim[:,0]
+ death = persistence_dim[:,1]
# line display of equation : birth = death
x = np.linspace(death.min(), birth.max(), 1000)
@@ -345,4 +340,4 @@ def plot_persistence_density(persistence=[], persistence_file='',
return plt
except ImportError:
- print("This function is not available, you may be missing numpy, matplotlib and/or scipy.")
+ print("This function is not available, you may be missing matplotlib and/or scipy.")