diff options
author | ROUVREAU Vincent <vincent.rouvreau@inria.fr> | 2021-08-10 09:33:18 +0200 |
---|---|---|
committer | ROUVREAU Vincent <vincent.rouvreau@inria.fr> | 2021-08-10 09:33:18 +0200 |
commit | e4d2b1563640331835bd3e4c08ef2f650cd49db8 (patch) | |
tree | 9756bd5da6b058d263a671c2fb40d83159be47f0 | |
parent | 5c35605763273cb34efe4227b6d748992e99ab09 (diff) |
black files
-rw-r--r-- | src/python/gudhi/sklearn/cubical_persistence.py | 18 | ||||
-rw-r--r-- | src/python/test/test_sklearn_cubical_persistence.py | 21 | ||||
-rw-r--r-- | src/python/test/test_sklearn_post_processing.py | 41 |
3 files changed, 43 insertions, 37 deletions
diff --git a/src/python/gudhi/sklearn/cubical_persistence.py b/src/python/gudhi/sklearn/cubical_persistence.py index 7b77000d..329c9435 100644 --- a/src/python/gudhi/sklearn/cubical_persistence.py +++ b/src/python/gudhi/sklearn/cubical_persistence.py @@ -30,7 +30,15 @@ class CubicalPersistence(BaseEstimator, TransformerMixin): This is a class for computing the persistence diagrams from a cubical complex. """ - def __init__(self, dimensions=None, max_persistence_dimension=0, only_this_dim=-1, homology_coeff_field=11, min_persistence=0., n_jobs=None): + def __init__( + self, + dimensions=None, + max_persistence_dimension=0, + only_this_dim=-1, + homology_coeff_field=11, + min_persistence=0.0, + n_jobs=None, + ): """ Constructor for the CubicalPersistence class. @@ -66,7 +74,9 @@ class CubicalPersistence(BaseEstimator, TransformerMixin): cubical_complex.compute_persistence( homology_coeff_field=self.homology_coeff_field, min_persistence=self.min_persistence ) - return [cubical_complex.persistence_intervals_in_dimension(dim) for dim in range(self.max_persistence_dimension + 1)] + return [ + cubical_complex.persistence_intervals_in_dimension(dim) for dim in range(self.max_persistence_dimension + 1) + ] def __transform_only_this_dim(self, cells): cubical_complex = CubicalComplex(top_dimensional_cells=cells, dimensions=self.dimensions) @@ -95,4 +105,6 @@ class CubicalPersistence(BaseEstimator, TransformerMixin): return Parallel(n_jobs=self.n_jobs, prefer="threads")(delayed(self.__transform)(cells) for cells in X) else: # threads is preferred as cubical construction and persistence computation releases the GIL - return Parallel(n_jobs=self.n_jobs, prefer="threads")(delayed(self.__transform_only_this_dim)(cells) for cells in X) + return Parallel(n_jobs=self.n_jobs, prefer="threads")( + delayed(self.__transform_only_this_dim)(cells) for cells in X + ) diff --git a/src/python/test/test_sklearn_cubical_persistence.py b/src/python/test/test_sklearn_cubical_persistence.py index 506985f1..488495d1 100644 --- a/src/python/test/test_sklearn_cubical_persistence.py +++ b/src/python/test/test_sklearn_cubical_persistence.py @@ -16,29 +16,28 @@ __author__ = "Vincent Rouvreau" __copyright__ = "Copyright (C) 2021 Inria" __license__ = "MIT" -CUBICAL_PERSISTENCE_H0_IMG0 = np.array([[0., 6.], [0., 8.], [ 0., np.inf]]) +CUBICAL_PERSISTENCE_H0_IMG0 = np.array([[0.0, 6.0], [0.0, 8.0], [0.0, np.inf]]) + def test_simple_constructor_from_top_cells(): cells = datasets.load_digits().images[0] - cp = CubicalPersistence(only_this_dim = 0) - np.testing.assert_array_equal(cp._CubicalPersistence__transform(cells), - [CUBICAL_PERSISTENCE_H0_IMG0]) - cp = CubicalPersistence(max_persistence_dimension = 2) + cp = CubicalPersistence(only_this_dim=0) + np.testing.assert_array_equal(cp._CubicalPersistence__transform(cells), [CUBICAL_PERSISTENCE_H0_IMG0]) + cp = CubicalPersistence(max_persistence_dimension=2) diags = cp._CubicalPersistence__transform(cells) assert len(diags) == 3 - np.testing.assert_array_equal(diags[0], - CUBICAL_PERSISTENCE_H0_IMG0) + np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0) + def test_simple_constructor_from_top_cells_list(): digits = datasets.load_digits().images[:10] - cp = CubicalPersistence(only_this_dim = 0, n_jobs=-2) + cp = CubicalPersistence(only_this_dim=0, n_jobs=-2) diags = cp.fit_transform(digits) assert len(diags) == 10 - np.testing.assert_array_equal(diags[0], - CUBICAL_PERSISTENCE_H0_IMG0) + np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0) - cp = CubicalPersistence(max_persistence_dimension = 1, n_jobs=-1) + cp = CubicalPersistence(max_persistence_dimension=1, n_jobs=-1) diagsH0H1 = cp.fit_transform(digits) assert len(diagsH0H1) == 10 for idx in range(10): diff --git a/src/python/test/test_sklearn_post_processing.py b/src/python/test/test_sklearn_post_processing.py index 3a251d34..60bf8162 100644 --- a/src/python/test/test_sklearn_post_processing.py +++ b/src/python/test/test_sklearn_post_processing.py @@ -16,33 +16,28 @@ __author__ = "Vincent Rouvreau" __copyright__ = "Copyright (C) 2021 Inria" __license__ = "MIT" -H0_0 = np.array([0., 0.]) -H1_0 = np.array([1., 0.]) -H0_1 = np.array([0., 1.]) -H1_1 = np.array([1., 1.]) -H0_2 = np.array([0., 2.]) -H1_2 = np.array([1., 2.]) +H0_0 = np.array([0.0, 0.0]) +H1_0 = np.array([1.0, 0.0]) +H0_1 = np.array([0.0, 1.0]) +H1_1 = np.array([1.0, 1.0]) +H0_2 = np.array([0.0, 2.0]) +H1_2 = np.array([1.0, 2.0]) + def test_dimension_selector(): X = [[H0_0, H1_0], [H0_1, H1_1], [H0_2, H1_2]] - ds = DimensionSelector(persistence_dimension = 0, n_jobs=-2) + ds = DimensionSelector(persistence_dimension=0, n_jobs=-2) h0 = ds.fit_transform(X) - np.testing.assert_array_equal(h0[0], - H0_0) - np.testing.assert_array_equal(h0[1], - H0_1) - np.testing.assert_array_equal(h0[2], - H0_2) - - ds = DimensionSelector(persistence_dimension = 1, n_jobs=-1) + np.testing.assert_array_equal(h0[0], H0_0) + np.testing.assert_array_equal(h0[1], H0_1) + np.testing.assert_array_equal(h0[2], H0_2) + + ds = DimensionSelector(persistence_dimension=1, n_jobs=-1) h1 = ds.fit_transform(X) - np.testing.assert_array_equal(h1[0], - H1_0) - np.testing.assert_array_equal(h1[1], - H1_1) - np.testing.assert_array_equal(h1[2], - H1_2) - - ds = DimensionSelector(persistence_dimension = 2, n_jobs=-2) + np.testing.assert_array_equal(h1[0], H1_0) + np.testing.assert_array_equal(h1[1], H1_1) + np.testing.assert_array_equal(h1[2], H1_2) + + ds = DimensionSelector(persistence_dimension=2, n_jobs=-2) with pytest.raises(IndexError): h2 = ds.fit_transform([[H0_0, H1_0], [H0_1, H1_1], [H0_2, H1_2]]) |