From ad7a50fb87ed4237b9a02165eac39ae355dd5440 Mon Sep 17 00:00:00 2001 From: Hind-M Date: Tue, 1 Feb 2022 10:32:03 +0100 Subject: Fetch spiral_2d.npy file instead of csv Add some modifications related to those done on files in gudhi-data --- src/python/gudhi/datasets/remote.py | 20 ++++++++++---------- src/python/test/test_remote_datasets.py | 14 +++++++------- 2 files changed, 17 insertions(+), 17 deletions(-) (limited to 'src') diff --git a/src/python/gudhi/datasets/remote.py b/src/python/gudhi/datasets/remote.py index ef797417..3498a645 100644 --- a/src/python/gudhi/datasets/remote.py +++ b/src/python/gudhi/datasets/remote.py @@ -85,24 +85,24 @@ def _fetch_remote(url, filename, dirname = "remote_datasets", file_checksum = No return file_path -def fetch_spiral_2d(filename = "spiral_2d.csv", dirname = "remote_datasets"): +def fetch_spiral_2d(filename = "spiral_2d.npy", dirname = "remote_datasets/spiral_2d"): """ - Fetch "spiral_2d.csv" remotely. + Fetch "spiral_2d.npy" remotely. Parameters ---------- filename : string - The name to give to downloaded file. Default is "spiral_2d.csv". + The name to give to downloaded file. Default is "spiral_2d.npy". dirname : string - The directory to save the file to. Default is "remote_datasets". + The directory to save the file to. Default is "remote_datasets/spiral_2d". Returns ------- points: array - Array of points stored in "spiral_2d.csv". + Array of points stored in "spiral_2d.npy". """ - file_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d.csv" - file_checksum = '37530355d980d957c4ec06b18c775f90a91e446107d06c6201c9b4000b077f38' + file_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d/spiral_2d.npy" + file_checksum = '88312ffd6df2e2cb2bde9c0e1f962d7d644c6f58dc369c7b377b298dacdc4eaf' archive_path = join(dirname, filename) @@ -113,9 +113,9 @@ def fetch_spiral_2d(filename = "spiral_2d.csv", dirname = "remote_datasets"): file_path_pkl = _fetch_remote(file_url, filename, dirname, file_checksum) - return np.loadtxt(file_path_pkl) + return np.load(file_path_pkl, mmap_mode='r') else: - return np.loadtxt(archive_path) + return np.load(archive_path, mmap_mode='r') def fetch_bunny(filename = "bunny.npy", dirname = "remote_datasets/bunny", accept_license = False): """ @@ -140,7 +140,7 @@ def fetch_bunny(filename = "bunny.npy", dirname = "remote_datasets/bunny", accep file_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/bunny.npy" file_checksum = '13f7842ebb4b45370e50641ff28c88685703efa5faab14edf0bb7d113a965e1b' license_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/LICENSE" - license_checksum = 'aeb1bad319b7d74fa0b8076358182f9c6b1284c67cc07dc67cbc9bc73025d956' + license_checksum = 'b763dbe1b2fc6015d05cbf7bcc686412a2eb100a1f2220296e3b4a644c69633a' archive_path = join(dirname, filename) diff --git a/src/python/test/test_remote_datasets.py b/src/python/test/test_remote_datasets.py index 56a273b4..2057c63b 100644 --- a/src/python/test/test_remote_datasets.py +++ b/src/python/test/test_remote_datasets.py @@ -51,21 +51,21 @@ def _get_bunny_license_print(accept_license = False): def test_fetch_remote_datasets(): # Test fetch with a wrong checksum with pytest.raises(OSError): - _check_fetch_output("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d.csv", "spiral_2d.csv", file_checksum = 'XXXXXXXXXX') + _check_fetch_output("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d/spiral_2d.npy", "spiral_2d.npy", file_checksum = 'XXXXXXXXXX') # Test files download from given urls with checksums provided - _check_fetch_output("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d.csv", "spiral_2d.csv", - file_checksum = '37530355d980d957c4ec06b18c775f90a91e446107d06c6201c9b4000b077f38') + _check_fetch_output("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d/spiral_2d.npy", "spiral_2d.npy", + file_checksum = '88312ffd6df2e2cb2bde9c0e1f962d7d644c6f58dc369c7b377b298dacdc4eaf') _check_fetch_output("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/sphere3D_pts_on_grid.off", "sphere3D_pts_on_grid.off", file_checksum = '32f96d2cafb1177f0dd5e0a019b6ff5658e14a619a7815ae55ad0fc5e8bd3f88') # Test files download from given urls without checksums - _check_fetch_output("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d.csv", "spiral_2d.csv") + _check_fetch_output("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d/spiral_2d.npy", "spiral_2d.npy") _check_fetch_output("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/sphere3D_pts_on_grid.off", "sphere3D_pts_on_grid.off") - # Test spiral_2d.csv wrapping function + # Test fetch_spiral_2d wrapping function spiral_2d_arr = remote.fetch_spiral_2d() assert spiral_2d_arr.shape == (114562, 2) @@ -74,9 +74,9 @@ def test_fetch_remote_datasets(): if not exists("remote_datasets/bunny"): makedirs("remote_datasets/bunny") remote._fetch_remote("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/LICENSE", "LICENSE", "remote_datasets/bunny", - 'aeb1bad319b7d74fa0b8076358182f9c6b1284c67cc07dc67cbc9bc73025d956') + 'b763dbe1b2fc6015d05cbf7bcc686412a2eb100a1f2220296e3b4a644c69633a') with open("remote_datasets/bunny/LICENSE") as f: - assert f.read() == _get_bunny_license_print().getvalue().rstrip("\n") + assert f.read().rstrip("\n") == _get_bunny_license_print().getvalue().rstrip("\n") # Test not printing bunny.npy LICENSE when accept_license = True assert "" == _get_bunny_license_print(accept_license = True).getvalue() -- cgit v1.2.3