From a13282e4da9910a5d2bdadf97040095ae5b7880a Mon Sep 17 00:00:00 2001 From: Hind-M Date: Fri, 4 Feb 2022 15:39:51 +0100 Subject: Store fetched datasets in user directory by default --- src/python/gudhi/datasets/remote.py | 68 ++++++++++++++++++++++++++++++------- 1 file changed, 55 insertions(+), 13 deletions(-) (limited to 'src/python/gudhi/datasets/remote.py') diff --git a/src/python/gudhi/datasets/remote.py b/src/python/gudhi/datasets/remote.py index 3498a645..3d6c01b0 100644 --- a/src/python/gudhi/datasets/remote.py +++ b/src/python/gudhi/datasets/remote.py @@ -7,14 +7,52 @@ # Modification(s): # - YYYY/MM Author: Description of the modification -from os.path import join, exists +from os.path import join, exists, expanduser from os import makedirs from urllib.request import urlretrieve import hashlib +import shutil import numpy as np +def get_data_home(data_home = None): + """ + Return the path of the remote datasets directory. + This folder is used to store remotely fetched datasets. + By default the datasets directory is set to a folder named 'remote_datasets' in the user home folder. + Alternatively, it can be set by giving an explicit folder path. The '~' symbol is expanded to the user home folder. + If the folder does not already exist, it is automatically created. + + Parameters + ---------- + data_home : string + The path to remote datasets directory. Default is `None`, meaning that the data home directory will be set to "~/remote_datasets". + + Returns + ------- + data_home: string + The path to remote datasets directory. + """ + if data_home is None: + data_home = join("~", "remote_datasets") + data_home = expanduser(data_home) + makedirs(data_home, exist_ok=True) + return data_home + + +def clear_data_home(data_home = None): + """ + Delete all the content of the data home cache. + + Parameters + ---------- + data_home : string, default is None. + The path to remote datasets directory. If `None`, the default directory to be removed is set to "~/remote_datasets". + """ + data_home = get_data_home(data_home) + shutil.rmtree(data_home) + def _checksum_sha256(file_path): """ Compute the file checksum using sha256. @@ -85,7 +123,7 @@ def _fetch_remote(url, filename, dirname = "remote_datasets", file_checksum = No return file_path -def fetch_spiral_2d(filename = "spiral_2d.npy", dirname = "remote_datasets/spiral_2d"): +def fetch_spiral_2d(filename = "spiral_2d.npy", dirname = None): """ Fetch "spiral_2d.npy" remotely. @@ -94,7 +132,7 @@ def fetch_spiral_2d(filename = "spiral_2d.npy", dirname = "remote_datasets/spira filename : string The name to give to downloaded file. Default is "spiral_2d.npy". dirname : string - The directory to save the file to. Default is "remote_datasets/spiral_2d". + The directory to save the file to. Default is None, meaning that the data home will be set to "~/remote_datasets/spiral_2d". Returns ------- @@ -104,20 +142,22 @@ def fetch_spiral_2d(filename = "spiral_2d.npy", dirname = "remote_datasets/spira file_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d/spiral_2d.npy" file_checksum = '88312ffd6df2e2cb2bde9c0e1f962d7d644c6f58dc369c7b377b298dacdc4eaf' + if dirname is None: + dirname = join(get_data_home(dirname), "spiral_2d") + makedirs(dirname, exist_ok=True) + else: + dirname = get_data_home(dirname) + archive_path = join(dirname, filename) if not exists(archive_path): - # Create directory if not existing - if not exists(dirname): - makedirs(dirname) - file_path_pkl = _fetch_remote(file_url, filename, dirname, file_checksum) return np.load(file_path_pkl, mmap_mode='r') else: return np.load(archive_path, mmap_mode='r') -def fetch_bunny(filename = "bunny.npy", dirname = "remote_datasets/bunny", accept_license = False): +def fetch_bunny(filename = "bunny.npy", dirname = None, accept_license = False): """ Fetch "bunny.npy" remotely and its LICENSE file. @@ -126,7 +166,7 @@ def fetch_bunny(filename = "bunny.npy", dirname = "remote_datasets/bunny", accep filename : string The name to give to downloaded file. Default is "bunny.npy". dirname : string - The directory to save the file to. Default is "remote_datasets/bunny". + The directory to save the file to. Default is None, meaning that the data home will be set to "~/remote_datasets/bunny". accept_license : boolean Flag to specify if user accepts the file LICENSE and prevents from printing the corresponding license terms. Default is False. @@ -142,13 +182,15 @@ def fetch_bunny(filename = "bunny.npy", dirname = "remote_datasets/bunny", accep license_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/LICENSE" license_checksum = 'b763dbe1b2fc6015d05cbf7bcc686412a2eb100a1f2220296e3b4a644c69633a' + if dirname is None: + dirname = join(get_data_home(dirname), "bunny") + makedirs(dirname, exist_ok=True) + else: + dirname = get_data_home(dirname) + archive_path = join(dirname, filename) if not exists(archive_path): - # Create directory if not existing - if not exists(dirname): - makedirs(dirname) - license_path = _fetch_remote(license_url, "LICENSE", dirname, license_checksum) file_path_pkl = _fetch_remote(file_url, filename, dirname, file_checksum, accept_license) -- cgit v1.2.3