summaryrefslogtreecommitdiff
path: root/src/python/test
diff options
context:
space:
mode:
authorVincent Rouvreau <vincent.rouvreau@inria.fr>2022-06-20 17:07:51 +0200
committerVincent Rouvreau <vincent.rouvreau@inria.fr>2022-06-20 17:07:51 +0200
commit8d7244b510eee0d7927c521117198ef286028acf (patch)
tree4ff8da12f39e52f955e506e0e20a1e16b0aa8f8e /src/python/test
parent3bfc066548fbdbd9b1dc06b39d8ccecd2ce4d0b5 (diff)
code review: rename homology_dimensions argument. Use and document numpy reshape instead of cubical dimension argument
Diffstat (limited to 'src/python/test')
-rw-r--r--src/python/test/test_sklearn_cubical_persistence.py21
1 files changed, 15 insertions, 6 deletions
diff --git a/src/python/test/test_sklearn_cubical_persistence.py b/src/python/test/test_sklearn_cubical_persistence.py
index 69b65dde..1c05a215 100644
--- a/src/python/test/test_sklearn_cubical_persistence.py
+++ b/src/python/test/test_sklearn_cubical_persistence.py
@@ -17,9 +17,9 @@ CUBICAL_PERSISTENCE_H0_IMG0 = np.array([[0.0, 6.0], [0.0, 8.0], [0.0, np.inf]])
def test_simple_constructor_from_top_cells():
cells = datasets.load_digits().images[0]
- cp = CubicalPersistence(persistence_dimension=0)
+ cp = CubicalPersistence(homology_dimensions=0)
np.testing.assert_array_equal(cp._CubicalPersistence__transform_only_this_dim(cells), CUBICAL_PERSISTENCE_H0_IMG0)
- cp = CubicalPersistence(persistence_dimension=[0, 2])
+ cp = CubicalPersistence(homology_dimensions=[0, 2])
diags = cp._CubicalPersistence__transform(cells)
assert len(diags) == 2
np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0)
@@ -27,13 +27,13 @@ def test_simple_constructor_from_top_cells():
def test_simple_constructor_from_top_cells_list():
digits = datasets.load_digits().images[:10]
- cp = CubicalPersistence(persistence_dimension=0, n_jobs=-2)
+ cp = CubicalPersistence(homology_dimensions=0, n_jobs=-2)
diags = cp.fit_transform(digits)
assert len(diags) == 10
np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0)
- cp = CubicalPersistence(persistence_dimension=[0, 1], n_jobs=-1)
+ cp = CubicalPersistence(homology_dimensions=[0, 1], n_jobs=-1)
diagsH0H1 = cp.fit_transform(digits)
assert len(diagsH0H1) == 10
for idx in range(10):
@@ -42,9 +42,18 @@ def test_simple_constructor_from_top_cells_list():
def test_simple_constructor_from_flattened_cells():
cells = datasets.load_digits().images[0]
# Not squared (extended) flatten cells
- cells = np.hstack((cells, np.zeros((cells.shape[0], 2)))).flatten()
+ flat_cells = np.hstack((cells, np.zeros((cells.shape[0], 2)))).flatten()
- cp = CubicalPersistence(persistence_dimension=0, newshape=[10, 8])
+ cp = CubicalPersistence(homology_dimensions=0, newshape=[-1, 8, 10])
+ diags = cp.fit_transform([flat_cells])
+
+ np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0)
+
+ # Not squared (extended) non-flatten cells
+ cells = np.hstack((cells, np.zeros((cells.shape[0], 2))))
+
+ # The aim of this second part of the test is to resize even if not mandatory
+ cp = CubicalPersistence(homology_dimensions=0, newshape=[-1, 8, 10])
diags = cp.fit_transform([cells])
np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0)