diff options
Diffstat (limited to 'src/python/gudhi')
-rw-r--r-- | src/python/gudhi/alpha_complex.pyx | 29 | ||||
-rw-r--r-- | src/python/gudhi/persistence_graphical_tools.py | 33 | ||||
-rw-r--r-- | src/python/gudhi/representations/metrics.py | 17 | ||||
-rw-r--r-- | src/python/gudhi/simplex_tree.pyx | 5 | ||||
-rw-r--r-- | src/python/gudhi/wasserstein/wasserstein.py | 10 |
5 files changed, 56 insertions, 38 deletions
diff --git a/src/python/gudhi/alpha_complex.pyx b/src/python/gudhi/alpha_complex.pyx index a356384d..ea128743 100644 --- a/src/python/gudhi/alpha_complex.pyx +++ b/src/python/gudhi/alpha_complex.pyx @@ -20,6 +20,7 @@ import os from gudhi.simplex_tree cimport * from gudhi.simplex_tree import SimplexTree +from gudhi import read_points_from_off_file __author__ = "Vincent Rouvreau" __copyright__ = "Copyright (C) 2016 Inria" @@ -27,11 +28,9 @@ __license__ = "GPL v3" cdef extern from "Alpha_complex_interface.h" namespace "Gudhi": cdef cppclass Alpha_complex_interface "Gudhi::alpha_complex::Alpha_complex_interface": - Alpha_complex_interface(vector[vector[double]] points, bool fast_version) nogil except + - # bool from_file is a workaround for cython to find the correct signature - Alpha_complex_interface(string off_file, bool fast_version, bool from_file) nogil except + + Alpha_complex_interface(vector[vector[double]] points, bool fast_version, bool exact_version) nogil except + vector[double] get_point(int vertex) nogil except + - void create_simplex_tree(Simplex_tree_interface_full_featured* simplex_tree, double max_alpha_square, bool exact_version, bool default_filtration_value) nogil except + + void create_simplex_tree(Simplex_tree_interface_full_featured* simplex_tree, double max_alpha_square, bool default_filtration_value) nogil except + # AlphaComplex python interface cdef class AlphaComplex: @@ -54,7 +53,6 @@ cdef class AlphaComplex: """ cdef Alpha_complex_interface * this_ptr - cdef bool exact # Fake constructor that does nothing but documenting the constructor def __init__(self, points=None, off_file='', precision='safe'): @@ -76,21 +74,20 @@ cdef class AlphaComplex: def __cinit__(self, points = None, off_file = '', precision = 'safe'): assert precision in ['fast', 'safe', 'exact'], "Alpha complex precision can only be 'fast', 'safe' or 'exact'" cdef bool fast = precision == 'fast' - self.exact = precision == 'exact' + cdef bool exact = precision == 'exact' cdef vector[vector[double]] pts if off_file: if os.path.isfile(off_file): - self.this_ptr = new Alpha_complex_interface(off_file.encode('utf-8'), fast, True) + points = read_points_from_off_file(off_file = off_file) else: print("file " + off_file + " not found.") - else: - if points is None: - # Empty Alpha construction - points=[] - pts = points - with nogil: - self.this_ptr = new Alpha_complex_interface(pts, fast) + if points is None: + # Empty Alpha construction + points=[] + pts = points + with nogil: + self.this_ptr = new Alpha_complex_interface(pts, fast, exact) def __dealloc__(self): if self.this_ptr != NULL: @@ -102,7 +99,7 @@ cdef class AlphaComplex: return self.this_ptr != NULL def get_point(self, vertex): - """This function returns the point corresponding to a given vertex. + """This function returns the point corresponding to a given vertex from the :class:`~gudhi.SimplexTree`. :param vertex: The vertex. :type vertex: int @@ -128,5 +125,5 @@ cdef class AlphaComplex: cdef bool compute_filtration = default_filtration_value == True with nogil: self.this_ptr.create_simplex_tree(<Simplex_tree_interface_full_featured*>stree_int_ptr, - mas, self.exact, compute_filtration) + mas, compute_filtration) return stree diff --git a/src/python/gudhi/persistence_graphical_tools.py b/src/python/gudhi/persistence_graphical_tools.py index 6a74a6ca..c6766c70 100644 --- a/src/python/gudhi/persistence_graphical_tools.py +++ b/src/python/gudhi/persistence_graphical_tools.py @@ -11,6 +11,7 @@ from os import path from math import isfinite import numpy as np +from functools import lru_cache from gudhi.reader_utils import read_persistence_intervals_in_dimension from gudhi.reader_utils import read_persistence_intervals_grouped_by_dimension @@ -56,6 +57,17 @@ def _array_handler(a): else: return a +@lru_cache(maxsize=1) +def _matplotlib_can_use_tex(): + """This function returns True if matplotlib can deal with LaTeX, False otherwise. + The returned value is cached. + """ + try: + from matplotlib import checkdep_usetex + return checkdep_usetex(True) + except ImportError: + print("This function is not available, you may be missing matplotlib.") + def plot_persistence_barcode( persistence=[], @@ -105,9 +117,10 @@ def plot_persistence_barcode( try: import matplotlib.pyplot as plt import matplotlib.patches as mpatches - from matplotlib import rc - plt.rc('text', usetex=True) - plt.rc('font', family='serif') + if _matplotlib_can_use_tex(): + from matplotlib import rc + plt.rc('text', usetex=True) + plt.rc('font', family='serif') if persistence_file != "": if path.isfile(persistence_file): @@ -250,9 +263,10 @@ def plot_persistence_diagram( try: import matplotlib.pyplot as plt import matplotlib.patches as mpatches - from matplotlib import rc - plt.rc('text', usetex=True) - plt.rc('font', family='serif') + if _matplotlib_can_use_tex(): + from matplotlib import rc + plt.rc('text', usetex=True) + plt.rc('font', family='serif') if persistence_file != "": if path.isfile(persistence_file): @@ -422,9 +436,10 @@ def plot_persistence_density( import matplotlib.pyplot as plt import matplotlib.patches as mpatches from scipy.stats import kde - from matplotlib import rc - plt.rc('text', usetex=True) - plt.rc('font', family='serif') + if _matplotlib_can_use_tex(): + from matplotlib import rc + plt.rc('text', usetex=True) + plt.rc('font', family='serif') if persistence_file != "": if dimension is None: diff --git a/src/python/gudhi/representations/metrics.py b/src/python/gudhi/representations/metrics.py index cf2e0879..142ddef1 100644 --- a/src/python/gudhi/representations/metrics.py +++ b/src/python/gudhi/representations/metrics.py @@ -350,23 +350,30 @@ class PersistenceFisherDistance(BaseEstimator, TransformerMixin): """ return _persistence_fisher_distance(diag1, diag2, bandwidth=self.bandwidth, kernel_approx=self.kernel_approx) + class WassersteinDistance(BaseEstimator, TransformerMixin): """ This is a class for computing the Wasserstein distance matrix from a list of persistence diagrams. """ - def __init__(self, order=2, internal_p=2, mode="pot", delta=0.01, n_jobs=None): + + def __init__(self, order=1, internal_p=np.inf, mode="hera", delta=0.01, n_jobs=None): """ Constructor for the WassersteinDistance class. Parameters: - order (int): exponent for Wasserstein, default value is 2., see :func:`gudhi.wasserstein.wasserstein_distance`. - internal_p (int): ground metric on the (upper-half) plane (i.e. norm l_p in R^2), default value is 2 (euclidean norm), see :func:`gudhi.wasserstein.wasserstein_distance`. - mode (str): method for computing Wasserstein distance. Either "pot" or "hera". + order (int): exponent for Wasserstein, default value is 1., see :func:`gudhi.wasserstein.wasserstein_distance`. + internal_p (int): ground metric on the (upper-half) plane (i.e. norm l_p in R^2), default value is `np.inf`, see :func:`gudhi.wasserstein.wasserstein_distance`. + mode (str): method for computing Wasserstein distance. Either "pot" or "hera". Default set to "hera". delta (float): relative error 1+delta. Used only if mode == "hera". n_jobs (int): number of jobs to use for the computation. See :func:`pairwise_persistence_diagram_distances` for details. """ self.order, self.internal_p, self.mode = order, internal_p, mode - self.metric = "pot_wasserstein" if mode == "pot" else "hera_wasserstein" + if mode == "pot": + self.metric = "pot_wasserstein" + elif mode == "hera": + self.metric = "hera_wasserstein" + else: + raise NameError("Unknown mode. Current available values for mode are 'hera' and 'pot'") self.delta = delta self.n_jobs = n_jobs diff --git a/src/python/gudhi/simplex_tree.pyx b/src/python/gudhi/simplex_tree.pyx index 261f7e1b..5e032e2f 100644 --- a/src/python/gudhi/simplex_tree.pyx +++ b/src/python/gudhi/simplex_tree.pyx @@ -250,13 +250,12 @@ cdef class SimplexTree: preincrement(it) def get_skeleton(self, dimension): - """This function returns the (simplices of the) skeleton of a maximum - given dimension. + """This function returns a generator with the (simplices of the) skeleton of a maximum given dimension. :param dimension: The skeleton dimension value. :type dimension: int. :returns: The (simplices of the) skeleton of a maximum dimension. - :rtype: list of tuples(simplex, filtration) + :rtype: generator with tuples(simplex, filtration) """ cdef Simplex_tree_skeleton_iterator it = self.get_ptr().get_skeleton_iterator_begin(dimension) cdef Simplex_tree_skeleton_iterator end = self.get_ptr().get_skeleton_iterator_end(dimension) diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py index 89ecab1c..b37d30bb 100644 --- a/src/python/gudhi/wasserstein/wasserstein.py +++ b/src/python/gudhi/wasserstein/wasserstein.py @@ -73,8 +73,8 @@ def _perstot_autodiff(X, order, internal_p): def _perstot(X, order, internal_p, enable_autodiff): ''' :param X: (n x 2) numpy.array (points of a given diagram). - :param order: exponent for Wasserstein. Default value is 2. - :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2); Default value is 2 (Euclidean norm). + :param order: exponent for Wasserstein. + :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2). :param enable_autodiff: If X is torch.tensor, tensorflow.Tensor or jax.numpy.ndarray, make the computation transparent to automatic differentiation. :type enable_autodiff: bool @@ -88,7 +88,7 @@ def _perstot(X, order, internal_p, enable_autodiff): return np.linalg.norm(_dist_to_diag(X, internal_p), ord=order) -def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2., enable_autodiff=False): +def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enable_autodiff=False): ''' :param X: (n x 2) numpy.array encoding the (finite points of the) first diagram. Must not contain essential points (i.e. with infinite coordinate). @@ -96,9 +96,9 @@ def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2., enable_a :param matching: if True, computes and returns the optimal matching between X and Y, encoded as a (n x 2) np.array [...[i,j]...], meaning the i-th point in X is matched to the j-th point in Y, with the convention (-1) represents the diagonal. - :param order: exponent for Wasserstein; Default value is 2. + :param order: exponent for Wasserstein; Default value is 1. :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2); - Default value is 2 (Euclidean norm). + Default value is `np.inf`. :param enable_autodiff: If X and Y are torch.tensor, tensorflow.Tensor or jax.numpy.ndarray, make the computation transparent to automatic differentiation. This requires the package EagerPy and is currently incompatible with `matching=True`. |