diff options
author | ROUVREAU Vincent <vincent.rouvreau@inria.fr> | 2021-08-09 10:38:31 +0200 |
---|---|---|
committer | ROUVREAU Vincent <vincent.rouvreau@inria.fr> | 2021-08-09 10:38:31 +0200 |
commit | 5c35605763273cb34efe4227b6d748992e99ab09 (patch) | |
tree | fe2c61809e63fe53165c2252639f53724a6ebb6d /src/python/test | |
parent | 91d72a69f2f04676fbd671af3dc2f3040c9f1c48 (diff) |
Make CubicalPersistence returns all dimensions. Post processing DimensionSelector can select the desired dimension
Diffstat (limited to 'src/python/test')
-rw-r--r-- | src/python/test/test_sklearn_cubical_persistence.py | 21 | ||||
-rw-r--r-- | src/python/test/test_sklearn_post_processing.py | 48 |
2 files changed, 65 insertions, 4 deletions
diff --git a/src/python/test/test_sklearn_cubical_persistence.py b/src/python/test/test_sklearn_cubical_persistence.py index c0082547..506985f1 100644 --- a/src/python/test/test_sklearn_cubical_persistence.py +++ b/src/python/test/test_sklearn_cubical_persistence.py @@ -16,17 +16,30 @@ __author__ = "Vincent Rouvreau" __copyright__ = "Copyright (C) 2021 Inria" __license__ = "MIT" +CUBICAL_PERSISTENCE_H0_IMG0 = np.array([[0., 6.], [0., 8.], [ 0., np.inf]]) + def test_simple_constructor_from_top_cells(): cells = datasets.load_digits().images[0] - cp = CubicalPersistence(persistence_dim = 0) + cp = CubicalPersistence(only_this_dim = 0) np.testing.assert_array_equal(cp._CubicalPersistence__transform(cells), - np.array([[0., 6.], [0., 8.], [ 0., np.inf]])) + [CUBICAL_PERSISTENCE_H0_IMG0]) + cp = CubicalPersistence(max_persistence_dimension = 2) + diags = cp._CubicalPersistence__transform(cells) + assert len(diags) == 3 + np.testing.assert_array_equal(diags[0], + CUBICAL_PERSISTENCE_H0_IMG0) def test_simple_constructor_from_top_cells_list(): digits = datasets.load_digits().images[:10] - cp = CubicalPersistence(persistence_dim = 0, n_jobs=-2) + cp = CubicalPersistence(only_this_dim = 0, n_jobs=-2) diags = cp.fit_transform(digits) assert len(diags) == 10 np.testing.assert_array_equal(diags[0], - np.array([[0., 6.], [0., 8.], [ 0., np.inf]])) + CUBICAL_PERSISTENCE_H0_IMG0) + + cp = CubicalPersistence(max_persistence_dimension = 1, n_jobs=-1) + diagsH0H1 = cp.fit_transform(digits) + assert len(diagsH0H1) == 10 + for idx in range(10): + np.testing.assert_array_equal(diags[idx], diagsH0H1[idx][0]) diff --git a/src/python/test/test_sklearn_post_processing.py b/src/python/test/test_sklearn_post_processing.py new file mode 100644 index 00000000..3a251d34 --- /dev/null +++ b/src/python/test/test_sklearn_post_processing.py @@ -0,0 +1,48 @@ +""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT. + See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details. + Author(s): Vincent Rouvreau + + Copyright (C) 2021 Inria + + Modification(s): + - YYYY/MM Author: Description of the modification +""" + +from gudhi.sklearn.post_processing import DimensionSelector +import numpy as np +import pytest + +__author__ = "Vincent Rouvreau" +__copyright__ = "Copyright (C) 2021 Inria" +__license__ = "MIT" + +H0_0 = np.array([0., 0.]) +H1_0 = np.array([1., 0.]) +H0_1 = np.array([0., 1.]) +H1_1 = np.array([1., 1.]) +H0_2 = np.array([0., 2.]) +H1_2 = np.array([1., 2.]) + +def test_dimension_selector(): + X = [[H0_0, H1_0], [H0_1, H1_1], [H0_2, H1_2]] + ds = DimensionSelector(persistence_dimension = 0, n_jobs=-2) + h0 = ds.fit_transform(X) + np.testing.assert_array_equal(h0[0], + H0_0) + np.testing.assert_array_equal(h0[1], + H0_1) + np.testing.assert_array_equal(h0[2], + H0_2) + + ds = DimensionSelector(persistence_dimension = 1, n_jobs=-1) + h1 = ds.fit_transform(X) + np.testing.assert_array_equal(h1[0], + H1_0) + np.testing.assert_array_equal(h1[1], + H1_1) + np.testing.assert_array_equal(h1[2], + H1_2) + + ds = DimensionSelector(persistence_dimension = 2, n_jobs=-2) + with pytest.raises(IndexError): + h2 = ds.fit_transform([[H0_0, H1_0], [H0_1, H1_1], [H0_2, H1_2]]) |