summaryrefslogtreecommitdiff
path: root/src/python/test/test_sklearn_post_processing.py
blob: 3a251d345332a6c72b42521e63f64be9f21f5a99 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
    See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
    Author(s):       Vincent Rouvreau

    Copyright (C) 2021 Inria

    Modification(s):
      - YYYY/MM Author: Description of the modification
"""

from gudhi.sklearn.post_processing import DimensionSelector
import numpy as np
import pytest

__author__ = "Vincent Rouvreau"
__copyright__ = "Copyright (C) 2021 Inria"
__license__ = "MIT"

# Fixture arrays used as per-sample entries: ``H<d>_<i>`` is the entry stored
# at index ``d`` of sample ``i``. Every array is a float 1-D pair whose first
# value is ``d`` and second value is ``i``.
H0_0, H0_1, H0_2 = (np.array([0.0, float(i)]) for i in range(3))
H1_0, H1_1, H1_2 = (np.array([1.0, float(i)]) for i in range(3))

def test_dimension_selector():
    """Check that DimensionSelector picks the entry at the requested index.

    ``X`` holds three samples, each a list of two arrays. Selecting
    ``persistence_dimension`` 0 (resp. 1) must return, for every sample, the
    array stored at index 0 (resp. 1), in sample order. Requesting index 2,
    which no sample has, must raise ``IndexError``.
    """
    X = [[H0_0, H1_0], [H0_1, H1_1], [H0_2, H1_2]]

    ds = DimensionSelector(persistence_dimension=0, n_jobs=-2)
    h0 = ds.fit_transform(X)
    expected_h0 = [H0_0, H0_1, H0_2]
    # Length check first: zip() below would silently ignore extra outputs.
    assert len(h0) == len(expected_h0)
    for selected, expected in zip(h0, expected_h0):
        np.testing.assert_array_equal(selected, expected)

    ds = DimensionSelector(persistence_dimension=1, n_jobs=-1)
    h1 = ds.fit_transform(X)
    expected_h1 = [H1_0, H1_1, H1_2]
    assert len(h1) == len(expected_h1)
    for selected, expected in zip(h1, expected_h1):
        np.testing.assert_array_equal(selected, expected)

    # Each sample only has indices 0 and 1, so index 2 is out of range.
    ds = DimensionSelector(persistence_dimension=2, n_jobs=-2)
    with pytest.raises(IndexError):
        ds.fit_transform(X)