summaryrefslogtreecommitdiff
path: root/src/python/gudhi/datasets
diff options
context:
space:
mode:
authorHind-M <hind.montassif@gmail.com>2022-01-28 16:21:33 +0100
committerHind-M <hind.montassif@gmail.com>2022-01-28 16:21:33 +0100
commit8d1e7aeb3416194d00f45587d1ecea85ba218028 (patch)
tree0e56666b1b52cfdf8d08f20150cff1f44d7dbbc8 /src/python/gudhi/datasets
parentd941ebc854880a06707999f677137a9d6ff7473f (diff)
Return arrays of points instead of files paths when fetching bunny.npy and spiral_2d.csv
Diffstat (limited to 'src/python/gudhi/datasets')
-rw-r--r--src/python/gudhi/datasets/remote.py83
1 files changed, 52 insertions, 31 deletions
diff --git a/src/python/gudhi/datasets/remote.py b/src/python/gudhi/datasets/remote.py
index 7e8f9ce7..ef797417 100644
--- a/src/python/gudhi/datasets/remote.py
+++ b/src/python/gudhi/datasets/remote.py
@@ -7,17 +7,17 @@
# Modification(s):
# - YYYY/MM Author: Description of the modification
-import hashlib
-
from os.path import join, exists
from os import makedirs
from urllib.request import urlretrieve
+import hashlib
+import numpy as np
def _checksum_sha256(file_path):
"""
- Compute the file checksum using sha256
+ Compute the file checksum using sha256.
Parameters
----------
@@ -26,7 +26,7 @@ def _checksum_sha256(file_path):
Returns
-------
- The hex digest of file_path
+ The hex digest of file_path.
"""
sha256_hash = hashlib.sha256()
chunk_size = 4096
@@ -39,9 +39,9 @@ def _checksum_sha256(file_path):
sha256_hash.update(buffer)
return sha256_hash.hexdigest()
-def fetch(url, filename, dirname = "remote_datasets", file_checksum = None, accept_license = False):
+def _fetch_remote(url, filename, dirname = "remote_datasets", file_checksum = None, accept_license = False):
"""
- Fetch the wanted dataset from the given url and save it in file_path
+ Fetch the wanted dataset from the given url and save it in file_path.
Parameters
----------
@@ -56,7 +56,7 @@ def fetch(url, filename, dirname = "remote_datasets", file_checksum = None, acce
Default is 'None'.
accept_license : boolean
Flag to specify if user accepts the file LICENSE and prevents from printing the corresponding license terms.
- Default is False
+ Default is False.
Returns
-------
@@ -66,14 +66,8 @@ def fetch(url, filename, dirname = "remote_datasets", file_checksum = None, acce
file_path = join(dirname, filename)
- # Check for an already existing file at file_path
- if not exists(file_path):
- # Create directory if not existing
- if not exists(dirname):
- makedirs(dirname)
-
- # Get the file
- urlretrieve(url, file_path)
+ # Get the file
+ urlretrieve(url, file_path)
if file_checksum is not None:
checksum = _checksum_sha256(file_path)
@@ -93,44 +87,71 @@ def fetch(url, filename, dirname = "remote_datasets", file_checksum = None, acce
def fetch_spiral_2d(filename = "spiral_2d.csv", dirname = "remote_datasets"):
"""
- Fetch spiral_2d.csv remotely
+ Fetch "spiral_2d.csv" remotely.
Parameters
----------
filename : string
- The name to give to downloaded file. Default is "spiral_2d.csv"
+ The name to give to downloaded file. Default is "spiral_2d.csv".
dirname : string
The directory to save the file to. Default is "remote_datasets".
Returns
-------
- file_path: string
- Full path of the created file.
+ points: array
+ Array of points stored in "spiral_2d.csv".
"""
- return fetch("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d.csv", filename, dirname,
- '37530355d980d957c4ec06b18c775f90a91e446107d06c6201c9b4000b077f38')
+ file_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d.csv"
+ file_checksum = '37530355d980d957c4ec06b18c775f90a91e446107d06c6201c9b4000b077f38'
-def fetch_bunny(filename = "bunny.off", dirname = "remote_datasets/bunny", accept_license = False):
+ archive_path = join(dirname, filename)
+
+ if not exists(archive_path):
+ # Create directory if not existing
+ if not exists(dirname):
+ makedirs(dirname)
+
+ file_path_pkl = _fetch_remote(file_url, filename, dirname, file_checksum)
+
+ return np.loadtxt(file_path_pkl)
+ else:
+ return np.loadtxt(archive_path)
+
+def fetch_bunny(filename = "bunny.npy", dirname = "remote_datasets/bunny", accept_license = False):
"""
- Fetch bunny.off remotely and its LICENSE file
+ Fetch "bunny.npy" remotely and its LICENSE file.
Parameters
----------
filename : string
- The name to give to downloaded file. Default is "bunny.off"
+ The name to give to downloaded file. Default is "bunny.npy".
dirname : string
The directory to save the file to. Default is "remote_datasets/bunny".
accept_license : boolean
Flag to specify if user accepts the file LICENSE and prevents from printing the corresponding license terms.
- Default is False
+ Default is False.
Returns
-------
- files_paths: list of strings
- Full paths of the created file and its LICENSE.
+ points: array
+ Array of points stored in "bunny.npy".
"""
- return [fetch("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points//bunny/LICENSE", "LICENSE", dirname,
- 'aeb1bad319b7d74fa0b8076358182f9c6b1284c67cc07dc67cbc9bc73025d956'),
- fetch("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points//bunny/bunny.off", filename, dirname,
- '11852d5e73e2d4bd7b86a2c5cc8a5884d0fbb72539493e8cec100ea922b19f5b', accept_license)]
+ file_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/bunny.npy"
+ file_checksum = '13f7842ebb4b45370e50641ff28c88685703efa5faab14edf0bb7d113a965e1b'
+ license_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/LICENSE"
+ license_checksum = 'aeb1bad319b7d74fa0b8076358182f9c6b1284c67cc07dc67cbc9bc73025d956'
+
+ archive_path = join(dirname, filename)
+
+ if not exists(archive_path):
+ # Create directory if not existing
+ if not exists(dirname):
+ makedirs(dirname)
+
+ license_path = _fetch_remote(license_url, "LICENSE", dirname, license_checksum)
+ file_path_pkl = _fetch_remote(file_url, filename, dirname, file_checksum, accept_license)
+
+ return np.load(file_path_pkl, mmap_mode='r')
+ else:
+ return np.load(archive_path, mmap_mode='r')