diff options
author | Vincent Rouvreau <vincent.rouvreau@inria.fr> | 2022-06-17 16:33:54 +0200 |
---|---|---|
committer | Vincent Rouvreau <vincent.rouvreau@inria.fr> | 2022-06-17 16:33:54 +0200 |
commit | fc910c1bcb3451bcf3e288db25fecafe15cc42bb (patch) | |
tree | 1be9b439b0e74eb04c873612dbca28990284f24c /src/python | |
parent | 048fff97cd0a53be5953c4d5799f8e2e097c181c (diff) | |
parent | 854ae4169ece5edfaae15526f42314b0976e2b84 (diff) |
Merge master
Diffstat (limited to 'src/python')
18 files changed, 411 insertions, 55 deletions
diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt index 4b04c487..63a2e006 100644 --- a/src/python/CMakeLists.txt +++ b/src/python/CMakeLists.txt @@ -330,9 +330,9 @@ if(PYTHONINTERP_FOUND) if(NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 5.1.0) set (GUDHI_SPHINX_MESSAGE "Generating API documentation with Sphinx in ${CMAKE_CURRENT_BINARY_DIR}/sphinx/") # User warning - Sphinx is a static pages generator, and configured to work fine with user_version - # Images and biblio warnings because not found on developper version + # Images and biblio warnings because not found on developer version if (GUDHI_PYTHON_PATH STREQUAL "src/python") - set (GUDHI_SPHINX_MESSAGE "${GUDHI_SPHINX_MESSAGE} \n WARNING : Sphinx is configured for user version, you run it on developper version. Images and biblio will miss") + set (GUDHI_SPHINX_MESSAGE "${GUDHI_SPHINX_MESSAGE} \n WARNING : Sphinx is configured for user version, you run it on developer version. Images and biblio will miss") endif() # sphinx target requires gudhi.so, because conf.py reads gudhi version from it add_custom_target(sphinx @@ -485,7 +485,7 @@ if(PYTHONINTERP_FOUND) add_gudhi_py_test(test_euclidean_witness_complex) # Datasets generators - add_gudhi_py_test(test_datasets_generators) # TODO separate full python datasets generators in another test file independant from CGAL ? + add_gudhi_py_test(test_datasets_generators) # TODO separate full python datasets generators in another test file independent from CGAL ? endif (NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 4.11.0) @@ -597,6 +597,11 @@ if(PYTHONINTERP_FOUND) add_gudhi_py_test(test_dtm_rips_complex) endif() + # Fetch remote datasets + if(WITH_GUDHI_REMOTE_TEST) + add_gudhi_py_test(test_remote_datasets) + endif() + # sklearn if(SKLEARN_FOUND) add_gudhi_py_test(test_sklearn_cubical_persistence) diff --git a/src/python/doc/datasets_generators.inc b/src/python/doc/datasets.inc index 8d169275..95a87678 100644 --- a/src/python/doc/datasets_generators.inc +++ b/src/python/doc/datasets.inc @@ -2,7 +2,7 @@ :widths: 30 40 30 +-----------------------------------+--------------------------------------------+--------------------------------------------------------------------------------------+ - | .. figure:: | Datasets generators (points). | :Authors: Hind Montassif | + | .. figure:: | Datasets either generated or fetched. | :Authors: Hind Montassif | | img/sphere_3d.png | | | | | | :Since: GUDHI 3.5.0 | | | | | @@ -10,5 +10,5 @@ | | | | | | | :Requires: `CGAL <installation.html#cgal>`_ | +-----------------------------------+--------------------------------------------+--------------------------------------------------------------------------------------+ - | * :doc:`datasets_generators` | + | * :doc:`datasets` | +-----------------------------------+-----------------------------------------------------------------------------------------------------------------------------------+ diff --git a/src/python/doc/datasets_generators.rst b/src/python/doc/datasets.rst index 260c3882..2d11a19d 100644 --- a/src/python/doc/datasets_generators.rst +++ b/src/python/doc/datasets.rst @@ -3,12 +3,14 @@ .. To get rid of WARNING: document isn't included in any toctree -=========================== -Datasets generators manual -=========================== +================ +Datasets manual +================ -We provide the generation of different customizable datasets to use as inputs for Gudhi complexes and data structures. +Datasets generators +=================== +We provide the generation of different customizable datasets to use as inputs for Gudhi complexes and data structures. Points generators ------------------ @@ -103,3 +105,29 @@ Example .. autofunction:: gudhi.datasets.generators.points.torus + + +Fetching datasets +================= + +We provide some ready-to-use datasets that are not available by default when getting GUDHI, and need to be fetched explicitly. + +By **default**, the fetched datasets directory is set to a folder named **'gudhi_data'** in the **user home folder**. +Alternatively, it can be set using the **'GUDHI_DATA'** environment variable. + +.. autofunction:: gudhi.datasets.remote.fetch_bunny + +.. figure:: ./img/bunny.png + :figclass: align-center + + 3D Stanford bunny with 35947 vertices. + + +.. autofunction:: gudhi.datasets.remote.fetch_spiral_2d + +.. figure:: ./img/spiral_2d.png + :figclass: align-center + + 2D spiral with 114562 vertices. + +.. autofunction:: gudhi.datasets.remote.clear_data_home diff --git a/src/python/doc/img/bunny.png b/src/python/doc/img/bunny.png Binary files differnew file mode 100644 index 00000000..769aa530 --- /dev/null +++ b/src/python/doc/img/bunny.png diff --git a/src/python/doc/img/spiral_2d.png b/src/python/doc/img/spiral_2d.png Binary files differnew file mode 100644 index 00000000..abd247cd --- /dev/null +++ b/src/python/doc/img/spiral_2d.png diff --git a/src/python/doc/index.rst b/src/python/doc/index.rst index 2d7921ae..35f4ba46 100644 --- a/src/python/doc/index.rst +++ b/src/python/doc/index.rst @@ -92,7 +92,7 @@ Clustering .. include:: clustering.inc -Datasets generators -******************* +Datasets +******** -.. include:: datasets_generators.inc +.. include:: datasets.inc diff --git a/src/python/example/rips_complex_diagram_persistence_from_correlation_matrix_file_example.py b/src/python/example/rips_complex_diagram_persistence_from_correlation_matrix_file_example.py index ea2eb7e1..0b35dbc5 100755 --- a/src/python/example/rips_complex_diagram_persistence_from_correlation_matrix_file_example.py +++ b/src/python/example/rips_complex_diagram_persistence_from_correlation_matrix_file_example.py @@ -40,7 +40,7 @@ parser.add_argument( args = parser.parse_args() if not (-1.0 < args.min_edge_correlation < 1.0): - print("Wrong value of the treshold corelation (should be between -1 and 1).") + print("Wrong value of the threshold corelation (should be between -1 and 1).") sys.exit(1) print("#####################################################################") diff --git a/src/python/gudhi/datasets/remote.py b/src/python/gudhi/datasets/remote.py new file mode 100644 index 00000000..f6d3fe56 --- /dev/null +++ b/src/python/gudhi/datasets/remote.py @@ -0,0 +1,223 @@ +# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT. +# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details. +# Author(s): Hind Montassif +# +# Copyright (C) 2021 Inria +# +# Modification(s): +# - YYYY/MM Author: Description of the modification + +from os.path import join, split, exists, expanduser +from os import makedirs, remove, environ + +from urllib.request import urlretrieve +import hashlib +import shutil + +import numpy as np + +def _get_data_home(data_home = None): + """ + Return the path of the remote datasets directory. + This folder is used to store remotely fetched datasets. + By default the datasets directory is set to a folder named 'gudhi_data' in the user home folder. + Alternatively, it can be set by the 'GUDHI_DATA' environment variable. + The '~' symbol is expanded to the user home folder. + If the folder does not already exist, it is automatically created. + + Parameters + ---------- + data_home : string + The path to remote datasets directory. + Default is `None`, meaning that the data home directory will be set to "~/gudhi_data", + if the 'GUDHI_DATA' environment variable does not exist. + + Returns + ------- + data_home: string + The path to remote datasets directory. + """ + if data_home is None: + data_home = environ.get("GUDHI_DATA", join("~", "gudhi_data")) + data_home = expanduser(data_home) + makedirs(data_home, exist_ok=True) + return data_home + + +def clear_data_home(data_home = None): + """ + Delete the data home cache directory and all its content. + + Parameters + ---------- + data_home : string, default is None. + The path to remote datasets directory. + If `None` and the 'GUDHI_DATA' environment variable does not exist, + the default directory to be removed is set to "~/gudhi_data". + """ + data_home = _get_data_home(data_home) + shutil.rmtree(data_home) + +def _checksum_sha256(file_path): + """ + Compute the file checksum using sha256. + + Parameters + ---------- + file_path: string + Full path of the created file including filename. + + Returns + ------- + The hex digest of file_path. + """ + sha256_hash = hashlib.sha256() + chunk_size = 4096 + with open(file_path,"rb") as f: + # Read and update hash string value in blocks of 4K + while True: + buffer = f.read(chunk_size) + if not buffer: + break + sha256_hash.update(buffer) + return sha256_hash.hexdigest() + +def _fetch_remote(url, file_path, file_checksum = None): + """ + Fetch the wanted dataset from the given url and save it in file_path. + + Parameters + ---------- + url : string + The url to fetch the dataset from. + file_path : string + Full path of the downloaded file including filename. + file_checksum : string + The file checksum using sha256 to check against the one computed on the downloaded file. + Default is 'None', which means the checksum is not checked. + + Raises + ------ + IOError + If the computed SHA256 checksum of file does not match the one given by the user. + """ + + # Get the file + urlretrieve(url, file_path) + + if file_checksum is not None: + checksum = _checksum_sha256(file_path) + if file_checksum != checksum: + # Remove file and raise error + remove(file_path) + raise IOError("{} has a SHA256 checksum : {}, " + "different from expected : {}." + "The file may be corrupted or the given url may be wrong !".format(file_path, checksum, file_checksum)) + +def _get_archive_path(file_path, label): + """ + Get archive path based on file_path given by user and label. + + Parameters + ---------- + file_path: string + Full path of the file to get including filename, or None. + label: string + Label used along with 'data_home' to get archive path, in case 'file_path' is None. + + Returns + ------- + Full path of archive including filename. + """ + if file_path is None: + archive_path = join(_get_data_home(), label) + dirname = split(archive_path)[0] + makedirs(dirname, exist_ok=True) + else: + archive_path = file_path + dirname = split(archive_path)[0] + makedirs(dirname, exist_ok=True) + + return archive_path + +def fetch_spiral_2d(file_path = None): + """ + Load the spiral_2d dataset. + + Note that if the dataset already exists in the target location, it is not downloaded again, + and the corresponding array is returned from cache. + + Parameters + ---------- + file_path : string + Full path of the downloaded file including filename. + + Default is None, meaning that it's set to "data_home/points/spiral_2d/spiral_2d.npy". + + The "data_home" directory is set by default to "~/gudhi_data", + unless the 'GUDHI_DATA' environment variable is set. + + Returns + ------- + points: numpy array + Array of shape (114562, 2). + """ + file_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d/spiral_2d.npy" + file_checksum = '2226024da76c073dd2f24b884baefbfd14928b52296df41ad2d9b9dc170f2401' + + archive_path = _get_archive_path(file_path, "points/spiral_2d/spiral_2d.npy") + + if not exists(archive_path): + _fetch_remote(file_url, archive_path, file_checksum) + + return np.load(archive_path, mmap_mode='r') + +def fetch_bunny(file_path = None, accept_license = False): + """ + Load the Stanford bunny dataset. + + This dataset contains 35947 vertices. + + Note that if the dataset already exists in the target location, it is not downloaded again, + and the corresponding array is returned from cache. + + Parameters + ---------- + file_path : string + Full path of the downloaded file including filename. + + Default is None, meaning that it's set to "data_home/points/bunny/bunny.npy". + In this case, the LICENSE file would be downloaded as "data_home/points/bunny/bunny.LICENSE". + + The "data_home" directory is set by default to "~/gudhi_data", + unless the 'GUDHI_DATA' environment variable is set. + + accept_license : boolean + Flag to specify if user accepts the file LICENSE and prevents from printing the corresponding license terms. + + Default is False. + + Returns + ------- + points: numpy array + Array of shape (35947, 3). + """ + + file_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/bunny.npy" + file_checksum = 'f382482fd89df8d6444152dc8fd454444fe597581b193fd139725a85af4a6c6e' + license_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/bunny.LICENSE" + license_checksum = 'b763dbe1b2fc6015d05cbf7bcc686412a2eb100a1f2220296e3b4a644c69633a' + + archive_path = _get_archive_path(file_path, "points/bunny/bunny.npy") + + if not exists(archive_path): + _fetch_remote(file_url, archive_path, file_checksum) + license_path = join(split(archive_path)[0], "bunny.LICENSE") + _fetch_remote(license_url, license_path, license_checksum) + # Print license terms unless accept_license is set to True + if not accept_license: + if exists(license_path): + with open(license_path, 'r') as f: + print(f.read()) + + return np.load(archive_path, mmap_mode='r') diff --git a/src/python/gudhi/hera/wasserstein.cc b/src/python/gudhi/hera/wasserstein.cc index 1a21f02f..fa0cf8aa 100644 --- a/src/python/gudhi/hera/wasserstein.cc +++ b/src/python/gudhi/hera/wasserstein.cc @@ -29,7 +29,7 @@ double wasserstein_distance( if(std::isinf(internal_p)) internal_p = hera::get_infinity<double>(); params.internal_p = internal_p; params.delta = delta; - // The extra parameters are purposedly not exposed for now. + // The extra parameters are purposely not exposed for now. return hera::wasserstein_dist(diag1, diag2, params); } diff --git a/src/python/gudhi/persistence_graphical_tools.py b/src/python/gudhi/persistence_graphical_tools.py index 7ed11360..21275cdd 100644 --- a/src/python/gudhi/persistence_graphical_tools.py +++ b/src/python/gudhi/persistence_graphical_tools.py @@ -332,7 +332,7 @@ def plot_persistence_diagram( axes.plot([axis_start, axis_end], [infinity, infinity], linewidth=1.0, color="k", alpha=alpha) # Infinity label yt = axes.get_yticks() - yt = yt[np.where(yt < axis_end)] # to avoid ploting ticklabel higher than infinity + yt = yt[np.where(yt < axis_end)] # to avoid plotting ticklabel higher than infinity yt = np.append(yt, infinity) ytl = ["%.3f" % e for e in yt] # to avoid float precision error ytl[-1] = r"$+\infty$" diff --git a/src/python/gudhi/simplex_tree.pxd b/src/python/gudhi/simplex_tree.pxd index 4f229663..5642f82d 100644 --- a/src/python/gudhi/simplex_tree.pxd +++ b/src/python/gudhi/simplex_tree.pxd @@ -63,7 +63,6 @@ cdef extern from "Simplex_tree_interface.h" namespace "Gudhi": bool prune_above_filtration(double filtration) nogil bool make_filtration_non_decreasing() nogil void compute_extended_filtration() nogil - vector[vector[pair[int, pair[double, double]]]] compute_extended_persistence_subdiagrams(vector[pair[int, pair[double, double]]] dgm, double min_persistence) nogil Simplex_tree_interface_full_featured* collapse_edges(int nb_collapse_iteration) nogil except + void reset_filtration(double filtration, int dimension) nogil bint operator==(Simplex_tree_interface_full_featured) nogil @@ -81,7 +80,7 @@ cdef extern from "Simplex_tree_interface.h" namespace "Gudhi": void expansion_with_blockers_callback(int dimension, blocker_func_t user_func, void *user_data) cdef extern from "Persistent_cohomology_interface.h" namespace "Gudhi": - cdef cppclass Simplex_tree_persistence_interface "Gudhi::Persistent_cohomology_interface<Gudhi::Simplex_tree<Gudhi::Simplex_tree_options_full_featured>>": + cdef cppclass Simplex_tree_persistence_interface "Gudhi::Persistent_cohomology_interface<Gudhi::Simplex_tree_interface<Gudhi::Simplex_tree_options_full_featured>>": Simplex_tree_persistence_interface(Simplex_tree_interface_full_featured * st, bool persistence_dim_max) nogil void compute_persistence(int homology_coeff_field, double min_persistence) nogil except + vector[pair[int, pair[double, double]]] get_persistence() nogil @@ -92,3 +91,4 @@ cdef extern from "Persistent_cohomology_interface.h" namespace "Gudhi": vector[pair[vector[int], vector[int]]] persistence_pairs() nogil pair[vector[vector[int]], vector[vector[int]]] lower_star_generators() nogil pair[vector[vector[int]], vector[vector[int]]] flag_generators() nogil + vector[vector[pair[int, pair[double, double]]]] compute_extended_persistence_subdiagrams(double min_persistence) nogil diff --git a/src/python/gudhi/simplex_tree.pyx b/src/python/gudhi/simplex_tree.pyx index 2c53a872..521a7763 100644 --- a/src/python/gudhi/simplex_tree.pyx +++ b/src/python/gudhi/simplex_tree.pyx @@ -474,8 +474,7 @@ cdef class SimplexTree: del self.pcohptr self.pcohptr = new Simplex_tree_persistence_interface(self.get_ptr(), False) self.pcohptr.compute_persistence(homology_coeff_field, -1.) - persistence_result = self.pcohptr.get_persistence() - return self.get_ptr().compute_extended_persistence_subdiagrams(persistence_result, min_persistence) + return self.pcohptr.compute_extended_persistence_subdiagrams(min_persistence) def expansion_with_blocker(self, max_dim, blocker_func): """Expands the Simplex_tree containing only a graph. Simplices corresponding to cliques in the graph are added diff --git a/src/python/gudhi/wasserstein/barycenter.py b/src/python/gudhi/wasserstein/barycenter.py index d67bcde7..bb6e641e 100644 --- a/src/python/gudhi/wasserstein/barycenter.py +++ b/src/python/gudhi/wasserstein/barycenter.py @@ -37,7 +37,7 @@ def lagrangian_barycenter(pdiagset, init=None, verbose=False): :param init: The initial value for barycenter estimate. If ``None``, init is made on a random diagram from the dataset. Otherwise, it can be an ``int`` (then initialization is made on ``pdiagset[init]``) - or a `(n x 2)` ``numpy.array`` enconding a persistence diagram with `n` points. + or a `(n x 2)` ``numpy.array`` encoding a persistence diagram with `n` points. :type init: ``int``, or (n x 2) ``np.array`` :param verbose: if ``True``, returns additional information about the barycenter. :type verbose: boolean @@ -45,7 +45,7 @@ def lagrangian_barycenter(pdiagset, init=None, verbose=False): (local minimum of the energy function). If ``pdiagset`` is empty, returns ``None``. If verbose, returns a couple ``(Y, log)`` where ``Y`` is the barycenter estimate, - and ``log`` is a ``dict`` that contains additional informations: + and ``log`` is a ``dict`` that contains additional information: - `"groupings"`, a list of list of pairs ``(i,j)``. Namely, ``G[k] = [...(i, j)...]``, where ``(i,j)`` indicates that `pdiagset[k][i]`` is matched to ``Y[j]`` if ``i = -1`` or ``j = -1``, it means they represent the diagonal. @@ -73,7 +73,7 @@ def lagrangian_barycenter(pdiagset, init=None, verbose=False): nb_iter = 0 - converged = False # stoping criterion + converged = False # stopping criterion while not converged: nb_iter += 1 K = len(Y) # current nb of points in Y (some might be on diagonal) diff --git a/src/python/include/Persistent_cohomology_interface.h b/src/python/include/Persistent_cohomology_interface.h index e5a3dfba..945378a0 100644 --- a/src/python/include/Persistent_cohomology_interface.h +++ b/src/python/include/Persistent_cohomology_interface.h @@ -12,6 +12,8 @@ #define INCLUDE_PERSISTENT_COHOMOLOGY_INTERFACE_H_ #include <gudhi/Persistent_cohomology.h> +#include <gudhi/Simplex_tree.h> // for Extended_simplex_type + #include <cstdlib> #include <vector> @@ -223,6 +225,44 @@ persistent_cohomology::Persistent_cohomology<FilteredComplex, persistent_cohomol return out; } + using Filtration_value = typename FilteredComplex::Filtration_value; + using Birth_death = std::pair<Filtration_value, Filtration_value>; + using Persistence_subdiagrams = std::vector<std::vector<std::pair<int, Birth_death>>>; + + Persistence_subdiagrams compute_extended_persistence_subdiagrams(Filtration_value min_persistence){ + Persistence_subdiagrams pers_subs(4); + auto const& persistent_pairs = Base::get_persistent_pairs(); + for (auto pair : persistent_pairs) { + std::pair<Filtration_value, Extended_simplex_type> px = stptr_->decode_extended_filtration(stptr_->filtration(get<0>(pair)), + stptr_->efd); + std::pair<Filtration_value, Extended_simplex_type> py = stptr_->decode_extended_filtration(stptr_->filtration(get<1>(pair)), + stptr_->efd); + std::pair<int, Birth_death> pd_point = std::make_pair(stptr_->dimension(get<0>(pair)), + std::make_pair(px.first, py.first)); + if(std::abs(px.first - py.first) > min_persistence){ + //Ordinary + if (px.second == Extended_simplex_type::UP && py.second == Extended_simplex_type::UP){ + pers_subs[0].push_back(pd_point); + } + // Relative + else if (px.second == Extended_simplex_type::DOWN && py.second == Extended_simplex_type::DOWN){ + pers_subs[1].push_back(pd_point); + } + else{ + // Extended+ + if (px.first < py.first){ + pers_subs[2].push_back(pd_point); + } + //Extended- + else{ + pers_subs[3].push_back(pd_point); + } + } + } + } + return pers_subs; + } + private: // A copy FilteredComplex* stptr_; diff --git a/src/python/include/Simplex_tree_interface.h b/src/python/include/Simplex_tree_interface.h index 7f9b0067..3848c5ad 100644 --- a/src/python/include/Simplex_tree_interface.h +++ b/src/python/include/Simplex_tree_interface.h @@ -132,36 +132,6 @@ class Simplex_tree_interface : public Simplex_tree<SimplexTreeOptions> { return; } - std::vector<std::vector<std::pair<int, std::pair<Filtration_value, Filtration_value>>>> compute_extended_persistence_subdiagrams(const std::vector<std::pair<int, std::pair<Filtration_value, Filtration_value>>>& dgm, Filtration_value min_persistence){ - std::vector<std::vector<std::pair<int, std::pair<Filtration_value, Filtration_value>>>> new_dgm(4); - for (unsigned int i = 0; i < dgm.size(); i++){ - std::pair<Filtration_value, Extended_simplex_type> px = this->decode_extended_filtration(dgm[i].second.first, this->efd); - std::pair<Filtration_value, Extended_simplex_type> py = this->decode_extended_filtration(dgm[i].second.second, this->efd); - std::pair<int, std::pair<Filtration_value, Filtration_value>> pd_point = std::make_pair(dgm[i].first, std::make_pair(px.first, py.first)); - if(std::abs(px.first - py.first) > min_persistence){ - //Ordinary - if (px.second == Extended_simplex_type::UP && py.second == Extended_simplex_type::UP){ - new_dgm[0].push_back(pd_point); - } - // Relative - else if (px.second == Extended_simplex_type::DOWN && py.second == Extended_simplex_type::DOWN){ - new_dgm[1].push_back(pd_point); - } - else{ - // Extended+ - if (px.first < py.first){ - new_dgm[2].push_back(pd_point); - } - //Extended- - else{ - new_dgm[3].push_back(pd_point); - } - } - } - } - return new_dgm; - } - Simplex_tree_interface* collapse_edges(int nb_collapse_iteration) { using Filtered_edge = std::tuple<Vertex_handle, Vertex_handle, Filtration_value>; std::vector<Filtered_edge> edges; diff --git a/src/python/test/test_remote_datasets.py b/src/python/test/test_remote_datasets.py new file mode 100644 index 00000000..e5d2de82 --- /dev/null +++ b/src/python/test/test_remote_datasets.py @@ -0,0 +1,87 @@ +# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT. +# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details. +# Author(s): Hind Montassif +# +# Copyright (C) 2021 Inria +# +# Modification(s): +# - YYYY/MM Author: Description of the modification + +from gudhi.datasets import remote + +import shutil +import io +import sys +import pytest + +from os.path import isdir, expanduser, exists +from os import remove, environ + +def test_data_home(): + # Test _get_data_home and clear_data_home on new empty folder + empty_data_home = remote._get_data_home(data_home="empty_folder_for_test") + assert isdir(empty_data_home) + + remote.clear_data_home(data_home=empty_data_home) + assert not isdir(empty_data_home) + +def test_fetch_remote(): + # Test fetch with a wrong checksum + with pytest.raises(OSError): + remote._fetch_remote("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d/spiral_2d.npy", "tmp_spiral_2d.npy", file_checksum = 'XXXXXXXXXX') + assert not exists("tmp_spiral_2d.npy") + +def _get_bunny_license_print(accept_license = False): + capturedOutput = io.StringIO() + # Redirect stdout + sys.stdout = capturedOutput + + bunny_arr = remote.fetch_bunny("./tmp_for_test/bunny.npy", accept_license) + assert bunny_arr.shape == (35947, 3) + del bunny_arr + remove("./tmp_for_test/bunny.npy") + + # Reset redirect + sys.stdout = sys.__stdout__ + return capturedOutput + +def test_print_bunny_license(): + # Test not printing bunny.npy LICENSE when accept_license = True + assert "" == _get_bunny_license_print(accept_license = True).getvalue() + # Test printing bunny.LICENSE file when fetching bunny.npy with accept_license = False (default) + with open("./tmp_for_test/bunny.LICENSE") as f: + assert f.read().rstrip("\n") == _get_bunny_license_print().getvalue().rstrip("\n") + shutil.rmtree("./tmp_for_test") + +def test_fetch_remote_datasets_wrapped(): + # Test fetch_spiral_2d and fetch_bunny wrapping functions with data directory different from default (twice, to test case of already fetched files) + # Default case is not tested because it would fail in case the user sets the 'GUDHI_DATA' environment variable locally + for i in range(2): + spiral_2d_arr = remote.fetch_spiral_2d("./another_fetch_folder_for_test/spiral_2d.npy") + assert spiral_2d_arr.shape == (114562, 2) + + bunny_arr = remote.fetch_bunny("./another_fetch_folder_for_test/bunny.npy") + assert bunny_arr.shape == (35947, 3) + + # Check that the directory was created + assert isdir("./another_fetch_folder_for_test") + # Check downloaded files + assert exists("./another_fetch_folder_for_test/spiral_2d.npy") + assert exists("./another_fetch_folder_for_test/bunny.npy") + assert exists("./another_fetch_folder_for_test/bunny.LICENSE") + + # Remove test folders + del spiral_2d_arr + del bunny_arr + shutil.rmtree("./another_fetch_folder_for_test") + +def test_gudhi_data_env(): + # Set environment variable "GUDHI_DATA" + environ["GUDHI_DATA"] = "./test_folder_from_env_var" + bunny_arr = remote.fetch_bunny() + assert bunny_arr.shape == (35947, 3) + assert exists("./test_folder_from_env_var/points/bunny/bunny.npy") + assert exists("./test_folder_from_env_var/points/bunny/bunny.LICENSE") + # Remove test folder + del bunny_arr + shutil.rmtree("./test_folder_from_env_var") diff --git a/src/python/test/test_simplex_tree.py b/src/python/test/test_simplex_tree.py index 688f4fd6..54bafed5 100755 --- a/src/python/test/test_simplex_tree.py +++ b/src/python/test/test_simplex_tree.py @@ -320,6 +320,10 @@ def test_extend_filtration(): ] dgms = st.extended_persistence(min_persistence=-1.) + assert len(dgms) == 4 + # Sort by (death-birth) descending - we are only interested in those with the longest life span + for idx in range(4): + dgms[idx] = sorted(dgms[idx], key=lambda x:(-abs(x[1][0]-x[1][1]))) assert dgms[0][0][1][0] == pytest.approx(2.) assert dgms[0][0][1][1] == pytest.approx(3.) @@ -528,7 +532,7 @@ def test_expansion_with_blocker(): def blocker(simplex): try: - # Block all simplices that countains vertex 6 + # Block all simplices that contain vertex 6 simplex.index(6) print(simplex, ' is blocked') return True diff --git a/src/python/test/test_subsampling.py b/src/python/test/test_subsampling.py index 4019852e..3431f372 100755 --- a/src/python/test/test_subsampling.py +++ b/src/python/test/test_subsampling.py @@ -91,7 +91,7 @@ def test_simple_choose_n_farthest_points_randomed(): assert gudhi.choose_n_farthest_points(points=[], nb_points=1) == [] assert gudhi.choose_n_farthest_points(points=point_set, nb_points=0) == [] - # Go furter than point set on purpose + # Go further than point set on purpose for iter in range(1, 10): sub_set = gudhi.choose_n_farthest_points(points=point_set, nb_points=iter) for sub in sub_set: @@ -117,7 +117,7 @@ def test_simple_pick_n_random_points(): assert gudhi.pick_n_random_points(points=[], nb_points=1) == [] assert gudhi.pick_n_random_points(points=point_set, nb_points=0) == [] - # Go furter than point set on purpose + # Go further than point set on purpose for iter in range(1, 10): sub_set = gudhi.pick_n_random_points(points=point_set, nb_points=iter) for sub in sub_set: |