summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorHind-M <hind.montassif@gmail.com>2022-02-04 15:39:51 +0100
committerHind-M <hind.montassif@gmail.com>2022-02-04 15:39:51 +0100
commita13282e4da9910a5d2bdadf97040095ae5b7880a (patch)
treefe6f0064dfc65c3101f02f530c10d37472795918 /src
parent6109fd920ba477f89e83fea3df9803232c169463 (diff)
Store fetched datasets in user directory by default
Diffstat (limited to 'src')
-rw-r--r--src/python/gudhi/datasets/remote.py68
-rw-r--r--src/python/test/test_remote_datasets.py31
2 files changed, 79 insertions, 20 deletions
diff --git a/src/python/gudhi/datasets/remote.py b/src/python/gudhi/datasets/remote.py
index 3498a645..3d6c01b0 100644
--- a/src/python/gudhi/datasets/remote.py
+++ b/src/python/gudhi/datasets/remote.py
@@ -7,14 +7,52 @@
# Modification(s):
# - YYYY/MM Author: Description of the modification
-from os.path import join, exists
+from os.path import join, exists, expanduser
from os import makedirs
from urllib.request import urlretrieve
import hashlib
+import shutil
import numpy as np
+def get_data_home(data_home = None):
+ """
+ Return the path of the remote datasets directory.
+ This folder is used to store remotely fetched datasets.
+ By default the datasets directory is set to a folder named 'remote_datasets' in the user home folder.
+ Alternatively, it can be set by giving an explicit folder path. The '~' symbol is expanded to the user home folder.
+ If the folder does not already exist, it is automatically created.
+
+ Parameters
+ ----------
+ data_home : string
+ The path to remote datasets directory. Default is `None`, meaning that the data home directory will be set to "~/remote_datasets".
+
+ Returns
+ -------
+ data_home: string
+ The path to remote datasets directory.
+ """
+ if data_home is None:
+ data_home = join("~", "remote_datasets")
+ data_home = expanduser(data_home)
+ makedirs(data_home, exist_ok=True)
+ return data_home
+
+
+def clear_data_home(data_home = None):
+ """
+ Delete all the content of the data home cache.
+
+ Parameters
+ ----------
+ data_home : string, default is None.
+ The path to remote datasets directory. If `None`, the default directory to be removed is set to "~/remote_datasets".
+ """
+ data_home = get_data_home(data_home)
+ shutil.rmtree(data_home)
+
def _checksum_sha256(file_path):
"""
Compute the file checksum using sha256.
@@ -85,7 +123,7 @@ def _fetch_remote(url, filename, dirname = "remote_datasets", file_checksum = No
return file_path
-def fetch_spiral_2d(filename = "spiral_2d.npy", dirname = "remote_datasets/spiral_2d"):
+def fetch_spiral_2d(filename = "spiral_2d.npy", dirname = None):
"""
Fetch "spiral_2d.npy" remotely.
@@ -94,7 +132,7 @@ def fetch_spiral_2d(filename = "spiral_2d.npy", dirname = "remote_datasets/spira
filename : string
The name to give to downloaded file. Default is "spiral_2d.npy".
dirname : string
- The directory to save the file to. Default is "remote_datasets/spiral_2d".
+ The directory to save the file to. Default is None, meaning that the data home will be set to "~/remote_datasets/spiral_2d".
Returns
-------
@@ -104,20 +142,22 @@ def fetch_spiral_2d(filename = "spiral_2d.npy", dirname = "remote_datasets/spira
file_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d/spiral_2d.npy"
file_checksum = '88312ffd6df2e2cb2bde9c0e1f962d7d644c6f58dc369c7b377b298dacdc4eaf'
+ if dirname is None:
+ dirname = join(get_data_home(dirname), "spiral_2d")
+ makedirs(dirname, exist_ok=True)
+ else:
+ dirname = get_data_home(dirname)
+
archive_path = join(dirname, filename)
if not exists(archive_path):
- # Create directory if not existing
- if not exists(dirname):
- makedirs(dirname)
-
file_path_pkl = _fetch_remote(file_url, filename, dirname, file_checksum)
return np.load(file_path_pkl, mmap_mode='r')
else:
return np.load(archive_path, mmap_mode='r')
-def fetch_bunny(filename = "bunny.npy", dirname = "remote_datasets/bunny", accept_license = False):
+def fetch_bunny(filename = "bunny.npy", dirname = None, accept_license = False):
"""
Fetch "bunny.npy" remotely and its LICENSE file.
@@ -126,7 +166,7 @@ def fetch_bunny(filename = "bunny.npy", dirname = "remote_datasets/bunny", accep
filename : string
The name to give to downloaded file. Default is "bunny.npy".
dirname : string
- The directory to save the file to. Default is "remote_datasets/bunny".
+ The directory to save the file to. Default is None, meaning that the data home will be set to "~/remote_datasets/bunny".
accept_license : boolean
Flag to specify if user accepts the file LICENSE and prevents from printing the corresponding license terms.
Default is False.
@@ -142,13 +182,15 @@ def fetch_bunny(filename = "bunny.npy", dirname = "remote_datasets/bunny", accep
license_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/LICENSE"
license_checksum = 'b763dbe1b2fc6015d05cbf7bcc686412a2eb100a1f2220296e3b4a644c69633a'
+ if dirname is None:
+ dirname = join(get_data_home(dirname), "bunny")
+ makedirs(dirname, exist_ok=True)
+ else:
+ dirname = get_data_home(dirname)
+
archive_path = join(dirname, filename)
if not exists(archive_path):
- # Create directory if not existing
- if not exists(dirname):
- makedirs(dirname)
-
license_path = _fetch_remote(license_url, "LICENSE", dirname, license_checksum)
file_path_pkl = _fetch_remote(file_url, filename, dirname, file_checksum, accept_license)
diff --git a/src/python/test/test_remote_datasets.py b/src/python/test/test_remote_datasets.py
index 93a8a982..27eb51b0 100644
--- a/src/python/test/test_remote_datasets.py
+++ b/src/python/test/test_remote_datasets.py
@@ -10,7 +10,7 @@
from gudhi.datasets import remote
import re
-from os.path import isfile, exists
+from os.path import isfile, isdir, expanduser
from os import makedirs
import io
import sys
@@ -30,8 +30,7 @@ def _check_dir_file_names(path_file_dw, filename, dirname):
assert filename == names_dw[1]
def _check_fetch_output(url, filename, dirname = "remote_datasets", file_checksum = None):
- if not exists(dirname):
- makedirs(dirname)
+ makedirs(dirname, exist_ok=True)
path_file_dw = remote._fetch_remote(url, filename, dirname, file_checksum)
_check_dir_file_names(path_file_dw, filename, dirname)
@@ -40,8 +39,7 @@ def _get_bunny_license_print(accept_license = False):
# Redirect stdout
sys.stdout = capturedOutput
- if not exists("remote_datasets/bunny"):
- makedirs("remote_datasets/bunny")
+ makedirs("remote_datasets/bunny", exist_ok=True)
remote._fetch_remote("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/bunny.npy", "bunny.npy", "remote_datasets/bunny",
'13f7842ebb4b45370e50641ff28c88685703efa5faab14edf0bb7d113a965e1b', accept_license)
@@ -68,8 +66,7 @@ def test_fetch_remote_datasets():
# Test printing existing LICENSE file when fetching bunny.npy with accept_license = False (default)
# Fetch LICENSE file
- if not exists("remote_datasets/bunny"):
- makedirs("remote_datasets/bunny")
+ makedirs("remote_datasets/bunny", exist_ok=True)
remote._fetch_remote("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/LICENSE", "LICENSE", "remote_datasets/bunny",
'b763dbe1b2fc6015d05cbf7bcc686412a2eb100a1f2220296e3b4a644c69633a')
with open("remote_datasets/bunny/LICENSE") as f:
@@ -88,3 +85,23 @@ def test_fetch_remote_datasets():
bunny_arr = remote.fetch_bunny()
assert bunny_arr.shape == (35947, 3)
+
+ # Check that default dir was created
+ assert isdir(expanduser("~/remote_datasets")) == True
+
+ # Test clear_data_home
+ clear_data_home()
+ assert isdir(expanduser("~/remote_datasets")) == False
+
+ # Test fetch_spiral_2d and fetch_bunny wrapping functions with data directory different from default
+ spiral_2d_arr = remote.fetch_spiral_2d(dirname = "~/test")
+ assert spiral_2d_arr.shape == (114562, 2)
+
+ bunny_arr = remote.fetch_bunny(dirname = "~/test")
+ assert bunny_arr.shape == (35947, 3)
+
+ assert isdir(expanduser("~/test")) == True
+
+ # Test clear_data_home with data directory different from default
+ clear_data_home("~/test")
+ assert isdir(expanduser("~/test")) == False