summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Gard Spreemann <gspr@nonempty.org>, 2021-02-28 18:40:46 +0100
committer: Gard Spreemann <gspr@nonempty.org>, 2021-02-28 21:50:48 +0100
commit: fddeb5724fe2e7f1f37476c5e3cfade992a4edec (patch)
tree: 41f77294950c7a3cdb793c0403e3be9c84ae4976
parent: ccb63b32bc65c0a6030dfab0b70ece62d9eff988 (diff)
Behave in line with scikit-learn guidelines
According to [1], we should in particular not do any validation in the constructor, and fit/fit_transform should always update underscored attributes (self.grid_ in this case).

We still want to allow for a user-defined, data-independent grid, so we make this a separate parameter, predefined_grid.

[1] https://scikit-learn.org/stable/developers/develop.html
-rw-r--r-- src/python/gudhi/representations/vector_methods.py | 86
1 file changed, 46 insertions, 40 deletions
diff --git a/src/python/gudhi/representations/vector_methods.py b/src/python/gudhi/representations/vector_methods.py
index 13630360..62a467c0 100644
--- a/src/python/gudhi/representations/vector_methods.py
+++ b/src/python/gudhi/representations/vector_methods.py
@@ -359,13 +359,13 @@ class BettiCurve2(BaseEstimator, TransformerMixin):
Parameters
----------
- grid: 1d array or None, default=None
- Filtration grid points at which to compute the Betti curves. Must be strictly ordered. Infinites are OK. If None (default), a grid will be computed that captures all the filtration value changes.
+ predefined_grid: 1d array or None, default=None
+ Predefined filtration grid points at which to compute the Betti curves. Must be strictly ordered. Infinities are OK. If None (default), a grid will be computed that captures all changes in Betti numbers in the provided data.
Attributes
----------
grid_: 1d array
- Contains the compute grid after fit or fit_transform.
+ The grid on which the Betti numbers are computed. If predefined_grid was specified, grid_ will always be that grid, independently of data. If not, the grid is fitted to capture all filtration values at which the Betti numbers change.
Examples
--------
@@ -381,13 +381,17 @@ class BettiCurve2(BaseEstimator, TransformerMixin):
are the same.
"""
- def __init__(self, grid = None):
- self.grid_ = np.array(grid)
+ def __init__(self, predefined_grid = None):
+ self.predefined_grid = predefined_grid
+
+
+ def is_fitted(self):
+ return hasattr(self, "grid_")
def fit(self, X, y = None):
"""
- Compute a filtration grid that captures all changes in Betti numbers for all the given persistence diagrams.
+ Compute a filtration grid that captures all changes in Betti numbers for all the given persistence diagrams, unless a predefined grid was provided.
Parameters
----------
@@ -398,12 +402,11 @@ class BettiCurve2(BaseEstimator, TransformerMixin):
Ignored.
"""
- events = np.unique(np.concatenate([pd.flatten() for pd in X], axis=0))
-
- if len(events) == 0:
- self.grid_ = np.array([-np.inf])
- else:
+ if self.predefined_grid is None:
+ events = np.unique(np.concatenate([pd.flatten() for pd in X] + [[-np.inf]], axis=0))
self.grid_ = np.array(events)
+ else:
+ self.grid_ = np.array(self.predefined_grid)
return self
@@ -413,37 +416,39 @@ class BettiCurve2(BaseEstimator, TransformerMixin):
Find a sampling grid that captures all changes in Betti numbers, and compute those Betti numbers. The result is the same as fit(X) followed by transform(X), but potentially faster.
"""
- N = len(X)
+ if self.predefined_grid is None:
+ N = len(X)
- events = np.concatenate([pd.flatten(order="F") for pd in X], axis=0)
- sorting = np.argsort(events)
- offsets = np.zeros(1 + N, dtype=int)
- for i in range(0, N):
- offsets[i+1] = offsets[i] + 2*X[i].shape[0]
- starts = offsets[0:N]
- ends = offsets[1:N + 1] - 1
+ events = np.concatenate([pd.flatten(order="F") for pd in X], axis=0)
+ sorting = np.argsort(events)
+ offsets = np.zeros(1 + N, dtype=int)
+ for i in range(0, N):
+ offsets[i+1] = offsets[i] + 2*X[i].shape[0]
+ starts = offsets[0:N]
+ ends = offsets[1:N + 1] - 1
- bettis = [[0] for i in range(0, N)]
- if len(sorting) == 0:
xs = [-np.inf]
- else:
- xs = [events[sorting[0]]]
+ bettis = [[0] for i in range(0, N)]
+
+ for i in sorting:
+ j = np.searchsorted(ends, i)
+ delta = 1 if i - starts[j] < len(X[j]) else -1
+ if events[i] == xs[-1]:
+ bettis[j][-1] += delta
+ else:
+ xs.append(events[i])
+ for k in range(0, j):
+ bettis[k].append(bettis[k][-1])
+ bettis[j].append(bettis[j][-1] + delta)
+ for k in range(j+1, N):
+ bettis[k].append(bettis[k][-1])
+
+ self.grid_ = np.array(xs)
+ return np.array(bettis, dtype=int)
- for i in sorting:
- j = np.searchsorted(ends, i)
- delta = 1 if i - starts[j] < len(X[j]) else -1
- if events[i] == xs[-1]:
- bettis[j][-1] += delta
- else:
- xs.append(events[i])
- for k in range(0, j):
- bettis[k].append(bettis[k][-1])
- bettis[j].append(bettis[j][-1] + delta)
- for k in range(j+1, N):
- bettis[k].append(bettis[k][-1])
-
- self.grid_ = np.array(xs)
- return np.array(bettis, dtype=int)
+ else:
+ self.grid_ = self.predefined_grid
+ return self.transform(X)
def transform(self, X):
@@ -461,8 +466,8 @@ class BettiCurve2(BaseEstimator, TransformerMixin):
Betti numbers of the given persistence diagrams at the grid points given in self.grid_.
"""
- if self.grid_ is None:
- raise NotFittedError("Not fitted. You need to call fit or construct with a chosen sampling grid.")
+ if not self.is_fitted():
+ raise NotFittedError("Not fitted.")
N = len(X)
@@ -496,6 +501,7 @@ class BettiCurve2(BaseEstimator, TransformerMixin):
return self.transform([diag])[0, :]
+
class Entropy(BaseEstimator, TransformerMixin):
"""
This is a class for computing persistence entropy. Persistence entropy is a statistic for persistence diagrams inspired from Shannon entropy. This statistic can also be used to compute a feature vector, called the entropy summary function. See https://arxiv.org/pdf/1803.08304.pdf for more details. Note that a previous implementation was contributed by Manuel Soriano-Trigueros.