summaryrefslogtreecommitdiff
path: root/src/python/test
diff options
context:
space:
mode:
authorHind-M <hind.montassif@gmail.com>2022-01-28 16:21:33 +0100
committerHind-M <hind.montassif@gmail.com>2022-01-28 16:21:33 +0100
commit8d1e7aeb3416194d00f45587d1ecea85ba218028 (patch)
tree0e56666b1b52cfdf8d08f20150cff1f44d7dbbc8 /src/python/test
parentd941ebc854880a06707999f677137a9d6ff7473f (diff)
Return arrays of points instead of files paths when fetching bunny.npy and spiral_2d.csv
Diffstat (limited to 'src/python/test')
-rw-r--r--src/python/test/test_remote_datasets.py33
1 files changed, 20 insertions, 13 deletions
diff --git a/src/python/test/test_remote_datasets.py b/src/python/test/test_remote_datasets.py
index e777abc6..56a273b4 100644
--- a/src/python/test/test_remote_datasets.py
+++ b/src/python/test/test_remote_datasets.py
@@ -10,13 +10,14 @@
from gudhi.datasets import remote
import re
-import os.path
+from os.path import isfile, exists
+from os import makedirs
import io
import sys
import pytest
def _check_dir_file_names(path_file_dw, filename, dirname):
- assert os.path.isfile(path_file_dw)
+ assert isfile(path_file_dw)
names_dw = re.split(r' |/|\\', path_file_dw)
# Case where inner directories are created in "remote_datasets/"; e.g: "remote_datasets/bunny"
@@ -29,15 +30,20 @@ def _check_dir_file_names(path_file_dw, filename, dirname):
assert filename == names_dw[1]
def _check_fetch_output(url, filename, dirname = "remote_datasets", file_checksum = None):
- path_file_dw = remote.fetch(url, filename, dirname, file_checksum)
+ if not exists(dirname):
+ makedirs(dirname)
+ path_file_dw = remote._fetch_remote(url, filename, dirname, file_checksum)
_check_dir_file_names(path_file_dw, filename, dirname)
def _get_bunny_license_print(accept_license = False):
capturedOutput = io.StringIO()
# Redirect stdout
sys.stdout = capturedOutput
- remote.fetch("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points//bunny/bunny.off", "bunny.off", "remote_datasets/bunny",
- '11852d5e73e2d4bd7b86a2c5cc8a5884d0fbb72539493e8cec100ea922b19f5b', accept_license)
+
+ if not exists("remote_datasets/bunny"):
+ makedirs("remote_datasets/bunny")
+ remote._fetch_remote("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/bunny.npy", "bunny.npy", "remote_datasets/bunny",
+ '13f7842ebb4b45370e50641ff28c88685703efa5faab14edf0bb7d113a965e1b', accept_license)
# Reset redirect
sys.stdout = sys.__stdout__
return capturedOutput
@@ -60,20 +66,21 @@ def test_fetch_remote_datasets():
_check_fetch_output("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/sphere3D_pts_on_grid.off", "sphere3D_pts_on_grid.off")
# Test spiral_2d.csv wrapping function
- path_file_dw = remote.fetch_spiral_2d()
- _check_dir_file_names(path_file_dw, 'spiral_2d.csv', 'remote_datasets')
+ spiral_2d_arr = remote.fetch_spiral_2d()
+ assert spiral_2d_arr.shape == (114562, 2)
- # Test printing existing LICENSE file when fetching bunny.off with accept_license = False (default)
+ # Test printing existing LICENSE file when fetching bunny.npy with accept_license = False (default)
# Fetch LICENSE file
- remote.fetch("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points//bunny/LICENSE", "LICENSE", "remote_datasets/bunny",
+ if not exists("remote_datasets/bunny"):
+ makedirs("remote_datasets/bunny")
+ remote._fetch_remote("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/LICENSE", "LICENSE", "remote_datasets/bunny",
'aeb1bad319b7d74fa0b8076358182f9c6b1284c67cc07dc67cbc9bc73025d956')
with open("remote_datasets/bunny/LICENSE") as f:
assert f.read() == _get_bunny_license_print().getvalue().rstrip("\n")
- # Test not printing bunny.off LICENSE when accept_license = True
+ # Test not printing bunny.npy LICENSE when accept_license = True
assert "" == _get_bunny_license_print(accept_license = True).getvalue()
# Test fetch_bunny wrapping function
- path_file_dw = remote.fetch_bunny()
- _check_dir_file_names(path_file_dw[0], 'LICENSE', 'remote_datasets/bunny')
- _check_dir_file_names(path_file_dw[1], 'bunny.off', 'remote_datasets/bunny')
+ bunny_arr = remote.fetch_bunny()
+ assert bunny_arr.shape == (35947, 3)