diff options
author | wreise <wojciech.reise@epfl.ch> | 2022-05-25 14:20:12 +0200 |
---|---|---|
committer | wreise <wojciech.reise@epfl.ch> | 2022-05-25 14:20:12 +0200 |
commit | 3aa89676d1dc2cafcc692480bbf424a97dbbd501 (patch) | |
tree | 6f227366694118be360ba44ca1ab0780f3e0b19c /src/python/gudhi/representations/vector_methods.py | |
parent | f911ed3882c03b049a18f2a638758cb5ef994bcc (diff) |
Vectorize Silhouette implementation
Diffstat (limited to 'src/python/gudhi/representations/vector_methods.py')
-rw-r--r-- | src/python/gudhi/representations/vector_methods.py | 48 |
1 files changed, 13 insertions, 35 deletions
```diff
diff --git a/src/python/gudhi/representations/vector_methods.py b/src/python/gudhi/representations/vector_methods.py
index f8078d03..62b35389 100644
--- a/src/python/gudhi/representations/vector_methods.py
+++ b/src/python/gudhi/representations/vector_methods.py
@@ -85,7 +85,7 @@ class PersistenceImage(BaseEstimator, TransformerMixin):
 
             Xfit.append(image.flatten()[np.newaxis,:])
 
-        Xfit = np.concatenate(Xfit,0)
+        Xfit = np.concatenate(Xfit, 0)
 
         return Xfit
 
@@ -235,6 +235,7 @@ class Silhouette(BaseEstimator, TransformerMixin):
             sample_range ([double, double]): minimum and maximum for the weighted average domain, of the form [x_min, x_max] (default [numpy.nan, numpy.nan]). It is the interval on which samples will be drawn evenly. If one of the values is numpy.nan, it can be computed from the persistence diagrams with the fit() method.
         """
         self.weight, self.resolution, self.sample_range = weight, resolution, sample_range
+        self.im_range = None
 
     def fit(self, X, y=None):
         """
@@ -245,6 +246,7 @@ class Silhouette(BaseEstimator, TransformerMixin):
             y (n x 1 array): persistence diagram labels (unused).
         """
         self.sample_range = _automatic_sample_range(np.array(self.sample_range), X, y)
+        self.im_range = np.linspace(self.sample_range[0], self.sample_range[1], self.resolution)
         return self
 
     def transform(self, X):
@@ -257,44 +259,20 @@ class Silhouette(BaseEstimator, TransformerMixin):
         Returns:
             numpy array with shape (number of diagrams) x (**resolution**): output persistence silhouettes.
         """
-        num_diag, Xfit = len(X), []
-        x_values = np.linspace(self.sample_range[0], self.sample_range[1], self.resolution)
-        step_x = x_values[1] - x_values[0]
+        Xfit = []
+        x_values = self.im_range
 
-        for i in range(num_diag):
-
-            diagram, num_pts_in_diag = X[i], X[i].shape[0]
-
-            sh, weights = np.zeros(self.resolution), np.zeros(num_pts_in_diag)
-            for j in range(num_pts_in_diag):
-                weights[j] = self.weight(diagram[j,:])
+        for i, diag in enumerate(X):
+            midpoints, heights = (diag[:, 0] + diag[:, 1])/2., (diag[:, 1] - diag[:, 0])/2.
+            weights = np.array([self.weight(point) for point in diag])
             total_weight = np.sum(weights)
 
-            for j in range(num_pts_in_diag):
-
-                [px,py] = diagram[j,:2]
-                weight = weights[j] / total_weight
-                min_idx = np.clip(np.ceil((px - self.sample_range[0]) / step_x).astype(int), 0, self.resolution)
-                mid_idx = np.clip(np.ceil((0.5*(py+px) - self.sample_range[0]) / step_x).astype(int), 0, self.resolution)
-                max_idx = np.clip(np.ceil((py - self.sample_range[0]) / step_x).astype(int), 0, self.resolution)
-
-                if min_idx < self.resolution and max_idx > 0:
-
-                    silhouette_value = self.sample_range[0] + min_idx * step_x - px
-                    for k in range(min_idx, mid_idx):
-                        sh[k] += weight * silhouette_value
-                        silhouette_value += step_x
-
-                    silhouette_value = py - self.sample_range[0] - mid_idx * step_x
-                    for k in range(mid_idx, max_idx):
-                        sh[k] += weight * silhouette_value
-                        silhouette_value -= step_x
-
-            Xfit.append(np.reshape(np.sqrt(2) * sh, [1,-1]))
-
-        Xfit = np.concatenate(Xfit, 0)
+            tent_functions = heights[None, :] - np.abs(x_values[:, None] - midpoints[None, :])
+            tent_functions[tent_functions < 0.] = 0.
+            silhouette = np.sum(weights[None, :]/total_weight * tent_functions, axis=1)
+            Xfit.append(silhouette * np.sqrt(2))
 
-        return Xfit
+        return np.stack(Xfit, axis=0)
 
     def __call__(self, diag):
         """
```