diff options
Diffstat (limited to 'src/python/test/test_sklearn_cubical_persistence.py')
-rw-r--r-- | src/python/test/test_sklearn_cubical_persistence.py | 21 |
1 files changed, 10 insertions, 11 deletions
diff --git a/src/python/test/test_sklearn_cubical_persistence.py b/src/python/test/test_sklearn_cubical_persistence.py index 506985f1..488495d1 100644 --- a/src/python/test/test_sklearn_cubical_persistence.py +++ b/src/python/test/test_sklearn_cubical_persistence.py @@ -16,29 +16,28 @@ __author__ = "Vincent Rouvreau" __copyright__ = "Copyright (C) 2021 Inria" __license__ = "MIT" -CUBICAL_PERSISTENCE_H0_IMG0 = np.array([[0., 6.], [0., 8.], [ 0., np.inf]]) +CUBICAL_PERSISTENCE_H0_IMG0 = np.array([[0.0, 6.0], [0.0, 8.0], [0.0, np.inf]]) + def test_simple_constructor_from_top_cells(): cells = datasets.load_digits().images[0] - cp = CubicalPersistence(only_this_dim = 0) - np.testing.assert_array_equal(cp._CubicalPersistence__transform(cells), - [CUBICAL_PERSISTENCE_H0_IMG0]) - cp = CubicalPersistence(max_persistence_dimension = 2) + cp = CubicalPersistence(only_this_dim=0) + np.testing.assert_array_equal(cp._CubicalPersistence__transform(cells), [CUBICAL_PERSISTENCE_H0_IMG0]) + cp = CubicalPersistence(max_persistence_dimension=2) diags = cp._CubicalPersistence__transform(cells) assert len(diags) == 3 - np.testing.assert_array_equal(diags[0], - CUBICAL_PERSISTENCE_H0_IMG0) + np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0) + def test_simple_constructor_from_top_cells_list(): digits = datasets.load_digits().images[:10] - cp = CubicalPersistence(only_this_dim = 0, n_jobs=-2) + cp = CubicalPersistence(only_this_dim=0, n_jobs=-2) diags = cp.fit_transform(digits) assert len(diags) == 10 - np.testing.assert_array_equal(diags[0], - CUBICAL_PERSISTENCE_H0_IMG0) + np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0) - cp = CubicalPersistence(max_persistence_dimension = 1, n_jobs=-1) + cp = CubicalPersistence(max_persistence_dimension=1, n_jobs=-1) diagsH0H1 = cp.fit_transform(digits) assert len(diagsH0H1) == 10 for idx in range(10): |