summaryrefslogtreecommitdiff
path: root/src/python/test/test_sklearn_cubical_persistence.py
diff options
context:
space:
mode:
authorVincent Rouvreau <vincent.rouvreau@inria.fr>2022-02-03 10:14:41 +0100
committerVincent Rouvreau <vincent.rouvreau@inria.fr>2022-02-03 10:14:41 +0100
commitaf30ee3c2966d29a4e71893fa2c671aeaeb3497f (patch)
treebd03a230a43b7a9bd58f6c3d7ddad2a36bd03edf /src/python/test/test_sklearn_cubical_persistence.py
parent36457ad5c33e253d4998666a70ce94c914e596a4 (diff)
Add flattened cells test
Diffstat (limited to 'src/python/test/test_sklearn_cubical_persistence.py')
-rw-r--r--src/python/test/test_sklearn_cubical_persistence.py10
1 files changed, 10 insertions, 0 deletions
diff --git a/src/python/test/test_sklearn_cubical_persistence.py b/src/python/test/test_sklearn_cubical_persistence.py
index bd728a29..56c44db0 100644
--- a/src/python/test/test_sklearn_cubical_persistence.py
+++ b/src/python/test/test_sklearn_cubical_persistence.py
@@ -38,3 +38,13 @@ def test_simple_constructor_from_top_cells_list():
assert len(diagsH0H1) == 10
for idx in range(10):
np.testing.assert_array_equal(diags[idx], diagsH0H1[idx][0])
+
+def test_simple_constructor_from_flattened_cells():
+ cells = datasets.load_digits().images[0]
+ # Not squared (extended) flatten cells
+ cells = np.hstack((cells, np.zeros((cells.shape[0], 2)))).flatten()
+
+ cp = CubicalPersistence(persistence_dimension=0, dimensions=[10, 8])
+ diags = cp.fit_transform([cells])
+
+ np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0)