summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorHind-M <hind.montassif@gmail.com>2022-02-01 10:32:03 +0100
committerHind-M <hind.montassif@gmail.com>2022-02-01 10:32:03 +0100
commitad7a50fb87ed4237b9a02165eac39ae355dd5440 (patch)
treea9c901925f4bfb8fb8483c6316a91d6eae0be22f /src
parent1209db091a89ed48527c34fff0cc9ef41c78d11f (diff)
Fetch spiral_2d.npy file instead of csv
Add some modifications related to those done on files in gudhi-data
Diffstat (limited to 'src')
-rw-r--r--src/python/gudhi/datasets/remote.py20
-rw-r--r--src/python/test/test_remote_datasets.py14
2 files changed, 17 insertions, 17 deletions
diff --git a/src/python/gudhi/datasets/remote.py b/src/python/gudhi/datasets/remote.py
index ef797417..3498a645 100644
--- a/src/python/gudhi/datasets/remote.py
+++ b/src/python/gudhi/datasets/remote.py
@@ -85,24 +85,24 @@ def _fetch_remote(url, filename, dirname = "remote_datasets", file_checksum = No
return file_path
-def fetch_spiral_2d(filename = "spiral_2d.csv", dirname = "remote_datasets"):
+def fetch_spiral_2d(filename = "spiral_2d.npy", dirname = "remote_datasets/spiral_2d"):
"""
- Fetch "spiral_2d.csv" remotely.
+ Fetch "spiral_2d.npy" remotely.
Parameters
----------
filename : string
- The name to give to downloaded file. Default is "spiral_2d.csv".
+ The name to give to downloaded file. Default is "spiral_2d.npy".
dirname : string
- The directory to save the file to. Default is "remote_datasets".
+ The directory to save the file to. Default is "remote_datasets/spiral_2d".
Returns
-------
points: array
- Array of points stored in "spiral_2d.csv".
+ Array of points stored in "spiral_2d.npy".
"""
- file_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d.csv"
- file_checksum = '37530355d980d957c4ec06b18c775f90a91e446107d06c6201c9b4000b077f38'
+ file_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d/spiral_2d.npy"
+ file_checksum = '88312ffd6df2e2cb2bde9c0e1f962d7d644c6f58dc369c7b377b298dacdc4eaf'
archive_path = join(dirname, filename)
@@ -113,9 +113,9 @@ def fetch_spiral_2d(filename = "spiral_2d.csv", dirname = "remote_datasets"):
file_path_pkl = _fetch_remote(file_url, filename, dirname, file_checksum)
- return np.loadtxt(file_path_pkl)
+ return np.load(file_path_pkl, mmap_mode='r')
else:
- return np.loadtxt(archive_path)
+ return np.load(archive_path, mmap_mode='r')
def fetch_bunny(filename = "bunny.npy", dirname = "remote_datasets/bunny", accept_license = False):
"""
@@ -140,7 +140,7 @@ def fetch_bunny(filename = "bunny.npy", dirname = "remote_datasets/bunny", accep
file_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/bunny.npy"
file_checksum = '13f7842ebb4b45370e50641ff28c88685703efa5faab14edf0bb7d113a965e1b'
license_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/LICENSE"
- license_checksum = 'aeb1bad319b7d74fa0b8076358182f9c6b1284c67cc07dc67cbc9bc73025d956'
+ license_checksum = 'b763dbe1b2fc6015d05cbf7bcc686412a2eb100a1f2220296e3b4a644c69633a'
archive_path = join(dirname, filename)
diff --git a/src/python/test/test_remote_datasets.py b/src/python/test/test_remote_datasets.py
index 56a273b4..2057c63b 100644
--- a/src/python/test/test_remote_datasets.py
+++ b/src/python/test/test_remote_datasets.py
@@ -51,21 +51,21 @@ def _get_bunny_license_print(accept_license = False):
def test_fetch_remote_datasets():
# Test fetch with a wrong checksum
with pytest.raises(OSError):
- _check_fetch_output("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d.csv", "spiral_2d.csv", file_checksum = 'XXXXXXXXXX')
+ _check_fetch_output("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d/spiral_2d.npy", "spiral_2d.npy", file_checksum = 'XXXXXXXXXX')
# Test files download from given urls with checksums provided
- _check_fetch_output("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d.csv", "spiral_2d.csv",
- file_checksum = '37530355d980d957c4ec06b18c775f90a91e446107d06c6201c9b4000b077f38')
+ _check_fetch_output("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d/spiral_2d.npy", "spiral_2d.npy",
+ file_checksum = '88312ffd6df2e2cb2bde9c0e1f962d7d644c6f58dc369c7b377b298dacdc4eaf')
_check_fetch_output("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/sphere3D_pts_on_grid.off", "sphere3D_pts_on_grid.off",
file_checksum = '32f96d2cafb1177f0dd5e0a019b6ff5658e14a619a7815ae55ad0fc5e8bd3f88')
# Test files download from given urls without checksums
- _check_fetch_output("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d.csv", "spiral_2d.csv")
+ _check_fetch_output("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d/spiral_2d.npy", "spiral_2d.npy")
_check_fetch_output("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/sphere3D_pts_on_grid.off", "sphere3D_pts_on_grid.off")
- # Test spiral_2d.csv wrapping function
+ # Test fetch_spiral_2d wrapping function
spiral_2d_arr = remote.fetch_spiral_2d()
assert spiral_2d_arr.shape == (114562, 2)
@@ -74,9 +74,9 @@ def test_fetch_remote_datasets():
if not exists("remote_datasets/bunny"):
makedirs("remote_datasets/bunny")
remote._fetch_remote("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/LICENSE", "LICENSE", "remote_datasets/bunny",
- 'aeb1bad319b7d74fa0b8076358182f9c6b1284c67cc07dc67cbc9bc73025d956')
+ 'b763dbe1b2fc6015d05cbf7bcc686412a2eb100a1f2220296e3b4a644c69633a')
with open("remote_datasets/bunny/LICENSE") as f:
- assert f.read() == _get_bunny_license_print().getvalue().rstrip("\n")
+ assert f.read().rstrip("\n") == _get_bunny_license_print().getvalue().rstrip("\n")
# Test not printing bunny.npy LICENSE when accept_license = True
assert "" == _get_bunny_license_print(accept_license = True).getvalue()