diff options
Diffstat (limited to 'src/python/test')
-rw-r--r-- | src/python/test/test_diff.py | 78 | ||||
-rwxr-xr-x | src/python/test/test_dtm.py | 16 | ||||
-rw-r--r-- | src/python/test/test_remote_datasets.py | 87 | ||||
-rwxr-xr-x | src/python/test/test_representations.py | 21 | ||||
-rwxr-xr-x | src/python/test/test_simplex_tree.py | 6 | ||||
-rwxr-xr-x | src/python/test/test_subsampling.py | 4 |
6 files changed, 200 insertions, 12 deletions
diff --git a/src/python/test/test_diff.py b/src/python/test/test_diff.py new file mode 100644 index 00000000..dca001a9 --- /dev/null +++ b/src/python/test/test_diff.py @@ -0,0 +1,78 @@ +from gudhi.tensorflow import * +import numpy as np +import tensorflow as tf +import gudhi as gd + +def test_rips_diff(): + + Xinit = np.array([[1.,1.],[2.,2.]], dtype=np.float32) + X = tf.Variable(initial_value=Xinit, trainable=True) + rl = RipsLayer(maximum_edge_length=2., homology_dimensions=[0]) + + with tf.GradientTape() as tape: + dgm = rl.call(X)[0][0] + loss = tf.math.reduce_sum(tf.square(.5*(dgm[:,1]-dgm[:,0]))) + grads = tape.gradient(loss, [X]) + assert tf.norm(grads[0]-tf.constant([[-.5,-.5],[.5,.5]]),1) <= 1e-6 + +def test_cubical_diff(): + + Xinit = np.array([[0.,2.,2.],[2.,2.,2.],[2.,2.,1.]], dtype=np.float32) + X = tf.Variable(initial_value=Xinit, trainable=True) + cl = CubicalLayer(homology_dimensions=[0]) + + with tf.GradientTape() as tape: + dgm = cl.call(X)[0][0] + loss = tf.math.reduce_sum(tf.square(.5*(dgm[:,1]-dgm[:,0]))) + grads = tape.gradient(loss, [X]) + assert tf.norm(grads[0]-tf.constant([[0.,0.,0.],[0.,.5,0.],[0.,0.,-.5]]),1) <= 1e-6 + +def test_nonsquare_cubical_diff(): + + Xinit = np.array([[-1.,1.,0.],[1.,1.,1.]], dtype=np.float32) + X = tf.Variable(initial_value=Xinit, trainable=True) + cl = CubicalLayer(homology_dimensions=[0]) + + with tf.GradientTape() as tape: + dgm = cl.call(X)[0][0] + loss = tf.math.reduce_sum(tf.square(.5*(dgm[:,1]-dgm[:,0]))) + grads = tape.gradient(loss, [X]) + assert tf.norm(grads[0]-tf.constant([[0.,0.5,-0.5],[0.,0.,0.]]),1) <= 1e-6 + +def test_st_diff(): + + st = gd.SimplexTree() + st.insert([0]) + st.insert([1]) + st.insert([2]) + st.insert([3]) + st.insert([4]) + st.insert([5]) + st.insert([6]) + st.insert([7]) + st.insert([8]) + st.insert([9]) + st.insert([10]) + st.insert([0, 1]) + st.insert([1, 2]) + st.insert([2, 3]) + st.insert([3, 4]) + st.insert([4, 5]) + st.insert([5, 6]) + st.insert([6, 7]) + st.insert([7, 8]) + st.insert([8, 9]) + st.insert([9, 10]) + + Finit = np.array([6.,4.,3.,4.,5.,4.,3.,2.,3.,4.,5.], dtype=np.float32) + F = tf.Variable(initial_value=Finit, trainable=True) + sl = LowerStarSimplexTreeLayer(simplextree=st, homology_dimensions=[0]) + + with tf.GradientTape() as tape: + dgm = sl.call(F)[0][0] + loss = tf.math.reduce_sum(tf.square(.5*(dgm[:,1]-dgm[:,0]))) + grads = tape.gradient(loss, [F]) + + assert tf.math.reduce_all(tf.math.equal(grads[0].indices, tf.constant([2,4]))) + assert tf.math.reduce_all(tf.math.equal(grads[0].values, tf.constant([-1.,1.]))) + diff --git a/src/python/test/test_dtm.py b/src/python/test/test_dtm.py index e46d616c..b276f041 100755 --- a/src/python/test/test_dtm.py +++ b/src/python/test/test_dtm.py @@ -91,11 +91,11 @@ def test_density(): def test_dtm_overflow_warnings(): pts = numpy.array([[10., 100000000000000000000000000000.], [1000., 100000000000000000000000000.]]) - - with warnings.catch_warnings(record=True) as w: - # TODO Test "keops" implementation as well when next version of pykeops (current is 1.5) is released (should fix the problem (cf. issue #543)) - dtm = DistanceToMeasure(2, implementation="hnsw") - r = dtm.fit_transform(pts) - assert len(w) == 1 - assert issubclass(w[0].category, RuntimeWarning) - assert "Overflow" in str(w[0].message) + impl_warn = ["keops", "hnsw"] + for impl in impl_warn: + with warnings.catch_warnings(record=True) as w: + dtm = DistanceToMeasure(2, implementation=impl) + r = dtm.fit_transform(pts) + assert len(w) == 1 + assert issubclass(w[0].category, RuntimeWarning) + assert "Overflow" in str(w[0].message) diff --git a/src/python/test/test_remote_datasets.py b/src/python/test/test_remote_datasets.py new file mode 100644 index 00000000..e5d2de82 --- /dev/null +++ b/src/python/test/test_remote_datasets.py @@ -0,0 +1,87 @@ +# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT. +# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details. +# Author(s): Hind Montassif +# +# Copyright (C) 2021 Inria +# +# Modification(s): +# - YYYY/MM Author: Description of the modification + +from gudhi.datasets import remote + +import shutil +import io +import sys +import pytest + +from os.path import isdir, expanduser, exists +from os import remove, environ + +def test_data_home(): + # Test _get_data_home and clear_data_home on new empty folder + empty_data_home = remote._get_data_home(data_home="empty_folder_for_test") + assert isdir(empty_data_home) + + remote.clear_data_home(data_home=empty_data_home) + assert not isdir(empty_data_home) + +def test_fetch_remote(): + # Test fetch with a wrong checksum + with pytest.raises(OSError): + remote._fetch_remote("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d/spiral_2d.npy", "tmp_spiral_2d.npy", file_checksum = 'XXXXXXXXXX') + assert not exists("tmp_spiral_2d.npy") + +def _get_bunny_license_print(accept_license = False): + capturedOutput = io.StringIO() + # Redirect stdout + sys.stdout = capturedOutput + + bunny_arr = remote.fetch_bunny("./tmp_for_test/bunny.npy", accept_license) + assert bunny_arr.shape == (35947, 3) + del bunny_arr + remove("./tmp_for_test/bunny.npy") + + # Reset redirect + sys.stdout = sys.__stdout__ + return capturedOutput + +def test_print_bunny_license(): + # Test not printing bunny.npy LICENSE when accept_license = True + assert "" == _get_bunny_license_print(accept_license = True).getvalue() + # Test printing bunny.LICENSE file when fetching bunny.npy with accept_license = False (default) + with open("./tmp_for_test/bunny.LICENSE") as f: + assert f.read().rstrip("\n") == _get_bunny_license_print().getvalue().rstrip("\n") + shutil.rmtree("./tmp_for_test") + +def test_fetch_remote_datasets_wrapped(): + # Test fetch_spiral_2d and fetch_bunny wrapping functions with data directory different from default (twice, to test case of already fetched files) + # Default case is not tested because it would fail in case the user sets the 'GUDHI_DATA' environment variable locally + for i in range(2): + spiral_2d_arr = remote.fetch_spiral_2d("./another_fetch_folder_for_test/spiral_2d.npy") + assert spiral_2d_arr.shape == (114562, 2) + + bunny_arr = remote.fetch_bunny("./another_fetch_folder_for_test/bunny.npy") + assert bunny_arr.shape == (35947, 3) + + # Check that the directory was created + assert isdir("./another_fetch_folder_for_test") + # Check downloaded files + assert exists("./another_fetch_folder_for_test/spiral_2d.npy") + assert exists("./another_fetch_folder_for_test/bunny.npy") + assert exists("./another_fetch_folder_for_test/bunny.LICENSE") + + # Remove test folders + del spiral_2d_arr + del bunny_arr + shutil.rmtree("./another_fetch_folder_for_test") + +def test_gudhi_data_env(): + # Set environment variable "GUDHI_DATA" + environ["GUDHI_DATA"] = "./test_folder_from_env_var" + bunny_arr = remote.fetch_bunny() + assert bunny_arr.shape == (35947, 3) + assert exists("./test_folder_from_env_var/points/bunny/bunny.npy") + assert exists("./test_folder_from_env_var/points/bunny/bunny.LICENSE") + # Remove test folder + del bunny_arr + shutil.rmtree("./test_folder_from_env_var") diff --git a/src/python/test/test_representations.py b/src/python/test/test_representations.py index 5f29740f..2ca72d07 100755 --- a/src/python/test/test_representations.py +++ b/src/python/test/test_representations.py @@ -152,7 +152,26 @@ def test_vectorization_empty_diagrams(): scv = Entropy(mode="vector", normalized=False, resolution=random_resolution)(empty_diag) assert not np.any(scv) assert scv.shape[0] == random_resolution - + +def test_entropy_miscalculation(): + diag_ex = np.array([[0.0,1.0], [0.0,1.0], [0.0,2.0]]) + def pe(pd): + l = pd[:,1] - pd[:,0] + l = l/sum(l) + return -np.dot(l, np.log(l)) + sce = Entropy(mode="scalar") + assert [[pe(diag_ex)]] == sce.fit_transform([diag_ex]) + sce = Entropy(mode="vector", resolution=4, normalized=False) + pef = [-1/4*np.log(1/4)-1/4*np.log(1/4)-1/2*np.log(1/2), + -1/4*np.log(1/4)-1/4*np.log(1/4)-1/2*np.log(1/2), + -1/2*np.log(1/2), + 0.0] + assert all(([pef] == sce.fit_transform([diag_ex]))[0]) + sce = Entropy(mode="vector", resolution=4, normalized=True) + pefN = (sce.fit_transform([diag_ex]))[0] + area = np.linalg.norm(pefN, ord=1) + assert area==1 + def test_kernel_empty_diagrams(): empty_diag = np.empty(shape = [0, 2]) assert SlicedWassersteinDistance(num_directions=100)(empty_diag, empty_diag) == 0. diff --git a/src/python/test/test_simplex_tree.py b/src/python/test/test_simplex_tree.py index 688f4fd6..54bafed5 100755 --- a/src/python/test/test_simplex_tree.py +++ b/src/python/test/test_simplex_tree.py @@ -320,6 +320,10 @@ def test_extend_filtration(): ] dgms = st.extended_persistence(min_persistence=-1.) + assert len(dgms) == 4 + # Sort by (death-birth) descending - we are only interested in those with the longest life span + for idx in range(4): + dgms[idx] = sorted(dgms[idx], key=lambda x:(-abs(x[1][0]-x[1][1]))) assert dgms[0][0][1][0] == pytest.approx(2.) assert dgms[0][0][1][1] == pytest.approx(3.) @@ -528,7 +532,7 @@ def test_expansion_with_blocker(): def blocker(simplex): try: - # Block all simplices that countains vertex 6 + # Block all simplices that contain vertex 6 simplex.index(6) print(simplex, ' is blocked') return True diff --git a/src/python/test/test_subsampling.py b/src/python/test/test_subsampling.py index 4019852e..3431f372 100755 --- a/src/python/test/test_subsampling.py +++ b/src/python/test/test_subsampling.py @@ -91,7 +91,7 @@ def test_simple_choose_n_farthest_points_randomed(): assert gudhi.choose_n_farthest_points(points=[], nb_points=1) == [] assert gudhi.choose_n_farthest_points(points=point_set, nb_points=0) == [] - # Go furter than point set on purpose + # Go further than point set on purpose for iter in range(1, 10): sub_set = gudhi.choose_n_farthest_points(points=point_set, nb_points=iter) for sub in sub_set: @@ -117,7 +117,7 @@ def test_simple_pick_n_random_points(): assert gudhi.pick_n_random_points(points=[], nb_points=1) == [] assert gudhi.pick_n_random_points(points=point_set, nb_points=0) == [] - # Go furter than point set on purpose + # Go further than point set on purpose for iter in range(1, 10): sub_set = gudhi.pick_n_random_points(points=point_set, nb_points=iter) for sub in sub_set: |