summaryrefslogtreecommitdiff
path: root/src/python
diff options
context:
space:
mode:
authorHind-M <hind.montassif@gmail.com>2022-05-04 15:27:34 +0200
committerHind-M <hind.montassif@gmail.com>2022-05-04 15:27:34 +0200
commitef8284cce27a8f11947e7f076034aa2fd8b5a395 (patch)
tree08799a7a10f89cf0d186a4de315fb7816bcfc7c2 /src/python
parent996c0669ac92ec98576c7a0dae2358dc0d4bc2c9 (diff)
Ask for file_path as parameter of remote fetching functions instead of both dirname and filename
Modify remote fetching test
Diffstat (limited to 'src/python')
-rw-r--r--src/python/gudhi/datasets/remote.py106
-rw-r--r--src/python/test/test_remote_datasets.py94
2 files changed, 83 insertions, 117 deletions
diff --git a/src/python/gudhi/datasets/remote.py b/src/python/gudhi/datasets/remote.py
index 8b3baef4..5b535911 100644
--- a/src/python/gudhi/datasets/remote.py
+++ b/src/python/gudhi/datasets/remote.py
@@ -7,7 +7,7 @@
# Modification(s):
# - YYYY/MM Author: Description of the modification
-from os.path import join, exists, expanduser
+from os.path import join, split, exists, expanduser
from os import makedirs, remove
from urllib.request import urlretrieve
@@ -60,7 +60,7 @@ def _checksum_sha256(file_path):
Parameters
----------
file_path: string
- Full path of the created file.
+ Full path of the created file including filename.
Returns
-------
@@ -77,7 +77,7 @@ def _checksum_sha256(file_path):
sha256_hash.update(buffer)
return sha256_hash.hexdigest()
-def _fetch_remote(url, filename, dirname, file_checksum = None, accept_license = False):
+def _fetch_remote(url, file_path, file_checksum = None):
"""
Fetch the wanted dataset from the given url and save it in file_path.
@@ -85,21 +85,11 @@ def _fetch_remote(url, filename, dirname, file_checksum = None, accept_license =
----------
url : string
The url to fetch the dataset from.
- filename : string
- The name to give to downloaded file.
- dirname : string
- The directory to save the file to.
+ file_path : string
+ Full path of the downloaded file including filename.
file_checksum : string
The file checksum using sha256 to check against the one computed on the downloaded file.
Default is 'None', which means the checksum is not checked.
- accept_license : boolean
- Flag to specify if user accepts the file LICENSE and prevents from printing the corresponding license terms.
- Default is False.
-
- Returns
- -------
- file_path: string
- Full path of the created file.
Raises
------
@@ -107,8 +97,6 @@ def _fetch_remote(url, filename, dirname, file_checksum = None, accept_license =
If the computed SHA256 checksum of file does not match the one given by the user.
"""
- file_path = join(dirname, filename)
-
# Get the file
urlretrieve(url, file_path)
@@ -121,36 +109,41 @@ def _fetch_remote(url, filename, dirname, file_checksum = None, accept_license =
"different from expected : {}."
"The file may be corrupted or the given url may be wrong !".format(file_path, checksum, file_checksum))
- # Print license terms unless accept_license is set to True
- if not accept_license:
- license_file = join(dirname, "LICENSE")
- if exists(license_file) and (file_path != license_file):
- with open(license_file, 'r') as f:
- print(f.read())
+def _get_archive_path(file_path, label):
+ """
+ Get archive path based on file_path given by user and label.
- return file_path
+ Parameters
+ ----------
+ file_path: string
+ Full path of the file to get including filename, or None.
+ label: string
+ Label used along with 'data_home' to get archive path, in case 'file_path' is None.
-def _get_archive_and_dir(dirname, filename, label):
- if dirname is None:
- dirname = join(get_data_home(dirname), label)
+ Returns
+ -------
+ Full path of archive including filename.
+ """
+ if file_path is None:
+ archive_path = join(get_data_home(), label)
+ dirname = split(archive_path)[0]
makedirs(dirname, exist_ok=True)
else:
- dirname = get_data_home(dirname)
-
- archive_path = join(dirname, filename)
+ archive_path = file_path
+ dirname = split(archive_path)[0]
+ makedirs(dirname, exist_ok=True)
- return archive_path, dirname
+ return archive_path
-def fetch_spiral_2d(filename = "spiral_2d.npy", dirname = None):
+def fetch_spiral_2d(file_path = None):
"""
Fetch spiral_2d dataset remotely.
Parameters
----------
- filename : string
- The name to give to downloaded file. Default is "spiral_2d.npy".
- dirname : string
- The directory to save the file to. Default is None, meaning that the downloaded file will be put in "~/gudhi_data/points/spiral_2d".
+ file_path : string
+ Full path of the downloaded file including filename.
+ Default is None, meaning that it's set to "data_home/points/spiral_2d/spiral_2d.npy".
Returns
-------
@@ -158,28 +151,25 @@ def fetch_spiral_2d(filename = "spiral_2d.npy", dirname = None):
Array of shape (114562, 2).
"""
file_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d/spiral_2d.npy"
- file_checksum = '88312ffd6df2e2cb2bde9c0e1f962d7d644c6f58dc369c7b377b298dacdc4eaf'
+ file_checksum = '2226024da76c073dd2f24b884baefbfd14928b52296df41ad2d9b9dc170f2401'
- archive_path, dirname = _get_archive_and_dir(dirname, filename, "points/spiral_2d")
+ archive_path = _get_archive_path(file_path, "points/spiral_2d/spiral_2d.npy")
if not exists(archive_path):
- file_path_pkl = _fetch_remote(file_url, filename, dirname, file_checksum)
+ _fetch_remote(file_url, archive_path, file_checksum)
- return np.load(file_path_pkl, mmap_mode='r')
- else:
- return np.load(archive_path, mmap_mode='r')
+ return np.load(archive_path, mmap_mode='r')
-def fetch_bunny(filename = "bunny.npy", dirname = None, accept_license = False):
+def fetch_bunny(file_path = None, accept_license = False):
"""
Fetch Stanford bunny dataset remotely and its LICENSE file.
This dataset contains 35947 vertices.
Parameters
----------
- filename : string
- The name to give to downloaded file. Default is "bunny.npy".
- dirname : string
- The directory to save the file to. Default is None, meaning that the downloaded files will be put in "~/gudhi_data/points/bunny".
+ file_path : string
+ Full path of the downloaded file including filename.
+ Default is None, meaning that it's set to "data_home/points/bunny/bunny.npy".
accept_license : boolean
Flag to specify if user accepts the file LICENSE and prevents from printing the corresponding license terms.
Default is False.
@@ -191,16 +181,20 @@ def fetch_bunny(filename = "bunny.npy", dirname = None, accept_license = False):
"""
file_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/bunny.npy"
- file_checksum = '13f7842ebb4b45370e50641ff28c88685703efa5faab14edf0bb7d113a965e1b'
- license_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/LICENSE"
+ file_checksum = 'f382482fd89df8d6444152dc8fd454444fe597581b193fd139725a85af4a6c6e'
+ license_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/bunny.LICENSE"
license_checksum = 'b763dbe1b2fc6015d05cbf7bcc686412a2eb100a1f2220296e3b4a644c69633a'
- archive_path, dirname = _get_archive_and_dir(dirname, filename, "points/bunny")
+ archive_path = _get_archive_path(file_path, "points/bunny/bunny.npy")
if not exists(archive_path):
- license_path = _fetch_remote(license_url, "LICENSE", dirname, license_checksum)
- file_path_pkl = _fetch_remote(file_url, filename, dirname, file_checksum, accept_license)
-
- return np.load(file_path_pkl, mmap_mode='r')
- else:
- return np.load(archive_path, mmap_mode='r')
+ _fetch_remote(file_url, archive_path, file_checksum)
+ license_path = join(split(archive_path)[0], "bunny.LICENSE")
+ _fetch_remote(license_url, license_path, license_checksum)
+ # Print license terms unless accept_license is set to True
+ if not accept_license:
+ if exists(license_path):
+ with open(license_path, 'r') as f:
+ print(f.read())
+
+ return np.load(archive_path, mmap_mode='r')
diff --git a/src/python/test/test_remote_datasets.py b/src/python/test/test_remote_datasets.py
index c44ac22b..5d0d397d 100644
--- a/src/python/test/test_remote_datasets.py
+++ b/src/python/test/test_remote_datasets.py
@@ -9,76 +9,48 @@
from gudhi.datasets import remote
-import re
import shutil
import io
import sys
import pytest
-from os.path import isfile, isdir, expanduser
-from os import makedirs
+from os.path import isdir, expanduser, exists
+from os import remove
-def _check_dir_file_names(path_file_dw, filename, dirname):
- assert isfile(path_file_dw)
+def test_data_home():
+ # Test get_data_home and clear_data_home on new empty folder
+ empty_data_home = remote.get_data_home(data_home="empty_folder_for_test")
+ assert isdir(empty_data_home)
- names_dw = re.split(r' |/|\\', path_file_dw)
- # Case where inner directories are created in "test_gudhi_data/"; e.g: "test_gudhi_data/bunny"
- if len(names_dw) >= 3:
- for i in range(len(names_dw)-1):
- assert re.split(r' |/|\\', dirname)[i] == names_dw[i]
- assert filename == names_dw[i+1]
- else:
- assert dirname == names_dw[0]
- assert filename == names_dw[1]
+ remote.clear_data_home(data_home=empty_data_home)
+ assert not isdir(empty_data_home)
-def _check_fetch_output(url, filename, dirname = "test_gudhi_data", file_checksum = None):
- makedirs(dirname, exist_ok=True)
- path_file_dw = remote._fetch_remote(url, filename, dirname, file_checksum)
- _check_dir_file_names(path_file_dw, filename, dirname)
+def test_fetch_remote():
+ # Test fetch with a wrong checksum
+ with pytest.raises(OSError):
+ remote._fetch_remote("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d/spiral_2d.npy", "tmp_spiral_2d.npy", file_checksum = 'XXXXXXXXXX')
+ assert not exists("tmp_spiral_2d.npy")
def _get_bunny_license_print(accept_license = False):
capturedOutput = io.StringIO()
# Redirect stdout
sys.stdout = capturedOutput
- makedirs("test_gudhi_data/bunny", exist_ok=True)
+ bunny_arr = remote.fetch_bunny("./tmp_for_test/bunny.npy", accept_license)
+ assert bunny_arr.shape == (35947, 3)
+ remove("./tmp_for_test/bunny.npy")
- remote._fetch_remote("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/bunny.npy", "bunny.npy", "test_gudhi_data/bunny",
- '13f7842ebb4b45370e50641ff28c88685703efa5faab14edf0bb7d113a965e1b', accept_license)
# Reset redirect
sys.stdout = sys.__stdout__
return capturedOutput
-def test_fetch_remote_datasets():
- # Test fetch with a wrong checksum
- with pytest.raises(OSError):
- _check_fetch_output("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d/spiral_2d.npy", "spiral_2d.npy", file_checksum = 'XXXXXXXXXX')
-
- # Test files download from given urls with checksums provided
- _check_fetch_output("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d/spiral_2d.npy", "spiral_2d.npy",
- file_checksum = '88312ffd6df2e2cb2bde9c0e1f962d7d644c6f58dc369c7b377b298dacdc4eaf')
-
- _check_fetch_output("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/sphere3D_pts_on_grid.off", "sphere3D_pts_on_grid.off",
- file_checksum = '32f96d2cafb1177f0dd5e0a019b6ff5658e14a619a7815ae55ad0fc5e8bd3f88')
-
- # Test files download from given urls without checksums
- _check_fetch_output("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d/spiral_2d.npy", "spiral_2d.npy")
-
- _check_fetch_output("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/sphere3D_pts_on_grid.off", "sphere3D_pts_on_grid.off")
-
- # Test printing existing LICENSE file when fetching bunny.npy with accept_license = False (default)
- # Fetch LICENSE file
- makedirs("test_gudhi_data/bunny", exist_ok=True)
- remote._fetch_remote("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/LICENSE", "LICENSE", "test_gudhi_data/bunny",
- 'b763dbe1b2fc6015d05cbf7bcc686412a2eb100a1f2220296e3b4a644c69633a')
- with open("test_gudhi_data/bunny/LICENSE") as f:
- assert f.read().rstrip("\n") == _get_bunny_license_print().getvalue().rstrip("\n")
-
+def test_print_bunny_license():
# Test not printing bunny.npy LICENSE when accept_license = True
assert "" == _get_bunny_license_print(accept_license = True).getvalue()
-
- # Remove "test_gudhi_data" directory and all its content
- shutil.rmtree("test_gudhi_data")
+ # Test printing bunny.LICENSE file when fetching bunny.npy with accept_license = False (default)
+ with open("./tmp_for_test/bunny.LICENSE") as f:
+ assert f.read().rstrip("\n") == _get_bunny_license_print().getvalue().rstrip("\n")
+ shutil.rmtree("./tmp_for_test")
def test_fetch_remote_datasets_wrapped():
# Check if gudhi_data default dir exists already
@@ -93,27 +65,27 @@ def test_fetch_remote_datasets_wrapped():
# Check that default dir was created
assert isdir(expanduser("~/gudhi_data"))
+ # Check downloaded files
+ assert exists(expanduser("~/gudhi_data/points/spiral_2d/spiral_2d.npy"))
+ assert exists(expanduser("~/gudhi_data/points/bunny/bunny.npy"))
+ assert exists(expanduser("~/gudhi_data/points/bunny/bunny.LICENSE"))
# Test fetch_spiral_2d and fetch_bunny wrapping functions with data directory different from default
- spiral_2d_arr = remote.fetch_spiral_2d(dirname = "./another_fetch_folder_for_test")
+ spiral_2d_arr = remote.fetch_spiral_2d("./another_fetch_folder_for_test/spiral_2d.npy")
assert spiral_2d_arr.shape == (114562, 2)
- bunny_arr = remote.fetch_bunny(dirname = "./another_fetch_folder_for_test")
+ bunny_arr = remote.fetch_bunny("./another_fetch_folder_for_test/bunny.npy")
assert bunny_arr.shape == (35947, 3)
- assert isdir(expanduser("./another_fetch_folder_for_test"))
+ assert isdir("./another_fetch_folder_for_test")
+ # Check downloaded files
+ assert exists("./another_fetch_folder_for_test/spiral_2d.npy")
+ assert exists("./another_fetch_folder_for_test/bunny.npy")
+ assert exists("./another_fetch_folder_for_test/bunny.LICENSE")
# Remove test folders
del spiral_2d_arr
del bunny_arr
if to_be_removed:
shutil.rmtree(expanduser("~/gudhi_data"))
- shutil.rmtree(expanduser("./another_fetch_folder_for_test"))
-
-def test_data_home():
- # Test get_data_home and clear_data_home on new empty folder
- empty_data_home = remote.get_data_home(data_home="empty_folder_for_test")
- assert isdir(empty_data_home)
-
- remote.clear_data_home(data_home=empty_data_home)
- assert not isdir(empty_data_home)
+ shutil.rmtree("./another_fetch_folder_for_test")