author     Hind-M <hind.montassif@gmail.com>  2022-03-02 10:26:52 +0100
committer  Hind-M <hind.montassif@gmail.com>  2022-03-02 10:26:52 +0100
commit     5c0c731fdd2bc41c2a4833be1612dca5a082c337 (patch)
tree       117d50b793c7906c76024de269b3f1fb052d8a3f /src/python
parent     e964ec32247ce02fb12939cfcddaeabc04639869 (diff)
Modifications following PR review
Diffstat (limited to 'src/python')
-rw-r--r--  src/python/gudhi/datasets/remote.py        60
-rw-r--r--  src/python/test/test_remote_datasets.py    38
2 files changed, 51 insertions, 47 deletions
diff --git a/src/python/gudhi/datasets/remote.py b/src/python/gudhi/datasets/remote.py
index 3d6c01b0..618fa80e 100644
--- a/src/python/gudhi/datasets/remote.py
+++ b/src/python/gudhi/datasets/remote.py
@@ -20,14 +20,14 @@ def get_data_home(data_home = None):
"""
Return the path of the remote datasets directory.
This folder is used to store remotely fetched datasets.
- By default the datasets directory is set to a folder named 'remote_datasets' in the user home folder.
+ By default the datasets directory is set to a folder named 'gudhi_data' in the user home folder.
Alternatively, it can be set by giving an explicit folder path. The '~' symbol is expanded to the user home folder.
If the folder does not already exist, it is automatically created.
Parameters
----------
data_home : string
- The path to remote datasets directory. Default is `None`, meaning that the data home directory will be set to "~/remote_datasets".
+ The path to remote datasets directory. Default is `None`, meaning that the data home directory will be set to "~/gudhi_data".
Returns
-------
@@ -35,7 +35,7 @@ def get_data_home(data_home = None):
The path to remote datasets directory.
"""
if data_home is None:
- data_home = join("~", "remote_datasets")
+ data_home = join("~", "gudhi_data")
data_home = expanduser(data_home)
makedirs(data_home, exist_ok=True)
return data_home
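For context, a short usage sketch of the updated default (assuming the module is importable as gudhi.datasets.remote, as the tests below do; the custom path is only an example):

    from gudhi.datasets import remote

    # Default: the datasets directory resolves to the expanded "~/gudhi_data"
    # and is created if it does not already exist.
    data_home = remote.get_data_home()

    # Explicit path: '~' is expanded and the folder is created as well.
    custom_home = remote.get_data_home(data_home="~/my_gudhi_cache")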
@@ -43,12 +43,12 @@ def get_data_home(data_home = None):
def clear_data_home(data_home = None):
"""
- Delete all the content of the data home cache.
+ Delete the data home cache directory and all its content.
Parameters
----------
data_home : string, default is None.
- The path to remote datasets directory. If `None`, the default directory to be removed is set to "~/remote_datasets".
+ The path to remote datasets directory. If `None`, the default directory to be removed is set to "~/gudhi_data".
"""
data_home = get_data_home(data_home)
shutil.rmtree(data_home)
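A matching sketch for removing the cache (the folder name is only an example):

    from gudhi.datasets import remote

    # Create (or reuse) a scratch data home, then delete the directory and all its content.
    scratch = remote.get_data_home(data_home="scratch_gudhi_data")
    remote.clear_data_home(data_home="scratch_gudhi_data")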
@@ -77,7 +77,7 @@ def _checksum_sha256(file_path):
sha256_hash.update(buffer)
return sha256_hash.hexdigest()
-def _fetch_remote(url, filename, dirname = "remote_datasets", file_checksum = None, accept_license = False):
+def _fetch_remote(url, filename, dirname = "gudhi_data", file_checksum = None, accept_license = False):
"""
Fetch the wanted dataset from the given url and save it in file_path.
@@ -88,10 +88,10 @@ def _fetch_remote(url, filename, dirname = "remote_datasets", file_checksum = No
filename : string
The name to give to downloaded file.
dirname : string
- The directory to save the file to. Default is "remote_datasets".
+ The directory to save the file to. Default is "gudhi_data".
file_checksum : string
The file checksum using sha256 to check against the one computed on the downloaded file.
- Default is 'None'.
+ Default is 'None', which means the checksum is not checked.
accept_license : boolean
Flag to specify if user accepts the file LICENSE and prevents from printing the corresponding license terms.
Default is False.
@@ -100,6 +100,11 @@ def _fetch_remote(url, filename, dirname = "remote_datasets", file_checksum = No
-------
file_path: string
Full path of the created file.
+
+ Raises
+ ------
+ IOError
+ If the computed SHA256 checksum of the file does not match the one given by the user.
"""
file_path = join(dirname, filename)
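The new Raises entry documents what happens on a checksum mismatch; a self-contained sketch of that kind of check (not the library code itself, helper names are illustrative):

    import hashlib

    def sha256_of_file(file_path, chunk_size=4096):
        # Hash the file in chunks to avoid loading it entirely into memory.
        sha256_hash = hashlib.sha256()
        with open(file_path, "rb") as f:
            for chunk in iter(lambda: f.read(chunk_size), b""):
                sha256_hash.update(chunk)
        return sha256_hash.hexdigest()

    def check_file_checksum(file_path, file_checksum=None):
        # Mirror the documented behaviour: skip the check when no checksum is given,
        # raise IOError when the computed digest differs from the expected one.
        if file_checksum is not None:
            computed = sha256_of_file(file_path)
            if computed != file_checksum:
                raise IOError("{} checksum mismatch: {} != {}".format(file_path, computed, file_checksum))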
@@ -123,32 +128,37 @@ def _fetch_remote(url, filename, dirname = "remote_datasets", file_checksum = No
return file_path
+def _get_archive_and_dir(dirname, filename, label):
+ if dirname is None:
+ dirname = join(get_data_home(dirname), label)
+ makedirs(dirname, exist_ok=True)
+ else:
+ dirname = get_data_home(dirname)
+
+ archive_path = join(dirname, filename)
+
+ return archive_path, dirname
+
def fetch_spiral_2d(filename = "spiral_2d.npy", dirname = None):
"""
- Fetch "spiral_2d.npy" remotely.
+ Fetch the spiral_2d dataset remotely.
Parameters
----------
filename : string
The name to give to downloaded file. Default is "spiral_2d.npy".
dirname : string
- The directory to save the file to. Default is None, meaning that the data home will be set to "~/remote_datasets/spiral_2d".
+ The directory to save the file to. Default is None, meaning that the data home will be set to "~/gudhi_data/spiral_2d".
Returns
-------
points: array
- Array of points stored in "spiral_2d.npy".
+ Array of points.
"""
file_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d/spiral_2d.npy"
file_checksum = '88312ffd6df2e2cb2bde9c0e1f962d7d644c6f58dc369c7b377b298dacdc4eaf'
- if dirname is None:
- dirname = join(get_data_home(dirname), "spiral_2d")
- makedirs(dirname, exist_ok=True)
- else:
- dirname = get_data_home(dirname)
-
- archive_path = join(dirname, filename)
+ archive_path, dirname = _get_archive_and_dir(dirname, filename, "spiral_2d")
if not exists(archive_path):
file_path_pkl = _fetch_remote(file_url, filename, dirname, file_checksum)
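To see what the _get_archive_and_dir factoring means in practice, a hedged usage sketch of fetch_spiral_2d (paths are only the expected defaults):

    from gudhi.datasets import remote

    # With dirname=None the file is cached under "~/gudhi_data/spiral_2d/spiral_2d.npy";
    # a second call finds the cached file and skips the download.
    points = remote.fetch_spiral_2d()
    print(points.shape)  # (114562, 2), as asserted in the tests

    # With an explicit dirname, the "spiral_2d" sub-folder is not appended.
    points = remote.fetch_spiral_2d(dirname="./my_fetch_folder")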
@@ -159,14 +169,14 @@ def fetch_spiral_2d(filename = "spiral_2d.npy", dirname = None):
def fetch_bunny(filename = "bunny.npy", dirname = None, accept_license = False):
"""
- Fetch "bunny.npy" remotely and its LICENSE file.
+ Fetch the Stanford bunny dataset and its LICENSE file remotely.
Parameters
----------
filename : string
The name to give to downloaded file. Default is "bunny.npy".
dirname : string
- The directory to save the file to. Default is None, meaning that the data home will be set to "~/remote_datasets/bunny".
+ The directory to save the file to. Default is None, meaning that the data home will be set to "~/gudhi_data/bunny".
accept_license : boolean
Flag to specify if user accepts the file LICENSE and prevents from printing the corresponding license terms.
Default is False.
@@ -174,7 +184,7 @@ def fetch_bunny(filename = "bunny.npy", dirname = None, accept_license = False):
Returns
-------
points: array
- Array of points stored in "bunny.npy".
+ Array of points.
"""
file_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/bunny.npy"
@@ -182,13 +192,7 @@ def fetch_bunny(filename = "bunny.npy", dirname = None, accept_license = False):
license_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/LICENSE"
license_checksum = 'b763dbe1b2fc6015d05cbf7bcc686412a2eb100a1f2220296e3b4a644c69633a'
- if dirname is None:
- dirname = join(get_data_home(dirname), "bunny")
- makedirs(dirname, exist_ok=True)
- else:
- dirname = get_data_home(dirname)
-
- archive_path = join(dirname, filename)
+ archive_path, dirname = _get_archive_and_dir(dirname, filename, "bunny")
if not exists(archive_path):
license_path = _fetch_remote(license_url, "LICENSE", dirname, license_checksum)
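And a usage sketch for the bunny fetcher; accept_license only controls whether the LICENSE terms are printed, the files are downloaded either way:

    from gudhi.datasets import remote

    # First call downloads LICENSE and bunny.npy under "~/gudhi_data/bunny/";
    # accept_license=True suppresses printing of the license terms.
    bunny = remote.fetch_bunny(accept_license=True)
    print(bunny.shape)  # (35947, 3), as asserted in the tests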
diff --git a/src/python/test/test_remote_datasets.py b/src/python/test/test_remote_datasets.py
index cb53cb85..c44ac22b 100644
--- a/src/python/test/test_remote_datasets.py
+++ b/src/python/test/test_remote_datasets.py
@@ -22,7 +22,7 @@ def _check_dir_file_names(path_file_dw, filename, dirname):
assert isfile(path_file_dw)
names_dw = re.split(r' |/|\\', path_file_dw)
- # Case where inner directories are created in "remote_datasets/"; e.g: "remote_datasets/bunny"
+ # Case where inner directories are created in "test_gudhi_data/"; e.g: "test_gudhi_data/bunny"
if len(names_dw) >= 3:
for i in range(len(names_dw)-1):
assert re.split(r' |/|\\', dirname)[i] == names_dw[i]
@@ -31,7 +31,7 @@ def _check_dir_file_names(path_file_dw, filename, dirname):
assert dirname == names_dw[0]
assert filename == names_dw[1]
-def _check_fetch_output(url, filename, dirname = "remote_datasets", file_checksum = None):
+def _check_fetch_output(url, filename, dirname = "test_gudhi_data", file_checksum = None):
makedirs(dirname, exist_ok=True)
path_file_dw = remote._fetch_remote(url, filename, dirname, file_checksum)
_check_dir_file_names(path_file_dw, filename, dirname)
@@ -41,9 +41,9 @@ def _get_bunny_license_print(accept_license = False):
# Redirect stdout
sys.stdout = capturedOutput
- makedirs("remote_datasets/bunny", exist_ok=True)
+ makedirs("test_gudhi_data/bunny", exist_ok=True)
- remote._fetch_remote("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/bunny.npy", "bunny.npy", "remote_datasets/bunny",
+ remote._fetch_remote("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/bunny.npy", "bunny.npy", "test_gudhi_data/bunny",
'13f7842ebb4b45370e50641ff28c88685703efa5faab14edf0bb7d113a965e1b', accept_license)
# Reset redirect
sys.stdout = sys.__stdout__
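The helper above captures whatever the fetch prints; a minimal standalone version of that stdout-capture pattern (function name is illustrative):

    import io
    import sys

    def capture_stdout(func, *args, **kwargs):
        captured = io.StringIO()
        sys.stdout = captured            # redirect stdout to the buffer
        try:
            func(*args, **kwargs)
        finally:
            sys.stdout = sys.__stdout__  # always restore the real stdout
        return captured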
@@ -68,19 +68,21 @@ def test_fetch_remote_datasets():
# Test printing existing LICENSE file when fetching bunny.npy with accept_license = False (default)
# Fetch LICENSE file
- makedirs("remote_datasets/bunny", exist_ok=True)
- remote._fetch_remote("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/LICENSE", "LICENSE", "remote_datasets/bunny",
+ makedirs("test_gudhi_data/bunny", exist_ok=True)
+ remote._fetch_remote("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/LICENSE", "LICENSE", "test_gudhi_data/bunny",
'b763dbe1b2fc6015d05cbf7bcc686412a2eb100a1f2220296e3b4a644c69633a')
- with open("remote_datasets/bunny/LICENSE") as f:
+ with open("test_gudhi_data/bunny/LICENSE") as f:
assert f.read().rstrip("\n") == _get_bunny_license_print().getvalue().rstrip("\n")
# Test not printing bunny.npy LICENSE when accept_license = True
assert "" == _get_bunny_license_print(accept_license = True).getvalue()
- # Remove "remote_datasets" directory and all its content
- shutil.rmtree("remote_datasets")
+ # Remove "test_gudhi_data" directory and all its content
+ shutil.rmtree("test_gudhi_data")
def test_fetch_remote_datasets_wrapped():
+ # Check if gudhi_data default dir exists already
+ to_be_removed = not isdir(expanduser("~/gudhi_data"))
# Test fetch_spiral_2d and fetch_bunny wrapping functions (twice, to test case of already fetched files)
for i in range(2):
spiral_2d_arr = remote.fetch_spiral_2d()
@@ -90,29 +92,27 @@ def test_fetch_remote_datasets_wrapped():
assert bunny_arr.shape == (35947, 3)
# Check that default dir was created
- assert isdir(expanduser("~/remote_datasets"))
+ assert isdir(expanduser("~/gudhi_data"))
# Test fetch_spiral_2d and fetch_bunny wrapping functions with data directory different from default
- spiral_2d_arr = remote.fetch_spiral_2d(dirname = "~/another_fetch_folder")
+ spiral_2d_arr = remote.fetch_spiral_2d(dirname = "./another_fetch_folder_for_test")
assert spiral_2d_arr.shape == (114562, 2)
- bunny_arr = remote.fetch_bunny(dirname = "~/another_fetch_folder")
+ bunny_arr = remote.fetch_bunny(dirname = "./another_fetch_folder_for_test")
assert bunny_arr.shape == (35947, 3)
- assert isdir(expanduser("~/another_fetch_folder"))
+ assert isdir(expanduser("./another_fetch_folder_for_test"))
# Remove test folders
del spiral_2d_arr
del bunny_arr
- shutil.rmtree(expanduser("~/remote_datasets"))
- shutil.rmtree(expanduser("~/another_fetch_folder"))
-
- assert not isdir(expanduser("~/remote_datasets"))
- assert not isdir(expanduser("~/another_fetch_folder"))
+ if to_be_removed:
+ shutil.rmtree(expanduser("~/gudhi_data"))
+ shutil.rmtree(expanduser("./another_fetch_folder_for_test"))
def test_data_home():
# Test get_data_home and clear_data_home on new empty folder
- empty_data_home = remote.get_data_home(data_home="empty_folder")
+ empty_data_home = remote.get_data_home(data_home="empty_folder_for_test")
assert isdir(empty_data_home)
remote.clear_data_home(data_home=empty_data_home)