summaryrefslogtreecommitdiff
path: root/src/python/test/test_sklearn_cubical_persistence.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/test/test_sklearn_cubical_persistence.py')
-rw-r--r--src/python/test/test_sklearn_cubical_persistence.py21
1 files changed, 15 insertions, 6 deletions
diff --git a/src/python/test/test_sklearn_cubical_persistence.py b/src/python/test/test_sklearn_cubical_persistence.py
index 69b65dde..1c05a215 100644
--- a/src/python/test/test_sklearn_cubical_persistence.py
+++ b/src/python/test/test_sklearn_cubical_persistence.py
@@ -17,9 +17,9 @@ CUBICAL_PERSISTENCE_H0_IMG0 = np.array([[0.0, 6.0], [0.0, 8.0], [0.0, np.inf]])
def test_simple_constructor_from_top_cells():
cells = datasets.load_digits().images[0]
- cp = CubicalPersistence(persistence_dimension=0)
+ cp = CubicalPersistence(homology_dimensions=0)
np.testing.assert_array_equal(cp._CubicalPersistence__transform_only_this_dim(cells), CUBICAL_PERSISTENCE_H0_IMG0)
- cp = CubicalPersistence(persistence_dimension=[0, 2])
+ cp = CubicalPersistence(homology_dimensions=[0, 2])
diags = cp._CubicalPersistence__transform(cells)
assert len(diags) == 2
np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0)
@@ -27,13 +27,13 @@ def test_simple_constructor_from_top_cells():
def test_simple_constructor_from_top_cells_list():
digits = datasets.load_digits().images[:10]
- cp = CubicalPersistence(persistence_dimension=0, n_jobs=-2)
+ cp = CubicalPersistence(homology_dimensions=0, n_jobs=-2)
diags = cp.fit_transform(digits)
assert len(diags) == 10
np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0)
- cp = CubicalPersistence(persistence_dimension=[0, 1], n_jobs=-1)
+ cp = CubicalPersistence(homology_dimensions=[0, 1], n_jobs=-1)
diagsH0H1 = cp.fit_transform(digits)
assert len(diagsH0H1) == 10
for idx in range(10):
@@ -42,9 +42,18 @@ def test_simple_constructor_from_top_cells_list():
def test_simple_constructor_from_flattened_cells():
cells = datasets.load_digits().images[0]
# Not squared (extended) flatten cells
- cells = np.hstack((cells, np.zeros((cells.shape[0], 2)))).flatten()
+ flat_cells = np.hstack((cells, np.zeros((cells.shape[0], 2)))).flatten()
- cp = CubicalPersistence(persistence_dimension=0, newshape=[10, 8])
+ cp = CubicalPersistence(homology_dimensions=0, newshape=[-1, 8, 10])
+ diags = cp.fit_transform([flat_cells])
+
+ np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0)
+
+ # Not squared (extended) non-flatten cells
+ cells = np.hstack((cells, np.zeros((cells.shape[0], 2))))
+
+ # The aim of this second part of the test is to resize even if not mandatory
+ cp = CubicalPersistence(homology_dimensions=0, newshape=[-1, 8, 10])
diags = cp.fit_transform([cells])
np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0)