summaryrefslogtreecommitdiff
path: root/src/python/gudhi/datasets
diff options
context:
space:
mode:
authorHind-M <hind.montassif@gmail.com>2022-02-01 10:32:03 +0100
committerHind-M <hind.montassif@gmail.com>2022-02-01 10:32:03 +0100
commitad7a50fb87ed4237b9a02165eac39ae355dd5440 (patch)
treea9c901925f4bfb8fb8483c6316a91d6eae0be22f /src/python/gudhi/datasets
parent1209db091a89ed48527c34fff0cc9ef41c78d11f (diff)
Fetch spiral_2d.npy file instead of csv
Add some modifications related to those done on files in gudhi-data
Diffstat (limited to 'src/python/gudhi/datasets')
-rw-r--r--src/python/gudhi/datasets/remote.py20
1 files changed, 10 insertions, 10 deletions
diff --git a/src/python/gudhi/datasets/remote.py b/src/python/gudhi/datasets/remote.py
index ef797417..3498a645 100644
--- a/src/python/gudhi/datasets/remote.py
+++ b/src/python/gudhi/datasets/remote.py
@@ -85,24 +85,24 @@ def _fetch_remote(url, filename, dirname = "remote_datasets", file_checksum = No
return file_path
-def fetch_spiral_2d(filename = "spiral_2d.csv", dirname = "remote_datasets"):
+def fetch_spiral_2d(filename = "spiral_2d.npy", dirname = "remote_datasets/spiral_2d"):
"""
- Fetch "spiral_2d.csv" remotely.
+ Fetch "spiral_2d.npy" remotely.
Parameters
----------
filename : string
- The name to give to downloaded file. Default is "spiral_2d.csv".
+ The name to give to downloaded file. Default is "spiral_2d.npy".
dirname : string
- The directory to save the file to. Default is "remote_datasets".
+ The directory to save the file to. Default is "remote_datasets/spiral_2d".
Returns
-------
points: array
- Array of points stored in "spiral_2d.csv".
+ Array of points stored in "spiral_2d.npy".
"""
- file_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d.csv"
- file_checksum = '37530355d980d957c4ec06b18c775f90a91e446107d06c6201c9b4000b077f38'
+ file_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d/spiral_2d.npy"
+ file_checksum = '88312ffd6df2e2cb2bde9c0e1f962d7d644c6f58dc369c7b377b298dacdc4eaf'
archive_path = join(dirname, filename)
@@ -113,9 +113,9 @@ def fetch_spiral_2d(filename = "spiral_2d.csv", dirname = "remote_datasets"):
file_path_pkl = _fetch_remote(file_url, filename, dirname, file_checksum)
- return np.loadtxt(file_path_pkl)
+ return np.load(file_path_pkl, mmap_mode='r')
else:
- return np.loadtxt(archive_path)
+ return np.load(archive_path, mmap_mode='r')
def fetch_bunny(filename = "bunny.npy", dirname = "remote_datasets/bunny", accept_license = False):
"""
@@ -140,7 +140,7 @@ def fetch_bunny(filename = "bunny.npy", dirname = "remote_datasets/bunny", accep
file_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/bunny.npy"
file_checksum = '13f7842ebb4b45370e50641ff28c88685703efa5faab14edf0bb7d113a965e1b'
license_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/LICENSE"
- license_checksum = 'aeb1bad319b7d74fa0b8076358182f9c6b1284c67cc07dc67cbc9bc73025d956'
+ license_checksum = 'b763dbe1b2fc6015d05cbf7bcc686412a2eb100a1f2220296e3b4a644c69633a'
archive_path = join(dirname, filename)