summaryrefslogtreecommitdiff
path: root/src/python/gudhi/representations/vector_methods.py
diff options
context:
space:
mode:
authorVincent Rouvreau <vincent.rouvreau@inria.fr>2021-11-05 12:05:45 +0100
committerVincent Rouvreau <vincent.rouvreau@inria.fr>2021-11-05 12:05:45 +0100
commit37d7743a91f7fb970425a06798ac6cb61b0be109 (patch)
tree3e22e8271c2c56b9a51fb02fb4d2b9cf5f3fe43c /src/python/gudhi/representations/vector_methods.py
parent3094e1fe51acc49e4ea7e4f38648bb25d96784a4 (diff)
code review: try/except in function and assert on length of diagrams for error menagement
Diffstat (limited to 'src/python/gudhi/representations/vector_methods.py')
-rw-r--r--src/python/gudhi/representations/vector_methods.py38
1 files changed, 15 insertions, 23 deletions
diff --git a/src/python/gudhi/representations/vector_methods.py b/src/python/gudhi/representations/vector_methods.py
index 140162af..e883b5dd 100644
--- a/src/python/gudhi/representations/vector_methods.py
+++ b/src/python/gudhi/representations/vector_methods.py
@@ -111,9 +111,14 @@ def _automatic_sample_range(sample_range, X, y):
"""
nan_in_range = np.isnan(sample_range)
if nan_in_range.any():
- pre = DiagramScaler(use=True, scalers=[([0], MinMaxScaler()), ([1], MinMaxScaler())]).fit(X,y)
- [mx,my],[Mx,My] = [pre.scalers[0][1].data_min_[0], pre.scalers[1][1].data_min_[0]], [pre.scalers[0][1].data_max_[0], pre.scalers[1][1].data_max_[0]]
- return np.where(nan_in_range, np.array([mx, My]), sample_range)
+ try:
+ pre = DiagramScaler(use=True, scalers=[([0], MinMaxScaler()), ([1], MinMaxScaler())]).fit(X,y)
+ [mx,my] = [pre.scalers[0][1].data_min_[0], pre.scalers[1][1].data_min_[0]]
+ [Mx,My] = [pre.scalers[0][1].data_max_[0], pre.scalers[1][1].data_max_[0]]
+ return np.where(nan_in_range, np.array([mx, My]), sample_range)
+ except ValueError:
+ # Empty persistence diagram case - https://github.com/GUDHI/gudhi-devel/issues/507
+ pass
return sample_range
class Landscape(BaseEstimator, TransformerMixin):
@@ -141,11 +146,7 @@ class Landscape(BaseEstimator, TransformerMixin):
X (list of n x 2 numpy arrays): input persistence diagrams.
y (n x 1 array): persistence diagram labels (unused).
"""
- try:
- self.sample_range = _automatic_sample_range(np.array(self.sample_range), X, y)
- except ValueError:
- # Empty persistence diagram case - https://github.com/GUDHI/gudhi-devel/issues/507
- pass
+ self.sample_range = _automatic_sample_range(np.array(self.sample_range), X, y)
return self
def transform(self, X):
@@ -241,11 +242,7 @@ class Silhouette(BaseEstimator, TransformerMixin):
X (list of n x 2 numpy arrays): input persistence diagrams.
y (n x 1 array): persistence diagram labels (unused).
"""
- try:
- self.sample_range = _automatic_sample_range(np.array(self.sample_range), X, y)
- except ValueError:
- # Empty persistence diagram case - https://github.com/GUDHI/gudhi-devel/issues/507
- pass
+ self.sample_range = _automatic_sample_range(np.array(self.sample_range), X, y)
return self
def transform(self, X):
@@ -331,11 +328,7 @@ class BettiCurve(BaseEstimator, TransformerMixin):
X (list of n x 2 numpy arrays): input persistence diagrams.
y (n x 1 array): persistence diagram labels (unused).
"""
- try:
- self.sample_range = _automatic_sample_range(np.array(self.sample_range), X, y)
- except ValueError:
- # Empty persistence diagram case - https://github.com/GUDHI/gudhi-devel/issues/507
- pass
+ self.sample_range = _automatic_sample_range(np.array(self.sample_range), X, y)
return self
def transform(self, X):
@@ -399,11 +392,7 @@ class Entropy(BaseEstimator, TransformerMixin):
X (list of n x 2 numpy arrays): input persistence diagrams.
y (n x 1 array): persistence diagram labels (unused).
"""
- try:
- self.sample_range = _automatic_sample_range(np.array(self.sample_range), X, y)
- except ValueError:
- # Empty persistence diagram case - https://github.com/GUDHI/gudhi-devel/issues/507
- pass
+ self.sample_range = _automatic_sample_range(np.array(self.sample_range), X, y)
return self
def transform(self, X):
@@ -427,6 +416,7 @@ class Entropy(BaseEstimator, TransformerMixin):
new_diagram = DiagramScaler(use=True, scalers=[([1], MaxAbsScaler())]).fit_transform([diagram])[0]
except ValueError:
# Empty persistence diagram case - https://github.com/GUDHI/gudhi-devel/issues/507
+ assert len(diagram) == 0
new_diagram = np.empty(shape = [0, 2])
if self.mode == "scalar":
@@ -510,6 +500,8 @@ class TopologicalVector(BaseEstimator, TransformerMixin):
try:
distances = DistanceMetric.get_metric("chebyshev").pairwise(diagram)
except ValueError:
+ # Empty persistence diagram case - https://github.com/GUDHI/gudhi-devel/issues/507
+ assert len(diagram) == 0
distances = np.empty(shape = [0, 0])
vect = np.flip(np.sort(np.triu(np.minimum(distances, min_pers)), axis=None), 0)
dim = min(len(vect), thresh)