summaryrefslogtreecommitdiff
path: root/src/python/gudhi
diff options
context:
space:
mode:
authorROUVREAU Vincent <vincent.rouvreau@inria.fr>2020-08-18 10:55:42 +0200
committerROUVREAU Vincent <vincent.rouvreau@inria.fr>2020-08-18 10:55:42 +0200
commita1cd7e9ead030654a1fdb6cfd50408103c458529 (patch)
tree9786156bfb00d5b4f85dda2458b087d60d1bc1a8 /src/python/gudhi
parent85eec1ba750d56b66e3739dc486c6205f49fb31e (diff)
parent4737aaeb36a4ff3b27d7bcbb374911197ed09e5a (diff)
Merge master and resolve conflicts
Diffstat (limited to 'src/python/gudhi')
-rw-r--r--src/python/gudhi/persistence_graphical_tools.py22
-rw-r--r--src/python/gudhi/point_cloud/knn.py2
-rw-r--r--src/python/gudhi/representations/vector_methods.py144
-rw-r--r--src/python/gudhi/simplex_tree.pxd1
-rw-r--r--src/python/gudhi/simplex_tree.pyx66
-rw-r--r--src/python/gudhi/wasserstein/wasserstein.py6
6 files changed, 205 insertions, 36 deletions
diff --git a/src/python/gudhi/persistence_graphical_tools.py b/src/python/gudhi/persistence_graphical_tools.py
index c6766c70..848dc03e 100644
--- a/src/python/gudhi/persistence_graphical_tools.py
+++ b/src/python/gudhi/persistence_graphical_tools.py
@@ -20,6 +20,7 @@ __author__ = "Vincent Rouvreau, Bertrand Michel, Theo Lacombe"
__copyright__ = "Copyright (C) 2016 Inria"
__license__ = "MIT"
+_gudhi_matplotlib_use_tex = True
def __min_birth_max_death(persistence, band=0.0):
"""This function returns (min_birth, max_death) from the persistence.
@@ -117,10 +118,13 @@ def plot_persistence_barcode(
try:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
- if _matplotlib_can_use_tex():
- from matplotlib import rc
+ from matplotlib import rc
+ if _gudhi_matplotlib_use_tex and _matplotlib_can_use_tex():
plt.rc('text', usetex=True)
plt.rc('font', family='serif')
+ else:
+ plt.rc('text', usetex=False)
+ plt.rc('font', family='DejaVu Sans')
if persistence_file != "":
if path.isfile(persistence_file):
@@ -263,10 +267,13 @@ def plot_persistence_diagram(
try:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
- if _matplotlib_can_use_tex():
- from matplotlib import rc
+ from matplotlib import rc
+ if _gudhi_matplotlib_use_tex and _matplotlib_can_use_tex():
plt.rc('text', usetex=True)
plt.rc('font', family='serif')
+ else:
+ plt.rc('text', usetex=False)
+ plt.rc('font', family='DejaVu Sans')
if persistence_file != "":
if path.isfile(persistence_file):
@@ -436,10 +443,13 @@ def plot_persistence_density(
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from scipy.stats import kde
- if _matplotlib_can_use_tex():
- from matplotlib import rc
+ from matplotlib import rc
+ if _gudhi_matplotlib_use_tex and _matplotlib_can_use_tex():
plt.rc('text', usetex=True)
plt.rc('font', family='serif')
+ else:
+ plt.rc('text', usetex=False)
+ plt.rc('font', family='DejaVu Sans')
if persistence_file != "":
if dimension is None:
diff --git a/src/python/gudhi/point_cloud/knn.py b/src/python/gudhi/point_cloud/knn.py
index 4652fe80..994be3b6 100644
--- a/src/python/gudhi/point_cloud/knn.py
+++ b/src/python/gudhi/point_cloud/knn.py
@@ -46,7 +46,7 @@ class KNearestNeighbors:
sort_results (bool): if True, then distances and indices of each point are
sorted on return, so that the first column contains the closest points.
Otherwise, neighbors are returned in an arbitrary order. Defaults to True.
- enable_autodiff (bool): if the input is a torch.tensor, jax.numpy.ndarray or tensorflow.Tensor, this
+ enable_autodiff (bool): if the input is a torch.tensor or tensorflow.Tensor, this
instructs the function to compute distances in a way that works with automatic differentiation.
This is experimental, not supported for all metrics, and requires the package EagerPy.
Defaults to False.
diff --git a/src/python/gudhi/representations/vector_methods.py b/src/python/gudhi/representations/vector_methods.py
index 46fee086..5ca127f6 100644
--- a/src/python/gudhi/representations/vector_methods.py
+++ b/src/python/gudhi/representations/vector_methods.py
@@ -1,16 +1,17 @@
# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
-# Author(s): Mathieu Carrière
+# Author(s): Mathieu Carrière, Martin Royer
#
-# Copyright (C) 2018-2019 Inria
+# Copyright (C) 2018-2020 Inria
#
# Modification(s):
-# - YYYY/MM Author: Description of the modification
+# - 2020/06 Martin: ATOL integration
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.preprocessing import MinMaxScaler, MaxAbsScaler
from sklearn.neighbors import DistanceMetric
+from sklearn.metrics import pairwise
from .preprocessing import DiagramScaler, BirthPersistenceTransform
@@ -574,3 +575,140 @@ class ComplexPolynomial(BaseEstimator, TransformerMixin):
numpy array with shape (**threshold**): output complex vector of coefficients.
"""
return self.fit_transform([diag])[0,:]
+
+def _lapl_contrast(measure, centers, inertias):
+ """contrast function for vectorising `measure` in ATOL"""
+ return np.exp(-pairwise.pairwise_distances(measure, Y=centers) / inertias)
+
+def _gaus_contrast(measure, centers, inertias):
+ """contrast function for vectorising `measure` in ATOL"""
+ return np.exp(-pairwise.pairwise_distances(measure, Y=centers, squared=True) / inertias**2)
+
+def _indicator_contrast(diags, centers, inertias):
+ """contrast function for vectorising `measure` in ATOL"""
+ robe_curve = np.clip(2-pairwise.pairwise_distances(diags, Y=centers)/inertias, 0, 1)
+ return robe_curve
+
+def _cloud_weighting(measure):
+ """automatic uniform weighting with mass 1 for `measure` in ATOL"""
+ return np.ones(shape=measure.shape[0])
+
+def _iidproba_weighting(measure):
+ """automatic uniform weighting with mass 1/N for `measure` in ATOL"""
+ return np.ones(shape=measure.shape[0]) / measure.shape[0]
+
+class Atol(BaseEstimator, TransformerMixin):
+ """
+ This class allows to vectorise measures (e.g. point clouds, persistence diagrams, etc) after a quantisation step.
+
+ ATOL paper: :cite:`royer2019atol`
+
+ Example
+ --------
+ >>> from sklearn.cluster import KMeans
+ >>> from gudhi.representations.vector_methods import Atol
+ >>> import numpy as np
+ >>> a = np.array([[1, 2, 4], [1, 4, 0], [1, 0, 4]])
+ >>> b = np.array([[4, 2, 0], [4, 4, 0], [4, 0, 2]])
+ >>> c = np.array([[3, 2, -1], [1, 2, -1]])
+ >>> atol_vectoriser = Atol(quantiser=KMeans(n_clusters=2, random_state=202006))
+ >>> atol_vectoriser.fit(X=[a, b, c]).centers
+ array([[ 2. , 0.66666667, 3.33333333],
+ [ 2.6 , 2.8 , -0.4 ]])
+ >>> atol_vectoriser(a)
+ array([1.18168665, 0.42375966])
+ >>> atol_vectoriser(c)
+ array([0.02062512, 1.25157463])
+ >>> atol_vectoriser.transform(X=[a, b, c])
+ array([[1.18168665, 0.42375966],
+ [0.29861028, 1.06330156],
+ [0.02062512, 1.25157463]])
+ """
+ def __init__(self, quantiser, weighting_method="cloud", contrast="gaussian"):
+ """
+ Constructor for the Atol measure vectorisation class.
+
+ Parameters:
+ quantiser (Object): Object with `fit` (sklearn API consistent) and `cluster_centers` and `n_clusters`
+ attributes, e.g. sklearn.cluster.KMeans. It will be fitted when the Atol object function `fit` is called.
+ weighting_method (string): constant generic function for weighting the measure points
+ choose from {"cloud", "iidproba"}
+ (default: constant function, i.e. the measure is seen as a point cloud by default).
+ This will have no impact if weights are provided along with measures all the way: `fit` and `transform`.
+ contrast (string): constant function for evaluating proximity of a measure with respect to centers
+ choose from {"gaussian", "laplacian", "indicator"}
+ (default: gaussian contrast function, see page 3 in the ATOL paper).
+ """
+ self.quantiser = quantiser
+ self.contrast = {
+ "gaussian": _gaus_contrast,
+ "laplacian": _lapl_contrast,
+ "indicator": _indicator_contrast,
+ }.get(contrast, _gaus_contrast)
+ self.weighting_method = {
+ "cloud" : _cloud_weighting,
+ "iidproba": _iidproba_weighting,
+ }.get(weighting_method, _cloud_weighting)
+
+ def fit(self, X, y=None, sample_weight=None):
+ """
+ Calibration step: fit centers to the sample measures and derive inertias between centers.
+
+ Parameters:
+ X (list N x d numpy arrays): input measures in R^d from which to learn center locations and inertias
+ (measures can have different N).
+ y: Ignored, present for API consistency by convention.
+ sample_weight (list of numpy arrays): weights for each measure point in X, optional.
+ If None, the object's weighting_method will be used.
+
+ Returns:
+ self
+ """
+ if not hasattr(self.quantiser, 'fit'):
+ raise TypeError("quantiser %s has no `fit` attribute." % (self.quantiser))
+ if sample_weight is None:
+ sample_weight = np.concatenate([self.weighting_method(measure) for measure in X])
+
+ measures_concat = np.concatenate(X)
+ self.quantiser.fit(X=measures_concat, sample_weight=sample_weight)
+ self.centers = self.quantiser.cluster_centers_
+ if self.quantiser.n_clusters == 1:
+ dist_centers = pairwise.pairwise_distances(measures_concat)
+ np.fill_diagonal(dist_centers, 0)
+ self.inertias = np.array([np.max(dist_centers)/2])
+ else:
+ dist_centers = pairwise.pairwise_distances(self.centers)
+ dist_centers[dist_centers == 0] = np.inf
+ self.inertias = np.min(dist_centers, axis=0)/2
+ return self
+
+ def __call__(self, measure, sample_weight=None):
+ """
+ Apply measure vectorisation on a single measure.
+
+ Parameters:
+ measure (n x d numpy array): input measure in R^d.
+
+ Returns:
+ numpy array in R^self.quantiser.n_clusters.
+ """
+ if sample_weight is None:
+ sample_weight = self.weighting_method(measure)
+ return np.sum(sample_weight * self.contrast(measure, self.centers, self.inertias.T).T, axis=1)
+
+ def transform(self, X, sample_weight=None):
+ """
+ Apply measure vectorisation on a list of measures.
+
+ Parameters:
+ X (list N x d numpy arrays): input measures in R^d from which to learn center locations and inertias
+ (measures can have different N).
+ sample_weight (list of numpy arrays): weights for each measure point in X, optional.
+ If None, the object's weighting_method will be used.
+
+ Returns:
+ numpy array with shape (number of measures) x (self.quantiser.n_clusters).
+ """
+ if sample_weight is None:
+ sample_weight = [self.weighting_method(measure) for measure in X]
+ return np.stack([self(measure, sample_weight=weight) for measure, weight in zip(X, sample_weight)])
diff --git a/src/python/gudhi/simplex_tree.pxd b/src/python/gudhi/simplex_tree.pxd
index 12c2065e..3b494ba3 100644
--- a/src/python/gudhi/simplex_tree.pxd
+++ b/src/python/gudhi/simplex_tree.pxd
@@ -57,6 +57,7 @@ cdef extern from "Simplex_tree_interface.h" namespace "Gudhi":
bool make_filtration_non_decreasing() nogil
void compute_extended_filtration() nogil
vector[vector[pair[int, pair[double, double]]]] compute_extended_persistence_subdiagrams(vector[pair[int, pair[double, double]]] dgm, double min_persistence) nogil
+ Simplex_tree_interface_full_featured* collapse_edges(int nb_collapse_iteration) nogil
void reset_filtration(double filtration, int dimension) nogil
# Iterators over Simplex tree
pair[vector[int], double] get_simplex_and_filtration(Simplex_tree_simplex_handle f_simplex) nogil
diff --git a/src/python/gudhi/simplex_tree.pyx b/src/python/gudhi/simplex_tree.pyx
index 41b06116..b7682693 100644
--- a/src/python/gudhi/simplex_tree.pyx
+++ b/src/python/gudhi/simplex_tree.pyx
@@ -69,7 +69,7 @@ cdef class SimplexTree:
this simplicial complex, or +infinity if it is not in the complex.
:param simplex: The N-simplex, represented by a list of vertex.
- :type simplex: list of int.
+ :type simplex: list of int
:returns: The simplicial complex filtration value.
:rtype: float
"""
@@ -80,7 +80,7 @@ cdef class SimplexTree:
given N-simplex.
:param simplex: The N-simplex, represented by a list of vertex.
- :type simplex: list of int.
+ :type simplex: list of int
:param filtration: The new filtration value.
:type filtration: float
@@ -153,7 +153,7 @@ cdef class SimplexTree:
"""This function sets the dimension of the simplicial complex.
:param dimension: The new dimension value.
- :type dimension: int.
+ :type dimension: int
.. note::
@@ -172,7 +172,7 @@ cdef class SimplexTree:
complex or not.
:param simplex: The N-simplex to find, represented by a list of vertex.
- :type simplex: list of int.
+ :type simplex: list of int
:returns: true if the simplex was found, false otherwise.
:rtype: bool
"""
@@ -186,9 +186,9 @@ cdef class SimplexTree:
:param simplex: The N-simplex to insert, represented by a list of
vertex.
- :type simplex: list of int.
+ :type simplex: list of int
:param filtration: The filtration value of the simplex.
- :type filtration: float.
+ :type filtration: float
:returns: true if the simplex was not yet in the complex, false
otherwise (whatever its original filtration value).
:rtype: bool
@@ -228,7 +228,7 @@ cdef class SimplexTree:
"""This function returns a generator with the (simplices of the) skeleton of a maximum given dimension.
:param dimension: The skeleton dimension value.
- :type dimension: int.
+ :type dimension: int
:returns: The (simplices of the) skeleton of a maximum dimension.
:rtype: generator with tuples(simplex, filtration)
"""
@@ -243,7 +243,7 @@ cdef class SimplexTree:
"""This function returns the star of a given N-simplex.
:param simplex: The N-simplex, represented by a list of vertex.
- :type simplex: list of int.
+ :type simplex: list of int
:returns: The (simplices of the) star of a simplex.
:rtype: list of tuples(simplex, filtration)
"""
@@ -265,10 +265,10 @@ cdef class SimplexTree:
given codimension.
:param simplex: The N-simplex, represented by a list of vertex.
- :type simplex: list of int.
+ :type simplex: list of int
:param codimension: The codimension. If codimension = 0, all cofaces
are returned (equivalent of get_star function)
- :type codimension: int.
+ :type codimension: int
:returns: The (simplices of the) cofaces of a simplex
:rtype: list of tuples(simplex, filtration)
"""
@@ -290,7 +290,7 @@ cdef class SimplexTree:
complex.
:param simplex: The N-simplex, represented by a list of vertex.
- :type simplex: list of int.
+ :type simplex: list of int
.. note::
@@ -308,7 +308,7 @@ cdef class SimplexTree:
"""Prune above filtration value given as parameter.
:param filtration: Maximum threshold value.
- :type filtration: float.
+ :type filtration: float
:returns: The filtration modification information.
:rtype: bool
@@ -342,7 +342,7 @@ cdef class SimplexTree:
1 when calling the method.
:param max_dim: The maximal dimension.
- :type max_dim: int.
+ :type max_dim: int
"""
cdef int maxdim = max_dim
with nogil:
@@ -393,12 +393,12 @@ cdef class SimplexTree:
:param homology_coeff_field: The homology coefficient field. Must be a
prime number. Default value is 11.
- :type homology_coeff_field: int.
+ :type homology_coeff_field: int
:param min_persistence: The minimum persistence value (i.e., the absolute value of the difference between the persistence diagram point coordinates) to take into
account (strictly greater than min_persistence). Default value is
0.0.
Sets min_persistence to -1.0 to see all values.
- :type min_persistence: float.
+ :type min_persistence: float
:returns: A list of four persistence diagrams in the format described in :func:`persistence`. The first one is Ordinary, the second one is Relative, the third one is Extended+ and the fourth one is Extended-. See https://link.springer.com/article/10.1007/s10208-008-9027-z and/or section 2.2 in https://link.springer.com/article/10.1007/s10208-017-9370-z for a description of these subtypes.
.. note::
@@ -425,12 +425,12 @@ cdef class SimplexTree:
:param homology_coeff_field: The homology coefficient field. Must be a
prime number. Default value is 11.
- :type homology_coeff_field: int.
+ :type homology_coeff_field: int
:param min_persistence: The minimum persistence value to take into
account (strictly greater than min_persistence). Default value is
0.0.
Set min_persistence to -1.0 to see all values.
- :type min_persistence: float.
+ :type min_persistence: float
:param persistence_dim_max: If true, the persistent homology for the
maximal dimension in the complex is computed. If false, it is
ignored. Default is false.
@@ -448,12 +448,12 @@ cdef class SimplexTree:
:param homology_coeff_field: The homology coefficient field. Must be a
prime number. Default value is 11.
- :type homology_coeff_field: int.
+ :type homology_coeff_field: int
:param min_persistence: The minimum persistence value to take into
account (strictly greater than min_persistence). Default value is
0.0.
Sets min_persistence to -1.0 to see all values.
- :type min_persistence: float.
+ :type min_persistence: float
:param persistence_dim_max: If true, the persistent homology for the
maximal dimension in the complex is computed. If false, it is
ignored. Default is false.
@@ -488,10 +488,10 @@ cdef class SimplexTree:
:param from_value: The persistence birth limit to be added in the
numbers (persistent birth <= from_value).
- :type from_value: float.
+ :type from_value: float
:param to_value: The persistence death limit to be added in the
numbers (persistent death > to_value).
- :type to_value: float.
+ :type to_value: float
:returns: The persistent Betti numbers ([B0, B1, ..., Bn]).
:rtype: list of int
@@ -508,7 +508,7 @@ cdef class SimplexTree:
complex in a specific dimension.
:param dimension: The specific dimension.
- :type dimension: int.
+ :type dimension: int
:returns: The persistence intervals.
:rtype: numpy array of dimension 2
@@ -537,7 +537,7 @@ cdef class SimplexTree:
complex in a user given file name.
:param persistence_file: Name of the file.
- :type persistence_file: string.
+ :type persistence_file: string
:note: intervals_in_dim function requires
:func:`compute_persistence`
@@ -591,3 +591,23 @@ cdef class SimplexTree:
infinite0 = np_array(next(l))
infinites = [np_array(d).reshape(-1,2) for d in l]
return (normal0, normals, infinite0, infinites)
+
+ def collapse_edges(self, nb_iterations = 1):
+ """Assuming the simplex tree is a 1-skeleton graph, this method collapse edges (simplices of higher dimension
+ are ignored) and resets the simplex tree from the remaining edges.
+ A good candidate is to build a simplex tree on top of a :class:`~gudhi.RipsComplex` of dimension 1 before
+ collapsing edges
+ (cf. :download:`rips_complex_edge_collapse_example.py <../example/rips_complex_edge_collapse_example.py>`).
+ For implementation details, please refer to :cite:`edgecollapsesocg2020`.
+
+ :param nb_iterations: The number of edge collapse iterations to perform. Default is 1.
+ :type nb_iterations: int
+ """
+ # Backup old pointer
+ cdef Simplex_tree_interface_full_featured* ptr = self.get_ptr()
+ cdef int nb_iter = nb_iterations
+ with nogil:
+ # New pointer is a new collapsed simplex tree
+ self.thisptr = <intptr_t>(ptr.collapse_edges(nb_iter))
+ # Delete old pointer
+ del ptr
diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py
index b37d30bb..a9d1cdff 100644
--- a/src/python/gudhi/wasserstein/wasserstein.py
+++ b/src/python/gudhi/wasserstein/wasserstein.py
@@ -99,7 +99,7 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab
:param order: exponent for Wasserstein; Default value is 1.
:param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2);
Default value is `np.inf`.
- :param enable_autodiff: If X and Y are torch.tensor, tensorflow.Tensor or jax.numpy.ndarray, make the computation
+ :param enable_autodiff: If X and Y are torch.tensor or tensorflow.Tensor, make the computation
transparent to automatic differentiation. This requires the package EagerPy and is currently incompatible
with `matching=True`.
@@ -165,9 +165,9 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab
# empty arrays are not handled properly by the helpers, so we avoid calling them
if len(pairs_X_Y):
dists.append((Y_orig[pairs_X_Y[:, 1]] - X_orig[pairs_X_Y[:, 0]]).norms.lp(internal_p, axis=-1).norms.lp(order))
- if len(pairs_X_diag):
+ if len(pairs_X_diag[0]):
dists.append(_perstot_autodiff(X_orig[pairs_X_diag], order, internal_p))
- if len(pairs_Y_diag):
+ if len(pairs_Y_diag[0]):
dists.append(_perstot_autodiff(Y_orig[pairs_Y_diag], order, internal_p))
dists = [dist.reshape(1) for dist in dists]
return ep.concatenate(dists).norms.lp(order).raw