# Repository web-viewer residue (navigation: summary | refs | log | tree | commit | diff)
# path: root/src/python/test/test_remote_datasets.py
# blob: e5d2de82e1d902d7bdb57ca13c8b470faf2f3835 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
# Author(s):       Hind Montassif
#
# Copyright (C) 2021 Inria
#
# Modification(s):
#   - YYYY/MM Author: Description of the modification

from gudhi.datasets import remote

import shutil
import io
import sys
import pytest

from os.path import isdir, expanduser, exists
from os import remove, environ

def test_data_home():
    """Check that a data home folder can be obtained and then cleared.

    `remote._get_data_home` is asked for a folder named "empty_folder_for_test";
    the returned path must be an existing directory. After
    `remote.clear_data_home` is called on it, the directory must be gone.
    """
    data_home_dir = remote._get_data_home(data_home="empty_folder_for_test")
    assert isdir(data_home_dir)

    # Clearing the data home is expected to remove the directory itself
    remote.clear_data_home(data_home=data_home_dir)
    assert not isdir(data_home_dir)

def test_fetch_remote():
    """Check that fetching with a wrong checksum raises and leaves no file behind.

    `remote._fetch_remote` is called with a deliberately bogus checksum; the
    call must raise `OSError`, and the target file must not exist afterwards.
    """
    spiral_url = "https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d/spiral_2d.npy"
    target_file = "tmp_spiral_2d.npy"

    with pytest.raises(OSError):
        remote._fetch_remote(spiral_url, target_file, file_checksum='XXXXXXXXXX')

    # The target file must not be present after the failed fetch
    assert not exists(target_file)

def _get_bunny_license_print(accept_license = False):
    """Fetch bunny.npy into ./tmp_for_test and return what the fetch wrote to stdout.

    The fetched array is checked to have shape (35947, 3) and the downloaded
    "./tmp_for_test/bunny.npy" file is removed afterwards, so that a later call
    fetches it again.

    Parameters:
        accept_license (bool): Forwarded as second positional argument to
            `remote.fetch_bunny`. Default is False.

    Returns:
        io.StringIO: Buffer holding everything written to `sys.stdout` while
        `remote.fetch_bunny` was running.
    """
    # Local import: keeps this helper self-contained without touching the file-level imports
    from contextlib import redirect_stdout

    captured_output = io.StringIO()
    # redirect_stdout puts back whatever stream was installed before the call
    # (not unconditionally sys.__stdout__), and does so even if the fetch raises.
    with redirect_stdout(captured_output):
        bunny_arr = remote.fetch_bunny("./tmp_for_test/bunny.npy", accept_license)

    assert bunny_arr.shape == (35947, 3)
    del bunny_arr
    remove("./tmp_for_test/bunny.npy")

    return captured_output

def test_print_bunny_license():
    """Check when the bunny LICENSE text is printed while fetching bunny.npy.

    With accept_license = True nothing must be printed. With the default
    (accept_license = False) the printed text must match the content of the
    bunny.LICENSE file, ignoring trailing newlines. The temporary folder is
    removed at the end.
    """
    # accept_license = True: the fetch must stay silent
    silent_output = _get_bunny_license_print(accept_license = True).getvalue()
    assert "" == silent_output

    # accept_license = False (default): the LICENSE content must be printed
    with open("./tmp_for_test/bunny.LICENSE") as license_file:
        expected_text = license_file.read().rstrip("\n")
        printed_text = _get_bunny_license_print().getvalue().rstrip("\n")
        assert expected_text == printed_text

    shutil.rmtree("./tmp_for_test")

def test_fetch_remote_datasets_wrapped():
    """Check fetch_spiral_2d and fetch_bunny with a non-default data directory.

    Each dataset is fetched twice so that the second pass exercises the case
    where the files are already present. The default location is not tested
    because it would fail in case the user sets the 'GUDHI_DATA' environment
    variable locally.
    """
    target_folder = "./another_fetch_folder_for_test"
    spiral_path = "./another_fetch_folder_for_test/spiral_2d.npy"
    bunny_path = "./another_fetch_folder_for_test/bunny.npy"
    license_path = "./another_fetch_folder_for_test/bunny.LICENSE"

    for _ in range(2):
        spiral_points = remote.fetch_spiral_2d(spiral_path)
        assert spiral_points.shape == (114562, 2)

        bunny_points = remote.fetch_bunny(bunny_path)
        assert bunny_points.shape == (35947, 3)

    # The directory must have been created ...
    assert isdir(target_folder)
    # ... and must contain every downloaded file
    for fetched_file in (spiral_path, bunny_path, license_path):
        assert exists(fetched_file)

    # Drop the arrays, then remove the test folder
    del spiral_points
    del bunny_points
    shutil.rmtree(target_folder)

def test_gudhi_data_env():
    """Check that fetch_bunny honours the "GUDHI_DATA" environment variable.

    "GUDHI_DATA" is pointed at "./test_folder_from_env_var" and `fetch_bunny`
    is called without a path; bunny.npy and bunny.LICENSE must then exist under
    "points/bunny" inside that folder.

    The previous value of "GUDHI_DATA" (or its absence) is restored afterwards
    and the test folder is removed, whether or not the assertions pass, so this
    test does not affect others running in the same process.
    """
    previous_data_dir = environ.get("GUDHI_DATA")
    # Set environment variable "GUDHI_DATA"
    environ["GUDHI_DATA"] = "./test_folder_from_env_var"
    try:
        bunny_arr = remote.fetch_bunny()
        assert bunny_arr.shape == (35947, 3)
        assert exists("./test_folder_from_env_var/points/bunny/bunny.npy")
        assert exists("./test_folder_from_env_var/points/bunny/bunny.LICENSE")
        # Drop the array before the folder it was loaded from is removed
        del bunny_arr
    finally:
        # Put the environment back exactly as it was found
        if previous_data_dir is None:
            environ.pop("GUDHI_DATA", None)
        else:
            environ["GUDHI_DATA"] = previous_data_dir
        # Remove test folder; ignore_errors so a missing folder cannot mask the real failure
        shutil.rmtree("./test_folder_from_env_var", ignore_errors=True)