summaryrefslogtreecommitdiff
path: root/src/python/gudhi
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/gudhi')
-rw-r--r--src/python/gudhi/alpha_complex.pyx29
-rw-r--r--src/python/gudhi/persistence_graphical_tools.py33
-rw-r--r--src/python/gudhi/representations/metrics.py17
-rw-r--r--src/python/gudhi/simplex_tree.pyx5
-rw-r--r--src/python/gudhi/wasserstein/wasserstein.py10
5 files changed, 56 insertions, 38 deletions
diff --git a/src/python/gudhi/alpha_complex.pyx b/src/python/gudhi/alpha_complex.pyx
index a356384d..ea128743 100644
--- a/src/python/gudhi/alpha_complex.pyx
+++ b/src/python/gudhi/alpha_complex.pyx
@@ -20,6 +20,7 @@ import os
from gudhi.simplex_tree cimport *
from gudhi.simplex_tree import SimplexTree
+from gudhi import read_points_from_off_file
__author__ = "Vincent Rouvreau"
__copyright__ = "Copyright (C) 2016 Inria"
@@ -27,11 +28,9 @@ __license__ = "GPL v3"
cdef extern from "Alpha_complex_interface.h" namespace "Gudhi":
cdef cppclass Alpha_complex_interface "Gudhi::alpha_complex::Alpha_complex_interface":
- Alpha_complex_interface(vector[vector[double]] points, bool fast_version) nogil except +
- # bool from_file is a workaround for cython to find the correct signature
- Alpha_complex_interface(string off_file, bool fast_version, bool from_file) nogil except +
+ Alpha_complex_interface(vector[vector[double]] points, bool fast_version, bool exact_version) nogil except +
vector[double] get_point(int vertex) nogil except +
- void create_simplex_tree(Simplex_tree_interface_full_featured* simplex_tree, double max_alpha_square, bool exact_version, bool default_filtration_value) nogil except +
+ void create_simplex_tree(Simplex_tree_interface_full_featured* simplex_tree, double max_alpha_square, bool default_filtration_value) nogil except +
# AlphaComplex python interface
cdef class AlphaComplex:
@@ -54,7 +53,6 @@ cdef class AlphaComplex:
"""
cdef Alpha_complex_interface * this_ptr
- cdef bool exact
# Fake constructor that does nothing but documenting the constructor
def __init__(self, points=None, off_file='', precision='safe'):
@@ -76,21 +74,20 @@ cdef class AlphaComplex:
def __cinit__(self, points = None, off_file = '', precision = 'safe'):
assert precision in ['fast', 'safe', 'exact'], "Alpha complex precision can only be 'fast', 'safe' or 'exact'"
cdef bool fast = precision == 'fast'
- self.exact = precision == 'exact'
+ cdef bool exact = precision == 'exact'
cdef vector[vector[double]] pts
if off_file:
if os.path.isfile(off_file):
- self.this_ptr = new Alpha_complex_interface(off_file.encode('utf-8'), fast, True)
+ points = read_points_from_off_file(off_file = off_file)
else:
print("file " + off_file + " not found.")
- else:
- if points is None:
- # Empty Alpha construction
- points=[]
- pts = points
- with nogil:
- self.this_ptr = new Alpha_complex_interface(pts, fast)
+ if points is None:
+ # Empty Alpha construction
+ points=[]
+ pts = points
+ with nogil:
+ self.this_ptr = new Alpha_complex_interface(pts, fast, exact)
def __dealloc__(self):
if self.this_ptr != NULL:
@@ -102,7 +99,7 @@ cdef class AlphaComplex:
return self.this_ptr != NULL
def get_point(self, vertex):
- """This function returns the point corresponding to a given vertex.
+ """This function returns the point corresponding to a given vertex from the :class:`~gudhi.SimplexTree`.
:param vertex: The vertex.
:type vertex: int
@@ -128,5 +125,5 @@ cdef class AlphaComplex:
cdef bool compute_filtration = default_filtration_value == True
with nogil:
self.this_ptr.create_simplex_tree(<Simplex_tree_interface_full_featured*>stree_int_ptr,
- mas, self.exact, compute_filtration)
+ mas, compute_filtration)
return stree
diff --git a/src/python/gudhi/persistence_graphical_tools.py b/src/python/gudhi/persistence_graphical_tools.py
index 6a74a6ca..c6766c70 100644
--- a/src/python/gudhi/persistence_graphical_tools.py
+++ b/src/python/gudhi/persistence_graphical_tools.py
@@ -11,6 +11,7 @@
from os import path
from math import isfinite
import numpy as np
+from functools import lru_cache
from gudhi.reader_utils import read_persistence_intervals_in_dimension
from gudhi.reader_utils import read_persistence_intervals_grouped_by_dimension
@@ -56,6 +57,17 @@ def _array_handler(a):
else:
return a
+@lru_cache(maxsize=1)
+def _matplotlib_can_use_tex():
+ """This function returns True if matplotlib can deal with LaTeX, False otherwise.
+ The returned value is cached.
+ """
+ try:
+ from matplotlib import checkdep_usetex
+ return checkdep_usetex(True)
+ except ImportError:
+ print("This function is not available, you may be missing matplotlib.")
+
def plot_persistence_barcode(
persistence=[],
@@ -105,9 +117,10 @@ def plot_persistence_barcode(
try:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
- from matplotlib import rc
- plt.rc('text', usetex=True)
- plt.rc('font', family='serif')
+ if _matplotlib_can_use_tex():
+ from matplotlib import rc
+ plt.rc('text', usetex=True)
+ plt.rc('font', family='serif')
if persistence_file != "":
if path.isfile(persistence_file):
@@ -250,9 +263,10 @@ def plot_persistence_diagram(
try:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
- from matplotlib import rc
- plt.rc('text', usetex=True)
- plt.rc('font', family='serif')
+ if _matplotlib_can_use_tex():
+ from matplotlib import rc
+ plt.rc('text', usetex=True)
+ plt.rc('font', family='serif')
if persistence_file != "":
if path.isfile(persistence_file):
@@ -422,9 +436,10 @@ def plot_persistence_density(
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from scipy.stats import kde
- from matplotlib import rc
- plt.rc('text', usetex=True)
- plt.rc('font', family='serif')
+ if _matplotlib_can_use_tex():
+ from matplotlib import rc
+ plt.rc('text', usetex=True)
+ plt.rc('font', family='serif')
if persistence_file != "":
if dimension is None:
diff --git a/src/python/gudhi/representations/metrics.py b/src/python/gudhi/representations/metrics.py
index cf2e0879..142ddef1 100644
--- a/src/python/gudhi/representations/metrics.py
+++ b/src/python/gudhi/representations/metrics.py
@@ -350,23 +350,30 @@ class PersistenceFisherDistance(BaseEstimator, TransformerMixin):
"""
return _persistence_fisher_distance(diag1, diag2, bandwidth=self.bandwidth, kernel_approx=self.kernel_approx)
+
class WassersteinDistance(BaseEstimator, TransformerMixin):
"""
This is a class for computing the Wasserstein distance matrix from a list of persistence diagrams.
"""
- def __init__(self, order=2, internal_p=2, mode="pot", delta=0.01, n_jobs=None):
+
+ def __init__(self, order=1, internal_p=np.inf, mode="hera", delta=0.01, n_jobs=None):
"""
Constructor for the WassersteinDistance class.
Parameters:
- order (int): exponent for Wasserstein, default value is 2., see :func:`gudhi.wasserstein.wasserstein_distance`.
- internal_p (int): ground metric on the (upper-half) plane (i.e. norm l_p in R^2), default value is 2 (euclidean norm), see :func:`gudhi.wasserstein.wasserstein_distance`.
- mode (str): method for computing Wasserstein distance. Either "pot" or "hera".
+ order (int): exponent for Wasserstein, default value is 1., see :func:`gudhi.wasserstein.wasserstein_distance`.
+ internal_p (int): ground metric on the (upper-half) plane (i.e. norm l_p in R^2), default value is `np.inf`, see :func:`gudhi.wasserstein.wasserstein_distance`.
+ mode (str): method for computing Wasserstein distance. Either "pot" or "hera". Default set to "hera".
delta (float): relative error 1+delta. Used only if mode == "hera".
n_jobs (int): number of jobs to use for the computation. See :func:`pairwise_persistence_diagram_distances` for details.
"""
self.order, self.internal_p, self.mode = order, internal_p, mode
- self.metric = "pot_wasserstein" if mode == "pot" else "hera_wasserstein"
+ if mode == "pot":
+ self.metric = "pot_wasserstein"
+ elif mode == "hera":
+ self.metric = "hera_wasserstein"
+ else:
+ raise NameError("Unknown mode. Current available values for mode are 'hera' and 'pot'")
self.delta = delta
self.n_jobs = n_jobs
diff --git a/src/python/gudhi/simplex_tree.pyx b/src/python/gudhi/simplex_tree.pyx
index 261f7e1b..5e032e2f 100644
--- a/src/python/gudhi/simplex_tree.pyx
+++ b/src/python/gudhi/simplex_tree.pyx
@@ -250,13 +250,12 @@ cdef class SimplexTree:
preincrement(it)
def get_skeleton(self, dimension):
- """This function returns the (simplices of the) skeleton of a maximum
- given dimension.
+ """This function returns a generator with the (simplices of the) skeleton of a maximum given dimension.
:param dimension: The skeleton dimension value.
:type dimension: int.
:returns: The (simplices of the) skeleton of a maximum dimension.
- :rtype: list of tuples(simplex, filtration)
+ :rtype: generator with tuples(simplex, filtration)
"""
cdef Simplex_tree_skeleton_iterator it = self.get_ptr().get_skeleton_iterator_begin(dimension)
cdef Simplex_tree_skeleton_iterator end = self.get_ptr().get_skeleton_iterator_end(dimension)
diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py
index 89ecab1c..b37d30bb 100644
--- a/src/python/gudhi/wasserstein/wasserstein.py
+++ b/src/python/gudhi/wasserstein/wasserstein.py
@@ -73,8 +73,8 @@ def _perstot_autodiff(X, order, internal_p):
def _perstot(X, order, internal_p, enable_autodiff):
'''
:param X: (n x 2) numpy.array (points of a given diagram).
- :param order: exponent for Wasserstein. Default value is 2.
- :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2); Default value is 2 (Euclidean norm).
+ :param order: exponent for Wasserstein.
+ :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2).
:param enable_autodiff: If X is torch.tensor, tensorflow.Tensor or jax.numpy.ndarray, make the computation
transparent to automatic differentiation.
:type enable_autodiff: bool
@@ -88,7 +88,7 @@ def _perstot(X, order, internal_p, enable_autodiff):
return np.linalg.norm(_dist_to_diag(X, internal_p), ord=order)
-def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2., enable_autodiff=False):
+def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enable_autodiff=False):
'''
:param X: (n x 2) numpy.array encoding the (finite points of the) first diagram. Must not contain essential points
(i.e. with infinite coordinate).
@@ -96,9 +96,9 @@ def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2., enable_a
:param matching: if True, computes and returns the optimal matching between X and Y, encoded as
a (n x 2) np.array [...[i,j]...], meaning the i-th point in X is matched to
the j-th point in Y, with the convention (-1) represents the diagonal.
- :param order: exponent for Wasserstein; Default value is 2.
+ :param order: exponent for Wasserstein; Default value is 1.
:param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2);
- Default value is 2 (Euclidean norm).
+ Default value is `np.inf`.
:param enable_autodiff: If X and Y are torch.tensor, tensorflow.Tensor or jax.numpy.ndarray, make the computation
transparent to automatic differentiation. This requires the package EagerPy and is currently incompatible
with `matching=True`.