summaryrefslogtreecommitdiff
path: root/src/python/gudhi/representations/vector_methods.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/gudhi/representations/vector_methods.py')
-rw-r--r--src/python/gudhi/representations/vector_methods.py152
1 files changed, 151 insertions, 1 deletions
diff --git a/src/python/gudhi/representations/vector_methods.py b/src/python/gudhi/representations/vector_methods.py
index cdcb1fde..fda0a22d 100644
--- a/src/python/gudhi/representations/vector_methods.py
+++ b/src/python/gudhi/representations/vector_methods.py
@@ -1,14 +1,16 @@
# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
-# Author(s): Mathieu Carrière, Martin Royer
+# Author(s): Mathieu Carrière, Martin Royer, Gard Spreemann
#
# Copyright (C) 2018-2020 Inria
#
# Modification(s):
# - 2020/06 Martin: ATOL integration
+# - 2020/12 Gard: A more flexible Betti curve class capable of computing exact curves.
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
+from sklearn.exceptions import NotFittedError
from sklearn.preprocessing import MinMaxScaler, MaxAbsScaler
from sklearn.neighbors import DistanceMetric
from sklearn.metrics import pairwise
@@ -350,6 +352,154 @@ class BettiCurve(BaseEstimator, TransformerMixin):
"""
return self.fit_transform([diag])[0,:]
+
+class BettiCurve2(BaseEstimator, TransformerMixin):
+ """
+ A more flexible replacement for the BettiCurve class.
+
+ Examples
+ --------
+ If pd is a persistence diagram and xs is a grid such that xs[0] >= pd.min(), then the result of
+ >>> bc = BettiCurve2(xs)
+ >>> result = bc(pd)
+ and
+ >>> from scipy.interpolate import interp1d
+ >>> bc = BettiCurve2(None)
+ >>> bettis = bc.fit_transform([pd])
+ >>> interp = interp1d(bc.grid_, bettis[0, :], kind="previous", fill_value="extrapolate")
+ >>> result = np.array(interp(xs), dtype=int)
+ are the same.
+ """
+
+ def __init__(self, grid = None):
+ """
+ Constructor for the BettiCurve class.
+
+ Parameters
+ ----------
+ grid: 1d array or None, default=None
+ Filtration grid points at which to compute the Betti curves. Must be strictly ordered. Infinites are OK. If None (default), a grid will be computed that captures all the filtration value changes.
+
+ Attributes
+ ----------
+ grid_: 1d array
+ Contains the compute grid after fit or fit_transform.
+ """
+
+ self.grid_ = np.array(grid)
+
+
+ def fit(self, X, y = None):
+ """
+ Compute a filtration grid that captures all changes in Betti numbers for all the given persistence diagrams.
+
+ Parameters
+ ----------
+ X: list of 2d arrays
+ Persistence diagrams.
+
+ y: None.
+ Ignored.
+ """
+
+ events = np.unique(np.concatenate([pd.flatten() for pd in X], axis=0))
+
+ if len(events) == 0:
+ self.grid_ = np.array([-np.inf])
+ else:
+ self.grid_ = np.array(events)
+
+ return self
+
+
+ def fit_transform(self, X):
+ """
+ Find a sampling grid that captures all changes in Betti numbers, and compute those Betti numbers. The result is the same as fit(X) followed by transform(X), but potentially faster.
+ """
+
+ N = len(X)
+
+ events = np.concatenate([pd.flatten(order="F") for pd in X], axis=0)
+ sorting = np.argsort(events)
+ offsets = np.zeros(1 + N, dtype=int)
+ for i in range(0, N):
+ offsets[i+1] = offsets[i] + 2*X[i].shape[0]
+ starts = offsets[0:N]
+ ends = offsets[1:N + 1] - 1
+
+ bettis = [[0] for i in range(0, N)]
+ if len(sorting) == 0:
+ xs = [-np.inf]
+ else:
+ xs = [events[sorting[0]]]
+
+ for i in sorting:
+ j = np.searchsorted(ends, i)
+ delta = 1 if i - starts[j] < len(X[j]) else -1
+ if events[i] == xs[-1]:
+ bettis[j][-1] += delta
+ else:
+ xs.append(events[i])
+ for k in range(0, j):
+ bettis[k].append(bettis[k][-1])
+ bettis[j].append(bettis[j][-1] + delta)
+ for k in range(j+1, N):
+ bettis[k].append(bettis[k][-1])
+
+ self.grid_ = np.array(xs)
+ return np.array(bettis, dtype=int)
+
+
+ def transform(self, X):
+ """
+ Compute Betti curves.
+
+ Parameters
+ ----------
+ X: list of 2d arrays
+ Persistence diagrams.
+
+ Returns
+ -------
+ (len(X))x(len(self.grid_)) array of ints
+ Betti numbers of the given persistence diagrams at the grid points given in self.grid_.
+ """
+
+ if self.grid_ is None:
+ raise NotFittedError("Not fitted. You need to call fit or construct with a chosen sampling grid.")
+
+ N = len(X)
+
+ events = np.concatenate([pd.flatten(order="F") for pd in X], axis=0)
+ sorting = np.argsort(events)
+ offsets = np.zeros(1 + N, dtype=int)
+ for i in range(0, N):
+ offsets[i+1] = offsets[i] + 2*X[i].shape[0]
+ starts = offsets[0:N]
+ ends = offsets[1:N + 1] - 1
+
+ bettis = [[0] for i in range(0, N)]
+
+ i = 0
+ for x in self.grid_:
+ while i < len(sorting) and events[sorting[i]] <= x:
+ j = np.searchsorted(ends, sorting[i])
+ delta = 1 if sorting[i] - starts[j] < len(X[j]) else -1
+ bettis[j][-1] += delta
+ i += 1
+ for k in range(0, N):
+ bettis[k].append(bettis[k][-1])
+
+ return np.array(bettis, dtype=int)[:, 0:-1]
+
+
+ def __call__(self, diag):
+ """
+ Shorthand for transform on a single persistence diagram.
+ """
+ return self.transform([diag])[0, :]
+
+
class Entropy(BaseEstimator, TransformerMixin):
"""
This is a class for computing persistence entropy. Persistence entropy is a statistic for persistence diagrams inspired from Shannon entropy. This statistic can also be used to compute a feature vector, called the entropy summary function. See https://arxiv.org/pdf/1803.08304.pdf for more details. Note that a previous implementation was contributed by Manuel Soriano-Trigueros.