summaryrefslogtreecommitdiff
path: root/src/python/test
diff options
context:
space:
mode:
authorHind-M <hind.montassif@gmail.com>2022-02-04 15:39:51 +0100
committerHind-M <hind.montassif@gmail.com>2022-02-04 15:39:51 +0100
commita13282e4da9910a5d2bdadf97040095ae5b7880a (patch)
treefe6f0064dfc65c3101f02f530c10d37472795918 /src/python/test
parent6109fd920ba477f89e83fea3df9803232c169463 (diff)
Store fetched datasets in user directory by default
Diffstat (limited to 'src/python/test')
-rw-r--r--src/python/test/test_remote_datasets.py31
1 files changed, 24 insertions, 7 deletions
diff --git a/src/python/test/test_remote_datasets.py b/src/python/test/test_remote_datasets.py
index 93a8a982..27eb51b0 100644
--- a/src/python/test/test_remote_datasets.py
+++ b/src/python/test/test_remote_datasets.py
@@ -10,7 +10,7 @@
from gudhi.datasets import remote
import re
-from os.path import isfile, exists
+from os.path import isfile, isdir, expanduser
from os import makedirs
import io
import sys
@@ -30,8 +30,7 @@ def _check_dir_file_names(path_file_dw, filename, dirname):
assert filename == names_dw[1]
def _check_fetch_output(url, filename, dirname = "remote_datasets", file_checksum = None):
- if not exists(dirname):
- makedirs(dirname)
+ makedirs(dirname, exist_ok=True)
path_file_dw = remote._fetch_remote(url, filename, dirname, file_checksum)
_check_dir_file_names(path_file_dw, filename, dirname)
@@ -40,8 +39,7 @@ def _get_bunny_license_print(accept_license = False):
# Redirect stdout
sys.stdout = capturedOutput
- if not exists("remote_datasets/bunny"):
- makedirs("remote_datasets/bunny")
+ makedirs("remote_datasets/bunny", exist_ok=True)
remote._fetch_remote("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/bunny.npy", "bunny.npy", "remote_datasets/bunny",
'13f7842ebb4b45370e50641ff28c88685703efa5faab14edf0bb7d113a965e1b', accept_license)
@@ -68,8 +66,7 @@ def test_fetch_remote_datasets():
# Test printing existing LICENSE file when fetching bunny.npy with accept_license = False (default)
# Fetch LICENSE file
- if not exists("remote_datasets/bunny"):
- makedirs("remote_datasets/bunny")
+ makedirs("remote_datasets/bunny", exist_ok=True)
remote._fetch_remote("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/bunny/LICENSE", "LICENSE", "remote_datasets/bunny",
'b763dbe1b2fc6015d05cbf7bcc686412a2eb100a1f2220296e3b4a644c69633a')
with open("remote_datasets/bunny/LICENSE") as f:
@@ -88,3 +85,23 @@ def test_fetch_remote_datasets():
bunny_arr = remote.fetch_bunny()
assert bunny_arr.shape == (35947, 3)
+
+ # Check that default dir was created
+ assert isdir(expanduser("~/remote_datasets")) == True
+
+ # Test clear_data_home
+ clear_data_home()
+ assert isdir(expanduser("~/remote_datasets")) == False
+
+ # Test fetch_spiral_2d and fetch_bunny wrapping functions with data directory different from default
+ spiral_2d_arr = remote.fetch_spiral_2d(dirname = "~/test")
+ assert spiral_2d_arr.shape == (114562, 2)
+
+ bunny_arr = remote.fetch_bunny(dirname = "~/test")
+ assert bunny_arr.shape == (35947, 3)
+
+ assert isdir(expanduser("~/test")) == True
+
+ # Test clear_data_home with data directory different from default
+ clear_data_home("~/test")
+ assert isdir(expanduser("~/test")) == False