summaryrefslogtreecommitdiff
path: root/src/python/gudhi/representations/vector_methods.py
diff options
context:
space:
mode:
authorwreise <wojciech.reise@epfl.ch>2022-05-25 14:20:12 +0200
committerwreise <wojciech.reise@epfl.ch>2022-05-25 14:20:12 +0200
commit3aa89676d1dc2cafcc692480bbf424a97dbbd501 (patch)
tree6f227366694118be360ba44ca1ab0780f3e0b19c /src/python/gudhi/representations/vector_methods.py
parentf911ed3882c03b049a18f2a638758cb5ef994bcc (diff)
Vectorize Silhouette implementation
Diffstat (limited to 'src/python/gudhi/representations/vector_methods.py')
-rw-r--r--src/python/gudhi/representations/vector_methods.py48
1 files changed, 13 insertions, 35 deletions
diff --git a/src/python/gudhi/representations/vector_methods.py b/src/python/gudhi/representations/vector_methods.py
index f8078d03..62b35389 100644
--- a/src/python/gudhi/representations/vector_methods.py
+++ b/src/python/gudhi/representations/vector_methods.py
@@ -85,7 +85,7 @@ class PersistenceImage(BaseEstimator, TransformerMixin):
Xfit.append(image.flatten()[np.newaxis,:])
- Xfit = np.concatenate(Xfit,0)
+ Xfit = np.concatenate(Xfit, 0)
return Xfit
@@ -235,6 +235,7 @@ class Silhouette(BaseEstimator, TransformerMixin):
sample_range ([double, double]): minimum and maximum for the weighted average domain, of the form [x_min, x_max] (default [numpy.nan, numpy.nan]). It is the interval on which samples will be drawn evenly. If one of the values is numpy.nan, it can be computed from the persistence diagrams with the fit() method.
"""
self.weight, self.resolution, self.sample_range = weight, resolution, sample_range
+ self.im_range = None
def fit(self, X, y=None):
"""
@@ -245,6 +246,7 @@ class Silhouette(BaseEstimator, TransformerMixin):
y (n x 1 array): persistence diagram labels (unused).
"""
self.sample_range = _automatic_sample_range(np.array(self.sample_range), X, y)
+ self.im_range = np.linspace(self.sample_range[0], self.sample_range[1], self.resolution)
return self
def transform(self, X):
@@ -257,44 +259,20 @@ class Silhouette(BaseEstimator, TransformerMixin):
Returns:
numpy array with shape (number of diagrams) x (**resolution**): output persistence silhouettes.
"""
- num_diag, Xfit = len(X), []
- x_values = np.linspace(self.sample_range[0], self.sample_range[1], self.resolution)
- step_x = x_values[1] - x_values[0]
+ Xfit = []
+ x_values = self.im_range
- for i in range(num_diag):
-
- diagram, num_pts_in_diag = X[i], X[i].shape[0]
-
- sh, weights = np.zeros(self.resolution), np.zeros(num_pts_in_diag)
- for j in range(num_pts_in_diag):
- weights[j] = self.weight(diagram[j,:])
+ for i, diag in enumerate(X):
+ midpoints, heights = (diag[:, 0] + diag[:, 1])/2., (diag[:, 1] - diag[:, 0])/2.
+ weights = np.array([self.weight(point) for point in diag])
total_weight = np.sum(weights)
- for j in range(num_pts_in_diag):
-
- [px,py] = diagram[j,:2]
- weight = weights[j] / total_weight
- min_idx = np.clip(np.ceil((px - self.sample_range[0]) / step_x).astype(int), 0, self.resolution)
- mid_idx = np.clip(np.ceil((0.5*(py+px) - self.sample_range[0]) / step_x).astype(int), 0, self.resolution)
- max_idx = np.clip(np.ceil((py - self.sample_range[0]) / step_x).astype(int), 0, self.resolution)
-
- if min_idx < self.resolution and max_idx > 0:
-
- silhouette_value = self.sample_range[0] + min_idx * step_x - px
- for k in range(min_idx, mid_idx):
- sh[k] += weight * silhouette_value
- silhouette_value += step_x
-
- silhouette_value = py - self.sample_range[0] - mid_idx * step_x
- for k in range(mid_idx, max_idx):
- sh[k] += weight * silhouette_value
- silhouette_value -= step_x
-
- Xfit.append(np.reshape(np.sqrt(2) * sh, [1,-1]))
-
- Xfit = np.concatenate(Xfit, 0)
+ tent_functions = heights[None, :] - np.abs(x_values[:, None] - midpoints[None, :])
+ tent_functions[tent_functions < 0.] = 0.
+ silhouette = np.sum(weights[None, :]/total_weight * tent_functions, axis=1)
+ Xfit.append(silhouette * np.sqrt(2))
- return Xfit
+ return np.stack(Xfit, axis=0)
def __call__(self, diag):
"""