summaryrefslogtreecommitdiff
path: root/src/python/test/test_persistence_graphical_tools.py
diff options
context:
space:
mode:
authorROUVREAU Vincent <vincent.rouvreau@inria.fr>2021-06-22 18:47:46 +0200
committerROUVREAU Vincent <vincent.rouvreau@inria.fr>2021-06-22 18:47:46 +0200
commitf9b1e50a6adaadf88c4940cb4214f9ecef542144 (patch)
tree6757404fef3464a2e32ca3d9dbac2ae8616c2319 /src/python/test/test_persistence_graphical_tools.py
parent72f72770ccf0d1ddb78a7f23103b2777f407e72c (diff)
black -l 120 modified files
Diffstat (limited to 'src/python/test/test_persistence_graphical_tools.py')
-rw-r--r--src/python/test/test_persistence_graphical_tools.py82
1 files changed, 47 insertions, 35 deletions
diff --git a/src/python/test/test_persistence_graphical_tools.py b/src/python/test/test_persistence_graphical_tools.py
index 1ad1ae23..7d9bae90 100644
--- a/src/python/test/test_persistence_graphical_tools.py
+++ b/src/python/test/test_persistence_graphical_tools.py
@@ -13,6 +13,7 @@ import numpy as np
import matplotlib as plt
import pytest
+
def test_array_handler():
diags = np.array([[1, 2], [3, 4], [5, 6]], np.float)
arr_diags = gd.persistence_graphical_tools._array_handler(diags)
@@ -20,86 +21,97 @@ def test_array_handler():
assert arr_diags[idx][0] == 0
np.testing.assert_array_equal(arr_diags[idx][1], diags[idx])
- diags = [(1., 2.), (3., 4.), (5., 6.)]
+ diags = [(1.0, 2.0), (3.0, 4.0), (5.0, 6.0)]
arr_diags = gd.persistence_graphical_tools._array_handler(diags)
for idx in range(len(diags)):
assert arr_diags[idx][0] == 0
assert arr_diags[idx][1] == diags[idx]
- diags = [(0, (1., 2.)), (0, (3., 4.)), (0, (5., 6.))]
+ diags = [(0, (1.0, 2.0)), (0, (3.0, 4.0)), (0, (5.0, 6.0))]
assert gd.persistence_graphical_tools._array_handler(diags) == diags
+
def test_min_birth_max_death():
diags = [
- (0, (0., float("inf"))),
+ (0, (0.0, float("inf"))),
(0, (0.0983494, float("inf"))),
- (0, (0., 0.122545)),
- (0, (0., 0.12047)),
- (0, (0., 0.118398)),
- (0, (0.118398, 1.)),
- (0, (0., 0.117908)),
- (0, (0., 0.112307)),
- (0, (0., 0.107535)),
- (0, (0., 0.106382)),
+ (0, (0.0, 0.122545)),
+ (0, (0.0, 0.12047)),
+ (0, (0.0, 0.118398)),
+ (0, (0.118398, 1.0)),
+ (0, (0.0, 0.117908)),
+ (0, (0.0, 0.112307)),
+ (0, (0.0, 0.107535)),
+ (0, (0.0, 0.106382)),
]
- assert gd.persistence_graphical_tools.__min_birth_max_death(diags) == (0., 1.)
- assert gd.persistence_graphical_tools.__min_birth_max_death(diags, band=4.) == (0., 5.)
+ assert gd.persistence_graphical_tools.__min_birth_max_death(diags) == (0.0, 1.0)
+ assert gd.persistence_graphical_tools.__min_birth_max_death(diags, band=4.0) == (0.0, 5.0)
+
def test_limit_min_birth_max_death():
diags = [
- (0, (2., float("inf"))),
- (0, (2., float("inf"))),
+ (0, (2.0, float("inf"))),
+ (0, (2.0, float("inf"))),
]
- assert gd.persistence_graphical_tools.__min_birth_max_death(diags) == (2., 3.)
- assert gd.persistence_graphical_tools.__min_birth_max_death(diags, band = 4.) == (2., 6.)
+ assert gd.persistence_graphical_tools.__min_birth_max_death(diags) == (2.0, 3.0)
+ assert gd.persistence_graphical_tools.__min_birth_max_death(diags, band=4.0) == (2.0, 6.0)
+
def test_limit_to_max_intervals():
diags = [
- (0, (0., float("inf"))),
+ (0, (0.0, float("inf"))),
(0, (0.0983494, float("inf"))),
- (0, (0., 0.122545)),
- (0, (0., 0.12047)),
- (0, (0., 0.118398)),
- (0, (0.118398, 1.)),
- (0, (0., 0.117908)),
- (0, (0., 0.112307)),
- (0, (0., 0.107535)),
- (0, (0., 0.106382)),
+ (0, (0.0, 0.122545)),
+ (0, (0.0, 0.12047)),
+ (0, (0.0, 0.118398)),
+ (0, (0.118398, 1.0)),
+ (0, (0.0, 0.117908)),
+ (0, (0.0, 0.112307)),
+ (0, (0.0, 0.107535)),
+ (0, (0.0, 0.106382)),
]
# check no warnings if max_intervals equals to the diagrams number
with pytest.warns(None) as record:
- truncated_diags = gd.persistence_graphical_tools._limit_to_max_intervals(diags, 10,
- key = lambda life_time: life_time[1][1] - life_time[1][0])
+ truncated_diags = gd.persistence_graphical_tools._limit_to_max_intervals(
+ diags, 10, key=lambda life_time: life_time[1][1] - life_time[1][0]
+ )
# check diagrams are not sorted
assert truncated_diags == diags
assert len(record) == 0
# check warning if max_intervals lower than the diagrams number
with pytest.warns(UserWarning) as record:
- truncated_diags = gd.persistence_graphical_tools._limit_to_max_intervals(diags, 5,
- key = lambda life_time: life_time[1][1] - life_time[1][0])
+ truncated_diags = gd.persistence_graphical_tools._limit_to_max_intervals(
+ diags, 5, key=lambda life_time: life_time[1][1] - life_time[1][0]
+ )
# check diagrams are truncated and sorted by life time
- assert truncated_diags == [(0, (0., float("inf"))),
- (0, (0.0983494, float("inf"))),
- (0, (0.118398, 1.0)),
- (0, (0., 0.122545)),
- (0, (0., 0.12047))]
+ assert truncated_diags == [
+ (0, (0.0, float("inf"))),
+ (0, (0.0983494, float("inf"))),
+ (0, (0.118398, 1.0)),
+ (0, (0.0, 0.122545)),
+ (0, (0.0, 0.12047)),
+ ]
assert len(record) == 1
+
def _limit_plot_persistence(function):
pplot = function(persistence=[()])
assert issubclass(type(pplot), plt.axes.SubplotBase)
pplot = function(persistence=[(0, float("inf"))])
assert issubclass(type(pplot), plt.axes.SubplotBase)
+
def test_limit_plot_persistence():
for function in [gd.plot_persistence_barcode, gd.plot_persistence_diagram, gd.plot_persistence_density]:
_limit_plot_persistence(function)
+
def _non_existing_persistence_file(function):
with pytest.raises(FileNotFoundError):
function(persistence_file="pouetpouettralala.toubiloubabdou")
+
def test_non_existing_persistence_file():
for function in [gd.plot_persistence_barcode, gd.plot_persistence_diagram, gd.plot_persistence_density]:
_non_existing_persistence_file(function)