From 8d1e7aeb3416194d00f45587d1ecea85ba218028 Mon Sep 17 00:00:00 2001 From: Hind-M Date: Fri, 28 Jan 2022 16:21:33 +0100 Subject: Return arrays of points instead of files paths when fetching bunny.npy and spiral_2d.csv --- src/python/gudhi/datasets/remote.py | 83 +++++++++++++++++++++------------ src/python/test/test_remote_datasets.py | 33 +++++++------ 2 files changed, 72 insertions(+), 44 deletions(-) diff --git a/src/python/gudhi/datasets/remote.py b/src/python/gudhi/datasets/remote.py index 7e8f9ce7..ef797417 100644 --- a/src/python/gudhi/datasets/remote.py +++ b/src/python/gudhi/datasets/remote.py @@ -7,17 +7,17 @@ # Modification(s): # - YYYY/MM Author: Description of the modification -import hashlib - from os.path import join, exists from os import makedirs from urllib.request import urlretrieve +import hashlib +import numpy as np def _checksum_sha256(file_path): """ - Compute the file checksum using sha256 + Compute the file checksum using sha256. Parameters ---------- @@ -26,7 +26,7 @@ def _checksum_sha256(file_path): Returns ------- - The hex digest of file_path + The hex digest of file_path. """ sha256_hash = hashlib.sha256() chunk_size = 4096 @@ -39,9 +39,9 @@ def _checksum_sha256(file_path): sha256_hash.update(buffer) return sha256_hash.hexdigest() -def fetch(url, filename, dirname = "remote_datasets", file_checksum = None, accept_license = False): +def _fetch_remote(url, filename, dirname = "remote_datasets", file_checksum = None, accept_license = False): """ - Fetch the wanted dataset from the given url and save it in file_path + Fetch the wanted dataset from the given url and save it in file_path. Parameters ---------- @@ -56,7 +56,7 @@ def fetch(url, filename, dirname = "remote_datasets", file_checksum = None, acce Default is 'None'. accept_license : boolean Flag to specify if user accepts the file LICENSE and prevents from printing the corresponding license terms. - Default is False + Default is False. Returns ------- @@ -66,14 +66,8 @@ def fetch(url, filename, dirname = "remote_datasets", file_checksum = None, acce file_path = join(dirname, filename) - # Check for an already existing file at file_path - if not exists(file_path): - # Create directory if not existing - if not exists(dirname): - makedirs(dirname) - - # Get the file - urlretrieve(url, file_path) + # Get the file + urlretrieve(url, file_path) if file_checksum is not None: checksum = _checksum_sha256(file_path) @@ -93,44 +87,71 @@ def fetch(url, filename, dirname = "remote_datasets", file_checksum = None, acce def fetch_spiral_2d(filename = "spiral_2d.csv", dirname = "remote_datasets"): """ - Fetch spiral_2d.csv remotely + Fetch "spiral_2d.csv" remotely. Parameters ---------- filename : string - The name to give to downloaded file. Default is "spiral_2d.csv" + The name to give to downloaded file. Default is "spiral_2d.csv". dirname : string The directory to save the file to. Default is "remote_datasets". Returns ------- - file_path: string - Full path of the created file. + points: array + Array of points stored in "spiral_2d.csv". """ - return fetch("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d.csv", filename, dirname, - '37530355d980d957c4ec06b18c775f90a91e446107d06c6201c9b4000b077f38') + file_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d.csv" + file_checksum = '37530355d980d957c4ec06b18c775f90a91e446107d06c6201c9b4000b077f38' -def fetch_bunny(filename = "bunny.off", dirname = "remote_datasets/bunny", accept_license = False): + archive_path = join(dirname, filename) + + if not exists(archive_path): + # Create directory if not existing + if not exists(dirname): + makedirs(dirname) + + file_path_pkl = _fetch_remote(file_url, filename, dirname, file_checksum) + + return np.loadtxt(file_path_pkl) + else: + return np.loadtxt(archive_path) + +def fetch_bunny(filename = "bunny.npy", dirname = "remote_datasets/bunny", accept_license = False): """ - Fetch bunny.off remotely and its LICENSE file + Fetch "bunny.npy" remotely and its LICENSE file. Parameters ---------- filename : string - The name to give to downloaded file. Default is "bunny.off" + The name to give to downloaded file. Default is "bunny.npy". dirname : string The directory to save the file to. Default is "remote_datasets/bunny". accept_license : boolean Flag to specify if user accepts the file LICENSE and prevents from printing the corresponding license terms. - Default is False + Default is False. Returns ------- - files_paths: list of strings - Full paths of the created file and its LICENSE. + points: array + Array of points stored in "bunny.npy". """ - return [fetch("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points//bunny/LICENSE", "LICENSE", dirname, - 'aeb1bad319b7d74fa0b8076358182f9c6b1284c67cc07dc67cbc9bc73025d956'), - fetch("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points//bunny/bunny.off", filename, dirname, - '11852d5e73e2d4bd7b86a2c5cc8a5884d0fbb72539493e8cec100ea922b19f5b', accept_license)] + file_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/bunny.npy" + file_checksum = '13f7842ebb4b45370e50641ff28c88685703efa5faab14edf0bb7d113a965e1b' + license_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/LICENSE" + license_checksum = 'aeb1bad319b7d74fa0b8076358182f9c6b1284c67cc07dc67cbc9bc73025d956' + + archive_path = join(dirname, filename) + + if not exists(archive_path): + # Create directory if not existing + if not exists(dirname): + makedirs(dirname) + + license_path = _fetch_remote(license_url, "LICENSE", dirname, license_checksum) + file_path_pkl = _fetch_remote(file_url, filename, dirname, file_checksum, accept_license) + + return np.load(file_path_pkl, mmap_mode='r') + else: + return np.load(archive_path, mmap_mode='r') diff --git a/src/python/test/test_remote_datasets.py b/src/python/test/test_remote_datasets.py index e777abc6..56a273b4 100644 --- a/src/python/test/test_remote_datasets.py +++ b/src/python/test/test_remote_datasets.py @@ -10,13 +10,14 @@ from gudhi.datasets import remote import re -import os.path +from os.path import isfile, exists +from os import makedirs import io import sys import pytest def _check_dir_file_names(path_file_dw, filename, dirname): - assert os.path.isfile(path_file_dw) + assert isfile(path_file_dw) names_dw = re.split(r' |/|\\', path_file_dw) # Case where inner directories are created in "remote_datasets/"; e.g: "remote_datasets/bunny" @@ -29,15 +30,20 @@ def _check_dir_file_names(path_file_dw, filename, dirname): assert filename == names_dw[1] def _check_fetch_output(url, filename, dirname = "remote_datasets", file_checksum = None): - path_file_dw = remote.fetch(url, filename, dirname, file_checksum) + if not exists(dirname): + makedirs(dirname) + path_file_dw = remote._fetch_remote(url, filename, dirname, file_checksum) _check_dir_file_names(path_file_dw, filename, dirname) def _get_bunny_license_print(accept_license = False): capturedOutput = io.StringIO() # Redirect stdout sys.stdout = capturedOutput - remote.fetch("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points//bunny/bunny.off", "bunny.off", "remote_datasets/bunny", - '11852d5e73e2d4bd7b86a2c5cc8a5884d0fbb72539493e8cec100ea922b19f5b', accept_license) + + if not exists("remote_datasets/bunny"): + makedirs("remote_datasets/bunny") + remote._fetch_remote("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/bunny.npy", "bunny.npy", "remote_datasets/bunny", + '13f7842ebb4b45370e50641ff28c88685703efa5faab14edf0bb7d113a965e1b', accept_license) # Reset redirect sys.stdout = sys.__stdout__ return capturedOutput @@ -60,20 +66,21 @@ def test_fetch_remote_datasets(): _check_fetch_output("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/sphere3D_pts_on_grid.off", "sphere3D_pts_on_grid.off") # Test spiral_2d.csv wrapping function - path_file_dw = remote.fetch_spiral_2d() - _check_dir_file_names(path_file_dw, 'spiral_2d.csv', 'remote_datasets') + spiral_2d_arr = remote.fetch_spiral_2d() + assert spiral_2d_arr.shape == (114562, 2) - # Test printing existing LICENSE file when fetching bunny.off with accept_license = False (default) + # Test printing existing LICENSE file when fetching bunny.npy with accept_license = False (default) # Fetch LICENSE file - remote.fetch("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points//bunny/LICENSE", "LICENSE", "remote_datasets/bunny", + if not exists("remote_datasets/bunny"): + makedirs("remote_datasets/bunny") + remote._fetch_remote("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/LICENSE", "LICENSE", "remote_datasets/bunny", 'aeb1bad319b7d74fa0b8076358182f9c6b1284c67cc07dc67cbc9bc73025d956') with open("remote_datasets/bunny/LICENSE") as f: assert f.read() == _get_bunny_license_print().getvalue().rstrip("\n") - # Test not printing bunny.off LICENSE when accept_license = True + # Test not printing bunny.npy LICENSE when accept_license = True assert "" == _get_bunny_license_print(accept_license = True).getvalue() # Test fetch_bunny wrapping function - path_file_dw = remote.fetch_bunny() - _check_dir_file_names(path_file_dw[0], 'LICENSE', 'remote_datasets/bunny') - _check_dir_file_names(path_file_dw[1], 'bunny.off', 'remote_datasets/bunny') + bunny_arr = remote.fetch_bunny() + assert bunny_arr.shape == (35947, 3) -- cgit v1.2.3