From d4fbf78cf12488ecdf79f26ef6c05d6d1323704a Mon Sep 17 00:00:00 2001 From: Vincent Rouvreau Date: Fri, 4 Feb 2022 09:30:13 +0100 Subject: code review: rename dimensions as 'newshape' --- src/python/doc/cubical_complex_sklearn_itf_ref.rst | 6 +++--- src/python/gudhi/sklearn/cubical_persistence.py | 14 +++++++------- src/python/test/test_sklearn_cubical_persistence.py | 2 +- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/python/doc/cubical_complex_sklearn_itf_ref.rst b/src/python/doc/cubical_complex_sklearn_itf_ref.rst index a57e5fbb..8248343b 100644 --- a/src/python/doc/cubical_complex_sklearn_itf_ref.rst +++ b/src/python/doc/cubical_complex_sklearn_itf_ref.rst @@ -54,9 +54,9 @@ two holes in :math:`\mathbf{H}_1`, or, like in this example, three connected com X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=0) pipe = Pipeline( [ - ("cub_pers", CubicalPersistence(persistence_dimension=0, dimensions=[28, 28], n_jobs=-2)), + ("cub_pers", CubicalPersistence(persistence_dimension=0, newshape=[28, 28], n_jobs=-2)), # Or for multiple persistence dimension computation - # ("cub_pers", CubicalPersistence(persistence_dimension=[0, 1], dimensions=[28, 28], n_jobs=-2)), + # ("cub_pers", CubicalPersistence(persistence_dimension=[0, 1], newshape=[28, 28], n_jobs=-2)), # ("H0_diags", DimensionSelector(index=0), # where index is the index in persistence_dimension array ("finite_diags", DiagramSelector(use=True, point_type="finite")), ( @@ -78,7 +78,7 @@ two holes in :math:`\mathbf{H}_1`, or, like in this example, three connected com There are 6825 eights out of 70000 numbers. Classification report for TDA pipeline Pipeline(steps=[('cub_pers', - CubicalPersistence(dimensions=[28, 28], n_jobs=-2)), + CubicalPersistence(newshape=[28, 28], n_jobs=-2)), ('finite_diags', DiagramSelector(use=True)), ('pers_img', PersistenceImage(bandwidth=50, im_range=[0, 256, 0, 256], diff --git a/src/python/gudhi/sklearn/cubical_persistence.py b/src/python/gudhi/sklearn/cubical_persistence.py index 3997bc8a..ed56d2dd 100644 --- a/src/python/gudhi/sklearn/cubical_persistence.py +++ b/src/python/gudhi/sklearn/cubical_persistence.py @@ -32,7 +32,7 @@ class CubicalPersistence(BaseEstimator, TransformerMixin): def __init__( self, - dimensions=None, + newshape=None, persistence_dimension=-1, homology_coeff_field=11, min_persistence=0.0, @@ -42,7 +42,7 @@ class CubicalPersistence(BaseEstimator, TransformerMixin): Constructor for the CubicalPersistence class. Parameters: - dimensions (list of int): A list of number of top dimensional cells if cells filtration values will require + newshape (list of int): A list of number of top dimensional cells if cells filtration values will require to be reshaped (cf. :func:`~gudhi.sklearn.cubical_persistence.CubicalPersistence.transform`) persistence_dimension (int or list of int): The returned persistence diagrams dimension(s). Short circuit the use of :class:`~gudhi.representations.preprocessing.DimensionSelector` when only one @@ -52,7 +52,7 @@ class CubicalPersistence(BaseEstimator, TransformerMixin): `min_persistence`). Default value is `0.0`. Set `min_persistence` to `-1.0` to see all values. n_jobs (int): cf. https://joblib.readthedocs.io/en/latest/generated/joblib.Parallel.html """ - self.dimensions = dimensions + self.newshape = newshape self.persistence_dimension = persistence_dimension self.homology_coeff_field = homology_coeff_field self.min_persistence = min_persistence @@ -65,7 +65,7 @@ class CubicalPersistence(BaseEstimator, TransformerMixin): return self def __transform(self, cells): - cubical_complex = CubicalComplex(top_dimensional_cells=cells, dimensions=self.dimensions) + cubical_complex = CubicalComplex(top_dimensional_cells=cells, dimensions=self.newshape) cubical_complex.compute_persistence( homology_coeff_field=self.homology_coeff_field, min_persistence=self.min_persistence ) @@ -74,7 +74,7 @@ class CubicalPersistence(BaseEstimator, TransformerMixin): ] def __transform_only_this_dim(self, cells): - cubical_complex = CubicalComplex(top_dimensional_cells=cells, dimensions=self.dimensions) + cubical_complex = CubicalComplex(top_dimensional_cells=cells, dimensions=self.newshape) cubical_complex.compute_persistence( homology_coeff_field=self.homology_coeff_field, min_persistence=self.min_persistence ) @@ -83,8 +83,8 @@ class CubicalPersistence(BaseEstimator, TransformerMixin): def transform(self, X, Y=None): """Compute all the cubical complexes and their associated persistence diagrams. - :param X: List of cells filtration values that should be flatten if `dimensions` is set in the constructor, or - already with the correct shape in a numpy.ndarray (and `dimensions` must not be set). + :param X: List of cells filtration values that should be flatten if `newshape` is set in the constructor, or + already with the correct shape in a numpy.ndarray (and `newshape` must not be set). :type X: list of list of float OR list of numpy.ndarray :return: Persistence diagrams in the format: diff --git a/src/python/test/test_sklearn_cubical_persistence.py b/src/python/test/test_sklearn_cubical_persistence.py index 56c44db0..69b65dde 100644 --- a/src/python/test/test_sklearn_cubical_persistence.py +++ b/src/python/test/test_sklearn_cubical_persistence.py @@ -44,7 +44,7 @@ def test_simple_constructor_from_flattened_cells(): # Not squared (extended) flatten cells cells = np.hstack((cells, np.zeros((cells.shape[0], 2)))).flatten() - cp = CubicalPersistence(persistence_dimension=0, dimensions=[10, 8]) + cp = CubicalPersistence(persistence_dimension=0, newshape=[10, 8]) diags = cp.fit_transform([cells]) np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0) -- cgit v1.2.3