summaryrefslogtreecommitdiff
path: root/src/python
diff options
context:
space:
mode:
authorHind-M <hind.montassif@gmail.com>2022-05-05 17:43:12 +0200
committerHind-M <hind.montassif@gmail.com>2022-05-05 17:43:12 +0200
commit52d5b524403a43bfdc0b27a7feeec04e9c9c34c2 (patch)
tree803334833f94658f4c8e448819eb5708dee8b64c /src/python
parentef8284cce27a8f11947e7f076034aa2fd8b5a395 (diff)
Add GUDHI_DATA environment variable option
Diffstat (limited to 'src/python')
-rw-r--r--src/python/gudhi/datasets/remote.py16
-rw-r--r--src/python/test/test_remote_datasets.py13
2 files changed, 23 insertions, 6 deletions
diff --git a/src/python/gudhi/datasets/remote.py b/src/python/gudhi/datasets/remote.py
index 5b535911..eac8caf3 100644
--- a/src/python/gudhi/datasets/remote.py
+++ b/src/python/gudhi/datasets/remote.py
@@ -8,7 +8,7 @@
# - YYYY/MM Author: Description of the modification
from os.path import join, split, exists, expanduser
-from os import makedirs, remove
+from os import makedirs, remove, environ
from urllib.request import urlretrieve
import hashlib
@@ -21,13 +21,16 @@ def get_data_home(data_home = None):
Return the path of the remote datasets directory.
This folder is used to store remotely fetched datasets.
By default the datasets directory is set to a folder named 'gudhi_data' in the user home folder.
- Alternatively, it can be set by giving an explicit folder path. The '~' symbol is expanded to the user home folder.
+ Alternatively, it can be set by the 'GUDHI_DATA' environment variable.
+ The '~' symbol is expanded to the user home folder.
If the folder does not already exist, it is automatically created.
Parameters
----------
data_home : string
- The path to remote datasets directory. Default is `None`, meaning that the data home directory will be set to "~/gudhi_data".
+ The path to remote datasets directory.
+ Default is `None`, meaning that the data home directory will be set to "~/gudhi_data",
+ if the 'GUDHI_DATA' environment variable does not exist.
Returns
-------
@@ -35,7 +38,7 @@ def get_data_home(data_home = None):
The path to remote datasets directory.
"""
if data_home is None:
- data_home = join("~", "gudhi_data")
+ data_home = environ.get("GUDHI_DATA", join("~", "gudhi_data"))
data_home = expanduser(data_home)
makedirs(data_home, exist_ok=True)
return data_home
@@ -48,7 +51,9 @@ def clear_data_home(data_home = None):
Parameters
----------
data_home : string, default is None.
- The path to remote datasets directory. If `None`, the default directory to be removed is set to "~/gudhi_data".
+ The path to remote datasets directory.
+ If `None` and the 'GUDHI_DATA' environment variable does not exist,
+ the default directory to be removed is set to "~/gudhi_data".
"""
data_home = get_data_home(data_home)
shutil.rmtree(data_home)
@@ -170,6 +175,7 @@ def fetch_bunny(file_path = None, accept_license = False):
file_path : string
Full path of the downloaded file including filename.
Default is None, meaning that it's set to "data_home/points/bunny/bunny.npy".
+ In this case, the LICENSE file would be downloaded as "data_home/points/bunny/bunny.LICENSE".
accept_license : boolean
Flag to specify if user accepts the file LICENSE and prevents from printing the corresponding license terms.
Default is False.
diff --git a/src/python/test/test_remote_datasets.py b/src/python/test/test_remote_datasets.py
index 5d0d397d..af26d77c 100644
--- a/src/python/test/test_remote_datasets.py
+++ b/src/python/test/test_remote_datasets.py
@@ -15,7 +15,7 @@ import sys
import pytest
from os.path import isdir, expanduser, exists
-from os import remove
+from os import remove, environ
def test_data_home():
# Test get_data_home and clear_data_home on new empty folder
@@ -89,3 +89,14 @@ def test_fetch_remote_datasets_wrapped():
if to_be_removed:
shutil.rmtree(expanduser("~/gudhi_data"))
shutil.rmtree("./another_fetch_folder_for_test")
+
+def test_gudhi_data_env():
+ # Set environment variable "GUDHI_DATA"
+ environ["GUDHI_DATA"] = "./test_folder_from_env_var"
+ bunny_arr = remote.fetch_bunny()
+ assert bunny_arr.shape == (35947, 3)
+ assert exists("./test_folder_from_env_var/points/bunny/bunny.npy")
+ assert exists("./test_folder_from_env_var/points/bunny/bunny.LICENSE")
+ # Remove test folder
+ del bunny_arr
+ shutil.rmtree("./test_folder_from_env_var")