summaryrefslogtreecommitdiff
path: root/src/python/gudhi/representations/vector_methods.py
diff options
context:
space:
mode:
authorwreise <wojciech.reise@epfl.ch>2022-05-25 15:14:15 +0200
committerwreise <wojciech.reise@epfl.ch>2022-05-25 15:14:15 +0200
commite8d0cbc3311765900e098b472608dc40b84d07d8 (patch)
treef182b5dbea11f92694ebc1de7b8dc6961de10ccd /src/python/gudhi/representations/vector_methods.py
parent1a76ecc3e7459e3461e1f182004362dcb663addd (diff)
Optimize using keops
Diffstat (limited to 'src/python/gudhi/representations/vector_methods.py')
-rw-r--r--src/python/gudhi/representations/vector_methods.py27
1 files changed, 19 insertions, 8 deletions
diff --git a/src/python/gudhi/representations/vector_methods.py b/src/python/gudhi/representations/vector_methods.py
index e6289a37..55dc2c5b 100644
--- a/src/python/gudhi/representations/vector_methods.py
+++ b/src/python/gudhi/representations/vector_methods.py
@@ -10,6 +10,7 @@
# - 2021/11 Vincent Rouvreau: factorize _automatic_sample_range
import numpy as np
+from pykeops.numpy import Genred
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.exceptions import NotFittedError
from sklearn.preprocessing import MinMaxScaler, MaxAbsScaler
@@ -259,19 +260,29 @@ class Silhouette(BaseEstimator, TransformerMixin):
Returns:
numpy array with shape (number of diagrams) x (**resolution**): output persistence silhouettes.
"""
- Xfit = []
- x_values = self.im_range
+ silhouette_formula = "normalized_weights * ReLU(heights - Abs(x_values - midpoints))"
+ variables = [
+ "normalized_weights = Vi(1)",
+ "heights = Vi(1)",
+ "midpoints = Vi(1)",
+ "x_values = Vj(1)",
+ ]
+ silhouette = Genred(silhouette_formula, variables, reduction_op="Sum", axis=0)
+
+ silhouettes_list = []
+ x_values = self.im_range
for i, diag in enumerate(X):
- midpoints, heights = (diag[:, 0] + diag[:, 1])/2., (diag[:, 1] - diag[:, 0])/2.
+ midpoints, heights = (diag[:, 0] + diag[:, 1]) / 2., (diag[:, 1] - diag[:, 0]) / 2.
weights = np.array([self.weight(point) for point in diag])
- total_weight = np.sum(weights)
+ weights /= np.sum(weights)
- tent_functions = np.maximum(heights[None, :] - np.abs(x_values[:, None] - midpoints[None, :]), 0)
- silhouette = np.sum(weights[None, :]/total_weight * tent_functions, axis=1)
- Xfit.append(silhouette * np.sqrt(2))
+ silhouettes_list.append(
+ np.sqrt(2) * silhouette(weights[:, None], heights[:, None],
+ midpoints[:, None], x_values[:, None])[:, 0]
+ )
- return np.stack(Xfit, axis=0)
+ return np.stack(silhouettes_list, axis=0)
def __call__(self, diag):
"""