diff options
author | ROUVREAU Vincent <vincent.rouvreau@inria.fr> | 2021-05-31 12:23:29 +0200 |
---|---|---|
committer | ROUVREAU Vincent <vincent.rouvreau@inria.fr> | 2021-05-31 12:23:29 +0200 |
commit | 5b75186ace327ddc17eb6f06c0ba2485c93235ec (patch) | |
tree | f263e4e6a1f0034ec30ef7f999eab81942373346 /src/python/test | |
parent | 8859128da7386955b00658ff5d71659a5de08c46 (diff) |
code review + parallelization
Diffstat (limited to 'src/python/test')
-rw-r--r-- | src/python/test/test_sklearn_cubical_persistence.py | 25 |
1 files changed, 13 insertions, 12 deletions
diff --git a/src/python/test/test_sklearn_cubical_persistence.py b/src/python/test/test_sklearn_cubical_persistence.py index f147ffe3..134611c9 100644 --- a/src/python/test/test_sklearn_cubical_persistence.py +++ b/src/python/test/test_sklearn_cubical_persistence.py @@ -10,26 +10,27 @@ from gudhi.sklearn.cubical_persistence import CubicalPersistence import numpy as np +from sklearn import datasets __author__ = "Vincent Rouvreau" __copyright__ = "Copyright (C) 2021 Inria" __license__ = "MIT" def test_simple_constructor_from_top_cells(): + cells = datasets.load_digits().images[0] cp = CubicalPersistence(persistence_dim = 0) + np.testing.assert_array_equal(cp._CubicalPersistence__transform(cells), + np.array([[0., 6.], [0., 8.]])) - # The first "0" from sklearn.datasets.load_digits() - bmp = np.array([[ 0., 0., 5., 13., 9., 1., 0., 0.], - [ 0., 0., 13., 15., 10., 15., 5., 0.], - [ 0., 3., 15., 2., 0., 11., 8., 0.], - [ 0., 4., 12., 0., 0., 8., 8., 0.], - [ 0., 5., 8., 0., 0., 9., 8., 0.], - [ 0., 4., 11., 0., 1., 12., 7., 0.], - [ 0., 2., 14., 5., 10., 12., 0., 0.], - [ 0., 0., 6., 13., 10., 0., 0., 0.]]) +def test_simple_constructor_from_top_cells_list(): + digits = datasets.load_digits().images[:10] + cp = CubicalPersistence(persistence_dim = 0, n_jobs=-2) - assert cp.fit_transform(bmp) == np.array([[0., 6.], [0., 8.]]) + diags = cp.fit_transform(digits) + assert len(diags) == 10 + np.testing.assert_array_equal(diags[0], + np.array([[0., 6.], [0., 8.]])) # from gudhi.representations import PersistenceImage -# PersistenceImage(bandwidth=50, weight=lambda x: x[1]**2, im_range=[0,256,0,256], resolution=[20, 20]) -# PI.fit_transform([diag])
\ No newline at end of file +# pi = PersistenceImage(bandwidth=50, weight=lambda x: x[1]**2, im_range=[0,256,0,256], resolution=[20, 20]) +# pi.fit_transform(diags)
\ No newline at end of file |