summaryrefslogtreecommitdiff
path: root/src/python/gudhi/datasets
diff options
context:
space:
mode:
authorHind-M <hind.montassif@gmail.com>2022-02-04 15:39:51 +0100
committerHind-M <hind.montassif@gmail.com>2022-02-04 15:39:51 +0100
commita13282e4da9910a5d2bdadf97040095ae5b7880a (patch)
treefe6f0064dfc65c3101f02f530c10d37472795918 /src/python/gudhi/datasets
parent6109fd920ba477f89e83fea3df9803232c169463 (diff)
Store fetched datasets in user directory by default
Diffstat (limited to 'src/python/gudhi/datasets')
-rw-r--r--src/python/gudhi/datasets/remote.py68
1 files changed, 55 insertions, 13 deletions
diff --git a/src/python/gudhi/datasets/remote.py b/src/python/gudhi/datasets/remote.py
index 3498a645..3d6c01b0 100644
--- a/src/python/gudhi/datasets/remote.py
+++ b/src/python/gudhi/datasets/remote.py
@@ -7,14 +7,52 @@
# Modification(s):
# - YYYY/MM Author: Description of the modification
-from os.path import join, exists
+from os.path import join, exists, expanduser
from os import makedirs
from urllib.request import urlretrieve
import hashlib
+import shutil
import numpy as np
+def get_data_home(data_home = None):
+ """
+ Return the path of the remote datasets directory.
+ This folder is used to store remotely fetched datasets.
+ By default the datasets directory is set to a folder named 'remote_datasets' in the user home folder.
+ Alternatively, it can be set by giving an explicit folder path. The '~' symbol is expanded to the user home folder.
+ If the folder does not already exist, it is automatically created.
+
+ Parameters
+ ----------
+ data_home : string
+ The path to remote datasets directory. Default is `None`, meaning that the data home directory will be set to "~/remote_datasets".
+
+ Returns
+ -------
+ data_home: string
+ The path to remote datasets directory.
+ """
+ if data_home is None:
+ data_home = join("~", "remote_datasets")
+ data_home = expanduser(data_home)
+ makedirs(data_home, exist_ok=True)
+ return data_home
+
+
+def clear_data_home(data_home = None):
+ """
+ Delete all the content of the data home cache.
+
+ Parameters
+ ----------
+ data_home : string, default is None.
+ The path to remote datasets directory. If `None`, the default directory to be removed is set to "~/remote_datasets".
+ """
+ data_home = get_data_home(data_home)
+ shutil.rmtree(data_home)
+
def _checksum_sha256(file_path):
"""
Compute the file checksum using sha256.
@@ -85,7 +123,7 @@ def _fetch_remote(url, filename, dirname = "remote_datasets", file_checksum = No
return file_path
-def fetch_spiral_2d(filename = "spiral_2d.npy", dirname = "remote_datasets/spiral_2d"):
+def fetch_spiral_2d(filename = "spiral_2d.npy", dirname = None):
"""
Fetch "spiral_2d.npy" remotely.
@@ -94,7 +132,7 @@ def fetch_spiral_2d(filename = "spiral_2d.npy", dirname = "remote_datasets/spira
filename : string
The name to give to downloaded file. Default is "spiral_2d.npy".
dirname : string
- The directory to save the file to. Default is "remote_datasets/spiral_2d".
+ The directory to save the file to. Default is None, meaning that the data home will be set to "~/remote_datasets/spiral_2d".
Returns
-------
@@ -104,20 +142,22 @@ def fetch_spiral_2d(filename = "spiral_2d.npy", dirname = "remote_datasets/spira
file_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d/spiral_2d.npy"
file_checksum = '88312ffd6df2e2cb2bde9c0e1f962d7d644c6f58dc369c7b377b298dacdc4eaf'
+ if dirname is None:
+ dirname = join(get_data_home(dirname), "spiral_2d")
+ makedirs(dirname, exist_ok=True)
+ else:
+ dirname = get_data_home(dirname)
+
archive_path = join(dirname, filename)
if not exists(archive_path):
- # Create directory if not existing
- if not exists(dirname):
- makedirs(dirname)
-
file_path_pkl = _fetch_remote(file_url, filename, dirname, file_checksum)
return np.load(file_path_pkl, mmap_mode='r')
else:
return np.load(archive_path, mmap_mode='r')
-def fetch_bunny(filename = "bunny.npy", dirname = "remote_datasets/bunny", accept_license = False):
+def fetch_bunny(filename = "bunny.npy", dirname = None, accept_license = False):
"""
Fetch "bunny.npy" remotely and its LICENSE file.
@@ -126,7 +166,7 @@ def fetch_bunny(filename = "bunny.npy", dirname = "remote_datasets/bunny", accep
filename : string
The name to give to downloaded file. Default is "bunny.npy".
dirname : string
- The directory to save the file to. Default is "remote_datasets/bunny".
+ The directory to save the file to. Default is None, meaning that the data home will be set to "~/remote_datasets/bunny".
accept_license : boolean
Flag to specify if user accepts the file LICENSE and prevents from printing the corresponding license terms.
Default is False.
@@ -142,13 +182,15 @@ def fetch_bunny(filename = "bunny.npy", dirname = "remote_datasets/bunny", accep
license_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/LICENSE"
license_checksum = 'b763dbe1b2fc6015d05cbf7bcc686412a2eb100a1f2220296e3b4a644c69633a'
+ if dirname is None:
+ dirname = join(get_data_home(dirname), "bunny")
+ makedirs(dirname, exist_ok=True)
+ else:
+ dirname = get_data_home(dirname)
+
archive_path = join(dirname, filename)
if not exists(archive_path):
- # Create directory if not existing
- if not exists(dirname):
- makedirs(dirname)
-
license_path = _fetch_remote(license_url, "LICENSE", dirname, license_checksum)
file_path_pkl = _fetch_remote(file_url, filename, dirname, file_checksum, accept_license)