From 5c0c731fdd2bc41c2a4833be1612dca5a082c337 Mon Sep 17 00:00:00 2001 From: Hind-M Date: Wed, 2 Mar 2022 10:26:52 +0100 Subject: Modifications following PR review --- src/python/gudhi/datasets/remote.py | 60 ++++++++++++++++++++----------------- 1 file changed, 32 insertions(+), 28 deletions(-) (limited to 'src/python/gudhi') diff --git a/src/python/gudhi/datasets/remote.py b/src/python/gudhi/datasets/remote.py index 3d6c01b0..618fa80e 100644 --- a/src/python/gudhi/datasets/remote.py +++ b/src/python/gudhi/datasets/remote.py @@ -20,14 +20,14 @@ def get_data_home(data_home = None): """ Return the path of the remote datasets directory. This folder is used to store remotely fetched datasets. - By default the datasets directory is set to a folder named 'remote_datasets' in the user home folder. + By default the datasets directory is set to a folder named 'gudhi_data' in the user home folder. Alternatively, it can be set by giving an explicit folder path. The '~' symbol is expanded to the user home folder. If the folder does not already exist, it is automatically created. Parameters ---------- data_home : string - The path to remote datasets directory. Default is `None`, meaning that the data home directory will be set to "~/remote_datasets". + The path to remote datasets directory. Default is `None`, meaning that the data home directory will be set to "~/gudhi_data". Returns ------- @@ -35,7 +35,7 @@ def get_data_home(data_home = None): The path to remote datasets directory. """ if data_home is None: - data_home = join("~", "remote_datasets") + data_home = join("~", "gudhi_data") data_home = expanduser(data_home) makedirs(data_home, exist_ok=True) return data_home @@ -43,12 +43,12 @@ def get_data_home(data_home = None): def clear_data_home(data_home = None): """ - Delete all the content of the data home cache. + Delete the data home cache directory and all its content. Parameters ---------- data_home : string, default is None. - The path to remote datasets directory. If `None`, the default directory to be removed is set to "~/remote_datasets". + The path to remote datasets directory. If `None`, the default directory to be removed is set to "~/gudhi_data". """ data_home = get_data_home(data_home) shutil.rmtree(data_home) @@ -77,7 +77,7 @@ def _checksum_sha256(file_path): sha256_hash.update(buffer) return sha256_hash.hexdigest() -def _fetch_remote(url, filename, dirname = "remote_datasets", file_checksum = None, accept_license = False): +def _fetch_remote(url, filename, dirname = "gudhi_data", file_checksum = None, accept_license = False): """ Fetch the wanted dataset from the given url and save it in file_path. @@ -88,10 +88,10 @@ def _fetch_remote(url, filename, dirname = "remote_datasets", file_checksum = No filename : string The name to give to downloaded file. dirname : string - The directory to save the file to. Default is "remote_datasets". + The directory to save the file to. Default is "gudhi_data". file_checksum : string The file checksum using sha256 to check against the one computed on the downloaded file. - Default is 'None'. + Default is 'None', which means the checksum is not checked. accept_license : boolean Flag to specify if user accepts the file LICENSE and prevents from printing the corresponding license terms. Default is False. @@ -100,6 +100,11 @@ def _fetch_remote(url, filename, dirname = "remote_datasets", file_checksum = No ------- file_path: string Full path of the created file. + + Raises + ------ + IOError + If the computed SHA256 checksum of file does not match the one given by the user. """ file_path = join(dirname, filename) @@ -123,32 +128,37 @@ def _fetch_remote(url, filename, dirname = "remote_datasets", file_checksum = No return file_path +def _get_archive_and_dir(dirname, filename, label): + if dirname is None: + dirname = join(get_data_home(dirname), label) + makedirs(dirname, exist_ok=True) + else: + dirname = get_data_home(dirname) + + archive_path = join(dirname, filename) + + return archive_path, dirname + def fetch_spiral_2d(filename = "spiral_2d.npy", dirname = None): """ - Fetch "spiral_2d.npy" remotely. + Fetch spiral_2d dataset remotely. Parameters ---------- filename : string The name to give to downloaded file. Default is "spiral_2d.npy". dirname : string - The directory to save the file to. Default is None, meaning that the data home will be set to "~/remote_datasets/spiral_2d". + The directory to save the file to. Default is None, meaning that the data home will be set to "~/gudhi_data/spiral_2d". Returns ------- points: array - Array of points stored in "spiral_2d.npy". + Array of points. """ file_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d/spiral_2d.npy" file_checksum = '88312ffd6df2e2cb2bde9c0e1f962d7d644c6f58dc369c7b377b298dacdc4eaf' - if dirname is None: - dirname = join(get_data_home(dirname), "spiral_2d") - makedirs(dirname, exist_ok=True) - else: - dirname = get_data_home(dirname) - - archive_path = join(dirname, filename) + archive_path, dirname = _get_archive_and_dir(dirname, filename, "spiral_2d") if not exists(archive_path): file_path_pkl = _fetch_remote(file_url, filename, dirname, file_checksum) @@ -159,14 +169,14 @@ def fetch_spiral_2d(filename = "spiral_2d.npy", dirname = None): def fetch_bunny(filename = "bunny.npy", dirname = None, accept_license = False): """ - Fetch "bunny.npy" remotely and its LICENSE file. + Fetch Stanford bunny dataset remotely and its LICENSE file. Parameters ---------- filename : string The name to give to downloaded file. Default is "bunny.npy". dirname : string - The directory to save the file to. Default is None, meaning that the data home will be set to "~/remote_datasets/bunny". + The directory to save the file to. Default is None, meaning that the data home will be set to "~/gudhi_data/bunny". accept_license : boolean Flag to specify if user accepts the file LICENSE and prevents from printing the corresponding license terms. Default is False. @@ -174,7 +184,7 @@ def fetch_bunny(filename = "bunny.npy", dirname = None, accept_license = False): Returns ------- points: array - Array of points stored in "bunny.npy". + Array of points. """ file_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/bunny.npy" @@ -182,13 +192,7 @@ def fetch_bunny(filename = "bunny.npy", dirname = None, accept_license = False): license_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/LICENSE" license_checksum = 'b763dbe1b2fc6015d05cbf7bcc686412a2eb100a1f2220296e3b4a644c69633a' - if dirname is None: - dirname = join(get_data_home(dirname), "bunny") - makedirs(dirname, exist_ok=True) - else: - dirname = get_data_home(dirname) - - archive_path = join(dirname, filename) + archive_path, dirname = _get_archive_and_dir(dirname, filename, "bunny") if not exists(archive_path): license_path = _fetch_remote(license_url, "LICENSE", dirname, license_checksum) -- cgit v1.2.3