diff options
author | MathieuCarriere <mathieu.carriere3@gmail.com> | 2020-04-28 13:48:45 -0400 |
---|---|---|
committer | MathieuCarriere <mathieu.carriere3@gmail.com> | 2020-04-28 13:48:45 -0400 |
commit | 4923f2bd8a18d2f66288f39c08309cb7cafa5627 (patch) | |
tree | 0f9572654e52fc0b0bc7994f07aee1a874c2a45a /src/python/gudhi | |
parent | 39b6731486838b8f2e608e5b5738c12e1c83266f (diff) | |
parent | 0fb22e4c499b665ad505e5d9d2c325f7561f69c4 (diff) |
fix conflict
Diffstat (limited to 'src/python/gudhi')
-rw-r--r-- | src/python/gudhi/alpha_complex.pyx | 10 | ||||
-rw-r--r-- | src/python/gudhi/cubical_complex.pyx | 80 | ||||
-rw-r--r-- | src/python/gudhi/nerve_gic.pyx | 37 | ||||
-rw-r--r-- | src/python/gudhi/off_reader.pyx | 12 | ||||
-rw-r--r-- | src/python/gudhi/periodic_cubical_complex.pyx | 73 | ||||
-rw-r--r-- | src/python/gudhi/point_cloud/dtm.py | 70 | ||||
-rw-r--r-- | src/python/gudhi/point_cloud/knn.py | 324 | ||||
-rw-r--r-- | src/python/gudhi/simplex_tree.pxd | 10 | ||||
-rw-r--r-- | src/python/gudhi/simplex_tree.pyx | 203 | ||||
-rw-r--r-- | src/python/gudhi/wasserstein/__init__.py | 1 | ||||
-rw-r--r-- | src/python/gudhi/wasserstein/barycenter.py | 159 | ||||
-rw-r--r-- | src/python/gudhi/wasserstein/wasserstein.py (renamed from src/python/gudhi/wasserstein.py) | 42 |
12 files changed, 818 insertions, 203 deletions
diff --git a/src/python/gudhi/alpha_complex.pyx b/src/python/gudhi/alpha_complex.pyx index fff3e920..e04dc652 100644 --- a/src/python/gudhi/alpha_complex.pyx +++ b/src/python/gudhi/alpha_complex.pyx @@ -1,5 +1,7 @@ -# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT. -# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details. +# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - +# which is released under MIT. +# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full +# license details. # Author(s): Vincent Rouvreau # # Copyright (C) 2016 Inria @@ -7,6 +9,7 @@ # Modification(s): # - YYYY/MM Author: Description of the modification +from __future__ import print_function from cython cimport numeric from libcpp.vector cimport vector from libcpp.utility cimport pair @@ -69,7 +72,8 @@ cdef class AlphaComplex: def __cinit__(self, points = None, off_file = ''): if off_file: if os.path.isfile(off_file): - self.thisptr = new Alpha_complex_interface(off_file.encode('utf-8'), True) + self.thisptr = new Alpha_complex_interface( + off_file.encode('utf-8'), True) else: print("file " + off_file + " not found.") else: diff --git a/src/python/gudhi/cubical_complex.pyx b/src/python/gudhi/cubical_complex.pyx index 84fec60e..69d0f0b6 100644 --- a/src/python/gudhi/cubical_complex.pyx +++ b/src/python/gudhi/cubical_complex.pyx @@ -1,5 +1,7 @@ -# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT. -# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details. +# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - +# which is released under MIT. +# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full +# license details. # Author(s): Vincent Rouvreau # # Copyright (C) 2016 Inria @@ -7,12 +9,15 @@ # Modification(s): # - YYYY/MM Author: Description of the modification +from __future__ import print_function from cython cimport numeric from libcpp.vector cimport vector from libcpp.utility cimport pair from libcpp.string cimport string from libcpp cimport bool +import errno import os +import sys import numpy as np @@ -30,7 +35,8 @@ cdef extern from "Cubical_complex_interface.h" namespace "Gudhi": cdef extern from "Persistent_cohomology_interface.h" namespace "Gudhi": cdef cppclass Cubical_complex_persistence_interface "Gudhi::Persistent_cohomology_interface<Gudhi::Cubical_complex::Cubical_complex_interface<>>": Cubical_complex_persistence_interface(Bitmap_cubical_complex_base_interface * st, bool persistence_dim_max) - vector[pair[int, pair[double, double]]] get_persistence(int homology_coeff_field, double min_persistence) + void compute_persistence(int homology_coeff_field, double min_persistence) + vector[pair[int, pair[double, double]]] get_persistence() vector[vector[int]] cofaces_of_cubical_persistence_pairs() vector[int] betti_numbers() vector[int] persistent_betti_numbers(double from_value, double to_value) @@ -88,10 +94,12 @@ cdef class CubicalComplex: if os.path.isfile(perseus_file): self.thisptr = new Bitmap_cubical_complex_base_interface(perseus_file.encode('utf-8')) else: - print("file " + perseus_file + " not found.") + raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), + perseus_file) else: print("CubicalComplex can be constructed from dimensions and " - "top_dimensional_cells or from a Perseus-style file name.") + "top_dimensional_cells or from a Perseus-style file name.", + file=sys.stderr) def __dealloc__(self): if self.thisptr != NULL: @@ -123,8 +131,31 @@ cdef class CubicalComplex: """ return self.thisptr.dimension() + def compute_persistence(self, homology_coeff_field=11, min_persistence=0): + """This function computes the persistence of the complex, so it can be + accessed through :func:`persistent_betti_numbers`, + :func:`persistence_intervals_in_dimension`, etc. This function is + equivalent to :func:`persistence` when you do not want the list + :func:`persistence` returns. + + :param homology_coeff_field: The homology coefficient field. Must be a + prime number + :type homology_coeff_field: int. + :param min_persistence: The minimum persistence value to take into + account (strictly greater than min_persistence). Default value is + 0.0. + Sets min_persistence to -1.0 to see all values. + :type min_persistence: float. + :returns: Nothing. + """ + if self.pcohptr != NULL: + del self.pcohptr + assert self.__is_defined() + self.pcohptr = new Cubical_complex_persistence_interface(self.thisptr, True) + self.pcohptr.compute_persistence(homology_coeff_field, min_persistence) + def persistence(self, homology_coeff_field=11, min_persistence=0): - """This function returns the persistence of the complex. + """This function computes and returns the persistence of the complex. :param homology_coeff_field: The homology coefficient field. Must be a prime number @@ -137,14 +168,8 @@ cdef class CubicalComplex: :returns: list of pairs(dimension, pair(birth, death)) -- the persistence of the complex. """ - if self.pcohptr != NULL: - del self.pcohptr - if self.thisptr != NULL: - self.pcohptr = new Cubical_complex_persistence_interface(self.thisptr, True) - cdef vector[pair[int, pair[double, double]]] persistence_result - if self.pcohptr != NULL: - persistence_result = self.pcohptr.get_persistence(homology_coeff_field, min_persistence) - return persistence_result + self.compute_persistence(homology_coeff_field, min_persistence) + return self.pcohptr.get_persistence() def cofaces_of_persistence_pairs(self): """A persistence interval is described by a pair of cells, one that creates the @@ -180,16 +205,14 @@ cdef class CubicalComplex: :returns: list of int -- The Betti numbers ([B0, B1, ..., Bn]). - :note: betti_numbers function requires persistence function to be + :note: betti_numbers function requires :func:`compute_persistence` function to be launched first. :note: betti_numbers function always returns [1, 0, 0, ...] as infinity filtration cubes are not removed from the complex. """ - cdef vector[int] bn_result - if self.pcohptr != NULL: - bn_result = self.pcohptr.betti_numbers() - return bn_result + assert self.pcohptr != NULL, "compute_persistence() must be called before betti_numbers()" + return self.pcohptr.betti_numbers() def persistent_betti_numbers(self, from_value, to_value): """This function returns the persistent Betti numbers of the complex. @@ -204,13 +227,11 @@ cdef class CubicalComplex: :returns: list of int -- The persistent Betti numbers ([B0, B1, ..., Bn]). - :note: persistent_betti_numbers function requires persistence + :note: persistent_betti_numbers function requires :func:`compute_persistence` function to be launched first. """ - cdef vector[int] pbn_result - if self.pcohptr != NULL: - pbn_result = self.pcohptr.persistent_betti_numbers(<double>from_value, <double>to_value) - return pbn_result + assert self.pcohptr != NULL, "compute_persistence() must be called before persistent_betti_numbers()" + return self.pcohptr.persistent_betti_numbers(<double>from_value, <double>to_value) def persistence_intervals_in_dimension(self, dimension): """This function returns the persistence intervals of the complex in a @@ -221,13 +242,8 @@ cdef class CubicalComplex: :returns: The persistence intervals. :rtype: numpy array of dimension 2 - :note: intervals_in_dim function requires persistence function to be + :note: intervals_in_dim function requires :func:`compute_persistence` function to be launched first. """ - cdef vector[pair[double,double]] intervals_result - if self.pcohptr != NULL: - intervals_result = self.pcohptr.intervals_in_dimension(dimension) - else: - print("intervals_in_dim function requires persistence function" - " to be launched first.") - return np.array(intervals_result) + assert self.pcohptr != NULL, "compute_persistence() must be called before persistence_intervals_in_dimension()" + return np.array(self.pcohptr.intervals_in_dimension(dimension)) diff --git a/src/python/gudhi/nerve_gic.pyx b/src/python/gudhi/nerve_gic.pyx index 45cc8eba..9c89b239 100644 --- a/src/python/gudhi/nerve_gic.pyx +++ b/src/python/gudhi/nerve_gic.pyx @@ -1,5 +1,7 @@ -# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT. -# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details. +# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - +# which is released under MIT. +# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full +# license details. # Author(s): Vincent Rouvreau # # Copyright (C) 2018 Inria @@ -7,11 +9,13 @@ # Modification(s): # - YYYY/MM Author: Description of the modification +from __future__ import print_function from cython cimport numeric from libcpp.vector cimport vector from libcpp.utility cimport pair from libcpp.string cimport string from libcpp cimport bool +import errno import os from libc.stdint cimport intptr_t @@ -96,7 +100,8 @@ cdef class CoverComplex: return self.thisptr != NULL def set_point_cloud_from_range(self, cloud): - """ Reads and stores the input point cloud from a vector stored in memory. + """ Reads and stores the input point cloud from a vector stored in + memory. :param cloud: Input vector containing the point cloud. :type cloud: vector[vector[double]] @@ -104,7 +109,8 @@ cdef class CoverComplex: return self.thisptr.set_point_cloud_from_range(cloud) def set_distances_from_range(self, distance_matrix): - """ Reads and stores the input distance matrix from a vector stored in memory. + """ Reads and stores the input distance matrix from a vector stored in + memory. :param distance_matrix: Input vector containing the distance matrix. :type distance_matrix: vector[vector[double]] @@ -163,7 +169,8 @@ cdef class CoverComplex: """ stree = SimplexTree() cdef intptr_t stree_int_ptr=stree.thisptr - self.thisptr.create_simplex_tree(<Simplex_tree_interface_full_featured*>stree_int_ptr) + self.thisptr.create_simplex_tree( + <Simplex_tree_interface_full_featured*>stree_int_ptr) return stree def find_simplices(self): @@ -182,8 +189,8 @@ cdef class CoverComplex: if os.path.isfile(off_file): return self.thisptr.read_point_cloud(off_file.encode('utf-8')) else: - print("file " + off_file + " not found.") - return False + raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), + off_file) def set_automatic_resolution(self): """Computes the optimal length of intervals (i.e. the smallest interval @@ -214,7 +221,8 @@ cdef class CoverComplex: if os.path.isfile(color_file_name): self.thisptr.set_color_from_file(color_file_name.encode('utf-8')) else: - print("file " + color_file_name + " not found.") + raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), + color_file_name) def set_color_from_range(self, color): """Computes the function used to color the nodes of the simplicial @@ -235,7 +243,8 @@ cdef class CoverComplex: if os.path.isfile(cover_file_name): self.thisptr.set_cover_from_file(cover_file_name.encode('utf-8')) else: - print("file " + cover_file_name + " not found.") + raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), + cover_file_name) def set_cover_from_function(self): """Creates a cover C from the preimages of the function f. @@ -268,7 +277,8 @@ cdef class CoverComplex: if os.path.isfile(func_file_name): self.thisptr.set_function_from_file(func_file_name.encode('utf-8')) else: - print("file " + func_file_name + " not found.") + raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), + func_file_name) def set_function_from_range(self, function): """Creates the function f from a vector stored in memory. @@ -302,14 +312,15 @@ cdef class CoverComplex: """Creates a graph G from a file containing the edges. :param graph_file_name: Name of the input graph file. The graph file - contains one edge per line, each edge being represented by the IDs of - its two nodes. + contains one edge per line, each edge being represented by the IDs + of its two nodes. :type graph_file_name: string """ if os.path.isfile(graph_file_name): self.thisptr.set_graph_from_file(graph_file_name.encode('utf-8')) else: - print("file " + graph_file_name + " not found.") + raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), + graph_file_name) def set_graph_from_OFF(self): """Creates a graph G from the triangulation given by the input OFF diff --git a/src/python/gudhi/off_reader.pyx b/src/python/gudhi/off_reader.pyx index 7e6d9d80..a3200704 100644 --- a/src/python/gudhi/off_reader.pyx +++ b/src/python/gudhi/off_reader.pyx @@ -1,5 +1,7 @@ -# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT. -# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details. +# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - +# which is released under MIT. +# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full +# license details. # Author(s): Vincent Rouvreau # # Copyright (C) 2016 Inria @@ -7,9 +9,11 @@ # Modification(s): # - YYYY/MM Author: Description of the modification +from __future__ import print_function from cython cimport numeric from libcpp.vector cimport vector from libcpp.string cimport string +import errno import os __author__ = "Vincent Rouvreau" @@ -32,6 +36,6 @@ def read_points_from_off_file(off_file=''): if os.path.isfile(off_file): return read_points_from_OFF_file(off_file.encode('utf-8')) else: - print("file " + off_file + " not found.") - return [] + raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), + off_file) diff --git a/src/python/gudhi/periodic_cubical_complex.pyx b/src/python/gudhi/periodic_cubical_complex.pyx index 993d95c7..78565cf8 100644 --- a/src/python/gudhi/periodic_cubical_complex.pyx +++ b/src/python/gudhi/periodic_cubical_complex.pyx @@ -7,11 +7,13 @@ # Modification(s): # - YYYY/MM Author: Description of the modification +from __future__ import print_function from cython cimport numeric from libcpp.vector cimport vector from libcpp.utility cimport pair from libcpp.string cimport string from libcpp cimport bool +import sys import os import numpy as np @@ -30,7 +32,8 @@ cdef extern from "Cubical_complex_interface.h" namespace "Gudhi": cdef extern from "Persistent_cohomology_interface.h" namespace "Gudhi": cdef cppclass Periodic_cubical_complex_persistence_interface "Gudhi::Persistent_cohomology_interface<Gudhi::Cubical_complex::Cubical_complex_interface<Gudhi::cubical_complex::Bitmap_cubical_complex_periodic_boundary_conditions_base<double>>>": Periodic_cubical_complex_persistence_interface(Periodic_cubical_complex_base_interface * st, bool persistence_dim_max) - vector[pair[int, pair[double, double]]] get_persistence(int homology_coeff_field, double min_persistence) + void compute_persistence(int homology_coeff_field, double min_persistence) + vector[pair[int, pair[double, double]]] get_persistence() vector[vector[int]] cofaces_of_cubical_persistence_pairs() vector[int] betti_numbers() vector[int] persistent_betti_numbers(double from_value, double to_value) @@ -96,12 +99,12 @@ cdef class PeriodicCubicalComplex: if os.path.isfile(perseus_file): self.thisptr = new Periodic_cubical_complex_base_interface(perseus_file.encode('utf-8')) else: - print("file " + perseus_file + " not found.") + print("file " + perseus_file + " not found.", file=sys.stderr) else: print("CubicalComplex can be constructed from dimensions, " "top_dimensional_cells and periodic_dimensions, or from " "top_dimensional_cells and periodic_dimensions or from " - "a Perseus-style file name.") + "a Perseus-style file name.", file=sys.stderr) def __dealloc__(self): if self.thisptr != NULL: @@ -133,8 +136,31 @@ cdef class PeriodicCubicalComplex: """ return self.thisptr.dimension() + def compute_persistence(self, homology_coeff_field=11, min_persistence=0): + """This function computes the persistence of the complex, so it can be + accessed through :func:`persistent_betti_numbers`, + :func:`persistence_intervals_in_dimension`, etc. This function is + equivalent to :func:`persistence` when you do not want the list + :func:`persistence` returns. + + :param homology_coeff_field: The homology coefficient field. Must be a + prime number + :type homology_coeff_field: int. + :param min_persistence: The minimum persistence value to take into + account (strictly greater than min_persistence). Default value is + 0.0. + Sets min_persistence to -1.0 to see all values. + :type min_persistence: float. + :returns: Nothing. + """ + if self.pcohptr != NULL: + del self.pcohptr + assert self.__is_defined() + self.pcohptr = new Periodic_cubical_complex_persistence_interface(self.thisptr, True) + self.pcohptr.compute_persistence(homology_coeff_field, min_persistence) + def persistence(self, homology_coeff_field=11, min_persistence=0): - """This function returns the persistence of the complex. + """This function computes and returns the persistence of the complex. :param homology_coeff_field: The homology coefficient field. Must be a prime number @@ -147,14 +173,8 @@ cdef class PeriodicCubicalComplex: :returns: list of pairs(dimension, pair(birth, death)) -- the persistence of the complex. """ - if self.pcohptr != NULL: - del self.pcohptr - if self.thisptr != NULL: - self.pcohptr = new Periodic_cubical_complex_persistence_interface(self.thisptr, True) - cdef vector[pair[int, pair[double, double]]] persistence_result - if self.pcohptr != NULL: - persistence_result = self.pcohptr.get_persistence(homology_coeff_field, min_persistence) - return persistence_result + self.compute_persistence(homology_coeff_field, min_persistence) + return self.pcohptr.get_persistence() def cofaces_of_persistence_pairs(self): """A persistence interval is described by a pair of cells, one that creates the @@ -190,16 +210,14 @@ cdef class PeriodicCubicalComplex: :returns: list of int -- The Betti numbers ([B0, B1, ..., Bn]). - :note: betti_numbers function requires persistence function to be + :note: betti_numbers function requires :func:`compute_persistence` function to be launched first. - :note: betti_numbers function always returns [1, 0, 0, ...] as infinity + :note: This function always returns the Betti numbers of a torus as infinity filtration cubes are not removed from the complex. """ - cdef vector[int] bn_result - if self.pcohptr != NULL: - bn_result = self.pcohptr.betti_numbers() - return bn_result + assert self.pcohptr != NULL, "compute_persistence() must be called before betti_numbers()" + return self.pcohptr.betti_numbers() def persistent_betti_numbers(self, from_value, to_value): """This function returns the persistent Betti numbers of the complex. @@ -214,13 +232,11 @@ cdef class PeriodicCubicalComplex: :returns: list of int -- The persistent Betti numbers ([B0, B1, ..., Bn]). - :note: persistent_betti_numbers function requires persistence + :note: persistent_betti_numbers function requires :func:`compute_persistence` function to be launched first. """ - cdef vector[int] pbn_result - if self.pcohptr != NULL: - pbn_result = self.pcohptr.persistent_betti_numbers(<double>from_value, <double>to_value) - return pbn_result + assert self.pcohptr != NULL, "compute_persistence() must be called before persistent_betti_numbers()" + return self.pcohptr.persistent_betti_numbers(<double>from_value, <double>to_value) def persistence_intervals_in_dimension(self, dimension): """This function returns the persistence intervals of the complex in a @@ -231,13 +247,8 @@ cdef class PeriodicCubicalComplex: :returns: The persistence intervals. :rtype: numpy array of dimension 2 - :note: intervals_in_dim function requires persistence function to be + :note: intervals_in_dim function requires :func:`compute_persistence` function to be launched first. """ - cdef vector[pair[double,double]] intervals_result - if self.pcohptr != NULL: - intervals_result = self.pcohptr.intervals_in_dimension(dimension) - else: - print("intervals_in_dim function requires persistence function" - " to be launched first.") - return np.array(intervals_result) + assert self.pcohptr != NULL, "compute_persistence() must be called before persistence_intervals_in_dimension()" + return np.array(self.pcohptr.intervals_in_dimension(dimension)) diff --git a/src/python/gudhi/point_cloud/dtm.py b/src/python/gudhi/point_cloud/dtm.py new file mode 100644 index 00000000..13e16d24 --- /dev/null +++ b/src/python/gudhi/point_cloud/dtm.py @@ -0,0 +1,70 @@ +# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT. +# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details. +# Author(s): Marc Glisse +# +# Copyright (C) 2020 Inria +# +# Modification(s): +# - YYYY/MM Author: Description of the modification + +from .knn import KNearestNeighbors + +__author__ = "Marc Glisse" +__copyright__ = "Copyright (C) 2020 Inria" +__license__ = "MIT" + + +class DistanceToMeasure: + """ + Class to compute the distance to the empirical measure defined by a point set, as introduced in :cite:`dtm`. + """ + + def __init__(self, k, q=2, **kwargs): + """ + Args: + k (int): number of neighbors (possibly including the point itself). + q (float): order used to compute the distance to measure. Defaults to 2. + kwargs: same parameters as :class:`~gudhi.point_cloud.knn.KNearestNeighbors`, except that + metric="neighbors" means that :func:`transform` expects an array with the distances + to the k nearest neighbors. + """ + self.k = k + self.q = q + self.params = kwargs + + def fit_transform(self, X, y=None): + return self.fit(X).transform(X) + + def fit(self, X, y=None): + """ + Args: + X (numpy.array): coordinates for mass points. + """ + if self.params.setdefault("metric", "euclidean") != "neighbors": + self.knn = KNearestNeighbors( + self.k, return_index=False, return_distance=True, sort_results=False, **self.params + ) + self.knn.fit(X) + return self + + def transform(self, X): + """ + Args: + X (numpy.array): coordinates for query points, or distance matrix if metric is "precomputed", + or distances to the k nearest neighbors if metric is "neighbors" (if the array has more + than k columns, the remaining ones are ignored). + + Returns: + numpy.array: a 1-d array with, for each point of X, its distance to the measure defined + by the argument of :func:`fit`. + """ + if self.params["metric"] == "neighbors": + distances = X[:, : self.k] + else: + distances = self.knn.transform(X) + distances = distances ** self.q + dtm = distances.sum(-1) / self.k + dtm = dtm ** (1.0 / self.q) + # We compute too many powers, 1/p in knn then q in dtm, 1/q in dtm then q or some log in the caller. + # Add option to skip the final root? + return dtm diff --git a/src/python/gudhi/point_cloud/knn.py b/src/python/gudhi/point_cloud/knn.py new file mode 100644 index 00000000..07553d6d --- /dev/null +++ b/src/python/gudhi/point_cloud/knn.py @@ -0,0 +1,324 @@ +# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT. +# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details. +# Author(s): Marc Glisse +# +# Copyright (C) 2020 Inria +# +# Modification(s): +# - YYYY/MM Author: Description of the modification + +import numpy + +# TODO: https://github.com/facebookresearch/faiss + +__author__ = "Marc Glisse" +__copyright__ = "Copyright (C) 2020 Inria" +__license__ = "MIT" + + +class KNearestNeighbors: + """ + Class wrapping several implementations for computing the k nearest neighbors in a point set. + """ + + def __init__(self, k, return_index=True, return_distance=False, metric="euclidean", **kwargs): + """ + Args: + k (int): number of neighbors (possibly including the point itself). + return_index (bool): if True, return the index of each neighbor. + return_distance (bool): if True, return the distance to each neighbor. + implementation (str): choice of the library that does the real work. + + * 'keops' for a brute-force, CUDA implementation through pykeops. Useful when the dimension becomes large (10+) but the number of points remains low (less than a million). Only "minkowski" and its aliases are supported. + * 'ckdtree' for scipy's cKDTree. Only "minkowski" and its aliases are supported. + * 'sklearn' for scikit-learn's NearestNeighbors. Note that this provides in particular an option algorithm="brute". + * 'hnsw' for hnswlib.Index. It can be very fast but does not provide guarantees. Only supports "euclidean" for now. + * None will try to select a sensible one (scipy if possible, scikit-learn otherwise). + metric (str): see `sklearn.neighbors.NearestNeighbors`. + eps (float): relative error when computing nearest neighbors with the cKDTree. + p (float): norm L^p on input points (including numpy.inf) if metric is "minkowski". Defaults to 2. + n_jobs (int): number of jobs to schedule for parallel processing of nearest neighbors on the CPU. + If -1 is given all processors are used. Default: 1. + sort_results (bool): if True, then distances and indices of each point are + sorted on return, so that the first column contains the closest points. + Otherwise, neighbors are returned in an arbitrary order. Defaults to True. + enable_autodiff (bool): if the input is a torch.tensor, jax.numpy.ndarray or tensorflow.Tensor, this + instructs the function to compute distances in a way that works with automatic differentiation. + This is experimental, not supported for all metrics, and requires the package EagerPy. + Defaults to False. + kwargs: additional parameters are forwarded to the backends. + """ + self.k = k + self.return_index = return_index + self.return_distance = return_distance + self.metric = metric + self.params = kwargs + # canonicalize + if metric == "euclidean": + self.params["p"] = 2 + self.metric = "minkowski" + elif metric == "manhattan": + self.params["p"] = 1 + self.metric = "minkowski" + elif metric == "chebyshev": + self.params["p"] = numpy.inf + self.metric = "minkowski" + elif metric == "minkowski": + self.params["p"] = kwargs.get("p", 2) + if self.params.get("implementation") in {"keops", "ckdtree"}: + assert self.metric == "minkowski" + if self.params.get("implementation") == "hnsw": + assert self.metric == "minkowski" and self.params["p"] == 2 + if not self.params.get("implementation"): + if self.metric == "minkowski": + self.params["implementation"] = "ckdtree" + else: + self.params["implementation"] = "sklearn" + if not return_distance: + self.params["enable_autodiff"] = False + + def fit_transform(self, X, y=None): + return self.fit(X).transform(X) + + def fit(self, X, y=None): + """ + Args: + X (numpy.array): coordinates for reference points. + """ + self.ref_points = X + if self.params.get("enable_autodiff", False): + import eagerpy as ep + + X = ep.astensor(X) + if self.params["implementation"] != "keops" or not isinstance(X, ep.PyTorchTensor): + # I don't know a clever way to reuse a GPU tensor from tensorflow in pytorch + # without copying to/from the CPU. + X = X.numpy() + if self.params["implementation"] == "ckdtree": + # sklearn could handle this, but it is much slower + from scipy.spatial import cKDTree + + self.kdtree = cKDTree(X) + + if self.params["implementation"] == "sklearn" and self.metric != "precomputed": + # FIXME: sklearn badly handles "precomputed" + from sklearn.neighbors import NearestNeighbors + + nargs = { + k: v for k, v in self.params.items() if k in {"p", "n_jobs", "metric_params", "algorithm", "leaf_size"} + } + self.nn = NearestNeighbors(self.k, metric=self.metric, **nargs) + self.nn.fit(X) + + if self.params["implementation"] == "hnsw": + import hnswlib + + self.graph = hnswlib.Index("l2", len(X[0])) # Actually returns squared distances + self.graph.init_index( + len(X), **{k: v for k, v in self.params.items() if k in {"ef_construction", "M", "random_seed"}} + ) + n = self.params.get("num_threads") + if n is None: + n = self.params.get("n_jobs", 1) + self.params["num_threads"] = n + self.graph.add_items(X, num_threads=n) + + return self + + def transform(self, X): + """ + Args: + X (numpy.array): coordinates for query points, or distance matrix if metric is "precomputed". + + Returns: + numpy.array: if return_index, an array of shape (len(X), k) with the indices (in the argument + of :func:`fit`) of the k nearest neighbors to the points of X. If return_distance, an array of the + same shape with the distances to those neighbors. If both, a tuple with the two arrays, in this order. + """ + if self.params.get("enable_autodiff", False): + # pykeops does not support autodiff for kmin yet, but when it does in the future, + # we may want a special path. + import eagerpy as ep + + save_return_index = self.return_index + self.return_index = True + self.return_distance = False + self.params["enable_autodiff"] = False + try: + newX = ep.astensor(X) + if self.params["implementation"] != "keops" or ( + not isinstance(newX, ep.PyTorchTensor) and not isinstance(newX, ep.NumPyTensor) + ): + newX = newX.numpy() + else: + newX = newX.raw + neighbors = self.transform(newX) + finally: + self.return_index = save_return_index + self.return_distance = True + self.params["enable_autodiff"] = True + # We can implement more later as needed + assert self.metric == "minkowski" + p = self.params["p"] + Y = ep.astensor(self.ref_points) + neighbor_pts = Y[ + neighbors, + ] + diff = neighbor_pts - X[:, None, :] + if isinstance(diff, ep.PyTorchTensor): + # https://github.com/jonasrauber/eagerpy/issues/6 + distances = ep.astensor(diff.raw.norm(p, -1)) + else: + distances = diff.norms.lp(p, -1) + if self.return_index: + return neighbors, distances.raw + else: + return distances.raw + + metric = self.metric + k = self.k + + if metric == "precomputed": + # scikit-learn could handle that, but they insist on calling fit() with an unused square array, which is too unnatural. + if self.return_index: + n_jobs = self.params.get("n_jobs", 1) + # Supposedly numpy can be compiled with OpenMP and handle this, but nobody does that?! + if n_jobs == 1: + neighbors = numpy.argpartition(X, k - 1)[:, 0:k] + if self.params.get("sort_results", True): + X = numpy.take_along_axis(X, neighbors, axis=-1) + ngb_order = numpy.argsort(X, axis=-1) + neighbors = numpy.take_along_axis(neighbors, ngb_order, axis=-1) + else: + ngb_order = neighbors + if self.return_distance: + distances = numpy.take_along_axis(X, ngb_order, axis=-1) + return neighbors, distances + else: + return neighbors + else: + from joblib import Parallel, delayed, effective_n_jobs + from sklearn.utils import gen_even_slices + + slices = gen_even_slices(len(X), effective_n_jobs(-1)) + parallel = Parallel(backend="threading", n_jobs=-1) + if self.params.get("sort_results", True): + + def func(M): + neighbors = numpy.argpartition(M, k - 1)[:, 0:k] + Y = numpy.take_along_axis(M, neighbors, axis=-1) + ngb_order = numpy.argsort(Y, axis=-1) + return numpy.take_along_axis(neighbors, ngb_order, axis=-1) + + else: + + def func(M): + return numpy.argpartition(M, k - 1)[:, 0:k] + + neighbors = numpy.concatenate(parallel(delayed(func)(X[s]) for s in slices)) + if self.return_distance: + distances = numpy.take_along_axis(X, neighbors, axis=-1) + return neighbors, distances + else: + return neighbors + if self.return_distance: + n_jobs = self.params.get("n_jobs", 1) + if n_jobs == 1: + distances = numpy.partition(X, k - 1)[:, 0:k] + if self.params.get("sort_results"): + # partition is not guaranteed to sort the lower half, although it often does + distances.sort(axis=-1) + else: + from joblib import Parallel, delayed, effective_n_jobs + from sklearn.utils import gen_even_slices + + if self.params.get("sort_results"): + + def func(M): + # Not partitioning in place, because we should not modify the user's array? + r = numpy.partition(M, k - 1)[:, 0:k] + r.sort(axis=-1) + return r + + else: + func = lambda M: numpy.partition(M, k - 1)[:, 0:k] + slices = gen_even_slices(len(X), effective_n_jobs(-1)) + parallel = Parallel(backend="threading", n_jobs=-1) + distances = numpy.concatenate(parallel(delayed(func)(X[s]) for s in slices)) + return distances + return None + + if self.params["implementation"] == "hnsw": + ef = self.params.get("ef") + if ef is not None: + self.graph.set_ef(ef) + neighbors, distances = self.graph.knn_query(X, k, num_threads=self.params["num_threads"]) + # The k nearest neighbors are always sorted. I couldn't find it in the doc, but the code calls searchKnn, + # which returns a priority_queue, and then fills the return array backwards with top/pop on the queue. + if self.return_index: + if self.return_distance: + return neighbors, numpy.sqrt(distances) + else: + return neighbors + if self.return_distance: + return numpy.sqrt(distances) + return None + + if self.params["implementation"] == "keops": + import torch + from pykeops.torch import LazyTensor + + # 'float64' is slow except on super expensive GPUs. Allow it with some param? + XX = torch.as_tensor(X, dtype=torch.float32) + if X is self.ref_points: + YY = XX + else: + YY = torch.as_tensor(self.ref_points, dtype=torch.float32) + p = self.params["p"] + if p == numpy.inf: + # Requires pykeops 1.4 or later + mat = (LazyTensor(XX[:, None, :]) - LazyTensor(YY[None, :, :])).abs().max(-1) + elif p == 2: # Any even integer? + mat = ((LazyTensor(XX[:, None, :]) - LazyTensor(YY[None, :, :])) ** p).sum(-1) + else: + mat = ((LazyTensor(XX[:, None, :]) - LazyTensor(YY[None, :, :])).abs() ** p).sum(-1) + + if self.return_index: + if self.return_distance: + distances, neighbors = mat.Kmin_argKmin(k, dim=1) + if p != numpy.inf: + distances = distances ** (1.0 / p) + return neighbors, distances + else: + neighbors = mat.argKmin(k, dim=1) + return neighbors + if self.return_distance: + distances = mat.Kmin(k, dim=1) + if p != numpy.inf: + distances = distances ** (1.0 / p) + return distances + return None + + if self.params["implementation"] == "ckdtree": + qargs = {key: val for key, val in self.params.items() if key in {"p", "eps", "n_jobs"}} + distances, neighbors = self.kdtree.query(X, k=self.k, **qargs) + if self.return_index: + if self.return_distance: + return neighbors, distances + else: + return neighbors + if self.return_distance: + return distances + return None + + assert self.params["implementation"] == "sklearn" + if self.return_distance: + distances, neighbors = self.nn.kneighbors(X, return_distance=True) + if self.return_index: + return neighbors, distances + else: + return distances + if self.return_index: + neighbors = self.nn.kneighbors(X, return_distance=False) + return neighbors + return None diff --git a/src/python/gudhi/simplex_tree.pxd b/src/python/gudhi/simplex_tree.pxd index 82f155de..5dea2449 100644 --- a/src/python/gudhi/simplex_tree.pxd +++ b/src/python/gudhi/simplex_tree.pxd @@ -48,8 +48,7 @@ cdef extern from "Simplex_tree_interface.h" namespace "Gudhi": int dimension() int upper_bound_dimension() bool find_simplex(vector[int] simplex) - bool insert_simplex_and_subfaces(vector[int] simplex, - double filtration) + bool insert(vector[int] simplex, double filtration) vector[pair[vector[int], double]] get_star(vector[int] simplex) vector[pair[vector[int], double]] get_cofaces(vector[int] simplex, int dimension) @@ -57,6 +56,8 @@ cdef extern from "Simplex_tree_interface.h" namespace "Gudhi": void remove_maximal_simplex(vector[int] simplex) bool prune_above_filtration(double filtration) bool make_filtration_non_decreasing() + void compute_extended_filtration() + vector[vector[pair[int, pair[double, double]]]] compute_extended_persistence_subdiagrams(vector[pair[int, pair[double, double]]] dgm, double min_persistence) # Iterators over Simplex tree pair[vector[int], double] get_simplex_and_filtration(Simplex_tree_simplex_handle f_simplex) Simplex_tree_simplices_iterator get_simplices_iterator_begin() @@ -69,9 +70,10 @@ cdef extern from "Simplex_tree_interface.h" namespace "Gudhi": cdef extern from "Persistent_cohomology_interface.h" namespace "Gudhi": cdef cppclass Simplex_tree_persistence_interface "Gudhi::Persistent_cohomology_interface<Gudhi::Simplex_tree<Gudhi::Simplex_tree_options_full_featured>>": Simplex_tree_persistence_interface(Simplex_tree_interface_full_featured * st, bool persistence_dim_max) - vector[pair[int, pair[double, double]]] get_persistence(int homology_coeff_field, double min_persistence) + void compute_persistence(int homology_coeff_field, double min_persistence) + vector[pair[int, pair[double, double]]] get_persistence() vector[int] betti_numbers() vector[int] persistent_betti_numbers(double from_value, double to_value) vector[pair[double,double]] intervals_in_dimension(int dimension) - void write_output_diagram(string diagram_file_name) + void write_output_diagram(string diagram_file_name) except + vector[pair[vector[int], vector[int]]] persistence_pairs() diff --git a/src/python/gudhi/simplex_tree.pyx b/src/python/gudhi/simplex_tree.pyx index c01cc905..9479118a 100644 --- a/src/python/gudhi/simplex_tree.pyx +++ b/src/python/gudhi/simplex_tree.pyx @@ -90,7 +90,7 @@ cdef class SimplexTree: (with more :meth:`assign_filtration` or :meth:`make_filtration_non_decreasing` for instance) before calling any function that relies on the filtration property, like - :meth:`initialize_filtration`. + :meth:`persistence`. """ self.get_ptr().assign_simplex_filtration(simplex, filtration) @@ -98,16 +98,7 @@ cdef class SimplexTree: """This function initializes and sorts the simplicial complex filtration vector. - .. note:: - - This function must be launched before - :func:`persistence()<gudhi.SimplexTree.persistence>`, - :func:`betti_numbers()<gudhi.SimplexTree.betti_numbers>`, - :func:`persistent_betti_numbers()<gudhi.SimplexTree.persistent_betti_numbers>`, - or :func:`get_filtration()<gudhi.SimplexTree.get_filtration>` - after :func:`inserting<gudhi.SimplexTree.insert>` or - :func:`removing<gudhi.SimplexTree.remove_maximal_simplex>` - simplices. + .. deprecated:: 3.2.0 """ self.get_ptr().initialize_filtration() @@ -139,9 +130,9 @@ cdef class SimplexTree: This function is not constant time because it can recompute dimension if required (can be triggered by - :func:`remove_maximal_simplex()<gudhi.SimplexTree.remove_maximal_simplex>` + :func:`remove_maximal_simplex` or - :func:`prune_above_filtration()<gudhi.SimplexTree.prune_above_filtration>` + :func:`prune_above_filtration` methods). """ return self.get_ptr().dimension() @@ -166,9 +157,9 @@ cdef class SimplexTree: This function must be used with caution because it disables dimension recomputation when required (this recomputation can be triggered by - :func:`remove_maximal_simplex()<gudhi.SimplexTree.remove_maximal_simplex>` + :func:`remove_maximal_simplex` or - :func:`prune_above_filtration()<gudhi.SimplexTree.prune_above_filtration>` + :func:`prune_above_filtration` ). """ self.get_ptr().set_dimension(<int>dimension) @@ -182,10 +173,7 @@ cdef class SimplexTree: :returns: true if the simplex was found, false otherwise. :rtype: bool """ - cdef vector[int] csimplex - for i in simplex: - csimplex.push_back(i) - return self.get_ptr().find_simplex(csimplex) + return self.get_ptr().find_simplex(simplex) def insert(self, simplex, filtration=0.0): """This function inserts the given N-simplex and its subfaces with the @@ -202,11 +190,7 @@ cdef class SimplexTree: otherwise (whatever its original filtration value). :rtype: bool """ - cdef vector[int] csimplex - for i in simplex: - csimplex.push_back(i) - return self.get_ptr().insert_simplex_and_subfaces(csimplex, - <double>filtration) + return self.get_ptr().insert(simplex, <double>filtration) def get_simplices(self): """This function returns a generator with simplices and their given @@ -308,17 +292,12 @@ cdef class SimplexTree: .. note:: - Be aware that removing is shifting data in a flat_map - (:func:`initialize_filtration()<gudhi.SimplexTree.initialize_filtration>` to be done). - - .. note:: - The dimension of the simplicial complex may be lower after calling remove_maximal_simplex than it was before. However, - :func:`upper_bound_dimension()<gudhi.SimplexTree.upper_bound_dimension>` + :func:`upper_bound_dimension` method will return the old value, which remains a valid upper bound. If you care, you can call - :func:`dimension()<gudhi.SimplexTree.dimension>` + :func:`dimension` to recompute the exact dimension. """ self.get_ptr().remove_maximal_simplex(simplex) @@ -334,24 +313,14 @@ cdef class SimplexTree: .. note:: - Some simplex tree functions require the filtration to be valid. - prune_above_filtration function is not launching - :func:`initialize_filtration()<gudhi.SimplexTree.initialize_filtration>` - but returns the filtration modification - information. If the complex has changed , please call - :func:`initialize_filtration()<gudhi.SimplexTree.initialize_filtration>` - to recompute it. - - .. note:: - Note that the dimension of the simplicial complex may be lower after calling - :func:`prune_above_filtration()<gudhi.SimplexTree.prune_above_filtration>` + :func:`prune_above_filtration` than it was before. However, - :func:`upper_bound_dimension()<gudhi.SimplexTree.upper_bound_dimension>` + :func:`upper_bound_dimension` will return the old value, which remains a valid upper bound. If you care, you can call - :func:`dimension()<gudhi.SimplexTree.dimension>` + :func:`dimension` method to recompute the exact dimension. """ return self.get_ptr().prune_above_filtration(filtration) @@ -382,22 +351,63 @@ cdef class SimplexTree: :returns: True if any filtration value was modified, False if the filtration was already non-decreasing. :rtype: bool + """ + return self.get_ptr().make_filtration_non_decreasing() + def extend_filtration(self): + """ Extend filtration for computing extended persistence. This function only uses the + filtration values at the 0-dimensional simplices, and computes the extended persistence + diagram induced by the lower-star filtration computed with these values. .. note:: - Some simplex tree functions require the filtration to be valid. - make_filtration_non_decreasing function is not launching - :func:`initialize_filtration()<gudhi.SimplexTree.initialize_filtration>` - but returns the filtration modification - information. If the complex has changed , please call - :func:`initialize_filtration()<gudhi.SimplexTree.initialize_filtration>` - to recompute it. + Note that after calling this function, the filtration + values are actually modified within the Simplex_tree. + The function :func:`extended_persistence` + retrieves the original values. + + .. note:: + + Note that this code creates an extra vertex internally, so you should make sure that + the Simplex_tree does not contain a vertex with the largest possible value (i.e., 4294967295). """ - return self.get_ptr().make_filtration_non_decreasing() + self.get_ptr().compute_extended_filtration() + + def extended_persistence(self, homology_coeff_field=11, min_persistence=0): + """This function retrieves good values for extended persistence, and separate the diagrams + into the Ordinary, Relative, Extended+ and Extended- subdiagrams. + + :param homology_coeff_field: The homology coefficient field. Must be a + prime number. Default value is 11. + :type homology_coeff_field: int. + :param min_persistence: The minimum persistence value (i.e., the absolute value of the difference between the persistence diagram point coordinates) to take into + account (strictly greater than min_persistence). Default value is + 0.0. + Sets min_persistence to -1.0 to see all values. + :type min_persistence: float. + :returns: A list of four persistence diagrams in the format described in :func:`persistence`. The first one is Ordinary, the second one is Relative, the third one is Extended+ and the fourth one is Extended-. See https://link.springer.com/article/10.1007/s10208-008-9027-z and/or section 2.2 in https://link.springer.com/article/10.1007/s10208-017-9370-z for a description of these subtypes. + + .. note:: + + This function should be called only if :func:`extend_filtration` has been called first! + + .. note:: + + The coordinates of the persistence diagram points might be a little different than the + original filtration values due to the internal transformation (scaling to [-2,-1]) that is + performed on these values during the computation of extended persistence. + """ + cdef vector[pair[int, pair[double, double]]] persistence_result + if self.pcohptr != NULL: + del self.pcohptr + self.pcohptr = new Simplex_tree_persistence_interface(self.get_ptr(), False) + self.pcohptr.compute_persistence(homology_coeff_field, -1.) + persistence_result = self.pcohptr.get_persistence() + return self.get_ptr().compute_extended_persistence_subdiagrams(persistence_result, min_persistence) + def persistence(self, homology_coeff_field=11, min_persistence=0, persistence_dim_max = False): - """This function returns the persistence of the simplicial complex. + """This function computes and returns the persistence of the simplicial complex. :param homology_coeff_field: The homology coefficient field. Must be a prime number. Default value is 11. @@ -414,13 +424,32 @@ cdef class SimplexTree: :returns: The persistence of the simplicial complex. :rtype: list of pairs(dimension, pair(birth, death)) """ + self.compute_persistence(homology_coeff_field, min_persistence, persistence_dim_max) + return self.pcohptr.get_persistence() + + def compute_persistence(self, homology_coeff_field=11, min_persistence=0, persistence_dim_max = False): + """This function computes the persistence of the simplicial complex, so it can be accessed through + :func:`persistent_betti_numbers`, :func:`persistence_pairs`, etc. This function is equivalent to :func:`persistence` + when you do not want the list :func:`persistence` returns. + + :param homology_coeff_field: The homology coefficient field. Must be a + prime number. Default value is 11. + :type homology_coeff_field: int. + :param min_persistence: The minimum persistence value to take into + account (strictly greater than min_persistence). Default value is + 0.0. + Sets min_persistence to -1.0 to see all values. + :type min_persistence: float. + :param persistence_dim_max: If true, the persistent homology for the + maximal dimension in the complex is computed. If false, it is + ignored. Default is false. + :type persistence_dim_max: bool + :returns: Nothing. + """ if self.pcohptr != NULL: del self.pcohptr self.pcohptr = new Simplex_tree_persistence_interface(self.get_ptr(), persistence_dim_max) - cdef vector[pair[int, pair[double, double]]] persistence_result - if self.pcohptr != NULL: - persistence_result = self.pcohptr.get_persistence(homology_coeff_field, min_persistence) - return persistence_result + self.pcohptr.compute_persistence(homology_coeff_field, min_persistence) def betti_numbers(self): """This function returns the Betti numbers of the simplicial complex. @@ -429,16 +458,11 @@ cdef class SimplexTree: :rtype: list of int :note: betti_numbers function requires - :func:`persistence()<gudhi.SimplexTree.persistence>` + :func:`compute_persistence` function to be launched first. """ - cdef vector[int] bn_result - if self.pcohptr != NULL: - bn_result = self.pcohptr.betti_numbers() - else: - print("betti_numbers function requires persistence function" - " to be launched first.") - return bn_result + assert self.pcohptr != NULL, "compute_persistence() must be called before betti_numbers()" + return self.pcohptr.betti_numbers() def persistent_betti_numbers(self, from_value, to_value): """This function returns the persistent Betti numbers of the @@ -455,16 +479,11 @@ cdef class SimplexTree: :rtype: list of int :note: persistent_betti_numbers function requires - :func:`persistence()<gudhi.SimplexTree.persistence>` + :func:`compute_persistence` function to be launched first. """ - cdef vector[int] pbn_result - if self.pcohptr != NULL: - pbn_result = self.pcohptr.persistent_betti_numbers(<double>from_value, <double>to_value) - else: - print("persistent_betti_numbers function requires persistence function" - " to be launched first.") - return pbn_result + assert self.pcohptr != NULL, "compute_persistence() must be called before persistent_betti_numbers()" + return self.pcohptr.persistent_betti_numbers(<double>from_value, <double>to_value) def persistence_intervals_in_dimension(self, dimension): """This function returns the persistence intervals of the simplicial @@ -476,16 +495,11 @@ cdef class SimplexTree: :rtype: numpy array of dimension 2 :note: intervals_in_dim function requires - :func:`persistence()<gudhi.SimplexTree.persistence>` + :func:`compute_persistence` function to be launched first. """ - cdef vector[pair[double,double]] intervals_result - if self.pcohptr != NULL: - intervals_result = self.pcohptr.intervals_in_dimension(dimension) - else: - print("intervals_in_dim function requires persistence function" - " to be launched first.") - return np_array(intervals_result) + assert self.pcohptr != NULL, "compute_persistence() must be called before persistence_intervals_in_dimension()" + return np_array(self.pcohptr.intervals_in_dimension(dimension)) def persistence_pairs(self): """This function returns a list of persistence birth and death simplices pairs. @@ -494,33 +508,22 @@ cdef class SimplexTree: :rtype: list of pair of list of int :note: persistence_pairs function requires - :func:`persistence()<gudhi.SimplexTree.persistence>` + :func:`compute_persistence` function to be launched first. """ - cdef vector[pair[vector[int],vector[int]]] persistence_pairs_result - if self.pcohptr != NULL: - persistence_pairs_result = self.pcohptr.persistence_pairs() - else: - print("persistence_pairs function requires persistence function" - " to be launched first.") - return persistence_pairs_result + assert self.pcohptr != NULL, "compute_persistence() must be called before persistence_pairs()" + return self.pcohptr.persistence_pairs() - def write_persistence_diagram(self, persistence_file=''): + def write_persistence_diagram(self, persistence_file): """This function writes the persistence intervals of the simplicial complex in a user given file name. - :param persistence_file: The specific dimension. + :param persistence_file: Name of the file. :type persistence_file: string. :note: intervals_in_dim function requires - :func:`persistence()<gudhi.SimplexTree.persistence>` + :func:`compute_persistence` function to be launched first. """ - if self.pcohptr != NULL: - if persistence_file != '': - self.pcohptr.write_output_diagram(persistence_file.encode('utf-8')) - else: - print("persistence_file must be specified") - else: - print("intervals_in_dim function requires persistence function" - " to be launched first.") + assert self.pcohptr != NULL, "compute_persistence() must be called before write_persistence_diagram()" + self.pcohptr.write_output_diagram(persistence_file.encode('utf-8')) diff --git a/src/python/gudhi/wasserstein/__init__.py b/src/python/gudhi/wasserstein/__init__.py new file mode 100644 index 00000000..ed225ba4 --- /dev/null +++ b/src/python/gudhi/wasserstein/__init__.py @@ -0,0 +1 @@ +from .wasserstein import wasserstein_distance diff --git a/src/python/gudhi/wasserstein/barycenter.py b/src/python/gudhi/wasserstein/barycenter.py new file mode 100644 index 00000000..de7aea81 --- /dev/null +++ b/src/python/gudhi/wasserstein/barycenter.py @@ -0,0 +1,159 @@ +# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT. +# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details. +# Author(s): Theo Lacombe +# +# Copyright (C) 2019 Inria +# +# Modification(s): +# - YYYY/MM Author: Description of the modification + + +import ot +import numpy as np +import scipy.spatial.distance as sc + +from gudhi.wasserstein import wasserstein_distance + + +def _mean(x, m): + ''' + :param x: a list of 2D-points, off diagonal, x_0... x_{k-1} + :param m: total amount of points taken into account, + that is we have (m-k) copies of diagonal + :returns: the weighted mean of x with (m-k) copies of the diagonal + ''' + k = len(x) + if k > 0: + w = np.mean(x, axis=0) + w_delta = (w[0] + w[1]) / 2 * np.ones(2) + return (k * w + (m-k) * w_delta) / m + else: + return np.array([0, 0]) + + +def lagrangian_barycenter(pdiagset, init=None, verbose=False): + ''' + :param pdiagset: a list of ``numpy.array`` of shape `(n x 2)` + (`n` can variate), encoding a set of + persistence diagrams with only finite coordinates. + :param init: The initial value for barycenter estimate. + If ``None``, init is made on a random diagram from the dataset. + Otherwise, it can be an ``int`` + (then initialization is made on ``pdiagset[init]``) + or a `(n x 2)` ``numpy.array`` enconding + a persistence diagram with `n` points. + :type init: ``int``, or (n x 2) ``np.array`` + :param verbose: if ``True``, returns additional information about the + barycenter. + :type verbose: boolean + :returns: If not verbose (default), a ``numpy.array`` encoding + the barycenter estimate of pdiagset + (local minimum of the energy function). + If ``pdiagset`` is empty, returns ``None``. + If verbose, returns a couple ``(Y, log)`` + where ``Y`` is the barycenter estimate, + and ``log`` is a ``dict`` that contains additional informations: + + - `"groupings"`, a list of list of pairs ``(i,j)``. + Namely, ``G[k] = [...(i, j)...]``, where ``(i,j)`` indicates + that ``pdiagset[k][i]`` is matched to ``Y[j]`` + if ``i = -1`` or ``j = -1``, it means they + represent the diagonal. + + - `"energy"`, ``float`` representing the Frechet energy value obtained. + It is the mean of squared distances of observations to the output. + + - `"nb_iter"`, ``int`` number of iterations performed before convergence of the algorithm. + ''' + X = pdiagset # to shorten notations, not a copy + m = len(X) # number of diagrams we are averaging + if m == 0: + print("Warning: computing barycenter of empty diag set. Returns None") + return None + + # store the number of off-diagonal point for each of the X_i + nb_off_diag = np.array([len(X_i) for X_i in X]) + # Initialisation of barycenter + if init is None: + i0 = np.random.randint(m) # Index of first state for the barycenter + Y = X[i0].copy() + else: + if type(init)==int: + Y = X[init].copy() + else: + Y = init.copy() + + nb_iter = 0 + + converged = False # stoping criterion + while not converged: + nb_iter += 1 + K = len(Y) # current nb of points in Y (some might be on diagonal) + G = np.full((K, m), -1, dtype=int) # will store for each j, the (index) + # point matched in each other diagram + #(might be the diagonal). + # that is G[j, i] = k <=> y_j is matched to + # x_k in the diagram i-th diagram X[i] + updated_points = np.zeros((K, 2)) # will store the new positions of + # the points of Y. + # If points disappear, there thrown + # on [0,0] by default. + new_created_points = [] # will store potential new points. + + # Step 1 : compute optimal matching (Y, X_i) for each X_i + # and create new points in Y if needed + for i in range(m): + _, indices = wasserstein_distance(Y, X[i], matching=True, order=2., internal_p=2.) + for y_j, x_i_j in indices: + if y_j >= 0: # we matched an off diagonal point to x_i_j... + if x_i_j >= 0: # ...which is also an off-diagonal point. + G[y_j, i] = x_i_j + else: # ...which is a diagonal point + G[y_j, i] = -1 # -1 stands for the diagonal (mask) + else: # We matched a diagonal point to x_i_j... + if x_i_j >= 0: # which is a off-diag point ! + # need to create new point in Y + new_y = _mean(np.array([X[i][x_i_j]]), m) + # Average this point with (m-1) copies of Delta + new_created_points.append(new_y) + + # Step 2 : Update current point position thanks to groupings computed + to_delete = [] + for j in range(K): + matched_points = [X[i][G[j, i]] for i in range(m) if G[j, i] > -1] + new_y_j = _mean(matched_points, m) + if not np.array_equal(new_y_j, np.array([0,0])): + updated_points[j] = new_y_j + else: # this points is no longer of any use. + to_delete.append(j) + # we remove the point to be deleted now. + updated_points = np.delete(updated_points, to_delete, axis=0) + + # we cannot converge if there have been new created points. + if new_created_points: + Y = np.concatenate((updated_points, new_created_points)) + else: + # Step 3 : we check convergence + if np.array_equal(updated_points, Y): + converged = True + Y = updated_points + + + if verbose: + groupings = [] + energy = 0 + log = {} + n_y = len(Y) + for i in range(m): + cost, edges = wasserstein_distance(Y, X[i], matching=True, order=2., internal_p=2.) + groupings.append(edges) + energy += cost + log["groupings"] = groupings + energy = energy/m + log["energy"] = energy + log["nb_iter"] = nb_iter + + return Y, log + else: + return Y + diff --git a/src/python/gudhi/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py index 3dd993f9..efc851a0 100644 --- a/src/python/gudhi/wasserstein.py +++ b/src/python/gudhi/wasserstein/wasserstein.py @@ -9,11 +9,14 @@ import numpy as np import scipy.spatial.distance as sc + try: import ot except ImportError: print("POT (Python Optimal Transport) package is not installed. Try to run $ conda install -c conda-forge pot ; or $ pip install POT") + +# Currently unused, but Théo says it is likely to be used again. def _proj_on_diag(X): ''' :param X: (n x 2) array encoding the points of a persistent diagram. @@ -23,28 +26,36 @@ def _proj_on_diag(X): return np.array([Z , Z]).T -def _build_dist_matrix(X, Y, order=2., internal_p=2.): +def _dist_to_diag(X, internal_p): + ''' + :param X: (n x 2) array encoding the points of a persistent diagram. + :param internal_p: Ground metric (i.e. norm L^p). + :returns: (n) array encoding the (respective orthogonal) distances of the points to the diagonal + + .. note:: + Assumes that the points are above the diagonal. + ''' + return (X[:, 1] - X[:, 0]) * 2 ** (1.0 / internal_p - 1) + + +def _build_dist_matrix(X, Y, order, internal_p): ''' :param X: (n x 2) numpy.array encoding the (points of the) first diagram. :param Y: (m x 2) numpy.array encoding the second diagram. :param order: exponent for the Wasserstein metric. :param internal_p: Ground metric (i.e. norm L^p). - :returns: (n+1) x (m+1) np.array encoding the cost matrix C. - For 0 <= i < n, 0 <= j < m, C[i,j] encodes the distance between X[i] and Y[j], - while C[i, m] (resp. C[n, j]) encodes the distance (to the p) between X[i] (resp Y[j]) + :returns: (n+1) x (m+1) np.array encoding the cost matrix C. + For 0 <= i < n, 0 <= j < m, C[i,j] encodes the distance between X[i] and Y[j], + while C[i, m] (resp. C[n, j]) encodes the distance (to the p) between X[i] (resp Y[j]) and its orthogonal projection onto the diagonal. note also that C[n, m] = 0 (it costs nothing to move from the diagonal to the diagonal). ''' - Xdiag = _proj_on_diag(X) - Ydiag = _proj_on_diag(Y) + Cxd = _dist_to_diag(X, internal_p)**order + Cdy = _dist_to_diag(Y, internal_p)**order if np.isinf(internal_p): C = sc.cdist(X,Y, metric='chebyshev')**order - Cxd = np.linalg.norm(X - Xdiag, ord=internal_p, axis=1)**order - Cdy = np.linalg.norm(Y - Ydiag, ord=internal_p, axis=1)**order else: C = sc.cdist(X,Y, metric='minkowski', p=internal_p)**order - Cxd = np.linalg.norm(X - Xdiag, ord=internal_p, axis=1)**order - Cdy = np.linalg.norm(Y - Ydiag, ord=internal_p, axis=1)**order Cf = np.hstack((C, Cxd[:,None])) Cdy = np.append(Cdy, 0) @@ -58,24 +69,23 @@ def _perstot(X, order, internal_p): :param X: (n x 2) numpy.array (points of a given diagram). :param order: exponent for Wasserstein. Default value is 2. :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2); Default value is 2 (Euclidean norm). - :returns: float, the total persistence of the diagram (that is, its distance to the empty diagram). + :returns: float, the total persistence of the diagram (that is, its distance to the empty diagram). ''' - Xdiag = _proj_on_diag(X) - return (np.sum(np.linalg.norm(X - Xdiag, ord=internal_p, axis=1)**order))**(1./order) + return np.linalg.norm(_dist_to_diag(X, internal_p), ord=order) def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2.): ''' - :param X: (n x 2) numpy.array encoding the (finite points of the) first diagram. Must not contain essential points + :param X: (n x 2) numpy.array encoding the (finite points of the) first diagram. Must not contain essential points (i.e. with infinite coordinate). :param Y: (m x 2) numpy.array encoding the second diagram. :param matching: if True, computes and returns the optimal matching between X and Y, encoded as a (n x 2) np.array [...[i,j]...], meaning the i-th point in X is matched to the j-th point in Y, with the convention (-1) represents the diagonal. :param order: exponent for Wasserstein; Default value is 2. - :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2); + :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2); Default value is 2 (Euclidean norm). - :returns: the Wasserstein distance of order q (1 <= q < infinity) between persistence diagrams with + :returns: the Wasserstein distance of order q (1 <= q < infinity) between persistence diagrams with respect to the internal_p-norm as ground metric. If matching is set to True, also returns the optimal matching between X and Y. ''' |