summaryrefslogtreecommitdiff
path: root/src/python/test
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/test')
-rw-r--r--src/python/test/test_sklearn_cubical_persistence.py25
1 files changed, 13 insertions, 12 deletions
diff --git a/src/python/test/test_sklearn_cubical_persistence.py b/src/python/test/test_sklearn_cubical_persistence.py
index f147ffe3..134611c9 100644
--- a/src/python/test/test_sklearn_cubical_persistence.py
+++ b/src/python/test/test_sklearn_cubical_persistence.py
@@ -10,26 +10,27 @@
from gudhi.sklearn.cubical_persistence import CubicalPersistence
import numpy as np
+from sklearn import datasets
__author__ = "Vincent Rouvreau"
__copyright__ = "Copyright (C) 2021 Inria"
__license__ = "MIT"
def test_simple_constructor_from_top_cells():
+ cells = datasets.load_digits().images[0]
cp = CubicalPersistence(persistence_dim = 0)
+ np.testing.assert_array_equal(cp._CubicalPersistence__transform(cells),
+ np.array([[0., 6.], [0., 8.]]))
- # The first "0" from sklearn.datasets.load_digits()
- bmp = np.array([[ 0., 0., 5., 13., 9., 1., 0., 0.],
- [ 0., 0., 13., 15., 10., 15., 5., 0.],
- [ 0., 3., 15., 2., 0., 11., 8., 0.],
- [ 0., 4., 12., 0., 0., 8., 8., 0.],
- [ 0., 5., 8., 0., 0., 9., 8., 0.],
- [ 0., 4., 11., 0., 1., 12., 7., 0.],
- [ 0., 2., 14., 5., 10., 12., 0., 0.],
- [ 0., 0., 6., 13., 10., 0., 0., 0.]])
+def test_simple_constructor_from_top_cells_list():
+ digits = datasets.load_digits().images[:10]
+ cp = CubicalPersistence(persistence_dim = 0, n_jobs=-2)
- assert cp.fit_transform(bmp) == np.array([[0., 6.], [0., 8.]])
+ diags = cp.fit_transform(digits)
+ assert len(diags) == 10
+ np.testing.assert_array_equal(diags[0],
+ np.array([[0., 6.], [0., 8.]]))
# from gudhi.representations import PersistenceImage
-# PersistenceImage(bandwidth=50, weight=lambda x: x[1]**2, im_range=[0,256,0,256], resolution=[20, 20])
-# PI.fit_transform([diag]) \ No newline at end of file
+# pi = PersistenceImage(bandwidth=50, weight=lambda x: x[1]**2, im_range=[0,256,0,256], resolution=[20, 20])
+# pi.fit_transform(diags) \ No newline at end of file