From ef8284cce27a8f11947e7f076034aa2fd8b5a395 Mon Sep 17 00:00:00 2001 From: Hind-M Date: Wed, 4 May 2022 15:27:34 +0200 Subject: Ask for file_path as parameter of remote fetching functions instead of both dirname and filename Modify remote fetching test --- src/python/gudhi/datasets/remote.py | 106 +++++++++++++++----------------- src/python/test/test_remote_datasets.py | 94 ++++++++++------------------ 2 files changed, 83 insertions(+), 117 deletions(-) diff --git a/src/python/gudhi/datasets/remote.py b/src/python/gudhi/datasets/remote.py index 8b3baef4..5b535911 100644 --- a/src/python/gudhi/datasets/remote.py +++ b/src/python/gudhi/datasets/remote.py @@ -7,7 +7,7 @@ # Modification(s): # - YYYY/MM Author: Description of the modification -from os.path import join, exists, expanduser +from os.path import join, split, exists, expanduser from os import makedirs, remove from urllib.request import urlretrieve @@ -60,7 +60,7 @@ def _checksum_sha256(file_path): Parameters ---------- file_path: string - Full path of the created file. + Full path of the created file including filename. Returns ------- @@ -77,7 +77,7 @@ def _checksum_sha256(file_path): sha256_hash.update(buffer) return sha256_hash.hexdigest() -def _fetch_remote(url, filename, dirname, file_checksum = None, accept_license = False): +def _fetch_remote(url, file_path, file_checksum = None): """ Fetch the wanted dataset from the given url and save it in file_path. @@ -85,21 +85,11 @@ def _fetch_remote(url, filename, dirname, file_checksum = None, accept_license = ---------- url : string The url to fetch the dataset from. - filename : string - The name to give to downloaded file. - dirname : string - The directory to save the file to. + file_path : string + Full path of the downloaded file including filename. file_checksum : string The file checksum using sha256 to check against the one computed on the downloaded file. Default is 'None', which means the checksum is not checked. - accept_license : boolean - Flag to specify if user accepts the file LICENSE and prevents from printing the corresponding license terms. - Default is False. - - Returns - ------- - file_path: string - Full path of the created file. Raises ------ @@ -107,8 +97,6 @@ def _fetch_remote(url, filename, dirname, file_checksum = None, accept_license = If the computed SHA256 checksum of file does not match the one given by the user. """ - file_path = join(dirname, filename) - # Get the file urlretrieve(url, file_path) @@ -121,36 +109,41 @@ def _fetch_remote(url, filename, dirname, file_checksum = None, accept_license = "different from expected : {}." "The file may be corrupted or the given url may be wrong !".format(file_path, checksum, file_checksum)) - # Print license terms unless accept_license is set to True - if not accept_license: - license_file = join(dirname, "LICENSE") - if exists(license_file) and (file_path != license_file): - with open(license_file, 'r') as f: - print(f.read()) +def _get_archive_path(file_path, label): + """ + Get archive path based on file_path given by user and label. - return file_path + Parameters + ---------- + file_path: string + Full path of the file to get including filename, or None. + label: string + Label used along with 'data_home' to get archive path, in case 'file_path' is None. -def _get_archive_and_dir(dirname, filename, label): - if dirname is None: - dirname = join(get_data_home(dirname), label) + Returns + ------- + Full path of archive including filename. + """ + if file_path is None: + archive_path = join(get_data_home(), label) + dirname = split(archive_path)[0] makedirs(dirname, exist_ok=True) else: - dirname = get_data_home(dirname) - - archive_path = join(dirname, filename) + archive_path = file_path + dirname = split(archive_path)[0] + makedirs(dirname, exist_ok=True) - return archive_path, dirname + return archive_path -def fetch_spiral_2d(filename = "spiral_2d.npy", dirname = None): +def fetch_spiral_2d(file_path = None): """ Fetch spiral_2d dataset remotely. Parameters ---------- - filename : string - The name to give to downloaded file. Default is "spiral_2d.npy". - dirname : string - The directory to save the file to. Default is None, meaning that the downloaded file will be put in "~/gudhi_data/points/spiral_2d". + file_path : string + Full path of the downloaded file including filename. + Default is None, meaning that it's set to "data_home/points/spiral_2d/spiral_2d.npy". Returns ------- @@ -158,28 +151,25 @@ def fetch_spiral_2d(filename = "spiral_2d.npy", dirname = None): Array of shape (114562, 2). """ file_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d/spiral_2d.npy" - file_checksum = '88312ffd6df2e2cb2bde9c0e1f962d7d644c6f58dc369c7b377b298dacdc4eaf' + file_checksum = '2226024da76c073dd2f24b884baefbfd14928b52296df41ad2d9b9dc170f2401' - archive_path, dirname = _get_archive_and_dir(dirname, filename, "points/spiral_2d") + archive_path = _get_archive_path(file_path, "points/spiral_2d/spiral_2d.npy") if not exists(archive_path): - file_path_pkl = _fetch_remote(file_url, filename, dirname, file_checksum) + _fetch_remote(file_url, archive_path, file_checksum) - return np.load(file_path_pkl, mmap_mode='r') - else: - return np.load(archive_path, mmap_mode='r') + return np.load(archive_path, mmap_mode='r') -def fetch_bunny(filename = "bunny.npy", dirname = None, accept_license = False): +def fetch_bunny(file_path = None, accept_license = False): """ Fetch Stanford bunny dataset remotely and its LICENSE file. This dataset contains 35947 vertices. Parameters ---------- - filename : string - The name to give to downloaded file. Default is "bunny.npy". - dirname : string - The directory to save the file to. Default is None, meaning that the downloaded files will be put in "~/gudhi_data/points/bunny". + file_path : string + Full path of the downloaded file including filename. + Default is None, meaning that it's set to "data_home/points/bunny/bunny.npy". accept_license : boolean Flag to specify if user accepts the file LICENSE and prevents from printing the corresponding license terms. Default is False. @@ -191,16 +181,20 @@ def fetch_bunny(filename = "bunny.npy", dirname = None, accept_license = False): """ file_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/bunny.npy" - file_checksum = '13f7842ebb4b45370e50641ff28c88685703efa5faab14edf0bb7d113a965e1b' - license_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/LICENSE" + file_checksum = 'f382482fd89df8d6444152dc8fd454444fe597581b193fd139725a85af4a6c6e' + license_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/bunny.LICENSE" license_checksum = 'b763dbe1b2fc6015d05cbf7bcc686412a2eb100a1f2220296e3b4a644c69633a' - archive_path, dirname = _get_archive_and_dir(dirname, filename, "points/bunny") + archive_path = _get_archive_path(file_path, "points/bunny/bunny.npy") if not exists(archive_path): - license_path = _fetch_remote(license_url, "LICENSE", dirname, license_checksum) - file_path_pkl = _fetch_remote(file_url, filename, dirname, file_checksum, accept_license) - - return np.load(file_path_pkl, mmap_mode='r') - else: - return np.load(archive_path, mmap_mode='r') + _fetch_remote(file_url, archive_path, file_checksum) + license_path = join(split(archive_path)[0], "bunny.LICENSE") + _fetch_remote(license_url, license_path, license_checksum) + # Print license terms unless accept_license is set to True + if not accept_license: + if exists(license_path): + with open(license_path, 'r') as f: + print(f.read()) + + return np.load(archive_path, mmap_mode='r') diff --git a/src/python/test/test_remote_datasets.py b/src/python/test/test_remote_datasets.py index c44ac22b..5d0d397d 100644 --- a/src/python/test/test_remote_datasets.py +++ b/src/python/test/test_remote_datasets.py @@ -9,76 +9,48 @@ from gudhi.datasets import remote -import re import shutil import io import sys import pytest -from os.path import isfile, isdir, expanduser -from os import makedirs +from os.path import isdir, expanduser, exists +from os import remove -def _check_dir_file_names(path_file_dw, filename, dirname): - assert isfile(path_file_dw) +def test_data_home(): + # Test get_data_home and clear_data_home on new empty folder + empty_data_home = remote.get_data_home(data_home="empty_folder_for_test") + assert isdir(empty_data_home) - names_dw = re.split(r' |/|\\', path_file_dw) - # Case where inner directories are created in "test_gudhi_data/"; e.g: "test_gudhi_data/bunny" - if len(names_dw) >= 3: - for i in range(len(names_dw)-1): - assert re.split(r' |/|\\', dirname)[i] == names_dw[i] - assert filename == names_dw[i+1] - else: - assert dirname == names_dw[0] - assert filename == names_dw[1] + remote.clear_data_home(data_home=empty_data_home) + assert not isdir(empty_data_home) -def _check_fetch_output(url, filename, dirname = "test_gudhi_data", file_checksum = None): - makedirs(dirname, exist_ok=True) - path_file_dw = remote._fetch_remote(url, filename, dirname, file_checksum) - _check_dir_file_names(path_file_dw, filename, dirname) +def test_fetch_remote(): + # Test fetch with a wrong checksum + with pytest.raises(OSError): + remote._fetch_remote("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d/spiral_2d.npy", "tmp_spiral_2d.npy", file_checksum = 'XXXXXXXXXX') + assert not exists("tmp_spiral_2d.npy") def _get_bunny_license_print(accept_license = False): capturedOutput = io.StringIO() # Redirect stdout sys.stdout = capturedOutput - makedirs("test_gudhi_data/bunny", exist_ok=True) + bunny_arr = remote.fetch_bunny("./tmp_for_test/bunny.npy", accept_license) + assert bunny_arr.shape == (35947, 3) + remove("./tmp_for_test/bunny.npy") - remote._fetch_remote("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/bunny.npy", "bunny.npy", "test_gudhi_data/bunny", - '13f7842ebb4b45370e50641ff28c88685703efa5faab14edf0bb7d113a965e1b', accept_license) # Reset redirect sys.stdout = sys.__stdout__ return capturedOutput -def test_fetch_remote_datasets(): - # Test fetch with a wrong checksum - with pytest.raises(OSError): - _check_fetch_output("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d/spiral_2d.npy", "spiral_2d.npy", file_checksum = 'XXXXXXXXXX') - - # Test files download from given urls with checksums provided - _check_fetch_output("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d/spiral_2d.npy", "spiral_2d.npy", - file_checksum = '88312ffd6df2e2cb2bde9c0e1f962d7d644c6f58dc369c7b377b298dacdc4eaf') - - _check_fetch_output("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/sphere3D_pts_on_grid.off", "sphere3D_pts_on_grid.off", - file_checksum = '32f96d2cafb1177f0dd5e0a019b6ff5658e14a619a7815ae55ad0fc5e8bd3f88') - - # Test files download from given urls without checksums - _check_fetch_output("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d/spiral_2d.npy", "spiral_2d.npy") - - _check_fetch_output("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/sphere3D_pts_on_grid.off", "sphere3D_pts_on_grid.off") - - # Test printing existing LICENSE file when fetching bunny.npy with accept_license = False (default) - # Fetch LICENSE file - makedirs("test_gudhi_data/bunny", exist_ok=True) - remote._fetch_remote("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/LICENSE", "LICENSE", "test_gudhi_data/bunny", - 'b763dbe1b2fc6015d05cbf7bcc686412a2eb100a1f2220296e3b4a644c69633a') - with open("test_gudhi_data/bunny/LICENSE") as f: - assert f.read().rstrip("\n") == _get_bunny_license_print().getvalue().rstrip("\n") - +def test_print_bunny_license(): # Test not printing bunny.npy LICENSE when accept_license = True assert "" == _get_bunny_license_print(accept_license = True).getvalue() - - # Remove "test_gudhi_data" directory and all its content - shutil.rmtree("test_gudhi_data") + # Test printing bunny.LICENSE file when fetching bunny.npy with accept_license = False (default) + with open("./tmp_for_test/bunny.LICENSE") as f: + assert f.read().rstrip("\n") == _get_bunny_license_print().getvalue().rstrip("\n") + shutil.rmtree("./tmp_for_test") def test_fetch_remote_datasets_wrapped(): # Check if gudhi_data default dir exists already @@ -93,27 +65,27 @@ def test_fetch_remote_datasets_wrapped(): # Check that default dir was created assert isdir(expanduser("~/gudhi_data")) + # Check downloaded files + assert exists(expanduser("~/gudhi_data/points/spiral_2d/spiral_2d.npy")) + assert exists(expanduser("~/gudhi_data/points/bunny/bunny.npy")) + assert exists(expanduser("~/gudhi_data/points/bunny/bunny.LICENSE")) # Test fetch_spiral_2d and fetch_bunny wrapping functions with data directory different from default - spiral_2d_arr = remote.fetch_spiral_2d(dirname = "./another_fetch_folder_for_test") + spiral_2d_arr = remote.fetch_spiral_2d("./another_fetch_folder_for_test/spiral_2d.npy") assert spiral_2d_arr.shape == (114562, 2) - bunny_arr = remote.fetch_bunny(dirname = "./another_fetch_folder_for_test") + bunny_arr = remote.fetch_bunny("./another_fetch_folder_for_test/bunny.npy") assert bunny_arr.shape == (35947, 3) - assert isdir(expanduser("./another_fetch_folder_for_test")) + assert isdir("./another_fetch_folder_for_test") + # Check downloaded files + assert exists("./another_fetch_folder_for_test/spiral_2d.npy") + assert exists("./another_fetch_folder_for_test/bunny.npy") + assert exists("./another_fetch_folder_for_test/bunny.LICENSE") # Remove test folders del spiral_2d_arr del bunny_arr if to_be_removed: shutil.rmtree(expanduser("~/gudhi_data")) - shutil.rmtree(expanduser("./another_fetch_folder_for_test")) - -def test_data_home(): - # Test get_data_home and clear_data_home on new empty folder - empty_data_home = remote.get_data_home(data_home="empty_folder_for_test") - assert isdir(empty_data_home) - - remote.clear_data_home(data_home=empty_data_home) - assert not isdir(empty_data_home) + shutil.rmtree("./another_fetch_folder_for_test") -- cgit v1.2.3