diff options
author | Hind-M <hind.montassif@gmail.com> | 2022-03-02 17:58:39 +0100 |
---|---|---|
committer | Hind-M <hind.montassif@gmail.com> | 2022-03-02 17:58:39 +0100 |
commit | 58e2f677081b4e9f21c47d6286b329218aa825d6 (patch) | |
tree | 9b6075210f624093785472f0bd9ce7c469360df3 /src/python/gudhi/datasets/remote.py | |
parent | 5c0c731fdd2bc41c2a4833be1612dca5a082c337 (diff) |
Remove file when given checksum does not match
Add more details to doc
Remove default dirname value in _fetch_remote
Add points/ subfolder in fetching functions
Diffstat (limited to 'src/python/gudhi/datasets/remote.py')
-rw-r--r-- | src/python/gudhi/datasets/remote.py | 25 |
1 file changed, 14 insertions, 11 deletions
diff --git a/src/python/gudhi/datasets/remote.py b/src/python/gudhi/datasets/remote.py index 618fa80e..8b3baef4 100644 --- a/src/python/gudhi/datasets/remote.py +++ b/src/python/gudhi/datasets/remote.py @@ -8,7 +8,7 @@ # - YYYY/MM Author: Description of the modification from os.path import join, exists, expanduser -from os import makedirs +from os import makedirs, remove from urllib.request import urlretrieve import hashlib @@ -77,7 +77,7 @@ def _checksum_sha256(file_path): sha256_hash.update(buffer) return sha256_hash.hexdigest() -def _fetch_remote(url, filename, dirname = "gudhi_data", file_checksum = None, accept_license = False): +def _fetch_remote(url, filename, dirname, file_checksum = None, accept_license = False): """ Fetch the wanted dataset from the given url and save it in file_path. @@ -88,7 +88,7 @@ def _fetch_remote(url, filename, dirname = "gudhi_data", file_checksum = None, a filename : string The name to give to downloaded file. dirname : string - The directory to save the file to. Default is "gudhi_data". + The directory to save the file to. file_checksum : string The file checksum using sha256 to check against the one computed on the downloaded file. Default is 'None', which means the checksum is not checked. @@ -115,6 +115,8 @@ def _fetch_remote(url, filename, dirname = "gudhi_data", file_checksum = None, a if file_checksum is not None: checksum = _checksum_sha256(file_path) if file_checksum != checksum: + # Remove file and raise error + remove(file_path) raise IOError("{} has a SHA256 checksum : {}, " "different from expected : {}." "The file may be corrupted or the given url may be wrong !".format(file_path, checksum, file_checksum)) @@ -148,17 +150,17 @@ def fetch_spiral_2d(filename = "spiral_2d.npy", dirname = None): filename : string The name to give to downloaded file. Default is "spiral_2d.npy". dirname : string - The directory to save the file to. Default is None, meaning that the data home will be set to "~/gudhi_data/spiral_2d". 
+ The directory to save the file to. Default is None, meaning that the downloaded file will be put in "~/gudhi_data/points/spiral_2d". Returns ------- - points: array - Array of points. + points: numpy array + Array of shape (114562, 2). """ file_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d/spiral_2d.npy" file_checksum = '88312ffd6df2e2cb2bde9c0e1f962d7d644c6f58dc369c7b377b298dacdc4eaf' - archive_path, dirname = _get_archive_and_dir(dirname, filename, "spiral_2d") + archive_path, dirname = _get_archive_and_dir(dirname, filename, "points/spiral_2d") if not exists(archive_path): file_path_pkl = _fetch_remote(file_url, filename, dirname, file_checksum) @@ -170,21 +172,22 @@ def fetch_spiral_2d(filename = "spiral_2d.npy", dirname = None): def fetch_bunny(filename = "bunny.npy", dirname = None, accept_license = False): """ Fetch Stanford bunny dataset remotely and its LICENSE file. + This dataset contains 35947 vertices. Parameters ---------- filename : string The name to give to downloaded file. Default is "bunny.npy". dirname : string - The directory to save the file to. Default is None, meaning that the data home will be set to "~/gudhi_data/bunny". + The directory to save the file to. Default is None, meaning that the downloaded files will be put in "~/gudhi_data/points/bunny". accept_license : boolean Flag to specify if user accepts the file LICENSE and prevents from printing the corresponding license terms. Default is False. Returns ------- - points: array - Array of points. + points: numpy array + Array of shape (35947, 3). 
""" file_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/bunny.npy" @@ -192,7 +195,7 @@ def fetch_bunny(filename = "bunny.npy", dirname = None, accept_license = False): license_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/LICENSE" license_checksum = 'b763dbe1b2fc6015d05cbf7bcc686412a2eb100a1f2220296e3b4a644c69633a' - archive_path, dirname = _get_archive_and_dir(dirname, filename, "bunny") + archive_path, dirname = _get_archive_and_dir(dirname, filename, "points/bunny") if not exists(archive_path): license_path = _fetch_remote(license_url, "LICENSE", dirname, license_checksum) |