summaryrefslogtreecommitdiff
path: root/src/python/test/test_sklearn_post_processing.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/test/test_sklearn_post_processing.py')
-rw-r--r--src/python/test/test_sklearn_post_processing.py41
1 files changed, 18 insertions, 23 deletions
diff --git a/src/python/test/test_sklearn_post_processing.py b/src/python/test/test_sklearn_post_processing.py
index 3a251d34..60bf8162 100644
--- a/src/python/test/test_sklearn_post_processing.py
+++ b/src/python/test/test_sklearn_post_processing.py
@@ -16,33 +16,28 @@ __author__ = "Vincent Rouvreau"
__copyright__ = "Copyright (C) 2021 Inria"
__license__ = "MIT"
-H0_0 = np.array([0., 0.])
-H1_0 = np.array([1., 0.])
-H0_1 = np.array([0., 1.])
-H1_1 = np.array([1., 1.])
-H0_2 = np.array([0., 2.])
-H1_2 = np.array([1., 2.])
+H0_0 = np.array([0.0, 0.0])
+H1_0 = np.array([1.0, 0.0])
+H0_1 = np.array([0.0, 1.0])
+H1_1 = np.array([1.0, 1.0])
+H0_2 = np.array([0.0, 2.0])
+H1_2 = np.array([1.0, 2.0])
+
def test_dimension_selector():
X = [[H0_0, H1_0], [H0_1, H1_1], [H0_2, H1_2]]
- ds = DimensionSelector(persistence_dimension = 0, n_jobs=-2)
+ ds = DimensionSelector(persistence_dimension=0, n_jobs=-2)
h0 = ds.fit_transform(X)
- np.testing.assert_array_equal(h0[0],
- H0_0)
- np.testing.assert_array_equal(h0[1],
- H0_1)
- np.testing.assert_array_equal(h0[2],
- H0_2)
-
- ds = DimensionSelector(persistence_dimension = 1, n_jobs=-1)
+ np.testing.assert_array_equal(h0[0], H0_0)
+ np.testing.assert_array_equal(h0[1], H0_1)
+ np.testing.assert_array_equal(h0[2], H0_2)
+
+ ds = DimensionSelector(persistence_dimension=1, n_jobs=-1)
h1 = ds.fit_transform(X)
- np.testing.assert_array_equal(h1[0],
- H1_0)
- np.testing.assert_array_equal(h1[1],
- H1_1)
- np.testing.assert_array_equal(h1[2],
- H1_2)
-
- ds = DimensionSelector(persistence_dimension = 2, n_jobs=-2)
+ np.testing.assert_array_equal(h1[0], H1_0)
+ np.testing.assert_array_equal(h1[1], H1_1)
+ np.testing.assert_array_equal(h1[2], H1_2)
+
+ ds = DimensionSelector(persistence_dimension=2, n_jobs=-2)
with pytest.raises(IndexError):
h2 = ds.fit_transform([[H0_0, H1_0], [H0_1, H1_1], [H0_2, H1_2]])