From 37d7743a91f7fb970425a06798ac6cb61b0be109 Mon Sep 17 00:00:00 2001 From: Vincent Rouvreau Date: Fri, 5 Nov 2021 12:05:45 +0100 Subject: code review: try/except in function and assert on length of diagrams for error menagement --- src/python/gudhi/representations/vector_methods.py | 38 +++++++++------------- 1 file changed, 15 insertions(+), 23 deletions(-) (limited to 'src/python') diff --git a/src/python/gudhi/representations/vector_methods.py b/src/python/gudhi/representations/vector_methods.py index 140162af..e883b5dd 100644 --- a/src/python/gudhi/representations/vector_methods.py +++ b/src/python/gudhi/representations/vector_methods.py @@ -111,9 +111,14 @@ def _automatic_sample_range(sample_range, X, y): """ nan_in_range = np.isnan(sample_range) if nan_in_range.any(): - pre = DiagramScaler(use=True, scalers=[([0], MinMaxScaler()), ([1], MinMaxScaler())]).fit(X,y) - [mx,my],[Mx,My] = [pre.scalers[0][1].data_min_[0], pre.scalers[1][1].data_min_[0]], [pre.scalers[0][1].data_max_[0], pre.scalers[1][1].data_max_[0]] - return np.where(nan_in_range, np.array([mx, My]), sample_range) + try: + pre = DiagramScaler(use=True, scalers=[([0], MinMaxScaler()), ([1], MinMaxScaler())]).fit(X,y) + [mx,my] = [pre.scalers[0][1].data_min_[0], pre.scalers[1][1].data_min_[0]] + [Mx,My] = [pre.scalers[0][1].data_max_[0], pre.scalers[1][1].data_max_[0]] + return np.where(nan_in_range, np.array([mx, My]), sample_range) + except ValueError: + # Empty persistence diagram case - https://github.com/GUDHI/gudhi-devel/issues/507 + pass return sample_range class Landscape(BaseEstimator, TransformerMixin): @@ -141,11 +146,7 @@ class Landscape(BaseEstimator, TransformerMixin): X (list of n x 2 numpy arrays): input persistence diagrams. y (n x 1 array): persistence diagram labels (unused). """ - try: - self.sample_range = _automatic_sample_range(np.array(self.sample_range), X, y) - except ValueError: - # Empty persistence diagram case - https://github.com/GUDHI/gudhi-devel/issues/507 - pass + self.sample_range = _automatic_sample_range(np.array(self.sample_range), X, y) return self def transform(self, X): @@ -241,11 +242,7 @@ class Silhouette(BaseEstimator, TransformerMixin): X (list of n x 2 numpy arrays): input persistence diagrams. y (n x 1 array): persistence diagram labels (unused). """ - try: - self.sample_range = _automatic_sample_range(np.array(self.sample_range), X, y) - except ValueError: - # Empty persistence diagram case - https://github.com/GUDHI/gudhi-devel/issues/507 - pass + self.sample_range = _automatic_sample_range(np.array(self.sample_range), X, y) return self def transform(self, X): @@ -331,11 +328,7 @@ class BettiCurve(BaseEstimator, TransformerMixin): X (list of n x 2 numpy arrays): input persistence diagrams. y (n x 1 array): persistence diagram labels (unused). """ - try: - self.sample_range = _automatic_sample_range(np.array(self.sample_range), X, y) - except ValueError: - # Empty persistence diagram case - https://github.com/GUDHI/gudhi-devel/issues/507 - pass + self.sample_range = _automatic_sample_range(np.array(self.sample_range), X, y) return self def transform(self, X): @@ -399,11 +392,7 @@ class Entropy(BaseEstimator, TransformerMixin): X (list of n x 2 numpy arrays): input persistence diagrams. y (n x 1 array): persistence diagram labels (unused). """ - try: - self.sample_range = _automatic_sample_range(np.array(self.sample_range), X, y) - except ValueError: - # Empty persistence diagram case - https://github.com/GUDHI/gudhi-devel/issues/507 - pass + self.sample_range = _automatic_sample_range(np.array(self.sample_range), X, y) return self def transform(self, X): @@ -427,6 +416,7 @@ class Entropy(BaseEstimator, TransformerMixin): new_diagram = DiagramScaler(use=True, scalers=[([1], MaxAbsScaler())]).fit_transform([diagram])[0] except ValueError: # Empty persistence diagram case - https://github.com/GUDHI/gudhi-devel/issues/507 + assert len(diagram) == 0 new_diagram = np.empty(shape = [0, 2]) if self.mode == "scalar": @@ -510,6 +500,8 @@ class TopologicalVector(BaseEstimator, TransformerMixin): try: distances = DistanceMetric.get_metric("chebyshev").pairwise(diagram) except ValueError: + # Empty persistence diagram case - https://github.com/GUDHI/gudhi-devel/issues/507 + assert len(diagram) == 0 distances = np.empty(shape = [0, 0]) vect = np.flip(np.sort(np.triu(np.minimum(distances, min_pers)), axis=None), 0) dim = min(len(vect), thresh) -- cgit v1.2.3