summaryrefslogtreecommitdiff
path: root/src/python
diff options
context:
space:
mode:
Diffstat (limited to 'src/python')
-rw-r--r--src/python/CMakeLists.txt16
-rw-r--r--src/python/doc/alpha_complex_user.rst4
-rw-r--r--src/python/doc/nerve_gic_complex_user.rst2
-rw-r--r--src/python/doc/persistence_graphical_tools_user.rst2
-rwxr-xr-xsrc/python/example/rips_complex_diagram_persistence_from_correlation_matrix_file_example.py2
-rw-r--r--src/python/gudhi/__init__.py.in4
-rw-r--r--src/python/gudhi/hera/wasserstein.cc2
-rw-r--r--src/python/gudhi/persistence_graphical_tools.py349
-rw-r--r--src/python/gudhi/simplex_tree.pyx15
-rw-r--r--src/python/gudhi/wasserstein/barycenter.py6
-rw-r--r--src/python/include/Simplex_tree_interface.h8
-rw-r--r--src/python/test/test_persistence_graphical_tools.py121
-rwxr-xr-xsrc/python/test/test_simplex_tree.py19
-rwxr-xr-xsrc/python/test/test_subsampling.py4
14 files changed, 326 insertions, 228 deletions
diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt
index e31af02e..af0b6115 100644
--- a/src/python/CMakeLists.txt
+++ b/src/python/CMakeLists.txt
@@ -148,10 +148,6 @@ if(PYTHONINTERP_FOUND)
add_gudhi_debug_info("Eigen3 version ${EIGEN3_VERSION}")
# No problem, even if no CGAL found
set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-DCGAL_EIGEN3_ENABLED', ")
- set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-DGUDHI_USE_EIGEN3', ")
- set(GUDHI_USE_EIGEN3 "True")
- else (EIGEN3_FOUND)
- set(GUDHI_USE_EIGEN3 "False")
endif (EIGEN3_FOUND)
set(GUDHI_CYTHON_MODULES "${GUDHI_CYTHON_MODULES}'off_reader', ")
@@ -333,9 +329,9 @@ if(PYTHONINTERP_FOUND)
if(NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 5.1.0)
set (GUDHI_SPHINX_MESSAGE "Generating API documentation with Sphinx in ${CMAKE_CURRENT_BINARY_DIR}/sphinx/")
# User warning - Sphinx is a static pages generator, and configured to work fine with user_version
- # Images and biblio warnings because not found on developper version
+ # Images and biblio warnings because not found on developer version
if (GUDHI_PYTHON_PATH STREQUAL "src/python")
- set (GUDHI_SPHINX_MESSAGE "${GUDHI_SPHINX_MESSAGE} \n WARNING : Sphinx is configured for user version, you run it on developper version. Images and biblio will miss")
+ set (GUDHI_SPHINX_MESSAGE "${GUDHI_SPHINX_MESSAGE} \n WARNING : Sphinx is configured for user version, you run it on developer version. Images and biblio will miss")
endif()
# sphinx target requires gudhi.so, because conf.py reads gudhi version from it
add_custom_target(sphinx
@@ -488,7 +484,7 @@ if(PYTHONINTERP_FOUND)
add_gudhi_py_test(test_euclidean_witness_complex)
# Datasets generators
- add_gudhi_py_test(test_datasets_generators) # TODO separate full python datasets generators in another test file independant from CGAL ?
+ add_gudhi_py_test(test_datasets_generators) # TODO separate full python datasets generators in another test file independent from CGAL ?
endif (NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 4.11.0)
@@ -595,6 +591,10 @@ if(PYTHONINTERP_FOUND)
add_gudhi_py_test(test_dtm_rips_complex)
endif()
+ # persistence graphical tools
+ if(MATPLOTLIB_FOUND)
+ add_gudhi_py_test(test_persistence_graphical_tools)
+ endif()
# Set missing or not modules
set(GUDHI_MODULES ${GUDHI_MODULES} "python" CACHE INTERNAL "GUDHI_MODULES")
@@ -605,4 +605,4 @@ if(PYTHONINTERP_FOUND)
else(PYTHONINTERP_FOUND)
message("++ Python module will not be compiled because no Python interpreter was found")
set(GUDHI_MISSING_MODULES ${GUDHI_MISSING_MODULES} "python" CACHE INTERNAL "GUDHI_MISSING_MODULES")
-endif(PYTHONINTERP_FOUND) \ No newline at end of file
+endif(PYTHONINTERP_FOUND)
diff --git a/src/python/doc/alpha_complex_user.rst b/src/python/doc/alpha_complex_user.rst
index b060c86e..9e67d38a 100644
--- a/src/python/doc/alpha_complex_user.rst
+++ b/src/python/doc/alpha_complex_user.rst
@@ -178,11 +178,11 @@ Weighted version
^^^^^^^^^^^^^^^^
A weighted version for Alpha complex is available. It is like a usual Alpha complex, but based on a
-`CGAL regular triangulation <https://doc.cgal.org/latest/Triangulation/index.html#title20>`_.
+`CGAL regular triangulation <https://doc.cgal.org/latest/Triangulation/index.html#TriangulationSecRT>`_.
This example builds the weighted alpha-complex of a small molecule, where atoms have different sizes.
It is taken from
-`CGAL 3d weighted alpha shapes <https://doc.cgal.org/latest/Alpha_shapes_3/index.html#title13>`_.
+`CGAL 3d weighted alpha shapes <https://doc.cgal.org/latest/Alpha_shapes_3/index.html#AlphaShape_3DExampleforWeightedAlphaShapes>`_.
Then, it is asked to display information about the alpha complex.
diff --git a/src/python/doc/nerve_gic_complex_user.rst b/src/python/doc/nerve_gic_complex_user.rst
index 0b820abf..8633cadb 100644
--- a/src/python/doc/nerve_gic_complex_user.rst
+++ b/src/python/doc/nerve_gic_complex_user.rst
@@ -12,7 +12,7 @@ Definition
Visualizations of the simplicial complexes can be done with either
neato (from `graphviz <http://www.graphviz.org/>`_),
`geomview <http://www.geomview.org/>`_,
-`KeplerMapper <https://github.com/MLWave/kepler-mapper>`_.
+`KeplerMapper <https://github.com/scikit-tda/kepler-mapper>`_.
Input point clouds are assumed to be OFF files (cf. `OFF file format <fileformats.html#off-file-format>`_).
Covers
diff --git a/src/python/doc/persistence_graphical_tools_user.rst b/src/python/doc/persistence_graphical_tools_user.rst
index d95b9d2b..e1d28c71 100644
--- a/src/python/doc/persistence_graphical_tools_user.rst
+++ b/src/python/doc/persistence_graphical_tools_user.rst
@@ -60,7 +60,7 @@ of shape (N x 2) encoding a persistence diagram (in a given dimension).
import matplotlib.pyplot as plt
import gudhi
import numpy as np
- d = np.array([[0, 1], [1, 2], [1, np.inf]])
+ d = np.array([[0., 1.], [1., 2.], [1., np.inf]])
gudhi.plot_persistence_diagram(d)
plt.show()
diff --git a/src/python/example/rips_complex_diagram_persistence_from_correlation_matrix_file_example.py b/src/python/example/rips_complex_diagram_persistence_from_correlation_matrix_file_example.py
index ea2eb7e1..0b35dbc5 100755
--- a/src/python/example/rips_complex_diagram_persistence_from_correlation_matrix_file_example.py
+++ b/src/python/example/rips_complex_diagram_persistence_from_correlation_matrix_file_example.py
@@ -40,7 +40,7 @@ parser.add_argument(
args = parser.parse_args()
if not (-1.0 < args.min_edge_correlation < 1.0):
- print("Wrong value of the treshold corelation (should be between -1 and 1).")
+ print("Wrong value of the threshold corelation (should be between -1 and 1).")
sys.exit(1)
print("#####################################################################")
diff --git a/src/python/gudhi/__init__.py.in b/src/python/gudhi/__init__.py.in
index 3043201a..79e12fbc 100644
--- a/src/python/gudhi/__init__.py.in
+++ b/src/python/gudhi/__init__.py.in
@@ -23,10 +23,6 @@ __all__ = [@GUDHI_PYTHON_MODULES@ @GUDHI_PYTHON_MODULES_EXTRA@]
__available_modules = ''
__missing_modules = ''
-# For unitary tests purpose
-# could use "if 'collapse_edges' in gudhi.__all__" when collapse edges will have a python module
-__GUDHI_USE_EIGEN3 = @GUDHI_USE_EIGEN3@
-
# Try to import * from gudhi.__module_name for default modules.
# Extra modules require an explicit import by the user (mostly because of
# unusual dependencies, but also to avoid cluttering namespace gudhi and
diff --git a/src/python/gudhi/hera/wasserstein.cc b/src/python/gudhi/hera/wasserstein.cc
index 1a21f02f..fa0cf8aa 100644
--- a/src/python/gudhi/hera/wasserstein.cc
+++ b/src/python/gudhi/hera/wasserstein.cc
@@ -29,7 +29,7 @@ double wasserstein_distance(
if(std::isinf(internal_p)) internal_p = hera::get_infinity<double>();
params.internal_p = internal_p;
params.delta = delta;
- // The extra parameters are purposedly not exposed for now.
+ // The extra parameters are purposely not exposed for now.
return hera::wasserstein_dist(diag1, diag2, params);
}
diff --git a/src/python/gudhi/persistence_graphical_tools.py b/src/python/gudhi/persistence_graphical_tools.py
index 848dc03e..21275cdd 100644
--- a/src/python/gudhi/persistence_graphical_tools.py
+++ b/src/python/gudhi/persistence_graphical_tools.py
@@ -12,6 +12,9 @@ from os import path
from math import isfinite
import numpy as np
from functools import lru_cache
+import warnings
+import errno
+import os
from gudhi.reader_utils import read_persistence_intervals_in_dimension
from gudhi.reader_utils import read_persistence_intervals_grouped_by_dimension
@@ -22,6 +25,7 @@ __license__ = "MIT"
_gudhi_matplotlib_use_tex = True
+
def __min_birth_max_death(persistence, band=0.0):
"""This function returns (min_birth, max_death) from the persistence.
@@ -44,20 +48,46 @@ def __min_birth_max_death(persistence, band=0.0):
min_birth = float(interval[1][0])
if band > 0.0:
max_death += band
+ # can happen if only points at inf death
+ if min_birth == max_death:
+ max_death = max_death + 1.0
return (min_birth, max_death)
def _array_handler(a):
- '''
+ """
:param a: if array, assumes it is a (n x 2) np.array and return a
persistence-compatible list (padding with 0), so that the
plot can be performed seamlessly.
- '''
- if isinstance(a[0][1], np.float64) or isinstance(a[0][1], float):
+ """
+ if isinstance(a[0][1], (np.floating, float)):
return [[0, x] for x in a]
else:
return a
+
+def _limit_to_max_intervals(persistence, max_intervals, key):
+ """This function returns truncated persistence if length is bigger than max_intervals.
+ :param persistence: Persistence intervals values list. Can be grouped by dimension or not.
+ :type persistence: an array of (dimension, array of (birth, death)) or an array of (birth, death).
+ :param max_intervals: maximal number of intervals to display.
+ Selected intervals are those with the longest life time. Set it
+ to 0 to see all. Default value is 1000.
+ :type max_intervals: int.
+ :param key: key function for sort algorithm.
+ :type key: function or lambda.
+ """
+ if max_intervals > 0 and max_intervals < len(persistence):
+ warnings.warn(
+ "There are %s intervals given as input, whereas max_intervals is set to %s."
+ % (len(persistence), max_intervals)
+ )
+ # Sort by life time, then takes only the max_intervals elements
+ return sorted(persistence, key=key, reverse=True)[:max_intervals]
+ else:
+ return persistence
+
+
@lru_cache(maxsize=1)
def _matplotlib_can_use_tex():
"""This function returns True if matplotlib can deal with LaTeX, False otherwise.
@@ -65,17 +95,17 @@ def _matplotlib_can_use_tex():
"""
try:
from matplotlib import checkdep_usetex
+
return checkdep_usetex(True)
- except ImportError:
- print("This function is not available, you may be missing matplotlib.")
+ except ImportError as import_error:
+ warnings.warn(f"This function is not available.\nModuleNotFoundError: No module named '{import_error.name}'.")
def plot_persistence_barcode(
persistence=[],
persistence_file="",
alpha=0.6,
- max_intervals=1000,
- max_barcodes=1000,
+ max_intervals=20000,
inf_delta=0.1,
legend=False,
colormap=None,
@@ -97,7 +127,7 @@ def plot_persistence_barcode(
:type alpha: float.
:param max_intervals: maximal number of intervals to display.
Selected intervals are those with the longest life time. Set it
- to 0 to see all. Default value is 1000.
+ to 0 to see all. Default value is 20000.
:type max_intervals: int.
:param inf_delta: Infinity is placed at :code:`((max_death - min_birth) x
inf_delta)` above :code:`max_death` value. A reasonable value is
@@ -119,99 +149,68 @@ def plot_persistence_barcode(
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib import rc
+
if _gudhi_matplotlib_use_tex and _matplotlib_can_use_tex():
- plt.rc('text', usetex=True)
- plt.rc('font', family='serif')
+ plt.rc("text", usetex=True)
+ plt.rc("font", family="serif")
else:
- plt.rc('text', usetex=False)
- plt.rc('font', family='DejaVu Sans')
+ plt.rc("text", usetex=False)
+ plt.rc("font", family="DejaVu Sans")
if persistence_file != "":
if path.isfile(persistence_file):
# Reset persistence
persistence = []
- diag = read_persistence_intervals_grouped_by_dimension(
- persistence_file=persistence_file
- )
+ diag = read_persistence_intervals_grouped_by_dimension(persistence_file=persistence_file)
for key in diag.keys():
for persistence_interval in diag[key]:
persistence.append((key, persistence_interval))
else:
- print("file " + persistence_file + " not found.")
- return None
-
- persistence = _array_handler(persistence)
-
- if max_barcodes != 1000:
- print("Deprecated parameter. It has been replaced by max_intervals")
- max_intervals = max_barcodes
-
- if max_intervals > 0 and max_intervals < len(persistence):
- # Sort by life time, then takes only the max_intervals elements
- persistence = sorted(
- persistence,
- key=lambda life_time: life_time[1][1] - life_time[1][0],
- reverse=True,
- )[:max_intervals]
-
- if colormap == None:
- colormap = plt.cm.Set1.colors
- if axes == None:
- fig, axes = plt.subplots(1, 1)
+ raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), persistence_file)
- persistence = sorted(persistence, key=lambda birth: birth[1][0])
+ try:
+ persistence = _array_handler(persistence)
+ persistence = _limit_to_max_intervals(
+ persistence, max_intervals, key=lambda life_time: life_time[1][1] - life_time[1][0]
+ )
+ (min_birth, max_death) = __min_birth_max_death(persistence)
+ persistence = sorted(persistence, key=lambda birth: birth[1][0])
+ except IndexError:
+ min_birth, max_death = 0.0, 1.0
+ pass
- (min_birth, max_death) = __min_birth_max_death(persistence)
- ind = 0
delta = (max_death - min_birth) * inf_delta
# Replace infinity values with max_death + delta for bar code to be more
# readable
infinity = max_death + delta
axis_start = min_birth - delta
- # Draw horizontal bars in loop
- for interval in reversed(persistence):
- if float(interval[1][1]) != float("inf"):
- # Finite death case
- axes.barh(
- ind,
- (interval[1][1] - interval[1][0]),
- height=0.8,
- left=interval[1][0],
- alpha=alpha,
- color=colormap[interval[0]],
- linewidth=0,
- )
- else:
- # Infinite death case for diagram to be nicer
- axes.barh(
- ind,
- (infinity - interval[1][0]),
- height=0.8,
- left=interval[1][0],
- alpha=alpha,
- color=colormap[interval[0]],
- linewidth=0,
- )
- ind = ind + 1
+
+ if axes == None:
+ _, axes = plt.subplots(1, 1)
+ if colormap == None:
+ colormap = plt.cm.Set1.colors
+
+ x=[birth for (dim,(birth,death)) in persistence]
+ y=[(death - birth) if death != float("inf") else (infinity - birth) for (dim,(birth,death)) in persistence]
+ c=[colormap[dim] for (dim,(birth,death)) in persistence]
+
+ axes.barh(list(reversed(range(len(x)))), y, height=0.8, left=x, alpha=alpha, color=c, linewidth=0)
if legend:
dimensions = list(set(item[0] for item in persistence))
axes.legend(
- handles=[
- mpatches.Patch(color=colormap[dim], label=str(dim))
- for dim in dimensions
- ],
- loc="lower right",
+ handles=[mpatches.Patch(color=colormap[dim], label=str(dim)) for dim in dimensions], loc="lower right",
)
axes.set_title("Persistence barcode", fontsize=fontsize)
# Ends plot on infinity value and starts a little bit before min_birth
- axes.axis([axis_start, infinity, 0, ind])
+ if len(x) != 0:
+ axes.axis([axis_start, infinity, 0, len(x)])
return axes
- except ImportError:
- print("This function is not available, you may be missing matplotlib.")
+ except ImportError as import_error:
+ warnings.warn(f"This function is not available.\nModuleNotFoundError: No module named '{import_error.name}'.")
def plot_persistence_diagram(
@@ -219,14 +218,13 @@ def plot_persistence_diagram(
persistence_file="",
alpha=0.6,
band=0.0,
- max_intervals=1000,
- max_plots=1000,
+ max_intervals=1000000,
inf_delta=0.1,
legend=False,
colormap=None,
axes=None,
fontsize=16,
- greyblock=True
+ greyblock=True,
):
"""This function plots the persistence diagram from persistence values
list, a np.array of shape (N x 2) representing a diagram in a single
@@ -244,7 +242,7 @@ def plot_persistence_diagram(
:type band: float.
:param max_intervals: maximal number of intervals to display.
Selected intervals are those with the longest life time. Set it
- to 0 to see all. Default value is 1000.
+ to 0 to see all. Default value is 1000000.
:type max_intervals: int.
:param inf_delta: Infinity is placed at :code:`((max_death - min_birth) x
inf_delta)` above :code:`max_death` value. A reasonable value is
@@ -268,47 +266,35 @@ def plot_persistence_diagram(
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib import rc
+
if _gudhi_matplotlib_use_tex and _matplotlib_can_use_tex():
- plt.rc('text', usetex=True)
- plt.rc('font', family='serif')
+ plt.rc("text", usetex=True)
+ plt.rc("font", family="serif")
else:
- plt.rc('text', usetex=False)
- plt.rc('font', family='DejaVu Sans')
+ plt.rc("text", usetex=False)
+ plt.rc("font", family="DejaVu Sans")
if persistence_file != "":
if path.isfile(persistence_file):
# Reset persistence
persistence = []
- diag = read_persistence_intervals_grouped_by_dimension(
- persistence_file=persistence_file
- )
+ diag = read_persistence_intervals_grouped_by_dimension(persistence_file=persistence_file)
for key in diag.keys():
for persistence_interval in diag[key]:
persistence.append((key, persistence_interval))
else:
- print("file " + persistence_file + " not found.")
- return None
-
- persistence = _array_handler(persistence)
+ raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), persistence_file)
- if max_plots != 1000:
- print("Deprecated parameter. It has been replaced by max_intervals")
- max_intervals = max_plots
-
- if max_intervals > 0 and max_intervals < len(persistence):
- # Sort by life time, then takes only the max_intervals elements
- persistence = sorted(
- persistence,
- key=lambda life_time: life_time[1][1] - life_time[1][0],
- reverse=True,
- )[:max_intervals]
-
- if colormap == None:
- colormap = plt.cm.Set1.colors
- if axes == None:
- fig, axes = plt.subplots(1, 1)
+ try:
+ persistence = _array_handler(persistence)
+ persistence = _limit_to_max_intervals(
+ persistence, max_intervals, key=lambda life_time: life_time[1][1] - life_time[1][0]
+ )
+ min_birth, max_death = __min_birth_max_death(persistence, band)
+ except IndexError:
+ min_birth, max_death = 0.0, 1.0
+ pass
- (min_birth, max_death) = __min_birth_max_death(persistence, band)
delta = (max_death - min_birth) * inf_delta
# Replace infinity values with max_death + delta for diagram to be more
# readable
@@ -316,61 +302,56 @@ def plot_persistence_diagram(
axis_end = max_death + delta / 2
axis_start = min_birth - delta
+ if axes == None:
+ _, axes = plt.subplots(1, 1)
+ if colormap == None:
+ colormap = plt.cm.Set1.colors
# bootstrap band
if band > 0.0:
x = np.linspace(axis_start, infinity, 1000)
axes.fill_between(x, x, x + band, alpha=alpha, facecolor="red")
# lower diag patch
if greyblock:
- axes.add_patch(mpatches.Polygon([[axis_start, axis_start], [axis_end, axis_start], [axis_end, axis_end]], fill=True, color='lightgrey'))
- # Draw points in loop
- pts_at_infty = False # Records presence of pts at infty
- for interval in reversed(persistence):
- if float(interval[1][1]) != float("inf"):
- # Finite death case
- axes.scatter(
- interval[1][0],
- interval[1][1],
- alpha=alpha,
- color=colormap[interval[0]],
+ axes.add_patch(
+ mpatches.Polygon(
+ [[axis_start, axis_start], [axis_end, axis_start], [axis_end, axis_end]],
+ fill=True,
+ color="lightgrey",
)
- else:
- pts_at_infty = True
- # Infinite death case for diagram to be nicer
- axes.scatter(
- interval[1][0], infinity, alpha=alpha, color=colormap[interval[0]]
- )
- if pts_at_infty:
+ )
+ # line display of equation : birth = death
+ axes.plot([axis_start, axis_end], [axis_start, axis_end], linewidth=1.0, color="k")
+
+ x=[birth for (dim,(birth,death)) in persistence]
+ y=[death if death != float("inf") else infinity for (dim,(birth,death)) in persistence]
+ c=[colormap[dim] for (dim,(birth,death)) in persistence]
+
+ axes.scatter(x,y,alpha=alpha,color=c)
+ if float("inf") in (death for (dim,(birth,death)) in persistence):
# infinity line and text
- axes.plot([axis_start, axis_end], [axis_start, axis_end], linewidth=1.0, color="k")
axes.plot([axis_start, axis_end], [infinity, infinity], linewidth=1.0, color="k", alpha=alpha)
# Infinity label
yt = axes.get_yticks()
- yt = yt[np.where(yt < axis_end)] # to avoid ploting ticklabel higher than infinity
+ yt = yt[np.where(yt < axis_end)] # to avoid plotting ticklabel higher than infinity
yt = np.append(yt, infinity)
ytl = ["%.3f" % e for e in yt] # to avoid float precision error
- ytl[-1] = r'$+\infty$'
+ ytl[-1] = r"$+\infty$"
axes.set_yticks(yt)
axes.set_yticklabels(ytl)
if legend:
dimensions = list(set(item[0] for item in persistence))
- axes.legend(
- handles=[
- mpatches.Patch(color=colormap[dim], label=str(dim))
- for dim in dimensions
- ]
- )
+ axes.legend(handles=[mpatches.Patch(color=colormap[dim], label=str(dim)) for dim in dimensions])
axes.set_xlabel("Birth", fontsize=fontsize)
axes.set_ylabel("Death", fontsize=fontsize)
axes.set_title("Persistence diagram", fontsize=fontsize)
# Ends plot on infinity value and starts a little bit before min_birth
- axes.axis([axis_start, axis_end, axis_start, infinity + delta/2])
+ axes.axis([axis_start, axis_end, axis_start, infinity + delta / 2])
return axes
- except ImportError:
- print("This function is not available, you may be missing matplotlib.")
+ except ImportError as import_error:
+ warnings.warn(f"This function is not available.\nModuleNotFoundError: No module named '{import_error.name}'.")
def plot_persistence_density(
@@ -384,7 +365,7 @@ def plot_persistence_density(
legend=False,
axes=None,
fontsize=16,
- greyblock=False
+ greyblock=False,
):
"""This function plots the persistence density from persistence
values list, np.array of shape (N x 2) representing a diagram
@@ -444,12 +425,13 @@ def plot_persistence_density(
import matplotlib.patches as mpatches
from scipy.stats import kde
from matplotlib import rc
+
if _gudhi_matplotlib_use_tex and _matplotlib_can_use_tex():
- plt.rc('text', usetex=True)
- plt.rc('font', family='serif')
+ plt.rc("text", usetex=True)
+ plt.rc("font", family="serif")
else:
- plt.rc('text', usetex=False)
- plt.rc('font', family='DejaVu Sans')
+ plt.rc("text", usetex=False)
+ plt.rc("font", family="DejaVu Sans")
if persistence_file != "":
if dimension is None:
@@ -460,10 +442,16 @@ def plot_persistence_density(
persistence_file=persistence_file, only_this_dim=dimension
)
else:
- print("file " + persistence_file + " not found.")
- return None
+ raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), persistence_file)
+
+ # default cmap value cannot be done at argument definition level as matplotlib is not yet defined.
+ if cmap is None:
+ cmap = plt.cm.hot_r
+ if axes == None:
+ _, axes = plt.subplots(1, 1)
- if len(persistence) > 0:
+ try:
+ # if not read from file but given by an argument
persistence = _array_handler(persistence)
persistence_dim = np.array(
[
@@ -472,47 +460,54 @@ def plot_persistence_density(
if (dim_interval[0] == dimension) or (dimension is None)
]
)
-
- persistence_dim = persistence_dim[np.isfinite(persistence_dim[:, 1])]
- if max_intervals > 0 and max_intervals < len(persistence_dim):
- # Sort by life time, then takes only the max_intervals elements
+ persistence_dim = persistence_dim[np.isfinite(persistence_dim[:, 1])]
persistence_dim = np.array(
- sorted(
- persistence_dim,
- key=lambda life_time: life_time[1] - life_time[0],
- reverse=True,
- )[:max_intervals]
+ _limit_to_max_intervals(
+ persistence_dim, max_intervals, key=lambda life_time: life_time[1] - life_time[0]
+ )
)
- # Set as numpy array birth and death (remove undefined values - inf and NaN)
- birth = persistence_dim[:, 0]
- death = persistence_dim[:, 1]
-
- # default cmap value cannot be done at argument definition level as matplotlib is not yet defined.
- if cmap is None:
- cmap = plt.cm.hot_r
- if axes == None:
- fig, axes = plt.subplots(1, 1)
+ # Set as numpy array birth and death (remove undefined values - inf and NaN)
+ birth = persistence_dim[:, 0]
+ death = persistence_dim[:, 1]
+ birth_min = birth.min()
+ birth_max = birth.max()
+ death_min = death.min()
+ death_max = death.max()
+
+ # Evaluate a gaussian kde on a regular grid of nbins x nbins over data extents
+ k = kde.gaussian_kde([birth, death], bw_method=bw_method)
+ xi, yi = np.mgrid[
+ birth_min : birth_max : nbins * 1j, death_min : death_max : nbins * 1j,
+ ]
+ zi = k(np.vstack([xi.flatten(), yi.flatten()]))
+ # Make the plot
+ img = axes.pcolormesh(xi, yi, zi.reshape(xi.shape), cmap=cmap, shading="auto")
+ plot_success = True
+
+ # IndexError on empty diagrams, ValueError on only inf death values
+ except (IndexError, ValueError):
+ birth_min = 0.0
+ birth_max = 1.0
+ death_min = 0.0
+ death_max = 1.0
+ plot_success = False
+ pass
# line display of equation : birth = death
- x = np.linspace(death.min(), birth.max(), 1000)
+ x = np.linspace(death_min, birth_max, 1000)
axes.plot(x, x, color="k", linewidth=1.0)
- # Evaluate a gaussian kde on a regular grid of nbins x nbins over data extents
- k = kde.gaussian_kde([birth, death], bw_method=bw_method)
- xi, yi = np.mgrid[
- birth.min() : birth.max() : nbins * 1j,
- death.min() : death.max() : nbins * 1j,
- ]
- zi = k(np.vstack([xi.flatten(), yi.flatten()]))
-
- # Make the plot
- img = axes.pcolormesh(xi, yi, zi.reshape(xi.shape), cmap=cmap)
-
if greyblock:
- axes.add_patch(mpatches.Polygon([[birth.min(), birth.min()], [death.max(), birth.min()], [death.max(), death.max()]], fill=True, color='lightgrey'))
+ axes.add_patch(
+ mpatches.Polygon(
+ [[birth_min, birth_min], [death_max, birth_min], [death_max, death_max]],
+ fill=True,
+ color="lightgrey",
+ )
+ )
- if legend:
+ if plot_success and legend:
plt.colorbar(img, ax=axes)
axes.set_xlabel("Birth", fontsize=fontsize)
@@ -521,7 +516,5 @@ def plot_persistence_density(
return axes
- except ImportError:
- print(
- "This function is not available, you may be missing matplotlib and/or scipy."
- )
+ except ImportError as import_error:
+ warnings.warn(f"This function is not available.\nModuleNotFoundError: No module named '{import_error.name}'.")
diff --git a/src/python/gudhi/simplex_tree.pyx b/src/python/gudhi/simplex_tree.pyx
index 97c7d2c8..521a7763 100644
--- a/src/python/gudhi/simplex_tree.pyx
+++ b/src/python/gudhi/simplex_tree.pyx
@@ -675,18 +675,17 @@ cdef class SimplexTree:
return (normal0, normals, infinite0, infinites)
def collapse_edges(self, nb_iterations = 1):
- """Assuming the simplex tree is a 1-skeleton graph, this method collapse edges (simplices of higher dimension
- are ignored) and resets the simplex tree from the remaining edges.
- A good candidate is to build a simplex tree on top of a :class:`~gudhi.RipsComplex` of dimension 1 before
- collapsing edges
+ """Assuming the complex is a graph (simplices of higher dimension are ignored), this method implicitly
+ interprets it as the 1-skeleton of a flag complex, and replaces it with another (smaller) graph whose
+ expansion has the same persistent homology, using a technique known as edge collapses
+ (see :cite:`edgecollapsearxiv`).
+
+ A natural application is to get a simplex tree of dimension 1 from :class:`~gudhi.RipsComplex`,
+ then collapse edges, perform :meth:`expansion()` and finally compute persistence
(cf. :download:`rips_complex_edge_collapse_example.py <../example/rips_complex_edge_collapse_example.py>`).
- For implementation details, please refer to :cite:`edgecollapsesocg2020`.
:param nb_iterations: The number of edge collapse iterations to perform. Default is 1.
:type nb_iterations: int
-
- :note: collapse_edges method requires `Eigen <installation.html#eigen>`_ >= 3.1.0 and an exception is thrown
- if this method is not available.
"""
# Backup old pointer
cdef Simplex_tree_interface_full_featured* ptr = self.get_ptr()
diff --git a/src/python/gudhi/wasserstein/barycenter.py b/src/python/gudhi/wasserstein/barycenter.py
index d67bcde7..bb6e641e 100644
--- a/src/python/gudhi/wasserstein/barycenter.py
+++ b/src/python/gudhi/wasserstein/barycenter.py
@@ -37,7 +37,7 @@ def lagrangian_barycenter(pdiagset, init=None, verbose=False):
:param init: The initial value for barycenter estimate.
If ``None``, init is made on a random diagram from the dataset.
Otherwise, it can be an ``int`` (then initialization is made on ``pdiagset[init]``)
- or a `(n x 2)` ``numpy.array`` enconding a persistence diagram with `n` points.
+ or a `(n x 2)` ``numpy.array`` encoding a persistence diagram with `n` points.
:type init: ``int``, or (n x 2) ``np.array``
:param verbose: if ``True``, returns additional information about the barycenter.
:type verbose: boolean
@@ -45,7 +45,7 @@ def lagrangian_barycenter(pdiagset, init=None, verbose=False):
(local minimum of the energy function).
If ``pdiagset`` is empty, returns ``None``.
If verbose, returns a couple ``(Y, log)`` where ``Y`` is the barycenter estimate,
- and ``log`` is a ``dict`` that contains additional informations:
+ and ``log`` is a ``dict`` that contains additional information:
- `"groupings"`, a list of list of pairs ``(i,j)``. Namely, ``G[k] = [...(i, j)...]``, where ``(i,j)`` indicates that `pdiagset[k][i]`` is matched to ``Y[j]`` if ``i = -1`` or ``j = -1``, it means they represent the diagonal.
@@ -73,7 +73,7 @@ def lagrangian_barycenter(pdiagset, init=None, verbose=False):
nb_iter = 0
- converged = False # stoping criterion
+ converged = False # stopping criterion
while not converged:
nb_iter += 1
K = len(Y) # current nb of points in Y (some might be on diagonal)
diff --git a/src/python/include/Simplex_tree_interface.h b/src/python/include/Simplex_tree_interface.h
index aec29f9b..3848c5ad 100644
--- a/src/python/include/Simplex_tree_interface.h
+++ b/src/python/include/Simplex_tree_interface.h
@@ -15,9 +15,7 @@
#include <gudhi/distance_functions.h>
#include <gudhi/Simplex_tree.h>
#include <gudhi/Points_off_io.h>
-#ifdef GUDHI_USE_EIGEN3
#include <gudhi/Flag_complex_edge_collapser.h>
-#endif
#include <iostream>
#include <vector>
@@ -135,7 +133,6 @@ class Simplex_tree_interface : public Simplex_tree<SimplexTreeOptions> {
}
Simplex_tree_interface* collapse_edges(int nb_collapse_iteration) {
-#ifdef GUDHI_USE_EIGEN3
using Filtered_edge = std::tuple<Vertex_handle, Vertex_handle, Filtration_value>;
std::vector<Filtered_edge> edges;
for (Simplex_handle sh : Base::skeleton_simplex_range(1)) {
@@ -149,7 +146,7 @@ class Simplex_tree_interface : public Simplex_tree<SimplexTreeOptions> {
}
for (int iteration = 0; iteration < nb_collapse_iteration; iteration++) {
- edges = Gudhi::collapse::flag_complex_collapse_edges(edges);
+ edges = Gudhi::collapse::flag_complex_collapse_edges(std::move(edges));
}
Simplex_tree_interface* collapsed_stree_ptr = new Simplex_tree_interface();
// Copy the original 0-skeleton
@@ -161,9 +158,6 @@ class Simplex_tree_interface : public Simplex_tree<SimplexTreeOptions> {
collapsed_stree_ptr->insert({std::get<0>(remaining_edge), std::get<1>(remaining_edge)}, std::get<2>(remaining_edge));
}
return collapsed_stree_ptr;
-#else
- throw std::runtime_error("Unable to collapse edges as it requires Eigen3 >= 3.1.0.");
-#endif
}
void expansion_with_blockers_callback(int dimension, blocker_func_t user_func, void *user_data) {
diff --git a/src/python/test/test_persistence_graphical_tools.py b/src/python/test/test_persistence_graphical_tools.py
new file mode 100644
index 00000000..c19836b7
--- /dev/null
+++ b/src/python/test/test_persistence_graphical_tools.py
@@ -0,0 +1,121 @@
+""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+ See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+ Author(s): Vincent Rouvreau
+
+ Copyright (C) 2021 Inria
+
+ Modification(s):
+ - YYYY/MM Author: Description of the modification
+"""
+
+import gudhi as gd
+import numpy as np
+import matplotlib as plt
+import pytest
+
+
+def test_array_handler():
+ diags = np.array([[1, 2], [3, 4], [5, 6]], float)
+ arr_diags = gd.persistence_graphical_tools._array_handler(diags)
+ for idx in range(len(diags)):
+ assert arr_diags[idx][0] == 0
+ np.testing.assert_array_equal(arr_diags[idx][1], diags[idx])
+
+ diags = [(1.0, 2.0), (3.0, 4.0), (5.0, 6.0)]
+ arr_diags = gd.persistence_graphical_tools._array_handler(diags)
+ for idx in range(len(diags)):
+ assert arr_diags[idx][0] == 0
+ assert arr_diags[idx][1] == diags[idx]
+
+ diags = [(0, (1.0, 2.0)), (0, (3.0, 4.0)), (0, (5.0, 6.0))]
+ assert gd.persistence_graphical_tools._array_handler(diags) == diags
+
+
+def test_min_birth_max_death():
+ diags = [
+ (0, (0.0, float("inf"))),
+ (0, (0.0983494, float("inf"))),
+ (0, (0.0, 0.122545)),
+ (0, (0.0, 0.12047)),
+ (0, (0.0, 0.118398)),
+ (0, (0.118398, 1.0)),
+ (0, (0.0, 0.117908)),
+ (0, (0.0, 0.112307)),
+ (0, (0.0, 0.107535)),
+ (0, (0.0, 0.106382)),
+ ]
+ assert gd.persistence_graphical_tools.__min_birth_max_death(diags) == (0.0, 1.0)
+ assert gd.persistence_graphical_tools.__min_birth_max_death(diags, band=4.0) == (0.0, 5.0)
+
+
+def test_limit_min_birth_max_death():
+ diags = [
+ (0, (2.0, float("inf"))),
+ (0, (2.0, float("inf"))),
+ ]
+ assert gd.persistence_graphical_tools.__min_birth_max_death(diags) == (2.0, 3.0)
+ assert gd.persistence_graphical_tools.__min_birth_max_death(diags, band=4.0) == (2.0, 6.0)
+
+
+def test_limit_to_max_intervals():
+ diags = [
+ (0, (0.0, float("inf"))),
+ (0, (0.0983494, float("inf"))),
+ (0, (0.0, 0.122545)),
+ (0, (0.0, 0.12047)),
+ (0, (0.0, 0.118398)),
+ (0, (0.118398, 1.0)),
+ (0, (0.0, 0.117908)),
+ (0, (0.0, 0.112307)),
+ (0, (0.0, 0.107535)),
+ (0, (0.0, 0.106382)),
+ ]
+ # check no warnings if max_intervals equals to the diagrams number
+ with pytest.warns(None) as record:
+ truncated_diags = gd.persistence_graphical_tools._limit_to_max_intervals(
+ diags, 10, key=lambda life_time: life_time[1][1] - life_time[1][0]
+ )
+ # check diagrams are not sorted
+ assert truncated_diags == diags
+ assert len(record) == 0
+
+ # check warning if max_intervals lower than the diagrams number
+ with pytest.warns(UserWarning) as record:
+ truncated_diags = gd.persistence_graphical_tools._limit_to_max_intervals(
+ diags, 5, key=lambda life_time: life_time[1][1] - life_time[1][0]
+ )
+ # check diagrams are truncated and sorted by life time
+ assert truncated_diags == [
+ (0, (0.0, float("inf"))),
+ (0, (0.0983494, float("inf"))),
+ (0, (0.118398, 1.0)),
+ (0, (0.0, 0.122545)),
+ (0, (0.0, 0.12047)),
+ ]
+ assert len(record) == 1
+
+
+def _limit_plot_persistence(function):
+ pplot = function(persistence=[])
+ assert isinstance(pplot, plt.axes.SubplotBase)
+ pplot = function(persistence=[], legend=True)
+ assert isinstance(pplot, plt.axes.SubplotBase)
+ pplot = function(persistence=[(0, float("inf"))])
+ assert isinstance(pplot, plt.axes.SubplotBase)
+ pplot = function(persistence=[(0, float("inf"))], legend=True)
+ assert isinstance(pplot, plt.axes.SubplotBase)
+
+
+def test_limit_plot_persistence():
+ for function in [gd.plot_persistence_barcode, gd.plot_persistence_diagram, gd.plot_persistence_density]:
+ _limit_plot_persistence(function)
+
+
+def _non_existing_persistence_file(function):
+ with pytest.raises(FileNotFoundError):
+ function(persistence_file="pouetpouettralala.toubiloubabdou")
+
+
+def test_non_existing_persistence_file():
+ for function in [gd.plot_persistence_barcode, gd.plot_persistence_diagram, gd.plot_persistence_density]:
+ _non_existing_persistence_file(function)
diff --git a/src/python/test/test_simplex_tree.py b/src/python/test/test_simplex_tree.py
index 7abb4366..54bafed5 100755
--- a/src/python/test/test_simplex_tree.py
+++ b/src/python/test/test_simplex_tree.py
@@ -8,7 +8,7 @@
- YYYY/MM Author: Description of the modification
"""
-from gudhi import SimplexTree, __GUDHI_USE_EIGEN3
+from gudhi import SimplexTree
import numpy as np
import pytest
@@ -358,16 +358,11 @@ def test_collapse_edges():
assert st.num_simplices() == 10
- if __GUDHI_USE_EIGEN3:
- st.collapse_edges()
- assert st.num_simplices() == 9
- assert st.find([1, 3]) == False
- for simplex in st.get_skeleton(0):
- assert simplex[1] == 1.
- else:
- # If no Eigen3, collapse_edges throws an exception
- with pytest.raises(RuntimeError):
- st.collapse_edges()
+ st.collapse_edges()
+ assert st.num_simplices() == 9
+ assert st.find([0, 2]) == False # [1, 3] would be fine as well
+ for simplex in st.get_skeleton(0):
+ assert simplex[1] == 1.
def test_reset_filtration():
st = SimplexTree()
@@ -537,7 +532,7 @@ def test_expansion_with_blocker():
def blocker(simplex):
try:
- # Block all simplices that countains vertex 6
+ # Block all simplices that contain vertex 6
simplex.index(6)
print(simplex, ' is blocked')
return True
diff --git a/src/python/test/test_subsampling.py b/src/python/test/test_subsampling.py
index 4019852e..3431f372 100755
--- a/src/python/test/test_subsampling.py
+++ b/src/python/test/test_subsampling.py
@@ -91,7 +91,7 @@ def test_simple_choose_n_farthest_points_randomed():
assert gudhi.choose_n_farthest_points(points=[], nb_points=1) == []
assert gudhi.choose_n_farthest_points(points=point_set, nb_points=0) == []
- # Go furter than point set on purpose
+ # Go further than point set on purpose
for iter in range(1, 10):
sub_set = gudhi.choose_n_farthest_points(points=point_set, nb_points=iter)
for sub in sub_set:
@@ -117,7 +117,7 @@ def test_simple_pick_n_random_points():
assert gudhi.pick_n_random_points(points=[], nb_points=1) == []
assert gudhi.pick_n_random_points(points=point_set, nb_points=0) == []
- # Go furter than point set on purpose
+ # Go further than point set on purpose
for iter in range(1, 10):
sub_set = gudhi.pick_n_random_points(points=point_set, nb_points=iter)
for sub in sub_set: