summaryrefslogtreecommitdiff
path: root/src/python/test
diff options
context:
space:
mode:
authorVincent Rouvreau <10407034+VincentRouvreau@users.noreply.github.com>2022-08-10 10:48:44 +0200
committerGitHub <noreply@github.com>2022-08-10 10:48:44 +0200
commita5978f81faf2aeaa3b3df682caf791aae50fd948 (patch)
tree9f4036e73e8083be95153af91ad761892bc1b8b2 /src/python/test
parent4f83706aa1263c04cb5e8763e1e8eb6c580bed3c (diff)
parent5fdb9e5e1ed77f7ad5a98c563fb9bfa09056271c (diff)
Merge pull request #499 from VincentRouvreau/sklearn_cubical
Scikit learn like cubical interface
Diffstat (limited to 'src/python/test')
-rw-r--r--src/python/test/test_representations_preprocessing.py39
-rw-r--r--src/python/test/test_sklearn_cubical_persistence.py59
2 files changed, 98 insertions, 0 deletions
diff --git a/src/python/test/test_representations_preprocessing.py b/src/python/test/test_representations_preprocessing.py
new file mode 100644
index 00000000..838cf30c
--- /dev/null
+++ b/src/python/test/test_representations_preprocessing.py
@@ -0,0 +1,39 @@
+""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+ See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+ Author(s): Vincent Rouvreau
+
+ Copyright (C) 2021 Inria
+
+ Modification(s):
+ - YYYY/MM Author: Description of the modification
+"""
+
+from gudhi.representations.preprocessing import DimensionSelector
+import numpy as np
+import pytest
+
+H0_0 = np.array([0.0, 0.0])
+H1_0 = np.array([1.0, 0.0])
+H0_1 = np.array([0.0, 1.0])
+H1_1 = np.array([1.0, 1.0])
+H0_2 = np.array([0.0, 2.0])
+H1_2 = np.array([1.0, 2.0])
+
+
+def test_dimension_selector():
+ X = [[H0_0, H1_0], [H0_1, H1_1], [H0_2, H1_2]]
+ ds = DimensionSelector(index=0)
+ h0 = ds.fit_transform(X)
+ np.testing.assert_array_equal(h0[0], H0_0)
+ np.testing.assert_array_equal(h0[1], H0_1)
+ np.testing.assert_array_equal(h0[2], H0_2)
+
+ ds = DimensionSelector(index=1)
+ h1 = ds.fit_transform(X)
+ np.testing.assert_array_equal(h1[0], H1_0)
+ np.testing.assert_array_equal(h1[1], H1_1)
+ np.testing.assert_array_equal(h1[2], H1_2)
+
+ ds = DimensionSelector(index=2)
+ with pytest.raises(IndexError):
+ h2 = ds.fit_transform([[H0_0, H1_0], [H0_1, H1_1], [H0_2, H1_2]])
diff --git a/src/python/test/test_sklearn_cubical_persistence.py b/src/python/test/test_sklearn_cubical_persistence.py
new file mode 100644
index 00000000..1c05a215
--- /dev/null
+++ b/src/python/test/test_sklearn_cubical_persistence.py
@@ -0,0 +1,59 @@
+""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+ See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+ Author(s): Vincent Rouvreau
+
+ Copyright (C) 2021 Inria
+
+ Modification(s):
+ - YYYY/MM Author: Description of the modification
+"""
+
+from gudhi.sklearn.cubical_persistence import CubicalPersistence
+import numpy as np
+from sklearn import datasets
+
+CUBICAL_PERSISTENCE_H0_IMG0 = np.array([[0.0, 6.0], [0.0, 8.0], [0.0, np.inf]])
+
+
+def test_simple_constructor_from_top_cells():
+ cells = datasets.load_digits().images[0]
+ cp = CubicalPersistence(homology_dimensions=0)
+ np.testing.assert_array_equal(cp._CubicalPersistence__transform_only_this_dim(cells), CUBICAL_PERSISTENCE_H0_IMG0)
+ cp = CubicalPersistence(homology_dimensions=[0, 2])
+ diags = cp._CubicalPersistence__transform(cells)
+ assert len(diags) == 2
+ np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0)
+
+
+def test_simple_constructor_from_top_cells_list():
+ digits = datasets.load_digits().images[:10]
+ cp = CubicalPersistence(homology_dimensions=0, n_jobs=-2)
+
+ diags = cp.fit_transform(digits)
+ assert len(diags) == 10
+ np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0)
+
+ cp = CubicalPersistence(homology_dimensions=[0, 1], n_jobs=-1)
+ diagsH0H1 = cp.fit_transform(digits)
+ assert len(diagsH0H1) == 10
+ for idx in range(10):
+ np.testing.assert_array_equal(diags[idx], diagsH0H1[idx][0])
+
+def test_simple_constructor_from_flattened_cells():
+ cells = datasets.load_digits().images[0]
+ # Not squared (extended) flatten cells
+ flat_cells = np.hstack((cells, np.zeros((cells.shape[0], 2)))).flatten()
+
+ cp = CubicalPersistence(homology_dimensions=0, newshape=[-1, 8, 10])
+ diags = cp.fit_transform([flat_cells])
+
+ np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0)
+
+ # Not squared (extended) non-flatten cells
+ cells = np.hstack((cells, np.zeros((cells.shape[0], 2))))
+
+ # The aim of this second part of the test is to resize even if not mandatory
+ cp = CubicalPersistence(homology_dimensions=0, newshape=[-1, 8, 10])
+ diags = cp.fit_transform([cells])
+
+ np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0)