diff options
Diffstat (limited to 'src/python/test')
-rw-r--r-- | src/python/test/test_representations_preprocessing.py (renamed from src/python/test/test_sklearn_post_processing.py) | 12 | ||||
-rw-r--r-- | src/python/test/test_sklearn_cubical_persistence.py | 16 |
2 files changed, 10 insertions, 18 deletions
diff --git a/src/python/test/test_sklearn_post_processing.py b/src/python/test/test_representations_preprocessing.py index e60eadc6..838cf30c 100644 --- a/src/python/test/test_sklearn_post_processing.py +++ b/src/python/test/test_representations_preprocessing.py @@ -8,14 +8,10 @@ - YYYY/MM Author: Description of the modification """ -from gudhi.sklearn.post_processing import DimensionSelector +from gudhi.representations.preprocessing import DimensionSelector import numpy as np import pytest -__author__ = "Vincent Rouvreau" -__copyright__ = "Copyright (C) 2021 Inria" -__license__ = "MIT" - H0_0 = np.array([0.0, 0.0]) H1_0 = np.array([1.0, 0.0]) H0_1 = np.array([0.0, 1.0]) @@ -26,18 +22,18 @@ H1_2 = np.array([1.0, 2.0]) def test_dimension_selector(): X = [[H0_0, H1_0], [H0_1, H1_1], [H0_2, H1_2]] - ds = DimensionSelector(persistence_dimension=0) + ds = DimensionSelector(index=0) h0 = ds.fit_transform(X) np.testing.assert_array_equal(h0[0], H0_0) np.testing.assert_array_equal(h0[1], H0_1) np.testing.assert_array_equal(h0[2], H0_2) - ds = DimensionSelector(persistence_dimension=1) + ds = DimensionSelector(index=1) h1 = ds.fit_transform(X) np.testing.assert_array_equal(h1[0], H1_0) np.testing.assert_array_equal(h1[1], H1_1) np.testing.assert_array_equal(h1[2], H1_2) - ds = DimensionSelector(persistence_dimension=2) + ds = DimensionSelector(index=2) with pytest.raises(IndexError): h2 = ds.fit_transform([[H0_0, H1_0], [H0_1, H1_1], [H0_2, H1_2]]) diff --git a/src/python/test/test_sklearn_cubical_persistence.py b/src/python/test/test_sklearn_cubical_persistence.py index 488495d1..bd728a29 100644 --- a/src/python/test/test_sklearn_cubical_persistence.py +++ b/src/python/test/test_sklearn_cubical_persistence.py @@ -12,32 +12,28 @@ from gudhi.sklearn.cubical_persistence import CubicalPersistence import numpy as np from sklearn import datasets -__author__ = "Vincent Rouvreau" -__copyright__ = "Copyright (C) 2021 Inria" -__license__ = "MIT" - CUBICAL_PERSISTENCE_H0_IMG0 = np.array([[0.0, 6.0], [0.0, 8.0], [0.0, np.inf]]) def test_simple_constructor_from_top_cells(): cells = datasets.load_digits().images[0] - cp = CubicalPersistence(only_this_dim=0) - np.testing.assert_array_equal(cp._CubicalPersistence__transform(cells), [CUBICAL_PERSISTENCE_H0_IMG0]) - cp = CubicalPersistence(max_persistence_dimension=2) + cp = CubicalPersistence(persistence_dimension=0) + np.testing.assert_array_equal(cp._CubicalPersistence__transform_only_this_dim(cells), CUBICAL_PERSISTENCE_H0_IMG0) + cp = CubicalPersistence(persistence_dimension=[0, 2]) diags = cp._CubicalPersistence__transform(cells) - assert len(diags) == 3 + assert len(diags) == 2 np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0) def test_simple_constructor_from_top_cells_list(): digits = datasets.load_digits().images[:10] - cp = CubicalPersistence(only_this_dim=0, n_jobs=-2) + cp = CubicalPersistence(persistence_dimension=0, n_jobs=-2) diags = cp.fit_transform(digits) assert len(diags) == 10 np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0) - cp = CubicalPersistence(max_persistence_dimension=1, n_jobs=-1) + cp = CubicalPersistence(persistence_dimension=[0, 1], n_jobs=-1) diagsH0H1 = cp.fit_transform(digits) assert len(diagsH0H1) == 10 for idx in range(10): |