summary | refs | log | tree | commit | diff
path: root/src/python/gudhi/representations/vector_methods.py
diff options
context:
space:
mode:
author: Vincent Rouvreau <vincent.rouvreau@inria.fr> 2021-11-05 10:27:46 +0100
committer: Vincent Rouvreau <vincent.rouvreau@inria.fr> 2021-11-05 10:27:46 +0100
commit: 3094e1fe51acc49e4ea7e4f38648bb25d96784a4 (patch)
tree: e1327baae54d36013485f048972939441035fb1f /src/python/gudhi/representations/vector_methods.py
parent: 7c26436a703a476d28cf568949275d26d1827c36 (diff)
code review: factorize sample range computation
Diffstat (limited to 'src/python/gudhi/representations/vector_methods.py')
-rw-r--r-- src/python/gudhi/representations/vector_methods.py 46
1 files changed, 26 insertions, 20 deletions
diff --git a/src/python/gudhi/representations/vector_methods.py b/src/python/gudhi/representations/vector_methods.py
index e7ee57a4..140162af 100644
--- a/src/python/gudhi/representations/vector_methods.py
+++ b/src/python/gudhi/representations/vector_methods.py
@@ -6,6 +6,7 @@
#
# Modification(s):
# - 2020/06 Martin: ATOL integration
+# - 2021/11 Vincent Rouvreau: factorize _automatic_sample_range
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
@@ -98,6 +99,23 @@ class PersistenceImage(BaseEstimator, TransformerMixin):
"""
return self.fit_transform([diag])[0,:]
+def _automatic_sample_range(sample_range, X, y):
+ """
+ Compute and return the sample range, filling in from the persistence diagrams any bound given as numpy.nan.
+
+ Parameters:
+ sample_range (a numpy array of 2 float): minimum and maximum of all piecewise-linear function domains, of
+ the form [x_min, x_max]. A numpy.nan entry requests that this bound be computed from X.
+ X (list of n x 2 numpy arrays): input persistence diagrams.
+ y (n x 1 array): persistence diagram labels (only forwarded to the scaler's fit).
+
+ Returns:
+ sample_range unchanged when it contains no numpy.nan; otherwise an array where each numpy.nan entry is
+ replaced by the matching entry of [mx, My] (fitted minimum of coordinate 0, fitted maximum of coordinate 1).
+
+ Note:
+ Callers in this module wrap this call in ``try/except ValueError`` for the empty persistence diagram case.
+ """
+ # Element-wise mask of the bounds that were left unspecified (numpy.nan).
+ nan_in_range = np.isnan(sample_range)
+ if nan_in_range.any():
+ # Fit one MinMaxScaler per diagram coordinate (index 0 and index 1) to read their data_min_ / data_max_.
+ pre = DiagramScaler(use=True, scalers=[([0], MinMaxScaler()), ([1], MinMaxScaler())]).fit(X,y)
+ # mx/my are the fitted minima, Mx/My the fitted maxima; only mx and My are used below.
+ [mx,my],[Mx,My] = [pre.scalers[0][1].data_min_[0], pre.scalers[1][1].data_min_[0]], [pre.scalers[0][1].data_max_[0], pre.scalers[1][1].data_max_[0]]
+ # Keep the provided bounds and substitute computed ones only where the mask is set.
+ return np.where(nan_in_range, np.array([mx, My]), sample_range)
+ return sample_range
+
class Landscape(BaseEstimator, TransformerMixin):
"""
This is a class for computing persistence landscapes from a list of persistence diagrams. A persistence landscape is a collection of 1D piecewise-linear functions computed from the rank function associated to the persistence diagram. These piecewise-linear functions are then sampled evenly on a given range and the corresponding vectors of samples are concatenated and returned. See http://jmlr.org/papers/v16/bubenik15a.html for more details.
@@ -123,14 +141,11 @@ class Landscape(BaseEstimator, TransformerMixin):
X (list of n x 2 numpy arrays): input persistence diagrams.
y (n x 1 array): persistence diagram labels (unused).
"""
- if self.nan_in_range.any():
- try:
- pre = DiagramScaler(use=True, scalers=[([0], MinMaxScaler()), ([1], MinMaxScaler())]).fit(X,y)
- [mx,my],[Mx,My] = [pre.scalers[0][1].data_min_[0], pre.scalers[1][1].data_min_[0]], [pre.scalers[0][1].data_max_[0], pre.scalers[1][1].data_max_[0]]
- self.sample_range = np.where(self.nan_in_range, np.array([mx, My]), np.array(self.sample_range))
- except ValueError:
- # Empty persistence diagram case - https://github.com/GUDHI/gudhi-devel/issues/507
- pass
+ try:
+ self.sample_range = _automatic_sample_range(np.array(self.sample_range), X, y)
+ except ValueError:
+ # Empty persistence diagram case - https://github.com/GUDHI/gudhi-devel/issues/507
+ pass
return self
def transform(self, X):
@@ -227,10 +242,7 @@ class Silhouette(BaseEstimator, TransformerMixin):
y (n x 1 array): persistence diagram labels (unused).
"""
try:
- if np.isnan(np.array(self.sample_range)).any():
- pre = DiagramScaler(use=True, scalers=[([0], MinMaxScaler()), ([1], MinMaxScaler())]).fit(X,y)
- [mx,my],[Mx,My] = [pre.scalers[0][1].data_min_[0], pre.scalers[1][1].data_min_[0]], [pre.scalers[0][1].data_max_[0], pre.scalers[1][1].data_max_[0]]
- self.sample_range = np.where(np.isnan(np.array(self.sample_range)), np.array([mx, My]), np.array(self.sample_range))
+ self.sample_range = _automatic_sample_range(np.array(self.sample_range), X, y)
except ValueError:
# Empty persistence diagram case - https://github.com/GUDHI/gudhi-devel/issues/507
pass
@@ -320,10 +332,7 @@ class BettiCurve(BaseEstimator, TransformerMixin):
y (n x 1 array): persistence diagram labels (unused).
"""
try:
- if np.isnan(np.array(self.sample_range)).any():
- pre = DiagramScaler(use=True, scalers=[([0], MinMaxScaler()), ([1], MinMaxScaler())]).fit(X,y)
- [mx,my],[Mx,My] = [pre.scalers[0][1].data_min_[0], pre.scalers[1][1].data_min_[0]], [pre.scalers[0][1].data_max_[0], pre.scalers[1][1].data_max_[0]]
- self.sample_range = np.where(np.isnan(np.array(self.sample_range)), np.array([mx, My]), np.array(self.sample_range))
+ self.sample_range = _automatic_sample_range(np.array(self.sample_range), X, y)
except ValueError:
# Empty persistence diagram case - https://github.com/GUDHI/gudhi-devel/issues/507
pass
@@ -391,10 +400,7 @@ class Entropy(BaseEstimator, TransformerMixin):
y (n x 1 array): persistence diagram labels (unused).
"""
try:
- if np.isnan(np.array(self.sample_range)).any():
- pre = DiagramScaler(use=True, scalers=[([0], MinMaxScaler()), ([1], MinMaxScaler())]).fit(X,y)
- [mx,my],[Mx,My] = [pre.scalers[0][1].data_min_[0], pre.scalers[1][1].data_min_[0]], [pre.scalers[0][1].data_max_[0], pre.scalers[1][1].data_max_[0]]
- self.sample_range = np.where(np.isnan(np.array(self.sample_range)), np.array([mx, My]), np.array(self.sample_range))
+ self.sample_range = _automatic_sample_range(np.array(self.sample_range), X, y)
except ValueError:
# Empty persistence diagram case - https://github.com/GUDHI/gudhi-devel/issues/507
pass