summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2022-11-04 14:05:42 +0100
committerMarc Glisse <marc.glisse@inria.fr>2022-11-04 14:05:42 +0100
commit2ebdeb905d3ca90e2ba2d24e6d3aac52240f6c86 (patch)
tree5096de5f1eb6b52bef10cab78cb5332bd13ced9a
parent7595644442365412cf7afce56eafea85342da07f (diff)
More consistent choice of a grid for diagram representations
-rw-r--r--src/python/gudhi/representations/vector_methods.py46
-rwxr-xr-xsrc/python/test/test_representations.py12
2 files changed, 44 insertions, 14 deletions
diff --git a/src/python/gudhi/representations/vector_methods.py b/src/python/gudhi/representations/vector_methods.py
index a169aee8..212fa9f5 100644
--- a/src/python/gudhi/representations/vector_methods.py
+++ b/src/python/gudhi/representations/vector_methods.py
@@ -124,7 +124,7 @@ def _automatic_sample_range(sample_range, X, y):
return sample_range
-def _trim_on_edges(x, are_endpoints_nan):
+def _trim_endpoints(x, are_endpoints_nan):
if are_endpoints_nan[0]:
x = x[1:]
if are_endpoints_nan[1]:
@@ -136,7 +136,7 @@ class Landscape(BaseEstimator, TransformerMixin):
"""
This is a class for computing persistence landscapes from a list of persistence diagrams. A persistence landscape is a collection of 1D piecewise-linear functions computed from the rank function associated to the persistence diagram. These piecewise-linear functions are then sampled evenly on a given range and the corresponding vectors of samples are concatenated and returned. See http://jmlr.org/papers/v16/bubenik15a.html for more details.
"""
- def __init__(self, num_landscapes=5, resolution=100, sample_range=[np.nan, np.nan]):
+ def __init__(self, num_landscapes=5, resolution=100, sample_range=[np.nan, np.nan], *, keep_endpoints=False):
"""
Constructor for the Landscape class.
@@ -144,10 +144,14 @@ class Landscape(BaseEstimator, TransformerMixin):
num_landscapes (int): number of piecewise-linear functions to output (default 5).
resolution (int): number of sample for all piecewise-linear functions (default 100).
sample_range ([double, double]): minimum and maximum of all piecewise-linear function domains, of the form [x_min, x_max] (default [numpy.nan, numpy.nan]). It is the interval on which samples will be drawn evenly. If one of the values is numpy.nan, it can be computed from the persistence diagrams with the fit() method.
+ keep_endpoints (bool): when guessing `sample_range`, use the exact extremities (where the value is always 0). This is mostly useful for plotting, the default is to use a slightly smaller range.
"""
self.num_landscapes, self.resolution, self.sample_range = num_landscapes, resolution, sample_range
self.nan_in_range = np.isnan(np.array(self.sample_range))
- self.new_resolution = self.resolution + self.nan_in_range.sum()
+ self.new_resolution = self.resolution
+ if not keep_endpoints:
+ self.new_resolution += self.nan_in_range.sum()
+ self.keep_endpoints = keep_endpoints
def fit(self, X, y=None):
"""
@@ -158,8 +162,9 @@ class Landscape(BaseEstimator, TransformerMixin):
y (n x 1 array): persistence diagram labels (unused).
"""
self.sample_range = _automatic_sample_range(np.array(self.sample_range), X, y)
- self.im_range = np.linspace(self.sample_range[0], self.sample_range[1], self.new_resolution)
- self.im_range = _trim_on_edges(self.im_range, self.nan_in_range)
+ self.grid_ = np.linspace(self.sample_range[0], self.sample_range[1], self.new_resolution)
+ if not self.keep_endpoints:
+ self.grid_ = _trim_endpoints(self.grid_, self.nan_in_range)
return self
def transform(self, X):
@@ -174,7 +179,7 @@ class Landscape(BaseEstimator, TransformerMixin):
"""
Xfit = []
- x_values = self.im_range
+ x_values = self.grid_
for diag in X:
midpoints, heights = (diag[:, 0] + diag[:, 1]) / 2., (diag[:, 1] - diag[:, 0]) / 2.
tent_functions = np.maximum(heights[None, :] - np.abs(x_values[:, None] - midpoints[None, :]), 0)
@@ -209,7 +214,7 @@ class Silhouette(BaseEstimator, TransformerMixin):
"""
This is a class for computing persistence silhouettes from a list of persistence diagrams. A persistence silhouette is computed by taking a weighted average of the collection of 1D piecewise-linear functions given by the persistence landscapes, and then by evenly sampling this average on a given range. Finally, the corresponding vector of samples is returned. See https://arxiv.org/abs/1312.0308 for more details.
"""
- def __init__(self, weight=lambda x: 1, resolution=100, sample_range=[np.nan, np.nan]):
+ def __init__(self, weight=lambda x: 1, resolution=100, sample_range=[np.nan, np.nan], *, keep_endpoints=False):
"""
Constructor for the Silhouette class.
@@ -217,10 +222,14 @@ class Silhouette(BaseEstimator, TransformerMixin):
weight (function): weight function for the persistence diagram points (default constant function, ie lambda x: 1). This function must be defined on 2D points, ie on lists or numpy arrays of the form [p_x,p_y].
resolution (int): number of samples for the weighted average (default 100).
sample_range ([double, double]): minimum and maximum for the weighted average domain, of the form [x_min, x_max] (default [numpy.nan, numpy.nan]). It is the interval on which samples will be drawn evenly. If one of the values is numpy.nan, it can be computed from the persistence diagrams with the fit() method.
+ keep_endpoints (bool): when guessing `sample_range`, use the exact extremities (where the value is always 0). This is mostly useful for plotting, the default is to use a slightly smaller range.
"""
self.weight, self.resolution, self.sample_range = weight, resolution, sample_range
self.nan_in_range = np.isnan(np.array(self.sample_range))
- self.new_resolution = self.resolution + self.nan_in_range.sum()
+ self.new_resolution = self.resolution
+ if not keep_endpoints:
+ self.new_resolution += self.nan_in_range.sum()
+ self.keep_endpoints = keep_endpoints
def fit(self, X, y=None):
"""
@@ -231,8 +240,9 @@ class Silhouette(BaseEstimator, TransformerMixin):
y (n x 1 array): persistence diagram labels (unused).
"""
self.sample_range = _automatic_sample_range(np.array(self.sample_range), X, y)
- self.im_range = np.linspace(self.sample_range[0], self.sample_range[1], self.new_resolution)
- self.im_range = _trim_on_edges(self.im_range, self.nan_in_range)
+ self.grid_ = np.linspace(self.sample_range[0], self.sample_range[1], self.new_resolution)
+ if not self.keep_endpoints:
+ self.grid_ = _trim_endpoints(self.grid_, self.nan_in_range)
return self
def transform(self, X):
@@ -246,7 +256,7 @@ class Silhouette(BaseEstimator, TransformerMixin):
numpy array with shape (number of diagrams) x (**resolution**): output persistence silhouettes.
"""
Xfit = []
- x_values = self.im_range
+ x_values = self.grid_
for diag in X:
midpoints, heights = (diag[:, 0] + diag[:, 1]) / 2., (diag[:, 1] - diag[:, 0]) / 2.
@@ -277,14 +287,15 @@ class BettiCurve(BaseEstimator, TransformerMixin):
Compute Betti curves from persistence diagrams. There are several modes of operation: with a given resolution (with or without a sample_range), with a predefined grid, and with none of the previous. With a predefined grid, the class computes the Betti numbers at those grid points. Without a predefined grid, if the resolution is set to None, it can be fit to a list of persistence diagrams and produce a grid that consists of (at least) the filtration values at which at least one of those persistence diagrams changes Betti numbers, and then compute the Betti numbers at those grid points. In the latter mode, the exact Betti curve is computed for the entire real line. Otherwise, if the resolution is given, the Betti curve is obtained by sampling evenly using either the given sample_range or based on the persistence diagrams.
"""
- def __init__(self, resolution=100, sample_range=[np.nan, np.nan], predefined_grid=None):
+ def __init__(self, resolution=100, sample_range=[np.nan, np.nan], predefined_grid=None, *, keep_endpoints=False):
"""
Constructor for the BettiCurve class.
Parameters:
- resolution (int): number of sample for the piecewise-constant function (default 100).
+ resolution (int): number of samples for the piecewise-constant function (default 100), or None for the exact curve.
sample_range ([double, double]): minimum and maximum of the piecewise-constant function domain, of the form [x_min, x_max] (default [numpy.nan, numpy.nan]). It is the interval on which samples will be drawn evenly. If one of the values is numpy.nan, it can be computed from the persistence diagrams with the fit() method.
predefined_grid (1d array or None, default=None): Predefined filtration grid points at which to compute the Betti curves. Must be strictly ordered. Infinities are ok. If None (default), and resolution is given, the grid will be uniform from x_min to x_max in 'resolution' steps, otherwise a grid will be computed that captures all changes in Betti numbers in the provided data.
+ keep_endpoints (bool): when guessing `sample_range` (fixed `resolution`, no `predefined_grid`), use the exact extremities. This is mostly useful for plotting, the default is to use a slightly smaller range.
Attributes:
grid_ (1d array): The grid on which the Betti numbers are computed. If predefined_grid was specified, `grid_` will always be that grid, independently of data. If not, the grid is fitted to capture all filtration values at which the Betti numbers change.
@@ -313,6 +324,7 @@ class BettiCurve(BaseEstimator, TransformerMixin):
self.predefined_grid = predefined_grid
self.resolution = resolution
self.sample_range = sample_range
+ self.keep_endpoints = keep_endpoints
def is_fitted(self):
return hasattr(self, "grid_")
@@ -331,8 +343,14 @@ class BettiCurve(BaseEstimator, TransformerMixin):
events = np.unique(np.concatenate([pd.flatten() for pd in X] + [[-np.inf]], axis=0))
self.grid_ = np.array(events)
else:
+ self.nan_in_range = np.isnan(np.array(self.sample_range))
+ self.new_resolution = self.resolution
+ if not self.keep_endpoints:
+ self.new_resolution += self.nan_in_range.sum()
self.sample_range = _automatic_sample_range(np.array(self.sample_range), X, y)
- self.grid_ = np.linspace(self.sample_range[0], self.sample_range[1], self.resolution)
+ self.grid_ = np.linspace(self.sample_range[0], self.sample_range[1], self.new_resolution)
+ if not self.keep_endpoints:
+ self.grid_ = _trim_endpoints(self.grid_, self.nan_in_range)
else:
self.grid_ = self.predefined_grid # Get the predefined grid from user
diff --git a/src/python/test/test_representations.py b/src/python/test/test_representations.py
index 58caab21..9e94feeb 100755
--- a/src/python/test/test_representations.py
+++ b/src/python/test/test_representations.py
@@ -251,3 +251,15 @@ def test_landscape_nan_range():
lds_dgm = lds(dgm)
assert (lds.sample_range[0] == 2) & (lds.sample_range[1] == 6)
assert lds.new_resolution == 10
+
+def test_endpoints():
+ diags = [ np.array([[2., 3.]]) ]
+ for vec in [ Landscape(), Silhouette(), BettiCurve() ]:
+ vec.fit(diags)
+ assert vec.grid_[0] > 2 and vec.grid_[-1] < 3
+ for vec in [ Landscape(keep_endpoints=True), Silhouette(keep_endpoints=True), BettiCurve(keep_endpoints=True) ]:
+ vec.fit(diags)
+ assert vec.grid_[0] == 2 and vec.grid_[-1] == 3
+ vec = BettiCurve(resolution=None)
+ vec.fit(diags)
+ assert np.equal(vec.grid_, [-np.inf, 2., 3.]).all()