summaryrefslogtreecommitdiff
path: root/src/python/test/test_sklearn_cubical_persistence.py
blob: 1c05a21586db4d75f9c67a94543382a62b2ff079 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
    See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
    Author(s):       Vincent Rouvreau

    Copyright (C) 2021 Inria

    Modification(s):
      - YYYY/MM Author: Description of the modification
"""

from gudhi.sklearn.cubical_persistence import CubicalPersistence
import numpy as np
from sklearn import datasets

# Expected dimension-0 persistence diagram (one [birth, death] pair per row)
# of the first scikit-learn digits image; the last pair never dies.
CUBICAL_PERSISTENCE_H0_IMG0 = np.array(
    [
        [0.0, 6.0],
        [0.0, 8.0],
        [0.0, np.inf],
    ]
)


def test_simple_constructor_from_top_cells():
    """Check the H0 diagram of the first digits image through the private helpers.

    Both the single-dimension helper (``homology_dimensions=0``) and the
    multi-dimension helper (``homology_dimensions=[0, 2]``) must reproduce
    ``CUBICAL_PERSISTENCE_H0_IMG0`` for dimension 0.
    """
    image = datasets.load_digits().images[0]

    # Single requested dimension: the helper yields one diagram directly.
    single_dim = CubicalPersistence(homology_dimensions=0)
    h0_diagram = single_dim._CubicalPersistence__transform_only_this_dim(image)
    np.testing.assert_array_equal(h0_diagram, CUBICAL_PERSISTENCE_H0_IMG0)

    # Two requested dimensions: expect two diagrams, the first one being H0.
    multi_dim = CubicalPersistence(homology_dimensions=[0, 2])
    diagrams = multi_dim._CubicalPersistence__transform(image)
    assert len(diagrams) == 2
    np.testing.assert_array_equal(diagrams[0], CUBICAL_PERSISTENCE_H0_IMG0)


def test_simple_constructor_from_top_cells_list():
    """Check ``fit_transform`` on a list of images for one and for several dimensions.

    ``fit_transform`` must return one result per input image.
    - With ``homology_dimensions=0`` each result is the H0 diagram itself.
    - With ``homology_dimensions=[0, 1]`` each result's first entry must equal
      the H0 diagram computed in the single-dimension run.
    """
    digits = datasets.load_digits().images[:10]
    # Derive the expected count from the input so it cannot drift from the slice.
    n_images = len(digits)

    cp = CubicalPersistence(homology_dimensions=0, n_jobs=-2)
    diags = cp.fit_transform(digits)
    assert len(diags) == n_images
    np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0)

    cp = CubicalPersistence(homology_dimensions=[0, 1], n_jobs=-1)
    diags_h0_h1 = cp.fit_transform(digits)
    assert len(diags_h0_h1) == n_images
    # Both result lists were just checked to have n_images entries, so zip
    # visits every image.
    for diag_h0, diags_of_image in zip(diags, diags_h0_h1):
        np.testing.assert_array_equal(diag_h0, diags_of_image[0])

def test_simple_constructor_from_flattened_cells():
    """Check that ``newshape`` lets ``fit_transform`` handle non-square images.

    The first digits image is padded with two columns of zeros (8 rows x 10
    columns). Its H0 diagram is expected to equal ``CUBICAL_PERSISTENCE_H0_IMG0``
    whether the padded image is passed flattened or already 2D.
    """
    cells = datasets.load_digits().images[0]
    # Pad each row with two zeros so the image is no longer square.
    padded = np.hstack((cells, np.zeros((cells.shape[0], 2))))

    # Flattened input: newshape restores the (8, 10) layout.
    cp = CubicalPersistence(homology_dimensions=0, newshape=[-1, 8, 10])
    diags = cp.fit_transform([padded.flatten()])

    np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0)

    # Non-flattened input: the input already has the target shape, so the
    # reshape is redundant. Check that it is still accepted and gives the
    # same result.
    cp = CubicalPersistence(homology_dimensions=0, newshape=[-1, 8, 10])
    diags = cp.fit_transform([padded])

    np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0)