summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
author: Hind-M <hind.montassif@gmail.com> 2022-03-02 17:58:39 +0100
committer: Hind-M <hind.montassif@gmail.com> 2022-03-02 17:58:39 +0100
commit: 58e2f677081b4e9f21c47d6286b329218aa825d6 (patch)
tree: 9b6075210f624093785472f0bd9ce7c469360df3 /src
parent: 5c0c731fdd2bc41c2a4833be1612dca5a082c337 (diff)
Remove file when given checksum does not match
Add more details to the documentation. Remove the default dirname value in _fetch_remote. Add the points/ subfolder in the fetching functions.
Diffstat (limited to 'src')
-rw-r--r-- src/python/gudhi/datasets/remote.py | 25
1 files changed, 14 insertions, 11 deletions
diff --git a/src/python/gudhi/datasets/remote.py b/src/python/gudhi/datasets/remote.py
index 618fa80e..8b3baef4 100644
--- a/src/python/gudhi/datasets/remote.py
+++ b/src/python/gudhi/datasets/remote.py
@@ -8,7 +8,7 @@
# - YYYY/MM Author: Description of the modification
from os.path import join, exists, expanduser
-from os import makedirs
+from os import makedirs, remove
from urllib.request import urlretrieve
import hashlib
@@ -77,7 +77,7 @@ def _checksum_sha256(file_path):
sha256_hash.update(buffer)
return sha256_hash.hexdigest()
-def _fetch_remote(url, filename, dirname = "gudhi_data", file_checksum = None, accept_license = False):
+def _fetch_remote(url, filename, dirname, file_checksum = None, accept_license = False):
"""
Fetch the wanted dataset from the given url and save it in file_path.
@@ -88,7 +88,7 @@ def _fetch_remote(url, filename, dirname = "gudhi_data", file_checksum = None, a
filename : string
The name to give to downloaded file.
dirname : string
- The directory to save the file to. Default is "gudhi_data".
+ The directory to save the file to.
file_checksum : string
The file checksum using sha256 to check against the one computed on the downloaded file.
Default is 'None', which means the checksum is not checked.
@@ -115,6 +115,8 @@ def _fetch_remote(url, filename, dirname = "gudhi_data", file_checksum = None, a
if file_checksum is not None:
checksum = _checksum_sha256(file_path)
if file_checksum != checksum:
+ # Remove file and raise error
+ remove(file_path)
raise IOError("{} has a SHA256 checksum : {}, "
"different from expected : {}."
"The file may be corrupted or the given url may be wrong !".format(file_path, checksum, file_checksum))
@@ -148,17 +150,17 @@ def fetch_spiral_2d(filename = "spiral_2d.npy", dirname = None):
filename : string
The name to give to downloaded file. Default is "spiral_2d.npy".
dirname : string
- The directory to save the file to. Default is None, meaning that the data home will be set to "~/gudhi_data/spiral_2d".
+ The directory to save the file to. Default is None, meaning that the downloaded file will be put in "~/gudhi_data/points/spiral_2d".
Returns
-------
- points: array
- Array of points.
+ points: numpy array
+ Array of shape (114562, 2).
"""
file_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d/spiral_2d.npy"
file_checksum = '88312ffd6df2e2cb2bde9c0e1f962d7d644c6f58dc369c7b377b298dacdc4eaf'
- archive_path, dirname = _get_archive_and_dir(dirname, filename, "spiral_2d")
+ archive_path, dirname = _get_archive_and_dir(dirname, filename, "points/spiral_2d")
if not exists(archive_path):
file_path_pkl = _fetch_remote(file_url, filename, dirname, file_checksum)
@@ -170,21 +172,22 @@ def fetch_spiral_2d(filename = "spiral_2d.npy", dirname = None):
def fetch_bunny(filename = "bunny.npy", dirname = None, accept_license = False):
"""
Fetch Stanford bunny dataset remotely and its LICENSE file.
+ This dataset contains 35947 vertices.
Parameters
----------
filename : string
The name to give to downloaded file. Default is "bunny.npy".
dirname : string
- The directory to save the file to. Default is None, meaning that the data home will be set to "~/gudhi_data/bunny".
+ The directory to save the file to. Default is None, meaning that the downloaded files will be put in "~/gudhi_data/points/bunny".
accept_license : boolean
Flag to specify if user accepts the file LICENSE and prevents from printing the corresponding license terms.
Default is False.
Returns
-------
- points: array
- Array of points.
+ points: numpy array
+ Array of shape (35947, 3).
"""
file_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/bunny.npy"
@@ -192,7 +195,7 @@ def fetch_bunny(filename = "bunny.npy", dirname = None, accept_license = False):
license_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/LICENSE"
license_checksum = 'b763dbe1b2fc6015d05cbf7bcc686412a2eb100a1f2220296e3b4a644c69633a'
- archive_path, dirname = _get_archive_and_dir(dirname, filename, "bunny")
+ archive_path, dirname = _get_archive_and_dir(dirname, filename, "points/bunny")
if not exists(archive_path):
license_path = _fetch_remote(license_url, "LICENSE", dirname, license_checksum)