summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorROUVREAU Vincent <vincent.rouvreau@inria.fr>2021-06-22 18:47:46 +0200
committerROUVREAU Vincent <vincent.rouvreau@inria.fr>2021-06-22 18:47:46 +0200
commitf9b1e50a6adaadf88c4940cb4214f9ecef542144 (patch)
tree6757404fef3464a2e32ca3d9dbac2ae8616c2319
parent72f72770ccf0d1ddb78a7f23103b2777f407e72c (diff)
black -l 120 modified files
-rw-r--r--src/python/gudhi/persistence_graphical_tools.py144
-rw-r--r--src/python/test/test_persistence_graphical_tools.py82
2 files changed, 120 insertions, 106 deletions
diff --git a/src/python/gudhi/persistence_graphical_tools.py b/src/python/gudhi/persistence_graphical_tools.py
index 9bfc19e0..837218ca 100644
--- a/src/python/gudhi/persistence_graphical_tools.py
+++ b/src/python/gudhi/persistence_graphical_tools.py
@@ -25,6 +25,7 @@ __license__ = "MIT"
_gudhi_matplotlib_use_tex = True
+
def __min_birth_max_death(persistence, band=0.0):
"""This function returns (min_birth, max_death) from the persistence.
@@ -49,23 +50,24 @@ def __min_birth_max_death(persistence, band=0.0):
max_death += band
# can happen if only points at inf death
if min_birth == max_death:
- max_death = max_death + 1.
+ max_death = max_death + 1.0
return (min_birth, max_death)
def _array_handler(a):
- '''
+ """
:param a: if array, assumes it is a (n x 2) np.array and return a
persistence-compatible list (padding with 0), so that the
plot can be performed seamlessly.
- '''
+ """
if isinstance(a[0][1], np.floating) or isinstance(a[0][1], float):
return [[0, x] for x in a]
else:
return a
+
def _limit_to_max_intervals(persistence, max_intervals, key):
- '''This function returns truncated persistence if length is bigger than max_intervals.
+ """This function returns truncated persistence if length is bigger than max_intervals.
:param persistence: Persistence intervals values list. Can be grouped by dimension or not.
:type persistence: an array of (dimension, array of (birth, death)) or an array of (birth, death).
:param max_intervals: maximal number of intervals to display.
@@ -74,16 +76,18 @@ def _limit_to_max_intervals(persistence, max_intervals, key):
:type max_intervals: int.
:param key: key function for sort algorithm.
:type key: function or lambda.
- '''
+ """
if max_intervals > 0 and max_intervals < len(persistence):
- warnings.warn('There are %s intervals given as input, whereas max_intervals is set to %s.' %
- (len(persistence), max_intervals)
- )
+ warnings.warn(
+ "There are %s intervals given as input, whereas max_intervals is set to %s."
+ % (len(persistence), max_intervals)
+ )
# Sort by life time, then takes only the max_intervals elements
- return sorted(persistence, key = key, reverse = True)[:max_intervals]
+ return sorted(persistence, key=key, reverse=True)[:max_intervals]
else:
return persistence
+
@lru_cache(maxsize=1)
def _matplotlib_can_use_tex():
"""This function returns True if matplotlib can deal with LaTeX, False otherwise.
@@ -91,6 +95,7 @@ def _matplotlib_can_use_tex():
"""
try:
from matplotlib import checkdep_usetex
+
return checkdep_usetex(True)
except ImportError:
warnings.warn("This function is not available, you may be missing matplotlib.")
@@ -144,20 +149,19 @@ def plot_persistence_barcode(
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib import rc
+
if _gudhi_matplotlib_use_tex and _matplotlib_can_use_tex():
- plt.rc('text', usetex=True)
- plt.rc('font', family='serif')
+ plt.rc("text", usetex=True)
+ plt.rc("font", family="serif")
else:
- plt.rc('text', usetex=False)
- plt.rc('font', family='DejaVu Sans')
+ plt.rc("text", usetex=False)
+ plt.rc("font", family="DejaVu Sans")
if persistence_file != "":
if path.isfile(persistence_file):
# Reset persistence
persistence = []
- diag = read_persistence_intervals_grouped_by_dimension(
- persistence_file=persistence_file
- )
+ diag = read_persistence_intervals_grouped_by_dimension(persistence_file=persistence_file)
for key in diag.keys():
for persistence_interval in diag[key]:
persistence.append((key, persistence_interval))
@@ -166,12 +170,13 @@ def plot_persistence_barcode(
try:
persistence = _array_handler(persistence)
- persistence = _limit_to_max_intervals(persistence, max_intervals,
- key = lambda life_time: life_time[1][1] - life_time[1][0])
+ persistence = _limit_to_max_intervals(
+ persistence, max_intervals, key=lambda life_time: life_time[1][1] - life_time[1][0]
+ )
(min_birth, max_death) = __min_birth_max_death(persistence)
persistence = sorted(persistence, key=lambda birth: birth[1][0])
except IndexError:
- min_birth, max_death = 0., 1.
+ min_birth, max_death = 0.0, 1.0
pass
delta = (max_death - min_birth) * inf_delta
@@ -218,11 +223,7 @@ def plot_persistence_barcode(
if legend:
dimensions = list(set(item[0] for item in persistence))
axes.legend(
- handles=[
- mpatches.Patch(color=colormap[dim], label=str(dim))
- for dim in dimensions
- ],
- loc="lower right",
+ handles=[mpatches.Patch(color=colormap[dim], label=str(dim)) for dim in dimensions], loc="lower right",
)
axes.set_title("Persistence barcode", fontsize=fontsize)
@@ -247,7 +248,7 @@ def plot_persistence_diagram(
colormap=None,
axes=None,
fontsize=16,
- greyblock=True
+ greyblock=True,
):
"""This function plots the persistence diagram from persistence values
list, a np.array of shape (N x 2) representing a diagram in a single
@@ -289,20 +290,19 @@ def plot_persistence_diagram(
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib import rc
+
if _gudhi_matplotlib_use_tex and _matplotlib_can_use_tex():
- plt.rc('text', usetex=True)
- plt.rc('font', family='serif')
+ plt.rc("text", usetex=True)
+ plt.rc("font", family="serif")
else:
- plt.rc('text', usetex=False)
- plt.rc('font', family='DejaVu Sans')
+ plt.rc("text", usetex=False)
+ plt.rc("font", family="DejaVu Sans")
if persistence_file != "":
if path.isfile(persistence_file):
# Reset persistence
persistence = []
- diag = read_persistence_intervals_grouped_by_dimension(
- persistence_file=persistence_file
- )
+ diag = read_persistence_intervals_grouped_by_dimension(persistence_file=persistence_file)
for key in diag.keys():
for persistence_interval in diag[key]:
persistence.append((key, persistence_interval))
@@ -311,11 +311,12 @@ def plot_persistence_diagram(
try:
persistence = _array_handler(persistence)
- persistence = _limit_to_max_intervals(persistence, max_intervals,
- key = lambda life_time: life_time[1][1] - life_time[1][0])
+ persistence = _limit_to_max_intervals(
+ persistence, max_intervals, key=lambda life_time: life_time[1][1] - life_time[1][0]
+ )
min_birth, max_death = __min_birth_max_death(persistence, band)
except IndexError:
- min_birth, max_death = 0., 1.
+ min_birth, max_death = 0.0, 1.0
pass
delta = (max_death - min_birth) * inf_delta
@@ -335,8 +336,13 @@ def plot_persistence_diagram(
axes.fill_between(x, x, x + band, alpha=alpha, facecolor="red")
# lower diag patch
if greyblock:
- axes.add_patch(mpatches.Polygon([[axis_start, axis_start], [axis_end, axis_start], [axis_end, axis_end]],
- fill=True, color='lightgrey'))
+ axes.add_patch(
+ mpatches.Polygon(
+ [[axis_start, axis_start], [axis_end, axis_start], [axis_end, axis_end]],
+ fill=True,
+ color="lightgrey",
+ )
+ )
# line display of equation : birth = death
axes.plot([axis_start, axis_end], [axis_start, axis_end], linewidth=1.0, color="k")
@@ -347,17 +353,12 @@ def plot_persistence_diagram(
if float(interval[1][1]) != float("inf"):
# Finite death case
axes.scatter(
- interval[1][0],
- interval[1][1],
- alpha=alpha,
- color=colormap[interval[0]],
+ interval[1][0], interval[1][1], alpha=alpha, color=colormap[interval[0]],
)
else:
pts_at_infty = True
# Infinite death case for diagram to be nicer
- axes.scatter(
- interval[1][0], infinity, alpha=alpha, color=colormap[interval[0]]
- )
+ axes.scatter(interval[1][0], infinity, alpha=alpha, color=colormap[interval[0]])
except IndexError:
pass
if pts_at_infty:
@@ -365,27 +366,22 @@ def plot_persistence_diagram(
axes.plot([axis_start, axis_end], [infinity, infinity], linewidth=1.0, color="k", alpha=alpha)
# Infinity label
yt = axes.get_yticks()
- yt = yt[np.where(yt < axis_end)] # to avoid ploting ticklabel higher than infinity
+ yt = yt[np.where(yt < axis_end)] # to avoid ploting ticklabel higher than infinity
yt = np.append(yt, infinity)
ytl = ["%.3f" % e for e in yt] # to avoid float precision error
- ytl[-1] = r'$+\infty$'
+ ytl[-1] = r"$+\infty$"
axes.set_yticks(yt)
axes.set_yticklabels(ytl)
if legend:
dimensions = list(set(item[0] for item in persistence))
- axes.legend(
- handles=[
- mpatches.Patch(color=colormap[dim], label=str(dim))
- for dim in dimensions
- ]
- )
+ axes.legend(handles=[mpatches.Patch(color=colormap[dim], label=str(dim)) for dim in dimensions])
axes.set_xlabel("Birth", fontsize=fontsize)
axes.set_ylabel("Death", fontsize=fontsize)
axes.set_title("Persistence diagram", fontsize=fontsize)
# Ends plot on infinity value and starts a little bit before min_birth
- axes.axis([axis_start, axis_end, axis_start, infinity + delta/2])
+ axes.axis([axis_start, axis_end, axis_start, infinity + delta / 2])
return axes
except ImportError:
@@ -403,7 +399,7 @@ def plot_persistence_density(
legend=False,
axes=None,
fontsize=16,
- greyblock=False
+ greyblock=False,
):
"""This function plots the persistence density from persistence
values list, np.array of shape (N x 2) representing a diagram
@@ -463,12 +459,13 @@ def plot_persistence_density(
import matplotlib.patches as mpatches
from scipy.stats import kde
from matplotlib import rc
+
if _gudhi_matplotlib_use_tex and _matplotlib_can_use_tex():
- plt.rc('text', usetex=True)
- plt.rc('font', family='serif')
+ plt.rc("text", usetex=True)
+ plt.rc("font", family="serif")
else:
- plt.rc('text', usetex=False)
- plt.rc('font', family='DejaVu Sans')
+ plt.rc("text", usetex=False)
+ plt.rc("font", family="DejaVu Sans")
if persistence_file != "":
if dimension is None:
@@ -499,8 +496,11 @@ def plot_persistence_density(
]
)
persistence_dim = persistence_dim[np.isfinite(persistence_dim[:, 1])]
- persistence_dim = np.array(_limit_to_max_intervals(persistence_dim, max_intervals,
- key = lambda life_time: life_time[1] - life_time[0]))
+ persistence_dim = np.array(
+ _limit_to_max_intervals(
+ persistence_dim, max_intervals, key=lambda life_time: life_time[1] - life_time[0]
+ )
+ )
# Set as numpy array birth and death (remove undefined values - inf and NaN)
birth = persistence_dim[:, 0]
@@ -513,19 +513,18 @@ def plot_persistence_density(
# Evaluate a gaussian kde on a regular grid of nbins x nbins over data extents
k = kde.gaussian_kde([birth, death], bw_method=bw_method)
xi, yi = np.mgrid[
- birth_min : birth_max : nbins * 1j,
- death_min : death_max : nbins * 1j,
+ birth_min : birth_max : nbins * 1j, death_min : death_max : nbins * 1j,
]
zi = k(np.vstack([xi.flatten(), yi.flatten()]))
# Make the plot
- img = axes.pcolormesh(xi, yi, zi.reshape(xi.shape), cmap=cmap, shading='auto')
+ img = axes.pcolormesh(xi, yi, zi.reshape(xi.shape), cmap=cmap, shading="auto")
# IndexError on empty diagrams, ValueError on only inf death values
except (IndexError, ValueError):
- birth_min = 0.
- birth_max = 1.
- death_min = 0.
- death_max = 1.
+ birth_min = 0.0
+ birth_max = 1.0
+ death_min = 0.0
+ death_max = 1.0
pass
# line display of equation : birth = death
@@ -533,10 +532,13 @@ def plot_persistence_density(
axes.plot(x, x, color="k", linewidth=1.0)
if greyblock:
- axes.add_patch(mpatches.Polygon([[birth_min, birth_min],
- [death_max, birth_min],
- [death_max, death_max]],
- fill=True, color='lightgrey'))
+ axes.add_patch(
+ mpatches.Polygon(
+ [[birth_min, birth_min], [death_max, birth_min], [death_max, death_max]],
+ fill=True,
+ color="lightgrey",
+ )
+ )
if legend:
plt.colorbar(img, ax=axes)
diff --git a/src/python/test/test_persistence_graphical_tools.py b/src/python/test/test_persistence_graphical_tools.py
index 1ad1ae23..7d9bae90 100644
--- a/src/python/test/test_persistence_graphical_tools.py
+++ b/src/python/test/test_persistence_graphical_tools.py
@@ -13,6 +13,7 @@ import numpy as np
import matplotlib as plt
import pytest
+
def test_array_handler():
diags = np.array([[1, 2], [3, 4], [5, 6]], np.float)
arr_diags = gd.persistence_graphical_tools._array_handler(diags)
@@ -20,86 +21,97 @@ def test_array_handler():
assert arr_diags[idx][0] == 0
np.testing.assert_array_equal(arr_diags[idx][1], diags[idx])
- diags = [(1., 2.), (3., 4.), (5., 6.)]
+ diags = [(1.0, 2.0), (3.0, 4.0), (5.0, 6.0)]
arr_diags = gd.persistence_graphical_tools._array_handler(diags)
for idx in range(len(diags)):
assert arr_diags[idx][0] == 0
assert arr_diags[idx][1] == diags[idx]
- diags = [(0, (1., 2.)), (0, (3., 4.)), (0, (5., 6.))]
+ diags = [(0, (1.0, 2.0)), (0, (3.0, 4.0)), (0, (5.0, 6.0))]
assert gd.persistence_graphical_tools._array_handler(diags) == diags
+
def test_min_birth_max_death():
diags = [
- (0, (0., float("inf"))),
+ (0, (0.0, float("inf"))),
(0, (0.0983494, float("inf"))),
- (0, (0., 0.122545)),
- (0, (0., 0.12047)),
- (0, (0., 0.118398)),
- (0, (0.118398, 1.)),
- (0, (0., 0.117908)),
- (0, (0., 0.112307)),
- (0, (0., 0.107535)),
- (0, (0., 0.106382)),
+ (0, (0.0, 0.122545)),
+ (0, (0.0, 0.12047)),
+ (0, (0.0, 0.118398)),
+ (0, (0.118398, 1.0)),
+ (0, (0.0, 0.117908)),
+ (0, (0.0, 0.112307)),
+ (0, (0.0, 0.107535)),
+ (0, (0.0, 0.106382)),
]
- assert gd.persistence_graphical_tools.__min_birth_max_death(diags) == (0., 1.)
- assert gd.persistence_graphical_tools.__min_birth_max_death(diags, band=4.) == (0., 5.)
+ assert gd.persistence_graphical_tools.__min_birth_max_death(diags) == (0.0, 1.0)
+ assert gd.persistence_graphical_tools.__min_birth_max_death(diags, band=4.0) == (0.0, 5.0)
+
def test_limit_min_birth_max_death():
diags = [
- (0, (2., float("inf"))),
- (0, (2., float("inf"))),
+ (0, (2.0, float("inf"))),
+ (0, (2.0, float("inf"))),
]
- assert gd.persistence_graphical_tools.__min_birth_max_death(diags) == (2., 3.)
- assert gd.persistence_graphical_tools.__min_birth_max_death(diags, band = 4.) == (2., 6.)
+ assert gd.persistence_graphical_tools.__min_birth_max_death(diags) == (2.0, 3.0)
+ assert gd.persistence_graphical_tools.__min_birth_max_death(diags, band=4.0) == (2.0, 6.0)
+
def test_limit_to_max_intervals():
diags = [
- (0, (0., float("inf"))),
+ (0, (0.0, float("inf"))),
(0, (0.0983494, float("inf"))),
- (0, (0., 0.122545)),
- (0, (0., 0.12047)),
- (0, (0., 0.118398)),
- (0, (0.118398, 1.)),
- (0, (0., 0.117908)),
- (0, (0., 0.112307)),
- (0, (0., 0.107535)),
- (0, (0., 0.106382)),
+ (0, (0.0, 0.122545)),
+ (0, (0.0, 0.12047)),
+ (0, (0.0, 0.118398)),
+ (0, (0.118398, 1.0)),
+ (0, (0.0, 0.117908)),
+ (0, (0.0, 0.112307)),
+ (0, (0.0, 0.107535)),
+ (0, (0.0, 0.106382)),
]
# check no warnings if max_intervals equals to the diagrams number
with pytest.warns(None) as record:
- truncated_diags = gd.persistence_graphical_tools._limit_to_max_intervals(diags, 10,
- key = lambda life_time: life_time[1][1] - life_time[1][0])
+ truncated_diags = gd.persistence_graphical_tools._limit_to_max_intervals(
+ diags, 10, key=lambda life_time: life_time[1][1] - life_time[1][0]
+ )
# check diagrams are not sorted
assert truncated_diags == diags
assert len(record) == 0
# check warning if max_intervals lower than the diagrams number
with pytest.warns(UserWarning) as record:
- truncated_diags = gd.persistence_graphical_tools._limit_to_max_intervals(diags, 5,
- key = lambda life_time: life_time[1][1] - life_time[1][0])
+ truncated_diags = gd.persistence_graphical_tools._limit_to_max_intervals(
+ diags, 5, key=lambda life_time: life_time[1][1] - life_time[1][0]
+ )
# check diagrams are truncated and sorted by life time
- assert truncated_diags == [(0, (0., float("inf"))),
- (0, (0.0983494, float("inf"))),
- (0, (0.118398, 1.0)),
- (0, (0., 0.122545)),
- (0, (0., 0.12047))]
+ assert truncated_diags == [
+ (0, (0.0, float("inf"))),
+ (0, (0.0983494, float("inf"))),
+ (0, (0.118398, 1.0)),
+ (0, (0.0, 0.122545)),
+ (0, (0.0, 0.12047)),
+ ]
assert len(record) == 1
+
def _limit_plot_persistence(function):
pplot = function(persistence=[()])
assert issubclass(type(pplot), plt.axes.SubplotBase)
pplot = function(persistence=[(0, float("inf"))])
assert issubclass(type(pplot), plt.axes.SubplotBase)
+
def test_limit_plot_persistence():
for function in [gd.plot_persistence_barcode, gd.plot_persistence_diagram, gd.plot_persistence_density]:
_limit_plot_persistence(function)
+
def _non_existing_persistence_file(function):
with pytest.raises(FileNotFoundError):
function(persistence_file="pouetpouettralala.toubiloubabdou")
+
def test_non_existing_persistence_file():
for function in [gd.plot_persistence_barcode, gd.plot_persistence_diagram, gd.plot_persistence_density]:
_non_existing_persistence_file(function)