From 72f72770ccf0d1ddb78a7f23103b2777f407e72c Mon Sep 17 00:00:00 2001 From: ROUVREAU Vincent Date: Thu, 17 Jun 2021 16:29:18 +0200 Subject: Forgot to add test file for graphical --- .../test/test_persistence_graphical_tools.py | 105 +++++++++++++++++++++ 1 file changed, 105 insertions(+) create mode 100644 src/python/test/test_persistence_graphical_tools.py (limited to 'src/python/test/test_persistence_graphical_tools.py') diff --git a/src/python/test/test_persistence_graphical_tools.py b/src/python/test/test_persistence_graphical_tools.py new file mode 100644 index 00000000..1ad1ae23 --- /dev/null +++ b/src/python/test/test_persistence_graphical_tools.py @@ -0,0 +1,105 @@ +""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT. + See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details. + Author(s): Vincent Rouvreau + + Copyright (C) 2021 Inria + + Modification(s): + - YYYY/MM Author: Description of the modification +""" + +import gudhi as gd +import numpy as np +import matplotlib as plt +import pytest + +def test_array_handler(): + diags = np.array([[1, 2], [3, 4], [5, 6]], np.float) + arr_diags = gd.persistence_graphical_tools._array_handler(diags) + for idx in range(len(diags)): + assert arr_diags[idx][0] == 0 + np.testing.assert_array_equal(arr_diags[idx][1], diags[idx]) + + diags = [(1., 2.), (3., 4.), (5., 6.)] + arr_diags = gd.persistence_graphical_tools._array_handler(diags) + for idx in range(len(diags)): + assert arr_diags[idx][0] == 0 + assert arr_diags[idx][1] == diags[idx] + + diags = [(0, (1., 2.)), (0, (3., 4.)), (0, (5., 6.))] + assert gd.persistence_graphical_tools._array_handler(diags) == diags + +def test_min_birth_max_death(): + diags = [ + (0, (0., float("inf"))), + (0, (0.0983494, float("inf"))), + (0, (0., 0.122545)), + (0, (0., 0.12047)), + (0, (0., 0.118398)), + (0, (0.118398, 1.)), + (0, (0., 0.117908)), + (0, (0., 0.112307)), + (0, (0., 0.107535)), + (0, (0., 0.106382)), + ] + assert gd.persistence_graphical_tools.__min_birth_max_death(diags) == (0., 1.) + assert gd.persistence_graphical_tools.__min_birth_max_death(diags, band=4.) == (0., 5.) + +def test_limit_min_birth_max_death(): + diags = [ + (0, (2., float("inf"))), + (0, (2., float("inf"))), + ] + assert gd.persistence_graphical_tools.__min_birth_max_death(diags) == (2., 3.) + assert gd.persistence_graphical_tools.__min_birth_max_death(diags, band = 4.) == (2., 6.) + +def test_limit_to_max_intervals(): + diags = [ + (0, (0., float("inf"))), + (0, (0.0983494, float("inf"))), + (0, (0., 0.122545)), + (0, (0., 0.12047)), + (0, (0., 0.118398)), + (0, (0.118398, 1.)), + (0, (0., 0.117908)), + (0, (0., 0.112307)), + (0, (0., 0.107535)), + (0, (0., 0.106382)), + ] + # check no warnings if max_intervals equals to the diagrams number + with pytest.warns(None) as record: + truncated_diags = gd.persistence_graphical_tools._limit_to_max_intervals(diags, 10, + key = lambda life_time: life_time[1][1] - life_time[1][0]) + # check diagrams are not sorted + assert truncated_diags == diags + assert len(record) == 0 + + # check warning if max_intervals lower than the diagrams number + with pytest.warns(UserWarning) as record: + truncated_diags = gd.persistence_graphical_tools._limit_to_max_intervals(diags, 5, + key = lambda life_time: life_time[1][1] - life_time[1][0]) + # check diagrams are truncated and sorted by life time + assert truncated_diags == [(0, (0., float("inf"))), + (0, (0.0983494, float("inf"))), + (0, (0.118398, 1.0)), + (0, (0., 0.122545)), + (0, (0., 0.12047))] + assert len(record) == 1 + +def _limit_plot_persistence(function): + pplot = function(persistence=[()]) + assert issubclass(type(pplot), plt.axes.SubplotBase) + pplot = function(persistence=[(0, float("inf"))]) + assert issubclass(type(pplot), plt.axes.SubplotBase) + +def test_limit_plot_persistence(): + for function in [gd.plot_persistence_barcode, gd.plot_persistence_diagram, gd.plot_persistence_density]: + _limit_plot_persistence(function) + +def _non_existing_persistence_file(function): + with pytest.raises(FileNotFoundError): + function(persistence_file="pouetpouettralala.toubiloubabdou") + +def test_non_existing_persistence_file(): + for function in [gd.plot_persistence_barcode, gd.plot_persistence_diagram, gd.plot_persistence_density]: + _non_existing_persistence_file(function) -- cgit v1.2.3 From f9b1e50a6adaadf88c4940cb4214f9ecef542144 Mon Sep 17 00:00:00 2001 From: ROUVREAU Vincent Date: Tue, 22 Jun 2021 18:47:46 +0200 Subject: black -l 120 modified files --- src/python/gudhi/persistence_graphical_tools.py | 144 +++++++++++---------- .../test/test_persistence_graphical_tools.py | 82 +++++++----- 2 files changed, 120 insertions(+), 106 deletions(-) (limited to 'src/python/test/test_persistence_graphical_tools.py') diff --git a/src/python/gudhi/persistence_graphical_tools.py b/src/python/gudhi/persistence_graphical_tools.py index 9bfc19e0..837218ca 100644 --- a/src/python/gudhi/persistence_graphical_tools.py +++ b/src/python/gudhi/persistence_graphical_tools.py @@ -25,6 +25,7 @@ __license__ = "MIT" _gudhi_matplotlib_use_tex = True + def __min_birth_max_death(persistence, band=0.0): """This function returns (min_birth, max_death) from the persistence. @@ -49,23 +50,24 @@ def __min_birth_max_death(persistence, band=0.0): max_death += band # can happen if only points at inf death if min_birth == max_death: - max_death = max_death + 1. + max_death = max_death + 1.0 return (min_birth, max_death) def _array_handler(a): - ''' + """ :param a: if array, assumes it is a (n x 2) np.array and return a persistence-compatible list (padding with 0), so that the plot can be performed seamlessly. - ''' + """ if isinstance(a[0][1], np.floating) or isinstance(a[0][1], float): return [[0, x] for x in a] else: return a + def _limit_to_max_intervals(persistence, max_intervals, key): - '''This function returns truncated persistence if length is bigger than max_intervals. + """This function returns truncated persistence if length is bigger than max_intervals. :param persistence: Persistence intervals values list. Can be grouped by dimension or not. :type persistence: an array of (dimension, array of (birth, death)) or an array of (birth, death). :param max_intervals: maximal number of intervals to display. @@ -74,16 +76,18 @@ def _limit_to_max_intervals(persistence, max_intervals, key): :type max_intervals: int. :param key: key function for sort algorithm. :type key: function or lambda. - ''' + """ if max_intervals > 0 and max_intervals < len(persistence): - warnings.warn('There are %s intervals given as input, whereas max_intervals is set to %s.' % - (len(persistence), max_intervals) - ) + warnings.warn( + "There are %s intervals given as input, whereas max_intervals is set to %s." + % (len(persistence), max_intervals) + ) # Sort by life time, then takes only the max_intervals elements - return sorted(persistence, key = key, reverse = True)[:max_intervals] + return sorted(persistence, key=key, reverse=True)[:max_intervals] else: return persistence + @lru_cache(maxsize=1) def _matplotlib_can_use_tex(): """This function returns True if matplotlib can deal with LaTeX, False otherwise. @@ -91,6 +95,7 @@ def _matplotlib_can_use_tex(): """ try: from matplotlib import checkdep_usetex + return checkdep_usetex(True) except ImportError: warnings.warn("This function is not available, you may be missing matplotlib.") @@ -144,20 +149,19 @@ def plot_persistence_barcode( import matplotlib.pyplot as plt import matplotlib.patches as mpatches from matplotlib import rc + if _gudhi_matplotlib_use_tex and _matplotlib_can_use_tex(): - plt.rc('text', usetex=True) - plt.rc('font', family='serif') + plt.rc("text", usetex=True) + plt.rc("font", family="serif") else: - plt.rc('text', usetex=False) - plt.rc('font', family='DejaVu Sans') + plt.rc("text", usetex=False) + plt.rc("font", family="DejaVu Sans") if persistence_file != "": if path.isfile(persistence_file): # Reset persistence persistence = [] - diag = read_persistence_intervals_grouped_by_dimension( - persistence_file=persistence_file - ) + diag = read_persistence_intervals_grouped_by_dimension(persistence_file=persistence_file) for key in diag.keys(): for persistence_interval in diag[key]: persistence.append((key, persistence_interval)) @@ -166,12 +170,13 @@ def plot_persistence_barcode( try: persistence = _array_handler(persistence) - persistence = _limit_to_max_intervals(persistence, max_intervals, - key = lambda life_time: life_time[1][1] - life_time[1][0]) + persistence = _limit_to_max_intervals( + persistence, max_intervals, key=lambda life_time: life_time[1][1] - life_time[1][0] + ) (min_birth, max_death) = __min_birth_max_death(persistence) persistence = sorted(persistence, key=lambda birth: birth[1][0]) except IndexError: - min_birth, max_death = 0., 1. + min_birth, max_death = 0.0, 1.0 pass delta = (max_death - min_birth) * inf_delta @@ -218,11 +223,7 @@ def plot_persistence_barcode( if legend: dimensions = list(set(item[0] for item in persistence)) axes.legend( - handles=[ - mpatches.Patch(color=colormap[dim], label=str(dim)) - for dim in dimensions - ], - loc="lower right", + handles=[mpatches.Patch(color=colormap[dim], label=str(dim)) for dim in dimensions], loc="lower right", ) axes.set_title("Persistence barcode", fontsize=fontsize) @@ -247,7 +248,7 @@ def plot_persistence_diagram( colormap=None, axes=None, fontsize=16, - greyblock=True + greyblock=True, ): """This function plots the persistence diagram from persistence values list, a np.array of shape (N x 2) representing a diagram in a single @@ -289,20 +290,19 @@ def plot_persistence_diagram( import matplotlib.pyplot as plt import matplotlib.patches as mpatches from matplotlib import rc + if _gudhi_matplotlib_use_tex and _matplotlib_can_use_tex(): - plt.rc('text', usetex=True) - plt.rc('font', family='serif') + plt.rc("text", usetex=True) + plt.rc("font", family="serif") else: - plt.rc('text', usetex=False) - plt.rc('font', family='DejaVu Sans') + plt.rc("text", usetex=False) + plt.rc("font", family="DejaVu Sans") if persistence_file != "": if path.isfile(persistence_file): # Reset persistence persistence = [] - diag = read_persistence_intervals_grouped_by_dimension( - persistence_file=persistence_file - ) + diag = read_persistence_intervals_grouped_by_dimension(persistence_file=persistence_file) for key in diag.keys(): for persistence_interval in diag[key]: persistence.append((key, persistence_interval)) @@ -311,11 +311,12 @@ def plot_persistence_diagram( try: persistence = _array_handler(persistence) - persistence = _limit_to_max_intervals(persistence, max_intervals, - key = lambda life_time: life_time[1][1] - life_time[1][0]) + persistence = _limit_to_max_intervals( + persistence, max_intervals, key=lambda life_time: life_time[1][1] - life_time[1][0] + ) min_birth, max_death = __min_birth_max_death(persistence, band) except IndexError: - min_birth, max_death = 0., 1. + min_birth, max_death = 0.0, 1.0 pass delta = (max_death - min_birth) * inf_delta @@ -335,8 +336,13 @@ def plot_persistence_diagram( axes.fill_between(x, x, x + band, alpha=alpha, facecolor="red") # lower diag patch if greyblock: - axes.add_patch(mpatches.Polygon([[axis_start, axis_start], [axis_end, axis_start], [axis_end, axis_end]], - fill=True, color='lightgrey')) + axes.add_patch( + mpatches.Polygon( + [[axis_start, axis_start], [axis_end, axis_start], [axis_end, axis_end]], + fill=True, + color="lightgrey", + ) + ) # line display of equation : birth = death axes.plot([axis_start, axis_end], [axis_start, axis_end], linewidth=1.0, color="k") @@ -347,17 +353,12 @@ def plot_persistence_diagram( if float(interval[1][1]) != float("inf"): # Finite death case axes.scatter( - interval[1][0], - interval[1][1], - alpha=alpha, - color=colormap[interval[0]], + interval[1][0], interval[1][1], alpha=alpha, color=colormap[interval[0]], ) else: pts_at_infty = True # Infinite death case for diagram to be nicer - axes.scatter( - interval[1][0], infinity, alpha=alpha, color=colormap[interval[0]] - ) + axes.scatter(interval[1][0], infinity, alpha=alpha, color=colormap[interval[0]]) except IndexError: pass if pts_at_infty: @@ -365,27 +366,22 @@ def plot_persistence_diagram( axes.plot([axis_start, axis_end], [infinity, infinity], linewidth=1.0, color="k", alpha=alpha) # Infinity label yt = axes.get_yticks() - yt = yt[np.where(yt < axis_end)] # to avoid ploting ticklabel higher than infinity + yt = yt[np.where(yt < axis_end)] # to avoid ploting ticklabel higher than infinity yt = np.append(yt, infinity) ytl = ["%.3f" % e for e in yt] # to avoid float precision error - ytl[-1] = r'$+\infty$' + ytl[-1] = r"$+\infty$" axes.set_yticks(yt) axes.set_yticklabels(ytl) if legend: dimensions = list(set(item[0] for item in persistence)) - axes.legend( - handles=[ - mpatches.Patch(color=colormap[dim], label=str(dim)) - for dim in dimensions - ] - ) + axes.legend(handles=[mpatches.Patch(color=colormap[dim], label=str(dim)) for dim in dimensions]) axes.set_xlabel("Birth", fontsize=fontsize) axes.set_ylabel("Death", fontsize=fontsize) axes.set_title("Persistence diagram", fontsize=fontsize) # Ends plot on infinity value and starts a little bit before min_birth - axes.axis([axis_start, axis_end, axis_start, infinity + delta/2]) + axes.axis([axis_start, axis_end, axis_start, infinity + delta / 2]) return axes except ImportError: @@ -403,7 +399,7 @@ def plot_persistence_density( legend=False, axes=None, fontsize=16, - greyblock=False + greyblock=False, ): """This function plots the persistence density from persistence values list, np.array of shape (N x 2) representing a diagram @@ -463,12 +459,13 @@ def plot_persistence_density( import matplotlib.patches as mpatches from scipy.stats import kde from matplotlib import rc + if _gudhi_matplotlib_use_tex and _matplotlib_can_use_tex(): - plt.rc('text', usetex=True) - plt.rc('font', family='serif') + plt.rc("text", usetex=True) + plt.rc("font", family="serif") else: - plt.rc('text', usetex=False) - plt.rc('font', family='DejaVu Sans') + plt.rc("text", usetex=False) + plt.rc("font", family="DejaVu Sans") if persistence_file != "": if dimension is None: @@ -499,8 +496,11 @@ def plot_persistence_density( ] ) persistence_dim = persistence_dim[np.isfinite(persistence_dim[:, 1])] - persistence_dim = np.array(_limit_to_max_intervals(persistence_dim, max_intervals, - key = lambda life_time: life_time[1] - life_time[0])) + persistence_dim = np.array( + _limit_to_max_intervals( + persistence_dim, max_intervals, key=lambda life_time: life_time[1] - life_time[0] + ) + ) # Set as numpy array birth and death (remove undefined values - inf and NaN) birth = persistence_dim[:, 0] @@ -513,19 +513,18 @@ def plot_persistence_density( # Evaluate a gaussian kde on a regular grid of nbins x nbins over data extents k = kde.gaussian_kde([birth, death], bw_method=bw_method) xi, yi = np.mgrid[ - birth_min : birth_max : nbins * 1j, - death_min : death_max : nbins * 1j, + birth_min : birth_max : nbins * 1j, death_min : death_max : nbins * 1j, ] zi = k(np.vstack([xi.flatten(), yi.flatten()])) # Make the plot - img = axes.pcolormesh(xi, yi, zi.reshape(xi.shape), cmap=cmap, shading='auto') + img = axes.pcolormesh(xi, yi, zi.reshape(xi.shape), cmap=cmap, shading="auto") # IndexError on empty diagrams, ValueError on only inf death values except (IndexError, ValueError): - birth_min = 0. - birth_max = 1. - death_min = 0. - death_max = 1. + birth_min = 0.0 + birth_max = 1.0 + death_min = 0.0 + death_max = 1.0 pass # line display of equation : birth = death @@ -533,10 +532,13 @@ def plot_persistence_density( axes.plot(x, x, color="k", linewidth=1.0) if greyblock: - axes.add_patch(mpatches.Polygon([[birth_min, birth_min], - [death_max, birth_min], - [death_max, death_max]], - fill=True, color='lightgrey')) + axes.add_patch( + mpatches.Polygon( + [[birth_min, birth_min], [death_max, birth_min], [death_max, death_max]], + fill=True, + color="lightgrey", + ) + ) if legend: plt.colorbar(img, ax=axes) diff --git a/src/python/test/test_persistence_graphical_tools.py b/src/python/test/test_persistence_graphical_tools.py index 1ad1ae23..7d9bae90 100644 --- a/src/python/test/test_persistence_graphical_tools.py +++ b/src/python/test/test_persistence_graphical_tools.py @@ -13,6 +13,7 @@ import numpy as np import matplotlib as plt import pytest + def test_array_handler(): diags = np.array([[1, 2], [3, 4], [5, 6]], np.float) arr_diags = gd.persistence_graphical_tools._array_handler(diags) @@ -20,86 +21,97 @@ def test_array_handler(): assert arr_diags[idx][0] == 0 np.testing.assert_array_equal(arr_diags[idx][1], diags[idx]) - diags = [(1., 2.), (3., 4.), (5., 6.)] + diags = [(1.0, 2.0), (3.0, 4.0), (5.0, 6.0)] arr_diags = gd.persistence_graphical_tools._array_handler(diags) for idx in range(len(diags)): assert arr_diags[idx][0] == 0 assert arr_diags[idx][1] == diags[idx] - diags = [(0, (1., 2.)), (0, (3., 4.)), (0, (5., 6.))] + diags = [(0, (1.0, 2.0)), (0, (3.0, 4.0)), (0, (5.0, 6.0))] assert gd.persistence_graphical_tools._array_handler(diags) == diags + def test_min_birth_max_death(): diags = [ - (0, (0., float("inf"))), + (0, (0.0, float("inf"))), (0, (0.0983494, float("inf"))), - (0, (0., 0.122545)), - (0, (0., 0.12047)), - (0, (0., 0.118398)), - (0, (0.118398, 1.)), - (0, (0., 0.117908)), - (0, (0., 0.112307)), - (0, (0., 0.107535)), - (0, (0., 0.106382)), + (0, (0.0, 0.122545)), + (0, (0.0, 0.12047)), + (0, (0.0, 0.118398)), + (0, (0.118398, 1.0)), + (0, (0.0, 0.117908)), + (0, (0.0, 0.112307)), + (0, (0.0, 0.107535)), + (0, (0.0, 0.106382)), ] - assert gd.persistence_graphical_tools.__min_birth_max_death(diags) == (0., 1.) - assert gd.persistence_graphical_tools.__min_birth_max_death(diags, band=4.) == (0., 5.) + assert gd.persistence_graphical_tools.__min_birth_max_death(diags) == (0.0, 1.0) + assert gd.persistence_graphical_tools.__min_birth_max_death(diags, band=4.0) == (0.0, 5.0) + def test_limit_min_birth_max_death(): diags = [ - (0, (2., float("inf"))), - (0, (2., float("inf"))), + (0, (2.0, float("inf"))), + (0, (2.0, float("inf"))), ] - assert gd.persistence_graphical_tools.__min_birth_max_death(diags) == (2., 3.) - assert gd.persistence_graphical_tools.__min_birth_max_death(diags, band = 4.) == (2., 6.) + assert gd.persistence_graphical_tools.__min_birth_max_death(diags) == (2.0, 3.0) + assert gd.persistence_graphical_tools.__min_birth_max_death(diags, band=4.0) == (2.0, 6.0) + def test_limit_to_max_intervals(): diags = [ - (0, (0., float("inf"))), + (0, (0.0, float("inf"))), (0, (0.0983494, float("inf"))), - (0, (0., 0.122545)), - (0, (0., 0.12047)), - (0, (0., 0.118398)), - (0, (0.118398, 1.)), - (0, (0., 0.117908)), - (0, (0., 0.112307)), - (0, (0., 0.107535)), - (0, (0., 0.106382)), + (0, (0.0, 0.122545)), + (0, (0.0, 0.12047)), + (0, (0.0, 0.118398)), + (0, (0.118398, 1.0)), + (0, (0.0, 0.117908)), + (0, (0.0, 0.112307)), + (0, (0.0, 0.107535)), + (0, (0.0, 0.106382)), ] # check no warnings if max_intervals equals to the diagrams number with pytest.warns(None) as record: - truncated_diags = gd.persistence_graphical_tools._limit_to_max_intervals(diags, 10, - key = lambda life_time: life_time[1][1] - life_time[1][0]) + truncated_diags = gd.persistence_graphical_tools._limit_to_max_intervals( + diags, 10, key=lambda life_time: life_time[1][1] - life_time[1][0] + ) # check diagrams are not sorted assert truncated_diags == diags assert len(record) == 0 # check warning if max_intervals lower than the diagrams number with pytest.warns(UserWarning) as record: - truncated_diags = gd.persistence_graphical_tools._limit_to_max_intervals(diags, 5, - key = lambda life_time: life_time[1][1] - life_time[1][0]) + truncated_diags = gd.persistence_graphical_tools._limit_to_max_intervals( + diags, 5, key=lambda life_time: life_time[1][1] - life_time[1][0] + ) # check diagrams are truncated and sorted by life time - assert truncated_diags == [(0, (0., float("inf"))), - (0, (0.0983494, float("inf"))), - (0, (0.118398, 1.0)), - (0, (0., 0.122545)), - (0, (0., 0.12047))] + assert truncated_diags == [ + (0, (0.0, float("inf"))), + (0, (0.0983494, float("inf"))), + (0, (0.118398, 1.0)), + (0, (0.0, 0.122545)), + (0, (0.0, 0.12047)), + ] assert len(record) == 1 + def _limit_plot_persistence(function): pplot = function(persistence=[()]) assert issubclass(type(pplot), plt.axes.SubplotBase) pplot = function(persistence=[(0, float("inf"))]) assert issubclass(type(pplot), plt.axes.SubplotBase) + def test_limit_plot_persistence(): for function in [gd.plot_persistence_barcode, gd.plot_persistence_diagram, gd.plot_persistence_density]: _limit_plot_persistence(function) + def _non_existing_persistence_file(function): with pytest.raises(FileNotFoundError): function(persistence_file="pouetpouettralala.toubiloubabdou") + def test_non_existing_persistence_file(): for function in [gd.plot_persistence_barcode, gd.plot_persistence_diagram, gd.plot_persistence_density]: _non_existing_persistence_file(function) -- cgit v1.2.3 From 3952b61604db976f4864b85ebdafa0962766b8bc Mon Sep 17 00:00:00 2001 From: Vincent Rouvreau Date: Tue, 10 May 2022 10:00:02 +0200 Subject: Fix limit tests for plot (and warning in test) --- src/python/gudhi/persistence_graphical_tools.py | 32 +++++++++++++++------- .../test/test_persistence_graphical_tools.py | 6 +++- 2 files changed, 27 insertions(+), 11 deletions(-) (limited to 'src/python/test/test_persistence_graphical_tools.py') diff --git a/src/python/gudhi/persistence_graphical_tools.py b/src/python/gudhi/persistence_graphical_tools.py index 604018d1..930df825 100644 --- a/src/python/gudhi/persistence_graphical_tools.py +++ b/src/python/gudhi/persistence_graphical_tools.py @@ -190,12 +190,17 @@ def plot_persistence_barcode( if colormap == None: colormap = plt.cm.Set1.colors - x=[birth for (dim,(birth,death)) in persistence] - y=[(death - birth) if death != float("inf") else (infinity - birth) for (dim,(birth,death)) in persistence] - c=[colormap[dim] for (dim,(birth,death)) in persistence] + non_empty_diagram = len(persistence[0]) > 0 + if non_empty_diagram: + x=[birth for (dim,(birth,death)) in persistence] + y=[(death - birth) if death != float("inf") else (infinity - birth) for (dim,(birth,death)) in persistence] + c=[colormap[dim] for (dim,(birth,death)) in persistence] + else: + x, y, c = [], [], [] + axes.barh(list(reversed(range(len(x)))), y, height=0.8, left=x, alpha=alpha, color=c, linewidth=0) - if legend: + if non_empty_diagram and legend: dimensions = list(set(item[0] for item in persistence)) axes.legend( handles=[mpatches.Patch(color=colormap[dim], label=str(dim)) for dim in dimensions], loc="lower right", @@ -321,11 +326,16 @@ def plot_persistence_diagram( # line display of equation : birth = death axes.plot([axis_start, axis_end], [axis_start, axis_end], linewidth=1.0, color="k") - x=[birth for (dim,(birth,death)) in persistence] - y=[death if death != float("inf") else infinity for (dim,(birth,death)) in persistence] - c=[colormap[dim] for (dim,(birth,death)) in persistence] + non_empty_diagram = len(persistence[0]) > 0 + if non_empty_diagram: + x=[birth for (dim,(birth,death)) in persistence] + y=[death if death != float("inf") else infinity for (dim,(birth,death)) in persistence] + c=[colormap[dim] for (dim,(birth,death)) in persistence] + else: + x, y, c = [], [], [] + axes.scatter(x,y,alpha=alpha,color=c) - if float("inf") in (death for (dim,(birth,death)) in persistence): + if non_empty_diagram and float("inf") in (death for (dim,(birth,death)) in persistence): # infinity line and text axes.plot([axis_start, axis_end], [infinity, infinity], linewidth=1.0, color="k", alpha=alpha) # Infinity label @@ -337,7 +347,7 @@ def plot_persistence_diagram( axes.set_yticks(yt) axes.set_yticklabels(ytl) - if legend: + if non_empty_diagram and legend: dimensions = list(set(item[0] for item in persistence)) axes.legend(handles=[mpatches.Patch(color=colormap[dim], label=str(dim)) for dim in dimensions]) @@ -482,6 +492,7 @@ def plot_persistence_density( zi = k(np.vstack([xi.flatten(), yi.flatten()])) # Make the plot img = axes.pcolormesh(xi, yi, zi.reshape(xi.shape), cmap=cmap, shading="auto") + non_empty_diagram = True # IndexError on empty diagrams, ValueError on only inf death values except (IndexError, ValueError): @@ -489,6 +500,7 @@ def plot_persistence_density( birth_max = 1.0 death_min = 0.0 death_max = 1.0 + non_empty_diagram = False pass # line display of equation : birth = death @@ -504,7 +516,7 @@ def plot_persistence_density( ) ) - if legend: + if non_empty_diagram and legend: plt.colorbar(img, ax=axes) axes.set_xlabel("Birth", fontsize=fontsize) diff --git a/src/python/test/test_persistence_graphical_tools.py b/src/python/test/test_persistence_graphical_tools.py index 7d9bae90..994791d7 100644 --- a/src/python/test/test_persistence_graphical_tools.py +++ b/src/python/test/test_persistence_graphical_tools.py @@ -15,7 +15,7 @@ import pytest def test_array_handler(): - diags = np.array([[1, 2], [3, 4], [5, 6]], np.float) + diags = np.array([[1, 2], [3, 4], [5, 6]], float) arr_diags = gd.persistence_graphical_tools._array_handler(diags) for idx in range(len(diags)): assert arr_diags[idx][0] == 0 @@ -98,8 +98,12 @@ def test_limit_to_max_intervals(): def _limit_plot_persistence(function): pplot = function(persistence=[()]) assert issubclass(type(pplot), plt.axes.SubplotBase) + pplot = function(persistence=[()], legend=True) + assert issubclass(type(pplot), plt.axes.SubplotBase) pplot = function(persistence=[(0, float("inf"))]) assert issubclass(type(pplot), plt.axes.SubplotBase) + pplot = function(persistence=[(0, float("inf"))], legend=True) + assert issubclass(type(pplot), plt.axes.SubplotBase) def test_limit_plot_persistence(): -- cgit v1.2.3 From 393216cc99af79526b1bb90ae4da21116884bcbd Mon Sep 17 00:00:00 2001 From: Vincent Rouvreau Date: Mon, 16 May 2022 13:47:16 +0200 Subject: code review: limit test was not respecting the one described in the documentation --- src/python/gudhi/persistence_graphical_tools.py | 51 +++++++++------------- .../test/test_persistence_graphical_tools.py | 4 +- 2 files changed, 23 insertions(+), 32 deletions(-) (limited to 'src/python/test/test_persistence_graphical_tools.py') diff --git a/src/python/gudhi/persistence_graphical_tools.py b/src/python/gudhi/persistence_graphical_tools.py index 930df825..7ed11360 100644 --- a/src/python/gudhi/persistence_graphical_tools.py +++ b/src/python/gudhi/persistence_graphical_tools.py @@ -190,17 +190,13 @@ def plot_persistence_barcode( if colormap == None: colormap = plt.cm.Set1.colors - non_empty_diagram = len(persistence[0]) > 0 - if non_empty_diagram: - x=[birth for (dim,(birth,death)) in persistence] - y=[(death - birth) if death != float("inf") else (infinity - birth) for (dim,(birth,death)) in persistence] - c=[colormap[dim] for (dim,(birth,death)) in persistence] - else: - x, y, c = [], [], [] + x=[birth for (dim,(birth,death)) in persistence] + y=[(death - birth) if death != float("inf") else (infinity - birth) for (dim,(birth,death)) in persistence] + c=[colormap[dim] for (dim,(birth,death)) in persistence] axes.barh(list(reversed(range(len(x)))), y, height=0.8, left=x, alpha=alpha, color=c, linewidth=0) - if non_empty_diagram and legend: + if legend: dimensions = list(set(item[0] for item in persistence)) axes.legend( handles=[mpatches.Patch(color=colormap[dim], label=str(dim)) for dim in dimensions], loc="lower right", @@ -326,16 +322,12 @@ def plot_persistence_diagram( # line display of equation : birth = death axes.plot([axis_start, axis_end], [axis_start, axis_end], linewidth=1.0, color="k") - non_empty_diagram = len(persistence[0]) > 0 - if non_empty_diagram: - x=[birth for (dim,(birth,death)) in persistence] - y=[death if death != float("inf") else infinity for (dim,(birth,death)) in persistence] - c=[colormap[dim] for (dim,(birth,death)) in persistence] - else: - x, y, c = [], [], [] + x=[birth for (dim,(birth,death)) in persistence] + y=[death if death != float("inf") else infinity for (dim,(birth,death)) in persistence] + c=[colormap[dim] for (dim,(birth,death)) in persistence] axes.scatter(x,y,alpha=alpha,color=c) - if non_empty_diagram and float("inf") in (death for (dim,(birth,death)) in persistence): + if float("inf") in (death for (dim,(birth,death)) in persistence): # infinity line and text axes.plot([axis_start, axis_end], [infinity, infinity], linewidth=1.0, color="k", alpha=alpha) # Infinity label @@ -347,7 +339,7 @@ def plot_persistence_diagram( axes.set_yticks(yt) axes.set_yticklabels(ytl) - if non_empty_diagram and legend: + if legend: dimensions = list(set(item[0] for item in persistence)) axes.legend(handles=[mpatches.Patch(color=colormap[dim], label=str(dim)) for dim in dimensions]) @@ -459,16 +451,15 @@ def plot_persistence_density( _, axes = plt.subplots(1, 1) try: - if len(persistence) > 0: - # if not read from file but given by an argument - persistence = _array_handler(persistence) - persistence_dim = np.array( - [ - (dim_interval[1][0], dim_interval[1][1]) - for dim_interval in persistence - if (dim_interval[0] == dimension) or (dimension is None) - ] - ) + # if not read from file but given by an argument + persistence = _array_handler(persistence) + persistence_dim = np.array( + [ + (dim_interval[1][0], dim_interval[1][1]) + for dim_interval in persistence + if (dim_interval[0] == dimension) or (dimension is None) + ] + ) persistence_dim = persistence_dim[np.isfinite(persistence_dim[:, 1])] persistence_dim = np.array( _limit_to_max_intervals( @@ -492,7 +483,7 @@ def plot_persistence_density( zi = k(np.vstack([xi.flatten(), yi.flatten()])) # Make the plot img = axes.pcolormesh(xi, yi, zi.reshape(xi.shape), cmap=cmap, shading="auto") - non_empty_diagram = True + plot_success = True # IndexError on empty diagrams, ValueError on only inf death values except (IndexError, ValueError): @@ -500,7 +491,7 @@ def plot_persistence_density( birth_max = 1.0 death_min = 0.0 death_max = 1.0 - non_empty_diagram = False + plot_success = False pass # line display of equation : birth = death @@ -516,7 +507,7 @@ def plot_persistence_density( ) ) - if non_empty_diagram and legend: + if plot_success and legend: plt.colorbar(img, ax=axes) axes.set_xlabel("Birth", fontsize=fontsize) diff --git a/src/python/test/test_persistence_graphical_tools.py b/src/python/test/test_persistence_graphical_tools.py index 994791d7..44decdd7 100644 --- a/src/python/test/test_persistence_graphical_tools.py +++ b/src/python/test/test_persistence_graphical_tools.py @@ -96,9 +96,9 @@ def test_limit_to_max_intervals(): def _limit_plot_persistence(function): - pplot = function(persistence=[()]) + pplot = function(persistence=[]) assert issubclass(type(pplot), plt.axes.SubplotBase) - pplot = function(persistence=[()], legend=True) + pplot = function(persistence=[], legend=True) assert issubclass(type(pplot), plt.axes.SubplotBase) pplot = function(persistence=[(0, float("inf"))]) assert issubclass(type(pplot), plt.axes.SubplotBase) -- cgit v1.2.3 From 563295694afda3dfa37cccb63a73a9f22131480e Mon Sep 17 00:00:00 2001 From: Vincent Rouvreau Date: Mon, 16 May 2022 17:14:29 +0200 Subject: code review: use isinstance(x, T) instead of issubclass(type(x),T) --- src/python/test/test_persistence_graphical_tools.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'src/python/test/test_persistence_graphical_tools.py') diff --git a/src/python/test/test_persistence_graphical_tools.py b/src/python/test/test_persistence_graphical_tools.py index 44decdd7..c19836b7 100644 --- a/src/python/test/test_persistence_graphical_tools.py +++ b/src/python/test/test_persistence_graphical_tools.py @@ -97,13 +97,13 @@ def test_limit_to_max_intervals(): def _limit_plot_persistence(function): pplot = function(persistence=[]) - assert issubclass(type(pplot), plt.axes.SubplotBase) + assert isinstance(pplot, plt.axes.SubplotBase) pplot = function(persistence=[], legend=True) - assert issubclass(type(pplot), plt.axes.SubplotBase) + assert isinstance(pplot, plt.axes.SubplotBase) pplot = function(persistence=[(0, float("inf"))]) - assert issubclass(type(pplot), plt.axes.SubplotBase) + assert isinstance(pplot, plt.axes.SubplotBase) pplot = function(persistence=[(0, float("inf"))], legend=True) - assert issubclass(type(pplot), plt.axes.SubplotBase) + assert isinstance(pplot, plt.axes.SubplotBase) def test_limit_plot_persistence(): -- cgit v1.2.3