Diffstat (limited to 'src/python/test')
-rwxr-xr-x  src/python/test/test_alpha_complex.py                 | 251
-rwxr-xr-x  src/python/test/test_betti_curve_representations.py   | 59
-rwxr-xr-x  src/python/test/test_bottleneck_distance.py           | 18
-rwxr-xr-x  src/python/test/test_cover_complex.py                 | 4
-rwxr-xr-x  src/python/test/test_cubical_complex.py               | 58
-rwxr-xr-x  src/python/test/test_datasets_generators.py           | 39
-rw-r--r--  src/python/test/test_diff.py                          | 78
-rwxr-xr-x  src/python/test/test_dtm.py                           | 101
-rw-r--r--  src/python/test/test_dtm_rips_complex.py              | 32
-rwxr-xr-x  src/python/test/test_euclidean_witness_complex.py     | 6
-rwxr-xr-x  src/python/test/test_knn.py                           | 130
-rw-r--r--  src/python/test/test_off.py                           | 21
-rw-r--r--  src/python/test/test_persistence_graphical_tools.py   | 122
-rwxr-xr-x  src/python/test/test_reader_utils.py                  | 35
-rw-r--r--  src/python/test/test_remote_datasets.py               | 87
-rwxr-xr-x  src/python/test/test_representations.py               | 259
-rw-r--r--  src/python/test/test_representations_preprocessing.py | 39
-rwxr-xr-x  src/python/test/test_rips_complex.py                  | 27
-rwxr-xr-x  src/python/test/test_simplex_generators.py            | 64
-rwxr-xr-x  src/python/test/test_simplex_tree.py                  | 414
-rw-r--r--  src/python/test/test_sklearn_cubical_persistence.py   | 59
-rwxr-xr-x  src/python/test/test_subsampling.py                   | 120
-rwxr-xr-x  src/python/test/test_tangential_complex.py            | 3
-rwxr-xr-x  src/python/test/test_time_delay.py                    | 43
-rwxr-xr-x  src/python/test/test_tomato.py                        | 65
-rwxr-xr-x  src/python/test/test_wasserstein_barycenter.py        | 46
-rwxr-xr-x  src/python/test/test_wasserstein_distance.py          | 187
-rwxr-xr-x  src/python/test/test_wasserstein_with_tensors.py      | 47
-rw-r--r--  src/python/test/test_weighted_rips_complex.py         | 63
29 files changed, 2294 insertions, 183 deletions
diff --git a/src/python/test/test_alpha_complex.py b/src/python/test/test_alpha_complex.py
index 3761fe16..f81e6137 100755
--- a/src/python/test/test_alpha_complex.py
+++ b/src/python/test/test_alpha_complex.py
@@ -8,10 +8,12 @@
- YYYY/MM Author: Description of the modification
"""
-from gudhi import AlphaComplex, SimplexTree
+from gudhi import AlphaComplex
import math
import numpy as np
import pytest
+import warnings
+
try:
# python3
from itertools import zip_longest
@@ -19,19 +21,24 @@ except ImportError:
# python2
from itertools import izip_longest as zip_longest
-__author__ = "Vincent Rouvreau"
-__copyright__ = "Copyright (C) 2016 Inria"
-__license__ = "MIT"
-def test_empty_alpha():
- alpha_complex = AlphaComplex(points=[[0, 0]])
+def _empty_alpha(precision):
+ alpha_complex = AlphaComplex(precision = precision)
assert alpha_complex.__is_defined() == True
+def _one_2d_point_alpha(precision):
+ alpha_complex = AlphaComplex(points=[[0, 0]], precision = precision)
+ assert alpha_complex.__is_defined() == True
-def test_infinite_alpha():
+def test_empty_alpha():
+ for precision in ['fast', 'safe', 'exact']:
+ _empty_alpha(precision)
+ _one_2d_point_alpha(precision)
+
+def _infinite_alpha(precision):
point_list = [[0, 0], [1, 0], [0, 1], [1, 1]]
- alpha_complex = AlphaComplex(points=point_list)
+ alpha_complex = AlphaComplex(points=point_list, precision = precision)
assert alpha_complex.__is_defined() == True
simplex_tree = alpha_complex.create_simplex_tree()
@@ -40,7 +47,7 @@ def test_infinite_alpha():
assert simplex_tree.num_simplices() == 11
assert simplex_tree.num_vertices() == 4
- assert simplex_tree.get_filtration() == [
+ assert list(simplex_tree.get_filtration()) == [
([0], 0.0),
([1], 0.0),
([2], 0.0),
@@ -53,6 +60,7 @@ def test_infinite_alpha():
([0, 1, 2], 0.5),
([1, 2, 3], 0.5),
]
+
assert simplex_tree.get_star([0]) == [
([0], 0.0),
([0, 1], 0.25),
@@ -65,23 +73,17 @@ def test_infinite_alpha():
assert point_list[1] == alpha_complex.get_point(1)
assert point_list[2] == alpha_complex.get_point(2)
assert point_list[3] == alpha_complex.get_point(3)
- try:
- alpha_complex.get_point(4) == []
- except IndexError:
- pass
- else:
- assert False
- try:
- alpha_complex.get_point(125) == []
- except IndexError:
- pass
- else:
- assert False
+ with pytest.raises(IndexError):
+ alpha_complex.get_point(len(point_list))
-def test_filtered_alpha():
+def test_infinite_alpha():
+ for precision in ['fast', 'safe', 'exact']:
+ _infinite_alpha(precision)
+
+def _filtered_alpha(precision):
point_list = [[0, 0], [1, 0], [0, 1], [1, 1]]
- filtered_alpha = AlphaComplex(points=point_list)
+ filtered_alpha = AlphaComplex(points=point_list, precision = precision)
simplex_tree = filtered_alpha.create_simplex_tree(max_alpha_square=0.25)
@@ -92,20 +94,11 @@ def test_filtered_alpha():
assert point_list[1] == filtered_alpha.get_point(1)
assert point_list[2] == filtered_alpha.get_point(2)
assert point_list[3] == filtered_alpha.get_point(3)
- try:
- filtered_alpha.get_point(4) == []
- except IndexError:
- pass
- else:
- assert False
- try:
- filtered_alpha.get_point(125) == []
- except IndexError:
- pass
- else:
- assert False
-
- assert simplex_tree.get_filtration() == [
+
+ with pytest.raises(IndexError):
+ filtered_alpha.get_point(len(point_list))
+
+ assert list(simplex_tree.get_filtration()) == [
([0], 0.0),
([1], 0.0),
([2], 0.0),
@@ -118,7 +111,11 @@ def test_filtered_alpha():
assert simplex_tree.get_star([0]) == [([0], 0.0), ([0, 1], 0.25), ([0, 2], 0.25)]
assert simplex_tree.get_cofaces([0], 1) == [([0, 1], 0.25), ([0, 2], 0.25)]
-def test_safe_alpha_persistence_comparison():
+def test_filtered_alpha():
+ for precision in ['fast', 'safe', 'exact']:
+ _filtered_alpha(precision)
+
+def _safe_alpha_persistence_comparison(precision):
#generate periodic signal
time = np.arange(0, 10, 1)
signal = [math.sin(x) for x in time]
@@ -130,10 +127,10 @@ def test_safe_alpha_persistence_comparison():
embedding2 = [[signal[i], delayed[i]] for i in range(len(time))]
#build alpha complex and simplex tree
- alpha_complex1 = AlphaComplex(points=embedding1)
+ alpha_complex1 = AlphaComplex(points=embedding1, precision = precision)
simplex_tree1 = alpha_complex1.create_simplex_tree()
- alpha_complex2 = AlphaComplex(points=embedding2)
+ alpha_complex2 = AlphaComplex(points=embedding2, precision = precision)
simplex_tree2 = alpha_complex2.create_simplex_tree()
diag1 = simplex_tree1.persistence()
@@ -142,3 +139,177 @@ def test_safe_alpha_persistence_comparison():
for (first_p, second_p) in zip_longest(diag1, diag2):
assert first_p[0] == pytest.approx(second_p[0])
assert first_p[1] == pytest.approx(second_p[1])
+
+
+def test_safe_alpha_persistence_comparison():
+ # Won't work for 'fast' version
+ _safe_alpha_persistence_comparison('safe')
+ _safe_alpha_persistence_comparison('exact')
+
+def _delaunay_complex(precision):
+ point_list = [[0, 0], [1, 0], [0, 1], [1, 1]]
+ filtered_alpha = AlphaComplex(points=point_list, precision = precision)
+
+ simplex_tree = filtered_alpha.create_simplex_tree(default_filtration_value = True)
+
+ assert simplex_tree.num_simplices() == 11
+ assert simplex_tree.num_vertices() == 4
+
+ assert point_list[0] == filtered_alpha.get_point(0)
+ assert point_list[1] == filtered_alpha.get_point(1)
+ assert point_list[2] == filtered_alpha.get_point(2)
+ assert point_list[3] == filtered_alpha.get_point(3)
+
+ with pytest.raises(IndexError):
+ filtered_alpha.get_point(4)
+ with pytest.raises(IndexError):
+ filtered_alpha.get_point(125)
+
+ for filtered_value in simplex_tree.get_filtration():
+ assert math.isnan(filtered_value[1])
+ for filtered_value in simplex_tree.get_star([0]):
+ assert math.isnan(filtered_value[1])
+ for filtered_value in simplex_tree.get_cofaces([0], 1):
+ assert math.isnan(filtered_value[1])
+
+def test_delaunay_complex():
+ for precision in ['fast', 'safe', 'exact']:
+ _delaunay_complex(precision)
+
+def _3d_points_on_a_plane(precision, default_filtration_value):
+ alpha = AlphaComplex(points = [[1.0, 1.0 , 0.0],
+ [7.0, 0.0 , 0.0],
+ [4.0, 6.0 , 0.0],
+ [9.0, 6.0 , 0.0],
+ [0.0, 14.0, 0.0],
+ [2.0, 19.0, 0.0],
+ [9.0, 17.0, 0.0]], precision = precision)
+
+ simplex_tree = alpha.create_simplex_tree(default_filtration_value = default_filtration_value)
+ assert simplex_tree.dimension() == 2
+ assert simplex_tree.num_vertices() == 7
+ assert simplex_tree.num_simplices() == 25
+
+def test_3d_points_on_a_plane():
+ for default_filtration_value in [True, False]:
+ for precision in ['fast', 'safe', 'exact']:
+ _3d_points_on_a_plane(precision, default_filtration_value)
+
+def _3d_tetrahedrons(precision):
+ points = 10*np.random.rand(10, 3)
+ alpha = AlphaComplex(points = points, precision = precision)
+ st_alpha = alpha.create_simplex_tree(default_filtration_value = False)
+ # New AlphaComplex for get_point to work
+ delaunay = AlphaComplex(points = points, precision = precision)
+ st_delaunay = delaunay.create_simplex_tree(default_filtration_value = True)
+
+ delaunay_tetra = []
+ for sk in st_delaunay.get_skeleton(4):
+ if len(sk[0]) == 4:
+ tetra = [delaunay.get_point(sk[0][0]),
+ delaunay.get_point(sk[0][1]),
+ delaunay.get_point(sk[0][2]),
+ delaunay.get_point(sk[0][3]) ]
+ delaunay_tetra.append(sorted(tetra, key=lambda tup: tup[0]))
+
+ alpha_tetra = []
+ for sk in st_alpha.get_skeleton(4):
+ if len(sk[0]) == 4:
+ tetra = [alpha.get_point(sk[0][0]),
+ alpha.get_point(sk[0][1]),
+ alpha.get_point(sk[0][2]),
+ alpha.get_point(sk[0][3]) ]
+ alpha_tetra.append(sorted(tetra, key=lambda tup: tup[0]))
+
+ # Check the tetrahedrons from one list are in the second one
+ assert len(alpha_tetra) == len(delaunay_tetra)
+ for tetra_from_del in delaunay_tetra:
+ assert tetra_from_del in alpha_tetra
+
+def test_3d_tetrahedrons():
+ for precision in ['fast', 'safe', 'exact']:
+ _3d_tetrahedrons(precision)
+
+def test_off_file_deprecation_warning():
+ off_file = open("alphacomplexdoc.off", "w")
+ off_file.write("OFF \n" \
+ "7 0 0 \n" \
+ "1.0 1.0 0.0\n" \
+ "7.0 0.0 0.0\n" \
+ "4.0 6.0 0.0\n" \
+ "9.0 6.0 0.0\n" \
+ "0.0 14.0 0.0\n" \
+ "2.0 19.0 0.0\n" \
+ "9.0 17.0 0.0\n" )
+ off_file.close()
+
+ with pytest.warns(DeprecationWarning):
+ alpha = AlphaComplex(off_file="alphacomplexdoc.off")
+
+def test_non_existing_off_file():
+ with pytest.warns(DeprecationWarning):
+ with pytest.raises(FileNotFoundError):
+ alpha = AlphaComplex(off_file="pouetpouettralala.toubiloubabdou")
+
+def test_inconsistency_points_and_weights():
+ points = [[1.0, 1.0 , 0.0],
+ [7.0, 0.0 , 0.0],
+ [4.0, 6.0 , 0.0],
+ [9.0, 6.0 , 0.0],
+ [0.0, 14.0, 0.0],
+ [2.0, 19.0, 0.0],
+ [9.0, 17.0, 0.0]]
+ with pytest.raises(ValueError):
+ # 7 points, 8 weights, on purpose
+ alpha = AlphaComplex(points = points,
+ weights = [1., 2., 3., 4., 5., 6., 7., 8.])
+
+ with pytest.raises(ValueError):
+ # 7 points, 6 weights, on purpose
+ alpha = AlphaComplex(points = points,
+ weights = [1., 2., 3., 4., 5., 6.])
+
+def _weighted_doc_example(precision):
+ stree = AlphaComplex(points=[[ 1., -1., -1.],
+ [-1., 1., -1.],
+ [-1., -1., 1.],
+ [ 1., 1., 1.],
+ [ 2., 2., 2.]],
+ weights = [4., 4., 4., 4., 1.],
+ precision = precision).create_simplex_tree()
+
+ assert stree.filtration([0, 1, 2, 3]) == pytest.approx(-1.)
+ assert stree.filtration([0, 1, 3, 4]) == pytest.approx(95.)
+ assert stree.filtration([0, 2, 3, 4]) == pytest.approx(95.)
+ assert stree.filtration([1, 2, 3, 4]) == pytest.approx(95.)
+
+def test_weighted_doc_example():
+ for precision in ['fast', 'safe', 'exact']:
+ _weighted_doc_example(precision)
+
+def test_float_relative_precision():
+ assert AlphaComplex.get_float_relative_precision() == 1e-5
+ # Must be > 0.
+ with pytest.raises(ValueError):
+ AlphaComplex.set_float_relative_precision(0.)
+ # Must be < 1.
+ with pytest.raises(ValueError):
+ AlphaComplex.set_float_relative_precision(1.)
+
+ points = [[1, 1], [7, 0], [4, 6], [9, 6], [0, 14], [2, 19], [9, 17]]
+ st = AlphaComplex(points=points).create_simplex_tree()
+ filtrations = list(st.get_filtration())
+
+ # Get a better precision
+ AlphaComplex.set_float_relative_precision(1e-15)
+ assert AlphaComplex.get_float_relative_precision() == 1e-15
+
+ st = AlphaComplex(points=points).create_simplex_tree()
+ filtrations_better_resolution = list(st.get_filtration())
+
+ assert len(filtrations) == len(filtrations_better_resolution)
+ for idx in range(len(filtrations)):
+ # check simplex is the same
+ assert filtrations[idx][0] == filtrations_better_resolution[idx][0]
+ # check filtration is about the same with a relative precision of the worst case
+ assert filtrations[idx][1] == pytest.approx(filtrations_better_resolution[idx][1], rel=1e-5)
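The hunks above converge on a small set of new AlphaComplex keyword arguments (precision, weights, default_filtration_value). A minimal sketch of how they combine, using only calls exercised by these tests:

    from gudhi import AlphaComplex

    points = [[0, 0], [1, 0], [0, 1], [1, 1]]
    # 'precision' selects the kernel: 'fast', 'safe' (default) or 'exact'
    ac = AlphaComplex(points=points, precision='safe')
    # filtered alpha complex, truncated at alpha^2 = 0.25
    st = ac.create_simplex_tree(max_alpha_square=0.25)
    # plain Delaunay complex: every simplex gets a NaN filtration value
    st_delaunay = ac.create_simplex_tree(default_filtration_value=True)
    # weighted variant; the number of weights must match the number of points
    st_weighted = AlphaComplex(points=points,
                               weights=[4., 4., 4., 1.]).create_simplex_tree()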
diff --git a/src/python/test/test_betti_curve_representations.py b/src/python/test/test_betti_curve_representations.py
new file mode 100755
index 00000000..6a45da4d
--- /dev/null
+++ b/src/python/test/test_betti_curve_representations.py
@@ -0,0 +1,59 @@
+import numpy as np
+import scipy.interpolate
+import pytest
+
+from gudhi.representations.vector_methods import BettiCurve
+
+def test_betti_curve_is_irregular_betti_curve_followed_by_interpolation():
+ m = 10
+ n = 1000
+ pinf = 0.05
+ pzero = 0.05
+ res = 100
+
+ pds = []
+ for i in range(0, m):
+ pd = np.zeros((n, 2))
+ pd[:, 0] = np.random.uniform(0, 10, n)
+ pd[:, 1] = np.random.uniform(pd[:, 0], 10, n)
+ pd[np.random.uniform(0, 1, n) < pzero, 0] = 0
+ pd[np.random.uniform(0, 1, n) < pinf, 1] = np.inf
+ pds.append(pd)
+
+ bc = BettiCurve(resolution=None, predefined_grid=None)
+ bc.fit(pds)
+ bettis = bc.transform(pds)
+
+ bc2 = BettiCurve(resolution=None, predefined_grid=None)
+ bettis2 = bc2.fit_transform(pds)
+ assert((bc2.grid_ == bc.grid_).all())
+ assert((bettis2 == bettis).all())
+
+ for i in range(0, m):
+ grid = np.linspace(pds[i][np.isfinite(pds[i])].min(), pds[i][np.isfinite(pds[i])].max() + 1, res)
+ bc_gridded = BettiCurve(predefined_grid=grid)
+ bc_gridded.fit([])
+ bettis_gridded = bc_gridded(pds[i])
+
+ interp = scipy.interpolate.interp1d(bc.grid_, bettis[i, :], kind="previous", fill_value="extrapolate")
+ bettis_interp = np.array(interp(grid), dtype=int)
+ assert((bettis_interp == bettis_gridded).all())
+
+
+def test_empty_with_predefined_grid():
+ random_grid = np.sort(np.random.uniform(0, 1, 100))
+ bc = BettiCurve(predefined_grid=random_grid)
+ bettis = bc.fit_transform([])
+ assert((bc.grid_ == random_grid).all())
+ assert((bettis == 0).all())
+
+
+def test_empty():
+ bc = BettiCurve(resolution=None, predefined_grid=None)
+ bettis = bc.fit_transform([])
+ assert(bc.grid_ == [-np.inf])
+ assert((bettis == 0).all())
+
+def test_wrong_value_of_predefined_grid():
+ with pytest.raises(ValueError):
+ BettiCurve(predefined_grid=[1, 2, 3])
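The new test relies on BettiCurve operating in two modes; a short sketch of both, assuming only the calls used above:

    import numpy as np
    from gudhi.representations.vector_methods import BettiCurve

    pd = np.array([[0.0, 1.0], [0.5, 2.0]])
    # exact mode: the grid is inferred from the diagrams at fit time
    bc = BettiCurve(resolution=None, predefined_grid=None)
    bettis = bc.fit_transform([pd])
    print(bc.grid_)       # change points of the Betti curve
    # gridded mode: evaluate on a fixed numpy array (a plain list is rejected,
    # as test_wrong_value_of_predefined_grid checks)
    bc_grid = BettiCurve(predefined_grid=np.linspace(0.0, 2.0, 5))
    bc_grid.fit([])
    print(bc_grid(pd))    # Betti numbers sampled on the predefined grid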
diff --git a/src/python/test/test_bottleneck_distance.py b/src/python/test/test_bottleneck_distance.py
index 70b2abad..07fcc9cc 100755
--- a/src/python/test/test_bottleneck_distance.py
+++ b/src/python/test/test_bottleneck_distance.py
@@ -9,6 +9,8 @@
"""
import gudhi
+import gudhi.hera
+import pytest
__author__ = "Vincent Rouvreau"
__copyright__ = "Copyright (C) 2016 Inria"
@@ -19,5 +21,19 @@ def test_basic_bottleneck():
diag1 = [[2.7, 3.7], [9.6, 14.0], [34.2, 34.974], [3.0, float("Inf")]]
diag2 = [[2.8, 4.45], [9.5, 14.1], [3.2, float("Inf")]]
- assert gudhi.bottleneck_distance(diag1, diag2, 0.1) == 0.8081763781405569
assert gudhi.bottleneck_distance(diag1, diag2) == 0.75
+ assert gudhi.bottleneck_distance(diag1, diag2, 0.1) == pytest.approx(0.75, abs=0.1)
+ assert gudhi.hera.bottleneck_distance(diag1, diag2, 0) == 0.75
+ assert gudhi.hera.bottleneck_distance(diag1, diag2, 0.1) == pytest.approx(0.75, rel=0.1)
+
+ import numpy as np
+
+ # Translating both diagrams along the diagonal should not affect the distance,
+ # checks that negative numbers are not an issue
+ diag1 = np.array(diag1) - 100
+ diag2 = np.array(diag2) - 100
+
+ assert gudhi.bottleneck_distance(diag1, diag2) == 0.75
+ assert gudhi.bottleneck_distance(diag1, diag2, 0.1) == pytest.approx(0.75, abs=0.1)
+ assert gudhi.hera.bottleneck_distance(diag1, diag2, 0) == 0.75
+ assert gudhi.hera.bottleneck_distance(diag1, diag2, 0.1) == pytest.approx(0.75, rel=0.1)
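Note the asymmetry the new assertions encode: the third argument of gudhi.bottleneck_distance is an additive error bound, while gudhi.hera.bottleneck_distance takes a relative one. A compressed reading of the test:

    import gudhi
    import gudhi.hera

    diag1 = [[2.7, 3.7], [9.6, 14.0], [34.2, 34.974], [3.0, float("inf")]]
    diag2 = [[2.8, 4.45], [9.5, 14.1], [3.2, float("inf")]]
    gudhi.bottleneck_distance(diag1, diag2)            # exact: 0.75
    gudhi.bottleneck_distance(diag1, diag2, 0.1)       # within +/- 0.1 of 0.75
    gudhi.hera.bottleneck_distance(diag1, diag2, 0)    # exact: 0.75
    gudhi.hera.bottleneck_distance(diag1, diag2, 0.1)  # within 10% of 0.75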
diff --git a/src/python/test/test_cover_complex.py b/src/python/test/test_cover_complex.py
index 32bc5a26..260f6a5c 100755
--- a/src/python/test/test_cover_complex.py
+++ b/src/python/test/test_cover_complex.py
@@ -9,6 +9,7 @@
"""
from gudhi import CoverComplex
+import pytest
__author__ = "Vincent Rouvreau"
__copyright__ = "Copyright (C) 2018 Inria"
@@ -24,7 +25,8 @@ def test_empty_constructor():
def test_non_existing_file_read():
# Try to open a non existing file
cover = CoverComplex()
- assert cover.read_point_cloud("pouetpouettralala.toubiloubabdou") == False
+ with pytest.raises(FileNotFoundError):
+ cover.read_point_cloud("pouetpouettralala.toubiloubabdou")
def test_files_creation():
diff --git a/src/python/test/test_cubical_complex.py b/src/python/test/test_cubical_complex.py
index 8c1b2600..29d559b3 100755
--- a/src/python/test/test_cubical_complex.py
+++ b/src/python/test/test_cubical_complex.py
@@ -10,6 +10,7 @@
from gudhi import CubicalComplex, PeriodicCubicalComplex
import numpy as np
+import pytest
__author__ = "Vincent Rouvreau"
__copyright__ = "Copyright (C) 2016 Inria"
@@ -25,9 +26,8 @@ def test_empty_constructor():
def test_non_existing_perseus_file_constructor():
# Try to open a non existing file
- cub = CubicalComplex(perseus_file="pouetpouettralala.toubiloubabdou")
- assert cub.__is_defined() == False
- assert cub.__is_persistence_defined() == False
+ with pytest.raises(FileNotFoundError):
+ cub = CubicalComplex(perseus_file="pouetpouettralala.toubiloubabdou")
def test_dimension_or_perseus_file_constructor():
@@ -147,3 +147,55 @@ def test_connected_sublevel_sets():
periodic_dimensions = periodic_dimensions)
assert cub.persistence() == [(0, (2.0, float("inf")))]
assert cub.betti_numbers() == [1, 0, 0]
+
+def test_cubical_generators():
+ cub = CubicalComplex(top_dimensional_cells = [[0, 0, 0], [0, 1, 0], [0, 0, 0]])
+ cub.persistence()
+ g = cub.cofaces_of_persistence_pairs()
+ assert len(g[0]) == 2
+ assert len(g[1]) == 1
+ assert np.array_equal(g[0][0], np.empty(shape=[0,2]))
+ assert np.array_equal(g[0][1], np.array([[7, 4]]))
+ assert np.array_equal(g[1][0], np.array([8]))
+
+def test_cubical_cofaces_of_persistence_pairs_when_pd_has_no_paired_birth_and_death():
+ cubCpx = CubicalComplex(dimensions=[1,2], top_dimensional_cells=[0.0, 1.0])
+ Diag = cubCpx.persistence(homology_coeff_field=2, min_persistence=0)
+ pairs = cubCpx.cofaces_of_persistence_pairs()
+ assert pairs[0] == []
+ assert np.array_equal(pairs[1][0], np.array([0]))
+
+def test_periodic_cofaces_of_persistence_pairs_when_pd_has_no_paired_birth_and_death():
+ perCubCpx = PeriodicCubicalComplex(dimensions=[1,2], top_dimensional_cells=[0.0, 1.0],
+ periodic_dimensions=[True, True])
+ Diag = perCubCpx.persistence(homology_coeff_field=2, min_persistence=0)
+ pairs = perCubCpx.cofaces_of_persistence_pairs()
+ assert pairs[0] == []
+ assert np.array_equal(pairs[1][0], np.array([0]))
+ assert np.array_equal(pairs[1][1], np.array([0, 1]))
+ assert np.array_equal(pairs[1][2], np.array([1]))
+
+def test_cubical_persistence_intervals_in_dimension():
+ cub = CubicalComplex(
+ dimensions=[3, 3],
+ top_dimensional_cells=[1, 2, 3, 4, 5, 6, 7, 8, 9],
+ )
+ cub.compute_persistence()
+ H0 = cub.persistence_intervals_in_dimension(0)
+ assert np.array_equal(H0, np.array([[ 1., float("inf")]]))
+ assert cub.persistence_intervals_in_dimension(1).shape == (0, 2)
+
+def test_periodic_cubical_persistence_intervals_in_dimension():
+ cub = PeriodicCubicalComplex(
+ dimensions=[3, 3],
+ top_dimensional_cells=[1, 2, 3, 4, 5, 6, 7, 8, 9],
+ periodic_dimensions = [True, True]
+ )
+ cub.compute_persistence()
+ H0 = cub.persistence_intervals_in_dimension(0)
+ assert np.array_equal(H0, np.array([[ 1., float("inf")]]))
+ H1 = cub.persistence_intervals_in_dimension(1)
+ assert np.array_equal(H1, np.array([[ 3., float("inf")], [ 7., float("inf")]]))
+ H2 = cub.persistence_intervals_in_dimension(2)
+ assert np.array_equal(H2, np.array([[ 9., float("inf")]]))
+ assert cub.persistence_intervals_in_dimension(3).shape == (0, 2)
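The generators test reads the two-part structure returned by cofaces_of_persistence_pairs(); a minimal sketch of that layout on the same 3x3 input:

    import numpy as np
    from gudhi import CubicalComplex

    cub = CubicalComplex(top_dimensional_cells=[[0, 0, 0], [0, 1, 0], [0, 0, 0]])
    cub.persistence()
    regular, essential = cub.cofaces_of_persistence_pairs()
    # regular[d]: (birth, death) top-cell indices of finite pairs in dimension d
    # essential[d]: birth top-cell indices of infinite pairs in dimension d
    print(regular[1])    # [[7 4]]: H1 pair born at cell 7, killed by cell 4
    print(essential[0])  # [8]: the essential H0 class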
diff --git a/src/python/test/test_datasets_generators.py b/src/python/test/test_datasets_generators.py
new file mode 100755
index 00000000..91ec4a65
--- /dev/null
+++ b/src/python/test/test_datasets_generators.py
@@ -0,0 +1,39 @@
+""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+ See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+ Author(s): Hind Montassif
+
+ Copyright (C) 2021 Inria
+
+ Modification(s):
+ - YYYY/MM Author: Description of the modification
+"""
+
+from gudhi.datasets.generators import points
+
+import pytest
+
+def test_sphere():
+ assert points.sphere(n_samples = 10, ambient_dim = 2, radius = 1., sample = 'random').shape == (10, 2)
+
+ with pytest.raises(ValueError):
+ points.sphere(n_samples = 10, ambient_dim = 2, radius = 1., sample = 'other')
+
+def _basic_torus(impl):
+ assert impl(n_samples = 64, dim = 3, sample = 'random').shape == (64, 6)
+ assert impl(n_samples = 64, dim = 3, sample = 'grid').shape == (64, 6)
+
+ assert impl(n_samples = 10, dim = 4, sample = 'random').shape == (10, 8)
+
+ # Here 1**dim < n_samples < 2**dim, the output shape is therefore (1, 2*dim) = (1, 8), where shape[0] is rounded down to the closest perfect 'dim'th power
+ assert impl(n_samples = 10, dim = 4, sample = 'grid').shape == (1, 8)
+
+ with pytest.raises(ValueError):
+ impl(n_samples = 10, dim = 4, sample = 'other')
+
+def test_torus():
+ for torus_impl in [points.torus, points.ctorus]:
+ _basic_torus(torus_impl)
+ # Check that the two versions (torus and ctorus) generate the same output
+ assert points.ctorus(n_samples = 64, dim = 3, sample = 'random').all() == points.torus(n_samples = 64, dim = 3, sample = 'random').all()
+ assert points.ctorus(n_samples = 64, dim = 3, sample = 'grid').all() == points.torus(n_samples = 64, dim = 3, sample = 'grid').all()
+ assert points.ctorus(n_samples = 10, dim = 3, sample = 'grid').all() == points.torus(n_samples = 10, dim = 3, sample = 'grid').all()
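The shape assertion for grid sampling follows from the rounding rule stated in the comment above; a worked instance, assuming the same API:

    from gudhi.datasets.generators import points

    # grid sampling rounds n_samples down to the closest perfect dim-th power:
    # floor(10 ** (1/4)) = 1, so only 1 ** 4 = 1 grid point is generated
    pts = points.torus(n_samples=10, dim=4, sample='grid')
    print(pts.shape)  # (1, 8): each point of the 4-torus lives in R^(2*4)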
diff --git a/src/python/test/test_diff.py b/src/python/test/test_diff.py
new file mode 100644
index 00000000..dca001a9
--- /dev/null
+++ b/src/python/test/test_diff.py
@@ -0,0 +1,78 @@
+from gudhi.tensorflow import *
+import numpy as np
+import tensorflow as tf
+import gudhi as gd
+
+def test_rips_diff():
+
+ Xinit = np.array([[1.,1.],[2.,2.]], dtype=np.float32)
+ X = tf.Variable(initial_value=Xinit, trainable=True)
+ rl = RipsLayer(maximum_edge_length=2., homology_dimensions=[0])
+
+ with tf.GradientTape() as tape:
+ dgm = rl.call(X)[0][0]
+ loss = tf.math.reduce_sum(tf.square(.5*(dgm[:,1]-dgm[:,0])))
+ grads = tape.gradient(loss, [X])
+ assert tf.norm(grads[0]-tf.constant([[-.5,-.5],[.5,.5]]),1) <= 1e-6
+
+def test_cubical_diff():
+
+ Xinit = np.array([[0.,2.,2.],[2.,2.,2.],[2.,2.,1.]], dtype=np.float32)
+ X = tf.Variable(initial_value=Xinit, trainable=True)
+ cl = CubicalLayer(homology_dimensions=[0])
+
+ with tf.GradientTape() as tape:
+ dgm = cl.call(X)[0][0]
+ loss = tf.math.reduce_sum(tf.square(.5*(dgm[:,1]-dgm[:,0])))
+ grads = tape.gradient(loss, [X])
+ assert tf.norm(grads[0]-tf.constant([[0.,0.,0.],[0.,.5,0.],[0.,0.,-.5]]),1) <= 1e-6
+
+def test_nonsquare_cubical_diff():
+
+ Xinit = np.array([[-1.,1.,0.],[1.,1.,1.]], dtype=np.float32)
+ X = tf.Variable(initial_value=Xinit, trainable=True)
+ cl = CubicalLayer(homology_dimensions=[0])
+
+ with tf.GradientTape() as tape:
+ dgm = cl.call(X)[0][0]
+ loss = tf.math.reduce_sum(tf.square(.5*(dgm[:,1]-dgm[:,0])))
+ grads = tape.gradient(loss, [X])
+ assert tf.norm(grads[0]-tf.constant([[0.,0.5,-0.5],[0.,0.,0.]]),1) <= 1e-6
+
+def test_st_diff():
+
+ st = gd.SimplexTree()
+ st.insert([0])
+ st.insert([1])
+ st.insert([2])
+ st.insert([3])
+ st.insert([4])
+ st.insert([5])
+ st.insert([6])
+ st.insert([7])
+ st.insert([8])
+ st.insert([9])
+ st.insert([10])
+ st.insert([0, 1])
+ st.insert([1, 2])
+ st.insert([2, 3])
+ st.insert([3, 4])
+ st.insert([4, 5])
+ st.insert([5, 6])
+ st.insert([6, 7])
+ st.insert([7, 8])
+ st.insert([8, 9])
+ st.insert([9, 10])
+
+ Finit = np.array([6.,4.,3.,4.,5.,4.,3.,2.,3.,4.,5.], dtype=np.float32)
+ F = tf.Variable(initial_value=Finit, trainable=True)
+ sl = LowerStarSimplexTreeLayer(simplextree=st, homology_dimensions=[0])
+
+ with tf.GradientTape() as tape:
+ dgm = sl.call(F)[0][0]
+ loss = tf.math.reduce_sum(tf.square(.5*(dgm[:,1]-dgm[:,0])))
+ grads = tape.gradient(loss, [F])
+
+ assert tf.math.reduce_all(tf.math.equal(grads[0].indices, tf.constant([2,4])))
+ assert tf.math.reduce_all(tf.math.equal(grads[0].values, tf.constant([-1.,1.])))
+
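These tests check single gradient evaluations through the differentiable layers. For context, a hypothetical end-to-end sketch of how such a layer is typically driven; the SGD loop is an illustration and not part of the test suite:

    import numpy as np
    import tensorflow as tf
    from gudhi.tensorflow import RipsLayer

    # toy objective: increase total 0-dimensional persistence of a point cloud
    X = tf.Variable(np.random.rand(10, 2).astype(np.float32), trainable=True)
    layer = RipsLayer(maximum_edge_length=2., homology_dimensions=[0])
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
    for _ in range(5):
        with tf.GradientTape() as tape:
            dgm = layer.call(X)[0][0]  # finite pairs in dimension 0
            loss = -tf.math.reduce_sum(dgm[:, 1] - dgm[:, 0])
        grads = tape.gradient(loss, [X])
        optimizer.apply_gradients(zip(grads, [X]))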
diff --git a/src/python/test/test_dtm.py b/src/python/test/test_dtm.py
new file mode 100755
index 00000000..b276f041
--- /dev/null
+++ b/src/python/test/test_dtm.py
@@ -0,0 +1,101 @@
+""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+ See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+ Author(s): Marc Glisse
+
+ Copyright (C) 2020 Inria
+
+ Modification(s):
+ - YYYY/MM Author: Description of the modification
+"""
+
+from gudhi.point_cloud.dtm import DistanceToMeasure, DTMDensity
+import numpy
+import pytest
+import torch
+import math
+import warnings
+
+
+def test_dtm_compare_euclidean():
+ pts = numpy.random.rand(1000, 4)
+ k = 6
+ dtm = DistanceToMeasure(k, implementation="ckdtree")
+ r0 = dtm.fit_transform(pts)
+ dtm = DistanceToMeasure(k, implementation="sklearn")
+ r1 = dtm.fit_transform(pts)
+ assert r1 == pytest.approx(r0)
+ dtm = DistanceToMeasure(k, implementation="sklearn", algorithm="brute")
+ r2 = dtm.fit_transform(pts)
+ assert r2 == pytest.approx(r0)
+ dtm = DistanceToMeasure(k, implementation="hnsw")
+ r3 = dtm.fit_transform(pts)
+ assert r3 == pytest.approx(r0, rel=0.1)
+ from scipy.spatial.distance import cdist
+
+ d = cdist(pts, pts)
+ dtm = DistanceToMeasure(k, metric="precomputed")
+ r4 = dtm.fit_transform(d)
+ assert r4 == pytest.approx(r0)
+ dtm = DistanceToMeasure(k, metric="precomputed", n_jobs=2)
+ r4b = dtm.fit_transform(d)
+ assert r4b == pytest.approx(r0)
+ dtm = DistanceToMeasure(k, implementation="keops")
+ r5 = dtm.fit_transform(pts)
+ assert r5 == pytest.approx(r0)
+ pts2 = torch.tensor(pts, requires_grad=True)
+ assert pts2.grad is None
+ dtm = DistanceToMeasure(k, implementation="keops", enable_autodiff=True)
+ r6 = dtm.fit_transform(pts2)
+ assert r6.detach().numpy() == pytest.approx(r0)
+ r6.sum().backward()
+ assert not torch.isnan(pts2.grad).any()
+ pts2 = torch.tensor(pts, requires_grad=True)
+ assert pts2.grad is None
+ dtm = DistanceToMeasure(k, implementation="ckdtree", enable_autodiff=True)
+ r7 = dtm.fit_transform(pts2)
+ assert r7.detach().numpy() == pytest.approx(r0)
+ r7.sum().backward()
+ assert not torch.isnan(pts2.grad).any()
+
+
+def test_dtm_precomputed():
+ dist = numpy.array([[1.0, 3, 8], [1, 5, 5], [0, 2, 3]])
+ dtm = DistanceToMeasure(2, q=1, metric="neighbors")
+ r = dtm.fit_transform(dist)
+ assert r == pytest.approx([2.0, 3, 1])
+
+ dist = numpy.array([[2.0, 2], [0, 1], [3, 4]])
+ dtm = DistanceToMeasure(2, q=2, metric="neighbors")
+ r = dtm.fit_transform(dist)
+ assert r == pytest.approx([2.0, 0.707, 3.5355], rel=0.01)
+
+
+def test_density_normalized():
+ sample = numpy.random.normal(0, 1, (1000000, 2))
+ queries = numpy.array([[0.0, 0.0], [-0.5, 0.7], [0.4, 1.7]])
+ expected = numpy.exp(-(queries ** 2).sum(-1) / 2) / (2 * math.pi)
+ estimated = DTMDensity(k=150, normalize=True).fit(sample).transform(queries)
+ assert estimated == pytest.approx(expected, rel=0.4)
+
+
+def test_density():
+ distances = [[0, 1, 10], [2, 0, 30], [1, 3, 5]]
+ density = DTMDensity(k=2, metric="neighbors", dim=1).fit_transform(distances)
+ expected = numpy.array([2.0, 1.0, 0.5])
+ assert density == pytest.approx(expected)
+ distances = [[0, 1], [2, 0], [1, 3]]
+ density = DTMDensity(metric="neighbors", dim=1).fit_transform(distances)
+ assert density == pytest.approx(expected)
+ density = DTMDensity(weights=[0.5, 0.5], metric="neighbors", dim=1).fit_transform(distances)
+ assert density == pytest.approx(expected)
+
+def test_dtm_overflow_warnings():
+ pts = numpy.array([[10., 100000000000000000000000000000.], [1000., 100000000000000000000000000.]])
+ impl_warn = ["keops", "hnsw"]
+ for impl in impl_warn:
+ with warnings.catch_warnings(record=True) as w:
+ dtm = DistanceToMeasure(2, implementation=impl)
+ r = dtm.fit_transform(pts)
+ assert len(w) == 1
+ assert issubclass(w[0].category, RuntimeWarning)
+ assert "Overflow" in str(w[0].message)
diff --git a/src/python/test/test_dtm_rips_complex.py b/src/python/test/test_dtm_rips_complex.py
new file mode 100644
index 00000000..e1c0ee44
--- /dev/null
+++ b/src/python/test/test_dtm_rips_complex.py
@@ -0,0 +1,32 @@
+""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+ See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+ Author(s): Yuichi Ike
+
+ Copyright (C) 2020 Inria, Copyright (C) 2020 Fujitsu Laboratories Ltd.
+
+ Modification(s):
+ - YYYY/MM Author: Description of the modification
+"""
+
+from gudhi.dtm_rips_complex import DTMRipsComplex
+from gudhi import RipsComplex
+import numpy as np
+from math import sqrt
+import pytest
+
+def test_dtm_rips_complex():
+ pts = np.array([[2.0, 2.0], [0.0, 1.0], [3.0, 4.0]])
+ dtm_rips = DTMRipsComplex(points=pts, k=2)
+ st = dtm_rips.create_simplex_tree(max_dimension=2)
+ st.persistence()
+ persistence_intervals0 = st.persistence_intervals_in_dimension(0)
+ assert persistence_intervals0 == pytest.approx(np.array([[3.16227766, 5.39834564],[3.16227766, 5.39834564], [3.16227766, float("inf")]]))
+
+def test_compatibility_with_rips():
+ distance_matrix = np.array([[0, 1, 1, sqrt(2)], [1, 0, sqrt(2), 1], [1, sqrt(2), 0, 1], [sqrt(2), 1, 1, 0]])
+ dtm_rips = DTMRipsComplex(distance_matrix=distance_matrix, max_filtration=42)
+ st = dtm_rips.create_simplex_tree(max_dimension=1)
+ rips_complex = RipsComplex(distance_matrix=distance_matrix, max_edge_length=42)
+ st_from_rips = rips_complex.create_simplex_tree(max_dimension=1)
+ assert list(st.get_filtration()) == list(st_from_rips.get_filtration())
+
diff --git a/src/python/test/test_euclidean_witness_complex.py b/src/python/test/test_euclidean_witness_complex.py
index c18d2484..f3664d39 100755
--- a/src/python/test/test_euclidean_witness_complex.py
+++ b/src/python/test/test_euclidean_witness_complex.py
@@ -40,7 +40,7 @@ def test_witness_complex():
assert landmarks[1] == euclidean_witness_complex.get_point(1)
assert landmarks[2] == euclidean_witness_complex.get_point(2)
- assert simplex_tree.get_filtration() == [
+ assert list(simplex_tree.get_filtration()) == [
([0], 0.0),
([1], 0.0),
([0, 1], 0.0),
@@ -78,13 +78,13 @@ def test_strong_witness_complex():
assert landmarks[1] == euclidean_strong_witness_complex.get_point(1)
assert landmarks[2] == euclidean_strong_witness_complex.get_point(2)
- assert simplex_tree.get_filtration() == [([0], 0.0), ([1], 0.0), ([2], 0.0)]
+ assert list(simplex_tree.get_filtration()) == [([0], 0.0), ([1], 0.0), ([2], 0.0)]
simplex_tree = euclidean_strong_witness_complex.create_simplex_tree(
max_alpha_square=100.0
)
- assert simplex_tree.get_filtration() == [
+ assert list(simplex_tree.get_filtration()) == [
([0], 0.0),
([1], 0.0),
([2], 0.0),
diff --git a/src/python/test/test_knn.py b/src/python/test/test_knn.py
new file mode 100755
index 00000000..a87ec212
--- /dev/null
+++ b/src/python/test/test_knn.py
@@ -0,0 +1,130 @@
+""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+ See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+ Author(s): Marc Glisse
+
+ Copyright (C) 2020 Inria
+
+ Modification(s):
+ - YYYY/MM Author: Description of the modification
+"""
+
+from gudhi.point_cloud.knn import KNearestNeighbors
+import numpy as np
+import pytest
+
+
+def test_knn_explicit():
+ base = np.array([[1.0, 1], [1, 2], [4, 2], [4, 3]])
+ query = np.array([[1.0, 1], [2, 2], [4, 4]])
+ knn = KNearestNeighbors(2, metric="manhattan", return_distance=True, return_index=True)
+ knn.fit(base)
+ r = knn.transform(query)
+ assert r[0] == pytest.approx(np.array([[0, 1], [1, 0], [3, 2]]))
+ assert r[1] == pytest.approx(np.array([[0.0, 1], [1, 2], [1, 2]]))
+
+ knn = KNearestNeighbors(2, metric="chebyshev", return_distance=True, return_index=False)
+ knn.fit(base)
+ r = knn.transform(query)
+ assert r == pytest.approx(np.array([[0.0, 1], [1, 1], [1, 2]]))
+ r = (
+ KNearestNeighbors(2, metric="chebyshev", return_distance=True, return_index=False, implementation="keops")
+ .fit(base)
+ .transform(query)
+ )
+ assert r == pytest.approx(np.array([[0.0, 1], [1, 1], [1, 2]]))
+ r = (
+ KNearestNeighbors(2, metric="chebyshev", return_distance=True, return_index=False, implementation="keops", enable_autodiff=True)
+ .fit(base)
+ .transform(query)
+ )
+ assert r == pytest.approx(np.array([[0.0, 1], [1, 1], [1, 2]]))
+
+ knn = KNearestNeighbors(2, metric="minkowski", p=3, return_distance=False, return_index=True)
+ knn.fit(base)
+ r = knn.transform(query)
+ assert np.array_equal(r, [[0, 1], [1, 0], [3, 2]])
+ r = (
+ KNearestNeighbors(2, metric="minkowski", p=3, return_distance=False, return_index=True, implementation="keops")
+ .fit(base)
+ .transform(query)
+ )
+ assert np.array_equal(r, [[0, 1], [1, 0], [3, 2]])
+
+ dist = np.array([[0.0, 3, 8], [1, 0, 5], [1, 2, 0]])
+ knn = KNearestNeighbors(2, metric="precomputed", return_index=True, return_distance=False)
+ r = knn.fit_transform(dist)
+ assert np.array_equal(r, [[0, 1], [1, 0], [2, 0]])
+ knn = KNearestNeighbors(2, metric="precomputed", return_index=True, return_distance=True, sort_results=True)
+ r = knn.fit_transform(dist)
+ assert np.array_equal(r[0], [[0, 1], [1, 0], [2, 0]])
+ assert np.array_equal(r[1], [[0, 3], [0, 1], [0, 1]])
+ # Second time in parallel
+ knn = KNearestNeighbors(2, metric="precomputed", return_index=True, return_distance=False, n_jobs=2, sort_results=True)
+ r = knn.fit_transform(dist)
+ assert np.array_equal(r, [[0, 1], [1, 0], [2, 0]])
+ knn = KNearestNeighbors(2, metric="precomputed", return_index=True, return_distance=True, n_jobs=2)
+ r = knn.fit_transform(dist)
+ assert np.array_equal(r[0], [[0, 1], [1, 0], [2, 0]])
+ assert np.array_equal(r[1], [[0, 3], [0, 1], [0, 1]])
+
+
+def test_knn_compare():
+ base = np.array([[1.0, 1], [1, 2], [4, 2], [4, 3]])
+ query = np.array([[1.0, 1], [2, 2], [4, 4]])
+ r0 = (
+ KNearestNeighbors(2, implementation="ckdtree", return_index=True, return_distance=False)
+ .fit(base)
+ .transform(query)
+ )
+ r1 = (
+ KNearestNeighbors(2, implementation="sklearn", return_index=True, return_distance=False)
+ .fit(base)
+ .transform(query)
+ )
+ r2 = (
+ KNearestNeighbors(2, implementation="hnsw", return_index=True, return_distance=False).fit(base).transform(query)
+ )
+ r3 = (
+ KNearestNeighbors(2, implementation="keops", return_index=True, return_distance=False)
+ .fit(base)
+ .transform(query)
+ )
+ assert np.array_equal(r0, r1) and np.array_equal(r0, r2) and np.array_equal(r0, r3)
+
+ r0 = (
+ KNearestNeighbors(2, implementation="ckdtree", return_index=True, return_distance=True)
+ .fit(base)
+ .transform(query)
+ )
+ r1 = (
+ KNearestNeighbors(2, implementation="sklearn", return_index=True, return_distance=True)
+ .fit(base)
+ .transform(query)
+ )
+ r2 = KNearestNeighbors(2, implementation="hnsw", return_index=True, return_distance=True).fit(base).transform(query)
+ r3 = (
+ KNearestNeighbors(2, implementation="keops", return_index=True, return_distance=True).fit(base).transform(query)
+ )
+ assert np.array_equal(r0[0], r1[0]) and np.array_equal(r0[0], r2[0]) and np.array_equal(r0[0], r3[0])
+ d0 = pytest.approx(r0[1])
+ assert r1[1] == d0 and r2[1] == d0 and r3[1] == d0
+
+
+def test_knn_nop():
+ # This doesn't look super useful...
+ p = np.array([[0.0]])
+ assert None is KNearestNeighbors(
+ k=1, return_index=False, return_distance=False, implementation="sklearn"
+ ).fit_transform(p)
+ assert None is KNearestNeighbors(
+ k=1, return_index=False, return_distance=False, implementation="ckdtree"
+ ).fit_transform(p)
+ assert None is KNearestNeighbors(
+ k=1, return_index=False, return_distance=False, implementation="hnsw", ef=5
+ ).fit_transform(p)
+ assert None is KNearestNeighbors(
+ k=1, return_index=False, return_distance=False, implementation="keops"
+ ).fit_transform(p)
+ assert None is KNearestNeighbors(
+ k=1, return_index=False, return_distance=False, metric="precomputed"
+ ).fit_transform(p)
diff --git a/src/python/test/test_off.py b/src/python/test/test_off.py
new file mode 100644
index 00000000..aea1941b
--- /dev/null
+++ b/src/python/test/test_off.py
@@ -0,0 +1,21 @@
+""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+ See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+ Author(s): Marc Glisse
+
+ Copyright (C) 2022 Inria
+
+ Modification(s):
+ - YYYY/MM Author: Description of the modification
+"""
+
+import gudhi as gd
+import numpy as np
+import pytest
+
+
+def test_off_rw():
+ for dim in range(2, 6):
+ X = np.random.rand(123, dim)
+ gd.write_points_to_off_file("rand.off", X)
+ Y = gd.read_points_from_off_file("rand.off")
+ assert Y == pytest.approx(X)
diff --git a/src/python/test/test_persistence_graphical_tools.py b/src/python/test/test_persistence_graphical_tools.py
new file mode 100644
index 00000000..0e2ac3f8
--- /dev/null
+++ b/src/python/test/test_persistence_graphical_tools.py
@@ -0,0 +1,122 @@
+""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+ See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+ Author(s): Vincent Rouvreau
+
+ Copyright (C) 2021 Inria
+
+ Modification(s):
+ - YYYY/MM Author: Description of the modification
+"""
+
+import gudhi as gd
+import numpy as np
+import matplotlib as plt
+import pytest
+import warnings
+
+
+def test_array_handler():
+ diags = np.array([[1, 2], [3, 4], [5, 6]], float)
+ arr_diags = gd.persistence_graphical_tools._array_handler(diags)
+ for idx in range(len(diags)):
+ assert arr_diags[idx][0] == 0
+ np.testing.assert_array_equal(arr_diags[idx][1], diags[idx])
+
+ diags = [(1.0, 2.0), (3.0, 4.0), (5.0, 6.0)]
+ arr_diags = gd.persistence_graphical_tools._array_handler(diags)
+ for idx in range(len(diags)):
+ assert arr_diags[idx][0] == 0
+ assert arr_diags[idx][1] == diags[idx]
+
+ diags = [(0, (1.0, 2.0)), (0, (3.0, 4.0)), (0, (5.0, 6.0))]
+ assert gd.persistence_graphical_tools._array_handler(diags) == diags
+
+
+def test_min_birth_max_death():
+ diags = [
+ (0, (0.0, float("inf"))),
+ (0, (0.0983494, float("inf"))),
+ (0, (0.0, 0.122545)),
+ (0, (0.0, 0.12047)),
+ (0, (0.0, 0.118398)),
+ (0, (0.118398, 1.0)),
+ (0, (0.0, 0.117908)),
+ (0, (0.0, 0.112307)),
+ (0, (0.0, 0.107535)),
+ (0, (0.0, 0.106382)),
+ ]
+ assert gd.persistence_graphical_tools.__min_birth_max_death(diags) == (0.0, 1.0)
+ assert gd.persistence_graphical_tools.__min_birth_max_death(diags, band=4.0) == (0.0, 5.0)
+
+
+def test_limit_min_birth_max_death():
+ diags = [
+ (0, (2.0, float("inf"))),
+ (0, (2.0, float("inf"))),
+ ]
+ assert gd.persistence_graphical_tools.__min_birth_max_death(diags) == (2.0, 3.0)
+ assert gd.persistence_graphical_tools.__min_birth_max_death(diags, band=4.0) == (2.0, 6.0)
+
+
+def test_limit_to_max_intervals():
+ diags = [
+ (0, (0.0, float("inf"))),
+ (0, (0.0983494, float("inf"))),
+ (0, (0.0, 0.122545)),
+ (0, (0.0, 0.12047)),
+ (0, (0.0, 0.118398)),
+ (0, (0.118398, 1.0)),
+ (0, (0.0, 0.117908)),
+ (0, (0.0, 0.112307)),
+ (0, (0.0, 0.107535)),
+ (0, (0.0, 0.106382)),
+ ]
+ # check no warnings if max_intervals equals to the diagrams number
+ with warnings.catch_warnings():
+ warnings.simplefilter("error")
+ truncated_diags = gd.persistence_graphical_tools._limit_to_max_intervals(
+ diags, 10, key=lambda life_time: life_time[1][1] - life_time[1][0]
+ )
+ # check diagrams are not sorted
+ assert truncated_diags == diags
+
+ # check warning if max_intervals lower than the diagrams number
+ with pytest.warns(UserWarning) as record:
+ truncated_diags = gd.persistence_graphical_tools._limit_to_max_intervals(
+ diags, 5, key=lambda life_time: life_time[1][1] - life_time[1][0]
+ )
+ # check diagrams are truncated and sorted by life time
+ assert truncated_diags == [
+ (0, (0.0, float("inf"))),
+ (0, (0.0983494, float("inf"))),
+ (0, (0.118398, 1.0)),
+ (0, (0.0, 0.122545)),
+ (0, (0.0, 0.12047)),
+ ]
+ assert len(record) == 1
+
+
+def _limit_plot_persistence(function):
+ pplot = function(persistence=[])
+ assert isinstance(pplot, plt.axes.SubplotBase)
+ pplot = function(persistence=[], legend=True)
+ assert isinstance(pplot, plt.axes.SubplotBase)
+ pplot = function(persistence=[(0, float("inf"))])
+ assert isinstance(pplot, plt.axes.SubplotBase)
+ pplot = function(persistence=[(0, float("inf"))], legend=True)
+ assert isinstance(pplot, plt.axes.SubplotBase)
+
+
+def test_limit_plot_persistence():
+ for function in [gd.plot_persistence_barcode, gd.plot_persistence_diagram, gd.plot_persistence_density]:
+ _limit_plot_persistence(function)
+
+
+def _non_existing_persistence_file(function):
+ with pytest.raises(FileNotFoundError):
+ function(persistence_file="pouetpouettralala.toubiloubabdou")
+
+
+def test_non_existing_persistence_file():
+ for function in [gd.plot_persistence_barcode, gd.plot_persistence_diagram, gd.plot_persistence_density]:
+ _non_existing_persistence_file(function)
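The private helpers tested above back the public plot_* functions, each of which returns a matplotlib axes object; a minimal usage sketch:

    import matplotlib.pyplot as plt
    import gudhi as gd

    # a tiny diagram with one finite and one essential interval
    diag = [(0, (0.0, 1.0)), (0, (0.5, float("inf")))]
    ax = gd.plot_persistence_diagram(diag, legend=True)
    ax.set_title("toy diagram")
    plt.show()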
diff --git a/src/python/test/test_reader_utils.py b/src/python/test/test_reader_utils.py
index 90da6651..fdfddc4b 100755
--- a/src/python/test/test_reader_utils.py
+++ b/src/python/test/test_reader_utils.py
@@ -8,8 +8,9 @@
- YYYY/MM Author: Description of the modification
"""
-import gudhi
+import gudhi as gd
import numpy as np
+from pytest import raises
__author__ = "Vincent Rouvreau"
__copyright__ = "Copyright (C) 2017 Inria"
@@ -18,7 +19,7 @@ __license__ = "MIT"
def test_non_existing_csv_file():
# Try to open a non existing file
- matrix = gudhi.read_lower_triangular_matrix_from_csv_file(
+ matrix = gd.read_lower_triangular_matrix_from_csv_file(
csv_file="pouetpouettralala.toubiloubabdou"
)
assert matrix == []
@@ -29,8 +30,8 @@ def test_full_square_distance_matrix_csv_file():
test_file = open("full_square_distance_matrix.csv", "w")
test_file.write("0;1;2;3;\n1;0;4;5;\n2;4;0;6;\n3;5;6;0;")
test_file.close()
- matrix = gudhi.read_lower_triangular_matrix_from_csv_file(
- csv_file="full_square_distance_matrix.csv"
+ matrix = gd.read_lower_triangular_matrix_from_csv_file(
+ csv_file="full_square_distance_matrix.csv", separator=";"
)
assert matrix == [[], [1.0], [2.0, 4.0], [3.0, 5.0, 6.0]]
@@ -40,7 +41,7 @@ def test_lower_triangular_distance_matrix_csv_file():
test_file = open("lower_triangular_distance_matrix.csv", "w")
test_file.write("\n1,\n2,3,\n4,5,6,\n7,8,9,10,")
test_file.close()
- matrix = gudhi.read_lower_triangular_matrix_from_csv_file(
+ matrix = gd.read_lower_triangular_matrix_from_csv_file(
csv_file="lower_triangular_distance_matrix.csv", separator=","
)
assert matrix == [[], [1.0], [2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0, 10.0]]
@@ -48,11 +49,11 @@ def test_lower_triangular_distance_matrix_csv_file():
def test_non_existing_persistence_file():
# Try to open a non existing file
- persistence = gudhi.read_persistence_intervals_grouped_by_dimension(
+ persistence = gd.read_persistence_intervals_grouped_by_dimension(
persistence_file="pouetpouettralala.toubiloubabdou"
)
assert persistence == []
- persistence = gudhi.read_persistence_intervals_in_dimension(
+ persistence = gd.read_persistence_intervals_in_dimension(
persistence_file="pouetpouettralala.toubiloubabdou", only_this_dim=1
)
np.testing.assert_array_equal(persistence, [])
@@ -65,21 +66,21 @@ def test_read_persistence_intervals_without_dimension():
"# Simple persistence diagram without dimension\n2.7 3.7\n9.6 14.\n34.2 34.974\n3. inf"
)
test_file.close()
- persistence = gudhi.read_persistence_intervals_in_dimension(
+ persistence = gd.read_persistence_intervals_in_dimension(
persistence_file="persistence_intervals_without_dimension.pers"
)
np.testing.assert_array_equal(
persistence, [(2.7, 3.7), (9.6, 14.0), (34.2, 34.974), (3.0, float("Inf"))]
)
- persistence = gudhi.read_persistence_intervals_in_dimension(
+ persistence = gd.read_persistence_intervals_in_dimension(
persistence_file="persistence_intervals_without_dimension.pers", only_this_dim=0
)
np.testing.assert_array_equal(persistence, [])
- persistence = gudhi.read_persistence_intervals_in_dimension(
+ persistence = gd.read_persistence_intervals_in_dimension(
persistence_file="persistence_intervals_without_dimension.pers", only_this_dim=1
)
np.testing.assert_array_equal(persistence, [])
- persistence = gudhi.read_persistence_intervals_grouped_by_dimension(
+ persistence = gd.read_persistence_intervals_grouped_by_dimension(
persistence_file="persistence_intervals_without_dimension.pers"
)
assert persistence == {
@@ -94,29 +95,29 @@ def test_read_persistence_intervals_with_dimension():
"# Simple persistence diagram with dimension\n0 2.7 3.7\n1 9.6 14.\n3 34.2 34.974\n1 3. inf"
)
test_file.close()
- persistence = gudhi.read_persistence_intervals_in_dimension(
+ persistence = gd.read_persistence_intervals_in_dimension(
persistence_file="persistence_intervals_with_dimension.pers"
)
np.testing.assert_array_equal(
persistence, [(2.7, 3.7), (9.6, 14.0), (34.2, 34.974), (3.0, float("Inf"))]
)
- persistence = gudhi.read_persistence_intervals_in_dimension(
+ persistence = gd.read_persistence_intervals_in_dimension(
persistence_file="persistence_intervals_with_dimension.pers", only_this_dim=0
)
np.testing.assert_array_equal(persistence, [(2.7, 3.7)])
- persistence = gudhi.read_persistence_intervals_in_dimension(
+ persistence = gd.read_persistence_intervals_in_dimension(
persistence_file="persistence_intervals_with_dimension.pers", only_this_dim=1
)
np.testing.assert_array_equal(persistence, [(9.6, 14.0), (3.0, float("Inf"))])
- persistence = gudhi.read_persistence_intervals_in_dimension(
+ persistence = gd.read_persistence_intervals_in_dimension(
persistence_file="persistence_intervals_with_dimension.pers", only_this_dim=2
)
np.testing.assert_array_equal(persistence, [])
- persistence = gudhi.read_persistence_intervals_in_dimension(
+ persistence = gd.read_persistence_intervals_in_dimension(
persistence_file="persistence_intervals_with_dimension.pers", only_this_dim=3
)
np.testing.assert_array_equal(persistence, [(34.2, 34.974)])
- persistence = gudhi.read_persistence_intervals_grouped_by_dimension(
+ persistence = gd.read_persistence_intervals_grouped_by_dimension(
persistence_file="persistence_intervals_with_dimension.pers"
)
assert persistence == {
diff --git a/src/python/test/test_remote_datasets.py b/src/python/test/test_remote_datasets.py
new file mode 100644
index 00000000..e5d2de82
--- /dev/null
+++ b/src/python/test/test_remote_datasets.py
@@ -0,0 +1,87 @@
+# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+# Author(s): Hind Montassif
+#
+# Copyright (C) 2021 Inria
+#
+# Modification(s):
+# - YYYY/MM Author: Description of the modification
+
+from gudhi.datasets import remote
+
+import shutil
+import io
+import sys
+import pytest
+
+from os.path import isdir, expanduser, exists
+from os import remove, environ
+
+def test_data_home():
+ # Test _get_data_home and clear_data_home on new empty folder
+ empty_data_home = remote._get_data_home(data_home="empty_folder_for_test")
+ assert isdir(empty_data_home)
+
+ remote.clear_data_home(data_home=empty_data_home)
+ assert not isdir(empty_data_home)
+
+def test_fetch_remote():
+ # Test fetch with a wrong checksum
+ with pytest.raises(OSError):
+ remote._fetch_remote("https://raw.githubusercontent.com/GUDHI/gudhi-data/main/points/spiral_2d/spiral_2d.npy", "tmp_spiral_2d.npy", file_checksum = 'XXXXXXXXXX')
+ assert not exists("tmp_spiral_2d.npy")
+
+def _get_bunny_license_print(accept_license = False):
+ capturedOutput = io.StringIO()
+ # Redirect stdout
+ sys.stdout = capturedOutput
+
+ bunny_arr = remote.fetch_bunny("./tmp_for_test/bunny.npy", accept_license)
+ assert bunny_arr.shape == (35947, 3)
+ del bunny_arr
+ remove("./tmp_for_test/bunny.npy")
+
+ # Reset redirect
+ sys.stdout = sys.__stdout__
+ return capturedOutput
+
+def test_print_bunny_license():
+ # Test not printing bunny.npy LICENSE when accept_license = True
+ assert "" == _get_bunny_license_print(accept_license = True).getvalue()
+ # Test printing bunny.LICENSE file when fetching bunny.npy with accept_license = False (default)
+ with open("./tmp_for_test/bunny.LICENSE") as f:
+ assert f.read().rstrip("\n") == _get_bunny_license_print().getvalue().rstrip("\n")
+ shutil.rmtree("./tmp_for_test")
+
+def test_fetch_remote_datasets_wrapped():
+ # Test fetch_spiral_2d and fetch_bunny wrapping functions with data directory different from default (twice, to test case of already fetched files)
+ # Default case is not tested because it would fail in case the user sets the 'GUDHI_DATA' environment variable locally
+ for i in range(2):
+ spiral_2d_arr = remote.fetch_spiral_2d("./another_fetch_folder_for_test/spiral_2d.npy")
+ assert spiral_2d_arr.shape == (114562, 2)
+
+ bunny_arr = remote.fetch_bunny("./another_fetch_folder_for_test/bunny.npy")
+ assert bunny_arr.shape == (35947, 3)
+
+ # Check that the directory was created
+ assert isdir("./another_fetch_folder_for_test")
+ # Check downloaded files
+ assert exists("./another_fetch_folder_for_test/spiral_2d.npy")
+ assert exists("./another_fetch_folder_for_test/bunny.npy")
+ assert exists("./another_fetch_folder_for_test/bunny.LICENSE")
+
+ # Remove test folders
+ del spiral_2d_arr
+ del bunny_arr
+ shutil.rmtree("./another_fetch_folder_for_test")
+
+def test_gudhi_data_env():
+ # Set environment variable "GUDHI_DATA"
+ environ["GUDHI_DATA"] = "./test_folder_from_env_var"
+ bunny_arr = remote.fetch_bunny()
+ assert bunny_arr.shape == (35947, 3)
+ assert exists("./test_folder_from_env_var/points/bunny/bunny.npy")
+ assert exists("./test_folder_from_env_var/points/bunny/bunny.LICENSE")
+ # Remove test folder
+ del bunny_arr
+ shutil.rmtree("./test_folder_from_env_var")
diff --git a/src/python/test/test_representations.py b/src/python/test/test_representations.py
index dba7f952..f4ffbdc1 100755
--- a/src/python/test/test_representations.py
+++ b/src/python/test/test_representations.py
@@ -1,12 +1,269 @@
import os
import sys
import matplotlib.pyplot as plt
+import numpy as np
+import pytest
+import random
+
+from sklearn.cluster import KMeans
+
+# Vectorization
+from gudhi.representations import (Landscape, Silhouette, BettiCurve, ComplexPolynomial,\
+ TopologicalVector, PersistenceImage, Entropy)
+
+# Preprocessing
+from gudhi.representations import (BirthPersistenceTransform, Clamping, DiagramScaler, Padding, ProminentPoints, \
+ DiagramSelector)
+
+# Kernel
+from gudhi.representations import (PersistenceWeightedGaussianKernel, \
+ PersistenceScaleSpaceKernel, SlicedWassersteinDistance,\
+ SlicedWassersteinKernel, PersistenceFisherKernel, WassersteinDistance)
+
def test_representations_examples():
# Disable graphics for testing purposes
- plt.show = lambda:None
+ plt.show = lambda: None
here = os.path.dirname(os.path.realpath(__file__))
sys.path.append(here + "/../example")
import diagram_vectorizations_distances_kernels
return None
+
+
+from gudhi.representations.vector_methods import Atol
+from gudhi.representations.metrics import *
+from gudhi.representations.kernel_methods import *
+
+
+def _n_diags(n):
+ l = []
+ for _ in range(n):
+ a = np.random.rand(50, 2)
+ a[:, 1] += a[:, 0] # So that y >= x
+ l.append(a)
+ return l
+
+
+def test_multiple():
+ l1 = _n_diags(9)
+ l2 = _n_diags(11)
+ l1b = l1.copy()
+ d1 = pairwise_persistence_diagram_distances(l1, e=0.00001, n_jobs=4)
+ d2 = BottleneckDistance(epsilon=0.00001).fit_transform(l1)
+ d3 = pairwise_persistence_diagram_distances(l1, l1b, e=0.00001, n_jobs=4)
+ assert d1 == pytest.approx(d2)
+ assert d3 == pytest.approx(d2, abs=1e-5) # Because of 0 entries (on the diagonal)
+ d1 = pairwise_persistence_diagram_distances(l1, l2, metric="wasserstein", order=2, internal_p=2)
+ d2 = WassersteinDistance(order=2, internal_p=2, n_jobs=4).fit(l2).transform(l1)
+ print(d1.shape, d2.shape)
+ assert d1 == pytest.approx(d2, rel=0.02)
+
+
+# Test sorted values as points order can be inverted, and sorted test is not documentation-friendly
+# Note the test below must be up to date with the Atol class documentation
+def test_atol_doc():
+ a = np.array([[1, 2, 4], [1, 4, 0], [1, 0, 4]])
+ b = np.array([[4, 2, 0], [4, 4, 0], [4, 0, 2]])
+ c = np.array([[3, 2, -1], [1, 2, -1]])
+
+ atol_vectoriser = Atol(quantiser=KMeans(n_clusters=2, random_state=202006))
+ # Atol will do
+ # X = np.concatenate([a,b,c])
+ # kmeans = KMeans(n_clusters=2, random_state=202006).fit(X)
+ # kmeans.labels_ will be : array([1, 0, 1, 0, 0, 1, 0, 0])
+ first_cluster = np.asarray([a[0], a[2], b[2]])
+ second_cluster = np.asarray([a[1], b[0], b[1], c[0], c[1]])
+
+ # Check the center of the first_cluster and second_cluster are in Atol centers
+ centers = atol_vectoriser.fit(X=[a, b, c]).centers
+ assert np.isclose(centers, first_cluster.mean(axis=0)).all(1).any()
+ assert np.isclose(centers, second_cluster.mean(axis=0)).all(1).any()
+
+ vectorization = atol_vectoriser.transform(X=[a, b, c])
+ assert np.allclose(vectorization[0], atol_vectoriser(a))
+ assert np.allclose(vectorization[1], atol_vectoriser(b))
+ assert np.allclose(vectorization[2], atol_vectoriser(c))
+
+
+def test_dummy_atol():
+ a = np.array([[1, 2, 4], [1, 4, 0], [1, 0, 4]])
+ b = np.array([[4, 2, 0], [4, 4, 0], [4, 0, 2]])
+ c = np.array([[3, 2, -1], [1, 2, -1]])
+
+ for weighting_method in ["cloud", "iidproba"]:
+ for contrast in ["gaussian", "laplacian", "indicator"]:
+ atol_vectoriser = Atol(
+ quantiser=KMeans(n_clusters=1, random_state=202006),
+ weighting_method=weighting_method,
+ contrast=contrast,
+ )
+ atol_vectoriser.fit([a, b, c])
+ atol_vectoriser(a)
+ atol_vectoriser.transform(X=[a, b, c])
+
+
+from gudhi.representations.vector_methods import BettiCurve
+
+def test_infinity():
+ a = np.array([[1.0, 8.0], [2.0, np.inf], [3.0, 4.0]])
+ c = BettiCurve(20, [0.0, 10.0])(a)
+ assert c[1] == 0
+ assert c[7] == 3
+ assert c[9] == 2
+
+def test_preprocessing_empty_diagrams():
+ empty_diag = np.empty(shape = [0, 2])
+ assert not np.any(BirthPersistenceTransform()(empty_diag))
+ assert not np.any(Clamping().fit_transform(empty_diag))
+ assert not np.any(DiagramScaler()(empty_diag))
+ assert not np.any(Padding()(empty_diag))
+ assert not np.any(ProminentPoints()(empty_diag))
+ assert not np.any(DiagramSelector()(empty_diag))
+
+def pow(n):
+ return lambda x: np.power(x[1]-x[0],n)
+
+def test_vectorization_empty_diagrams():
+ empty_diag = np.empty(shape = [0, 2])
+ random_resolution = random.randint(50,100)*10 # between 500 and 1000
+ print("resolution = ", random_resolution)
+ lsc = Landscape(resolution=random_resolution)(empty_diag)
+ assert not np.any(lsc)
+ assert lsc.shape[0]%random_resolution == 0
+ slt = Silhouette(resolution=random_resolution, weight=pow(2))(empty_diag)
+ assert not np.any(slt)
+ assert slt.shape[0] == random_resolution
+ btc = BettiCurve(resolution=random_resolution)(empty_diag)
+ assert not np.any(btc)
+ assert btc.shape[0] == random_resolution
+ cpp = ComplexPolynomial(threshold=random_resolution, polynomial_type="T")(empty_diag)
+ assert not np.any(cpp)
+ assert cpp.shape[0] == random_resolution
+ tpv = TopologicalVector(threshold=random_resolution)(empty_diag)
+ assert tpv.shape[0] == random_resolution
+ assert not np.any(tpv)
+ prmg = PersistenceImage(resolution=[random_resolution,random_resolution])(empty_diag)
+ assert not np.any(prmg)
+ assert prmg.shape[0] == random_resolution * random_resolution
+ sce = Entropy(mode="scalar", resolution=random_resolution)(empty_diag)
+ assert not np.any(sce)
+ assert sce.shape[0] == 1
+ scv = Entropy(mode="vector", normalized=False, resolution=random_resolution)(empty_diag)
+ assert not np.any(scv)
+ assert scv.shape[0] == random_resolution
+
+def test_entropy_miscalculation():
+ diag_ex = np.array([[0.0,1.0], [0.0,1.0], [0.0,2.0]])
+ def pe(pd):
+ l = pd[:,1] - pd[:,0]
+ l = l/sum(l)
+ return -np.dot(l, np.log(l))
+ sce = Entropy(mode="scalar")
+ assert [[pe(diag_ex)]] == sce.fit_transform([diag_ex])
+ sce = Entropy(mode="vector", resolution=4, normalized=False, keep_endpoints=True)
+ pef = [-1/4*np.log(1/4)-1/4*np.log(1/4)-1/2*np.log(1/2),
+ -1/4*np.log(1/4)-1/4*np.log(1/4)-1/2*np.log(1/2),
+ -1/2*np.log(1/2),
+ 0.0]
+ assert all(([pef] == sce.fit_transform([diag_ex]))[0])
+ sce = Entropy(mode="vector", resolution=4, normalized=True)
+ pefN = (sce.fit_transform([diag_ex]))[0]
+ area = np.linalg.norm(pefN, ord=1)
+ assert area==pytest.approx(1)
+
+def test_kernel_empty_diagrams():
+ empty_diag = np.empty(shape = [0, 2])
+ assert SlicedWassersteinDistance(num_directions=100)(empty_diag, empty_diag) == 0.
+ assert SlicedWassersteinKernel(num_directions=100, bandwidth=1.)(empty_diag, empty_diag) == 1.
+ assert WassersteinDistance(mode="hera", delta=0.0001)(empty_diag, empty_diag) == 0.
+ assert WassersteinDistance(mode="pot")(empty_diag, empty_diag) == 0.
+ assert BottleneckDistance(epsilon=.001)(empty_diag, empty_diag) == 0.
+ assert BottleneckDistance()(empty_diag, empty_diag) == 0.
+# PersistenceWeightedGaussianKernel(bandwidth=1., kernel_approx=None, weight=arctan(1.,1.))(empty_diag, empty_diag)
+# PersistenceWeightedGaussianKernel(kernel_approx=RBFSampler(gamma=1./2, n_components=100000).fit(np.ones([1,2])), weight=arctan(1.,1.))(empty_diag, empty_diag)
+# PersistenceScaleSpaceKernel(bandwidth=1.)(empty_diag, empty_diag)
+# PersistenceScaleSpaceKernel(kernel_approx=RBFSampler(gamma=1./2, n_components=100000).fit(np.ones([1,2])))(empty_diag, empty_diag)
+# PersistenceFisherKernel(bandwidth_fisher=1., bandwidth=1.)(empty_diag, empty_diag)
+# PersistenceFisherKernel(bandwidth_fisher=1., bandwidth=1., kernel_approx=RBFSampler(gamma=1./2, n_components=100000).fit(np.ones([1,2])))(empty_diag, empty_diag)
+
+
+def test_silhouette_permutation_invariance():
+ dgm = _n_diags(1)[0]
+ dgm_permuted = dgm[np.random.permutation(dgm.shape[0]).astype(int)]
+ random_resolution = random.randint(50, 100) * 10
+ slt = Silhouette(resolution=random_resolution, weight=pow(2))
+
+ assert np.all(np.isclose(slt(dgm), slt(dgm_permuted)))
+
+
+def test_silhouette_multiplication_invariance():
+ dgm = _n_diags(1)[0]
+ n_repetitions = np.random.randint(2, high=10)
+ dgm_augmented = np.repeat(dgm, repeats=n_repetitions, axis=0)
+
+ random_resolution = random.randint(50, 100) * 10
+ slt = Silhouette(resolution=random_resolution, weight=pow(2))
+ assert np.all(np.isclose(slt(dgm), slt(dgm_augmented)))
+
+
+def test_silhouette_numeric():
+ dgm = np.array([[2., 3.], [5., 6.]])
+ slt = Silhouette(resolution=9, weight=pow(1), sample_range=[2., 6.])
+ #slt.fit([dgm])
+ # x_values = array([2., 2.5, 3., 3.5, 4., 4.5, 5., 5.5, 6.])
+
+ expected_silhouette = np.array([0., 0.5, 0., 0., 0., 0., 0., 0.5, 0.])/np.sqrt(2)
+ output_silhouette = slt(dgm)
+ assert np.all(np.isclose(output_silhouette, expected_silhouette))
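+ # Informal derivation: with weight=pow(1) both bars get weight 1, and each tent
+ # max(0, min(x - b, d - x)) peaks at 0.5 over its bar's midpoint. Averaging the
+ # two tents and applying the sqrt(2) scaling shared with Landscape yields
+ # 0.25*sqrt(2) = 0.5/sqrt(2) at x = 2.5 and x = 5.5, and 0 elsewhere on the grid.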
+
+
+def test_landscape_small_persistence_invariance():
+ dgm = np.array([[2., 6.], [2., 5.], [3., 7.]])
+ small_persistence_pts = np.random.rand(10, 2)
+ small_persistence_pts[:, 1] += small_persistence_pts[:, 0]
+ small_persistence_pts += np.min(dgm)
+ dgm_augmented = np.concatenate([dgm, small_persistence_pts], axis=0)
+
+ lds = Landscape(num_landscapes=2, resolution=5)
+ lds_dgm, lds_dgm_augmented = lds(dgm), lds(dgm_augmented)
+
+ assert np.all(np.isclose(lds_dgm, lds_dgm_augmented))
+
+
+def test_landscape_numeric():
+ dgm = np.array([[2., 6.], [3., 5.]])
+ lds_ref = np.array([
+ 0., 0.5, 1., 1.5, 2., 1.5, 1., 0.5, 0., # tent of [2, 6]
+ 0., 0., 0., 0.5, 1., 0.5, 0., 0., 0.,
+ 0., 0., 0., 0., 0., 0., 0., 0., 0.,
+ 0., 0., 0., 0., 0., 0., 0., 0., 0.,
+ ])
+ lds_ref *= np.sqrt(2)
+ lds = Landscape(num_landscapes=4, resolution=9, sample_range=[2., 6.])
+ lds_dgm = lds(dgm)
+ assert np.all(np.isclose(lds_dgm, lds_ref))
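+ # Informal reading: the k-th landscape value at x is the k-th largest tent
+ # sqrt(2) * max(0, min(x - b, d - x)) over the bars. With only two bars the
+ # third and fourth landscapes vanish, hence the two zero rows in lds_ref.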
+
+
+def test_landscape_nan_range():
+ dgm = np.array([[2., 6.], [3., 5.]])
+ lds = Landscape(num_landscapes=2, resolution=9, sample_range=[np.nan, 6.])
+ lds_dgm = lds(dgm)
+ assert (lds.sample_range_fixed[0] == 2) & (lds.sample_range_fixed[1] == 6)
+ assert lds.new_resolution == 10
+
+def test_endpoints():
+ diags = [ np.array([[2., 3.]]) ]
+ for vec in [ Landscape(), Silhouette(), BettiCurve(), Entropy(mode="vector") ]:
+ vec.fit(diags)
+ assert vec.grid_[0] > 2 and vec.grid_[-1] < 3
+ for vec in [ Landscape(keep_endpoints=True), Silhouette(keep_endpoints=True), BettiCurve(keep_endpoints=True), Entropy(mode="vector", keep_endpoints=True)]:
+ vec.fit(diags)
+ assert vec.grid_[0] == 2 and vec.grid_[-1] == 3
+ vec = BettiCurve(resolution=None)
+ vec.fit(diags)
+ assert np.equal(vec.grid_, [-np.inf, 2., 3.]).all()
+
+def test_get_params():
+ for vec in [ Landscape(), Silhouette(), BettiCurve(), Entropy(mode="vector") ]:
+ vec.get_params()
diff --git a/src/python/test/test_representations_preprocessing.py b/src/python/test/test_representations_preprocessing.py
new file mode 100644
index 00000000..838cf30c
--- /dev/null
+++ b/src/python/test/test_representations_preprocessing.py
@@ -0,0 +1,39 @@
+""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+ See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+ Author(s): Vincent Rouvreau
+
+ Copyright (C) 2021 Inria
+
+ Modification(s):
+ - YYYY/MM Author: Description of the modification
+"""
+
+from gudhi.representations.preprocessing import DimensionSelector
+import numpy as np
+import pytest
+
+H0_0 = np.array([0.0, 0.0])
+H1_0 = np.array([1.0, 0.0])
+H0_1 = np.array([0.0, 1.0])
+H1_1 = np.array([1.0, 1.0])
+H0_2 = np.array([0.0, 2.0])
+H1_2 = np.array([1.0, 2.0])
+
+
+def test_dimension_selector():
+ X = [[H0_0, H1_0], [H0_1, H1_1], [H0_2, H1_2]]
+ ds = DimensionSelector(index=0)
+ h0 = ds.fit_transform(X)
+ np.testing.assert_array_equal(h0[0], H0_0)
+ np.testing.assert_array_equal(h0[1], H0_1)
+ np.testing.assert_array_equal(h0[2], H0_2)
+
+ ds = DimensionSelector(index=1)
+ h1 = ds.fit_transform(X)
+ np.testing.assert_array_equal(h1[0], H1_0)
+ np.testing.assert_array_equal(h1[1], H1_1)
+ np.testing.assert_array_equal(h1[2], H1_2)
+
+ ds = DimensionSelector(index=2)
+ with pytest.raises(IndexError):
+ h2 = ds.fit_transform([[H0_0, H1_0], [H0_1, H1_1], [H0_2, H1_2]])
diff --git a/src/python/test/test_rips_complex.py b/src/python/test/test_rips_complex.py
index b02a68e1..a2f43a1b 100755
--- a/src/python/test/test_rips_complex.py
+++ b/src/python/test/test_rips_complex.py
@@ -32,7 +32,7 @@ def test_rips_from_points():
assert simplex_tree.num_simplices() == 10
assert simplex_tree.num_vertices() == 4
- assert simplex_tree.get_filtration() == [
+ assert list(simplex_tree.get_filtration()) == [
([0], 0.0),
([1], 0.0),
([2], 0.0),
@@ -44,6 +44,7 @@ def test_rips_from_points():
([1, 2], 1.4142135623730951),
([0, 3], 1.4142135623730951),
]
+
assert simplex_tree.get_star([0]) == [
([0], 0.0),
([0, 1], 1.0),
@@ -95,7 +96,7 @@ def test_rips_from_distance_matrix():
assert simplex_tree.num_simplices() == 10
assert simplex_tree.num_vertices() == 4
- assert simplex_tree.get_filtration() == [
+ assert list(simplex_tree.get_filtration()) == [
([0], 0.0),
([1], 0.0),
([2], 0.0),
@@ -107,6 +108,7 @@ def test_rips_from_distance_matrix():
([1, 2], 1.4142135623730951),
([0, 3], 1.4142135623730951),
]
+
assert simplex_tree.get_star([0]) == [
([0], 0.0),
([0, 1], 1.0),
@@ -131,3 +133,24 @@ def test_filtered_rips_from_distance_matrix():
assert simplex_tree.num_simplices() == 8
assert simplex_tree.num_vertices() == 4
+
+
+def test_sparse_with_multiplicity():
+ points = [
+ [3, 4],
+ [0.1, 2],
+ [0.1, 2],
+ [0.1, 2],
+ [0.1, 2],
+ [0.1, 2],
+ [0.1, 2],
+ [0.1, 2],
+ [0.1, 2],
+ [0.1, 2],
+ [0.1, 2],
+ [3, 4.1],
+ ]
+ rips = RipsComplex(points=points, sparse=0.01)
+ simplex_tree = rips.create_simplex_tree(max_dimension=2)
+ assert simplex_tree.num_simplices() == 7
+ diag = simplex_tree.persistence()
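+ # Informal reading of the count: with sparse=0.01 the ten duplicates of [0.1, 2]
+ # are apparently represented by a single vertex, leaving 3 effective points and
+ # hence 3 vertices + 3 edges + 1 triangle = 7 simplices.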
diff --git a/src/python/test/test_simplex_generators.py b/src/python/test/test_simplex_generators.py
new file mode 100755
index 00000000..c567d4c1
--- /dev/null
+++ b/src/python/test/test_simplex_generators.py
@@ -0,0 +1,64 @@
+""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+ See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+ Author(s): Marc Glisse
+
+ Copyright (C) 2020 Inria
+
+ Modification(s):
+ - YYYY/MM Author: Description of the modification
+"""
+
+import gudhi
+import numpy as np
+
+
+def test_flag_generators():
+ pts = np.array([[0, 0], [0, 1.01], [1, 0], [1.02, 1.03], [100, 0], [100, 3.01], [103, 0], [103.02, 3.03]])
+ r = gudhi.RipsComplex(points=pts, max_edge_length=4)
+ st = r.create_simplex_tree(max_dimension=50)
+ st.persistence()
+ g = st.flag_persistence_generators()
+ assert np.array_equal(g[0], [[2, 2, 0], [1, 1, 0], [3, 3, 1], [6, 6, 4], [5, 5, 4], [7, 7, 5]])
+ assert len(g[1]) == 1
+ assert np.array_equal(g[1][0], [[3, 2, 2, 1]])
+ assert np.array_equal(g[2], [0, 4])
+ assert len(g[3]) == 1
+ assert np.array_equal(g[3][0], [[7, 6]])
+ # Compare trivial cases (where the simplex is the generator) with persistence_pairs.
+ # This still makes assumptions on the order of vertices in a simplex and could be more robust.
+ pairs = st.persistence_pairs()
+ assert {tuple(i) for i in g[0]} == {(i[0][0],) + tuple(i[1]) for i in pairs if len(i[0]) == 1 and len(i[1]) != 0}
+ assert {(i[0], i[1]) for i in g[1][0]} == {tuple(i[0]) for i in pairs if len(i[0]) == 2 and len(i[1]) != 0}
+ assert set(g[2]) == {i[0][0] for i in pairs if len(i[0]) == 1 and len(i[1]) == 0}
+ assert {(i[0], i[1]) for i in g[3][0]} == {tuple(i[0]) for i in pairs if len(i[0]) == 2 and len(i[1]) == 0}
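+ # Informal reading of the output layout, as the comparisons with
+ # persistence_pairs above suggest: rows of g[0] are (birth vertex, death edge)
+ # for finite dim-0 pairs, rows of g[1][d] are (birth edge, death edge) for
+ # finite pairs in dimension d+1, g[2] lists essential dim-0 birth vertices,
+ # and g[3][d] lists essential birth edges in dimension d+1.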
+
+
+def test_lower_star_generators():
+ st = gudhi.SimplexTree()
+ st.insert([0, 1, 2], -10)
+ st.insert([0, 3], -10)
+ st.insert([1, 3], -10)
+ st.assign_filtration([2], -1)
+ st.assign_filtration([3], 0)
+ st.assign_filtration([0], 1)
+ st.assign_filtration([1], 2)
+ st.make_filtration_non_decreasing()
+ st.persistence(min_persistence=-1)
+ g = st.lower_star_persistence_generators()
+ assert len(g[0]) == 2
+ assert np.array_equal(g[0][0], [[0, 0], [3, 0], [1, 1]])
+ assert np.array_equal(g[0][1], [[1, 1]])
+ assert len(g[1]) == 2
+ assert np.array_equal(g[1][0], [2])
+ assert np.array_equal(g[1][1], [1])
+
+
+def test_empty():
+ st = gudhi.SimplexTree()
+ st.persistence()
+ assert st.lower_star_persistence_generators() == ([], [])
+ g = st.flag_persistence_generators()
+ assert np.array_equal(g[0], np.empty((0, 3)))
+ assert g[1] == []
+ assert np.array_equal(g[2], [])
+ assert g[3] == []
diff --git a/src/python/test/test_simplex_tree.py b/src/python/test/test_simplex_tree.py
index 1822c43b..2ccbfbf5 100755
--- a/src/python/test/test_simplex_tree.py
+++ b/src/python/test/test_simplex_tree.py
@@ -9,6 +9,8 @@
"""
from gudhi import SimplexTree
+import numpy as np
+import pytest
__author__ = "Vincent Rouvreau"
__copyright__ = "Copyright (C) 2016 Inria"
@@ -45,7 +47,6 @@ def test_insertion():
assert st.find([2, 3]) == False
# filtration test
- st.initialize_filtration()
assert st.filtration([0, 1, 2]) == 4.0
assert st.filtration([0, 2]) == 4.0
assert st.filtration([1, 2]) == 4.0
@@ -55,7 +56,7 @@ def test_insertion():
assert st.filtration([1]) == 0.0
# skeleton test
- assert st.get_skeleton(2) == [
+ assert list(st.get_skeleton(2)) == [
([0, 1, 2], 4.0),
([0, 1], 0.0),
([0, 2], 4.0),
@@ -64,7 +65,7 @@ def test_insertion():
([1], 0.0),
([2], 4.0),
]
- assert st.get_skeleton(1) == [
+ assert list(st.get_skeleton(1)) == [
([0, 1], 0.0),
([0, 2], 4.0),
([0], 0.0),
@@ -72,12 +73,12 @@ def test_insertion():
([1], 0.0),
([2], 4.0),
]
- assert st.get_skeleton(0) == [([0], 0.0), ([1], 0.0), ([2], 4.0)]
+ assert list(st.get_skeleton(0)) == [([0], 0.0), ([1], 0.0), ([2], 4.0)]
# remove_maximal_simplex test
assert st.get_cofaces([0, 1, 2], 1) == []
st.remove_maximal_simplex([0, 1, 2])
- assert st.get_skeleton(2) == [
+ assert list(st.get_skeleton(2)) == [
([0, 1], 0.0),
([0, 2], 4.0),
([0], 0.0),
@@ -92,7 +93,6 @@ def test_insertion():
assert st.find([1]) == True
assert st.find([2]) == True
- st.initialize_filtration()
assert st.persistence(persistence_dim_max=True) == [
(1, (4.0, float("inf"))),
(0, (0.0, float("inf"))),
@@ -126,7 +126,8 @@ def test_expansion():
assert st.num_vertices() == 7
assert st.num_simplices() == 17
- assert st.get_filtration() == [
+
+ assert list(st.get_filtration()) == [
([2], 0.1),
([3], 0.1),
([2, 3], 0.1),
@@ -149,9 +150,8 @@ def test_expansion():
st.expansion(3)
assert st.num_vertices() == 7
assert st.num_simplices() == 22
- st.initialize_filtration()
- assert st.get_filtration() == [
+ assert list(st.get_filtration()) == [
([2], 0.1),
([3], 0.1),
([2, 3], 0.1),
@@ -248,3 +248,399 @@ def test_make_filtration_non_decreasing():
assert st.filtration([3, 4, 5]) == 2.0
assert st.filtration([3, 4]) == 2.0
assert st.filtration([4, 5]) == 2.0
+
+
+def test_extend_filtration():
+
+ # Inserted simplex:
+ # 5 4
+ # o o
+ # / \ /
+ # o o
+ # /2\ /3
+ # o o
+ # 1 0
+
+ st = SimplexTree()
+ st.insert([0, 2])
+ st.insert([1, 2])
+ st.insert([0, 3])
+ st.insert([2, 5])
+ st.insert([3, 4])
+ st.insert([3, 5])
+ st.assign_filtration([0], 1.0)
+ st.assign_filtration([1], 2.0)
+ st.assign_filtration([2], 3.0)
+ st.assign_filtration([3], 4.0)
+ st.assign_filtration([4], 5.0)
+ st.assign_filtration([5], 6.0)
+
+ assert list(st.get_filtration()) == [
+ ([0, 2], 0.0),
+ ([1, 2], 0.0),
+ ([0, 3], 0.0),
+ ([3, 4], 0.0),
+ ([2, 5], 0.0),
+ ([3, 5], 0.0),
+ ([0], 1.0),
+ ([1], 2.0),
+ ([2], 3.0),
+ ([3], 4.0),
+ ([4], 5.0),
+ ([5], 6.0),
+ ]
+
+ st.extend_filtration()
+
+ assert list(st.get_filtration()) == [
+ ([6], -3.0),
+ ([0], -2.0),
+ ([1], -1.8),
+ ([2], -1.6),
+ ([0, 2], -1.6),
+ ([1, 2], -1.6),
+ ([3], -1.4),
+ ([0, 3], -1.4),
+ ([4], -1.2),
+ ([3, 4], -1.2),
+ ([5], -1.0),
+ ([2, 5], -1.0),
+ ([3, 5], -1.0),
+ ([5, 6], 1.0),
+ ([4, 6], 1.2),
+ ([3, 6], 1.4),
+ ([3, 4, 6], 1.4),
+ ([3, 5, 6], 1.4),
+ ([2, 6], 1.6),
+ ([2, 5, 6], 1.6),
+ ([1, 6], 1.8),
+ ([1, 2, 6], 1.8),
+ ([0, 6], 2.0),
+ ([0, 2, 6], 2.0),
+ ([0, 3, 6], 2.0),
+ ]
+
+ dgms = st.extended_persistence(min_persistence=-1.0)
+ assert len(dgms) == 4
+ # Sort by (death-birth) descending - we are only interested in those with the longest life span
+ for idx in range(4):
+ dgms[idx] = sorted(dgms[idx], key=lambda x: (-abs(x[1][0] - x[1][1])))
+
+ assert dgms[0][0][1][0] == pytest.approx(2.0)
+ assert dgms[0][0][1][1] == pytest.approx(3.0)
+ assert dgms[1][0][1][0] == pytest.approx(5.0)
+ assert dgms[1][0][1][1] == pytest.approx(4.0)
+ assert dgms[2][0][1][0] == pytest.approx(1.0)
+ assert dgms[2][0][1][1] == pytest.approx(6.0)
+ assert dgms[3][0][1][0] == pytest.approx(6.0)
+ assert dgms[3][0][1][1] == pytest.approx(1.0)
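+ # Informal reading of the extended filtration above: a cone vertex 6 is added
+ # at -3, the original vertex values in [1, 6] are rescaled into [-2, -1] for
+ # the ascending copy, and mirrored into [1, 2] for the descending (coned) copy.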
+
+
+def test_simplices_iterator():
+ st = SimplexTree()
+
+ assert st.insert([0, 1, 2], filtration=4.0) == True
+ assert st.insert([2, 3, 4], filtration=2.0) == True
+
+ for simplex in st.get_simplices():
+ print("simplex is: ", simplex[0])
+ assert st.find(simplex[0]) == True
+ print("filtration is: ", simplex[1])
+ assert st.filtration(simplex[0]) == simplex[1]
+
+
+def test_collapse_edges():
+ st = SimplexTree()
+
+ assert st.insert([0, 1], filtration=1.0) == True
+ assert st.insert([1, 2], filtration=1.0) == True
+ assert st.insert([2, 3], filtration=1.0) == True
+ assert st.insert([0, 3], filtration=1.0) == True
+ assert st.insert([0, 2], filtration=2.0) == True
+ assert st.insert([1, 3], filtration=2.0) == True
+
+ assert st.num_simplices() == 10
+
+ st.collapse_edges()
+ assert st.num_simplices() == 9
+ assert st.find([0, 2]) == False # [1, 3] would be fine as well
+ for simplex in st.get_skeleton(0):
+ assert simplex[1] == 1.0
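+ # Informal reading: the square plus both diagonals has 4 vertices and 6 edges
+ # (10 simplices); the collapse removes one dominated diagonal, leaving 9.
+ # Which of [0, 2] and [1, 3] goes is implementation-dependent, as noted above.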
+
+
+def test_reset_filtration():
+ st = SimplexTree()
+
+ assert st.insert([0, 1, 2], 3.0) == True
+ assert st.insert([0, 3], 2.0) == True
+ assert st.insert([3, 4, 5], 3.0) == True
+ assert st.insert([0, 1, 6, 7], 4.0) == True
+
+ # Guaranteed by construction
+ for simplex in st.get_simplices():
+ assert st.filtration(simplex[0]) >= 2.0
+
+ # Go up to dimension 5 even though the simplex tree has dimension 3, to test the limits
+ for dimension in range(5, -1, -1):
+ st.reset_filtration(0.0, dimension)
+ for simplex in st.get_skeleton(3):
+ print(simplex)
+ if len(simplex[0]) < (dimension) + 1:
+ assert st.filtration(simplex[0]) >= 2.0
+ else:
+ assert st.filtration(simplex[0]) == 0.0
+
+
+def test_boundaries_iterator():
+ st = SimplexTree()
+
+ assert st.insert([0, 1, 2, 3], filtration=1.0) == True
+ assert st.insert([1, 2, 3, 4], filtration=2.0) == True
+
+ assert list(st.get_boundaries([1, 2, 3])) == [([1, 2], 1.0), ([1, 3], 1.0), ([2, 3], 1.0)]
+ assert list(st.get_boundaries([2, 3, 4])) == [([2, 3], 1.0), ([2, 4], 2.0), ([3, 4], 2.0)]
+ assert list(st.get_boundaries([2])) == []
+
+ with pytest.raises(RuntimeError):
+ list(st.get_boundaries([]))
+
+ with pytest.raises(RuntimeError):
+ list(st.get_boundaries([0, 4])) # (0, 4) does not exist
+
+ with pytest.raises(RuntimeError):
+ list(st.get_boundaries([6])) # (6) does not exist
+
+
+def test_persistence_intervals_in_dimension():
+ # Here is our triangulation of a 2-torus - taken from https://dioscuri-tda.org/Paris_TDA_Tutorial_2021.html
+ # 0-----3-----4-----0
+ # | \ | \ | \ | \ |
+ # | \ | \ | \| \ |
+ # 1-----8-----7-----1
+ # | \ | \ | \ | \ |
+ # | \ | \ | \ | \ |
+ # 2-----5-----6-----2
+ # | \ | \ | \ | \ |
+ # | \ | \ | \ | \ |
+ # 0-----3-----4-----0
+ st = SimplexTree()
+ st.insert([0, 1, 8])
+ st.insert([0, 3, 8])
+ st.insert([3, 7, 8])
+ st.insert([3, 4, 7])
+ st.insert([1, 4, 7])
+ st.insert([0, 1, 4])
+ st.insert([1, 2, 5])
+ st.insert([1, 5, 8])
+ st.insert([5, 6, 8])
+ st.insert([6, 7, 8])
+ st.insert([2, 6, 7])
+ st.insert([1, 2, 7])
+ st.insert([0, 2, 3])
+ st.insert([2, 3, 5])
+ st.insert([3, 4, 5])
+ st.insert([4, 5, 6])
+ st.insert([0, 4, 6])
+ st.insert([0, 2, 6])
+ st.compute_persistence(persistence_dim_max=True)
+
+ H0 = st.persistence_intervals_in_dimension(0)
+ assert np.array_equal(H0, np.array([[0.0, float("inf")]]))
+ H1 = st.persistence_intervals_in_dimension(1)
+ assert np.array_equal(H1, np.array([[0.0, float("inf")], [0.0, float("inf")]]))
+ H2 = st.persistence_intervals_in_dimension(2)
+ assert np.array_equal(H2, np.array([[0.0, float("inf")]]))
+ # Test empty case
+ assert st.persistence_intervals_in_dimension(3).shape == (0, 2)
+
+
+def test_equality_operator():
+ st1 = SimplexTree()
+ st2 = SimplexTree()
+
+ assert st1 == st2
+
+ st1.insert([1, 2, 3], 4.0)
+ assert st1 != st2
+
+ st2.insert([1, 2, 3], 4.0)
+ assert st1 == st2
+
+
+def test_simplex_tree_deep_copy():
+ st = SimplexTree()
+ st.insert([1, 2, 3], 0.0)
+ # compute persistence only on the original
+ st.compute_persistence()
+
+ st_copy = st.copy()
+ assert st_copy == st
+ st_filt_list = list(st.get_filtration())
+
+ # check persistence is not copied
+ assert st.__is_persistence_defined() == True
+ assert st_copy.__is_persistence_defined() == False
+
+ # remove something in the copy and check the copy is included in the original
+ st_copy.remove_maximal_simplex([1, 2, 3])
+ a_filt_list = list(st_copy.get_filtration())
+ assert len(a_filt_list) < len(st_filt_list)
+
+ for a_splx in a_filt_list:
+ assert a_splx in st_filt_list
+
+ # test double free
+ del st
+ del st_copy
+
+
+def test_simplex_tree_deep_copy_constructor():
+ st = SimplexTree()
+ st.insert([1, 2, 3], 0.0)
+ # compute persistence only on the original
+ st.compute_persistence()
+
+ st_copy = SimplexTree(st)
+ assert st_copy == st
+ st_filt_list = list(st.get_filtration())
+
+ # check persistence is not copied
+ assert st.__is_persistence_defined() == True
+ assert st_copy.__is_persistence_defined() == False
+
+ # remove something in the copy and check the copy is included in the original
+ st_copy.remove_maximal_simplex([1, 2, 3])
+ a_filt_list = list(st_copy.get_filtration())
+ assert len(a_filt_list) < len(st_filt_list)
+
+ for a_splx in a_filt_list:
+ assert a_splx in st_filt_list
+
+ # test double free
+ del st
+ del st_copy
+
+
+def test_simplex_tree_constructor_exception():
+ with pytest.raises(TypeError):
+ st = SimplexTree(other="Construction from a string shall raise an exception")
+
+
+def test_create_from_array():
+ a = np.array([[1, 4, 13, 6], [4, 3, 11, 5], [13, 11, 10, 12], [6, 5, 12, 2]])
+ st = SimplexTree.create_from_array(a, max_filtration=5.0)
+ assert list(st.get_filtration()) == [([0], 1.0), ([3], 2.0), ([1], 3.0), ([0, 1], 4.0), ([1, 3], 5.0)]
+
+
+def test_insert_edges_from_coo_matrix():
+ try:
+ from scipy.sparse import coo_matrix
+ from scipy.spatial import cKDTree
+ except ImportError:
+ print("Skipping, no SciPy")
+ return
+
+ st = SimplexTree()
+ st.insert([1, 2, 7], 7)
+ row = np.array([2, 5, 3])
+ col = np.array([1, 4, 6])
+ dat = np.array([1, 2, 3])
+ edges = coo_matrix((dat, (row, col)))
+ st.insert_edges_from_coo_matrix(edges)
+ assert list(st.get_filtration()) == [
+ ([1], 1.0),
+ ([2], 1.0),
+ ([1, 2], 1.0),
+ ([4], 2.0),
+ ([5], 2.0),
+ ([4, 5], 2.0),
+ ([3], 3.0),
+ ([6], 3.0),
+ ([3, 6], 3.0),
+ ([7], 7.0),
+ ([1, 7], 7.0),
+ ([2, 7], 7.0),
+ ([1, 2, 7], 7.0),
+ ]
+
+ pts = np.random.rand(100, 2)
+ tree = cKDTree(pts)
+ edges = tree.sparse_distance_matrix(tree, max_distance=0.15, output_type="coo_matrix")
+ st = SimplexTree()
+ st.insert_edges_from_coo_matrix(edges)
+ assert 100 < st.num_simplices() < 1000
+
+
+def test_insert_batch():
+ st = SimplexTree()
+ # vertices
+ st.insert_batch(np.array([[6, 1, 5]]), np.array([-5.0, 2.0, -3.0]))
+ # triangles
+ st.insert_batch(np.array([[2, 10], [5, 0], [6, 11]]), np.array([4.0, 0.0]))
+ # edges
+ st.insert_batch(np.array([[1, 5], [2, 5]]), np.array([1.0, 3.0]))
+
+ assert list(st.get_filtration()) == [
+ ([6], -5.0),
+ ([5], -3.0),
+ ([0], 0.0),
+ ([10], 0.0),
+ ([0, 10], 0.0),
+ ([11], 0.0),
+ ([0, 11], 0.0),
+ ([10, 11], 0.0),
+ ([0, 10, 11], 0.0),
+ ([1], 1.0),
+ ([2], 1.0),
+ ([1, 2], 1.0),
+ ([2, 5], 4.0),
+ ([2, 6], 4.0),
+ ([5, 6], 4.0),
+ ([2, 5, 6], 4.0),
+ ]
+
+
+def test_expansion_with_blocker():
+ st = SimplexTree()
+ st.insert([0, 1], 0)
+ st.insert([0, 2], 1)
+ st.insert([0, 3], 2)
+ st.insert([1, 2], 3)
+ st.insert([1, 3], 4)
+ st.insert([2, 3], 5)
+ st.insert([2, 4], 6)
+ st.insert([3, 6], 7)
+ st.insert([4, 5], 8)
+ st.insert([4, 6], 9)
+ st.insert([5, 6], 10)
+ st.insert([6], 10)
+
+ def blocker(simplex):
+ try:
+ # Block all simplices that contain vertex 6
+ simplex.index(6)
+ print(simplex, " is blocked")
+ return True
+ except ValueError:
+ print(simplex, " is accepted")
+ st.assign_filtration(simplex, st.filtration(simplex) + 1.0)
+ return False
+
+ st.expansion_with_blocker(2, blocker)
+ assert st.num_simplices() == 22
+ assert st.dimension() == 2
+ assert st.find([4, 5, 6]) == False
+ assert st.filtration([0, 1, 2]) == 4.0
+ assert st.filtration([0, 1, 3]) == 5.0
+ assert st.filtration([0, 2, 3]) == 6.0
+ assert st.filtration([1, 2, 3]) == 6.0
+
+ st.expansion_with_blocker(3, blocker)
+ assert st.num_simplices() == 23
+ assert st.dimension() == 3
+ assert st.find([4, 5, 6]) == False
+ assert st.filtration([0, 1, 2]) == 4.0
+ assert st.filtration([0, 1, 3]) == 5.0
+ assert st.filtration([0, 2, 3]) == 6.0
+ assert st.filtration([1, 2, 3]) == 6.0
+ assert st.filtration([0, 1, 2, 3]) == 7.0
diff --git a/src/python/test/test_sklearn_cubical_persistence.py b/src/python/test/test_sklearn_cubical_persistence.py
new file mode 100644
index 00000000..1c05a215
--- /dev/null
+++ b/src/python/test/test_sklearn_cubical_persistence.py
@@ -0,0 +1,59 @@
+""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+ See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+ Author(s): Vincent Rouvreau
+
+ Copyright (C) 2021 Inria
+
+ Modification(s):
+ - YYYY/MM Author: Description of the modification
+"""
+
+from gudhi.sklearn.cubical_persistence import CubicalPersistence
+import numpy as np
+from sklearn import datasets
+
+CUBICAL_PERSISTENCE_H0_IMG0 = np.array([[0.0, 6.0], [0.0, 8.0], [0.0, np.inf]])
+
+
+def test_simple_constructor_from_top_cells():
+ cells = datasets.load_digits().images[0]
+ cp = CubicalPersistence(homology_dimensions=0)
+ np.testing.assert_array_equal(cp._CubicalPersistence__transform_only_this_dim(cells), CUBICAL_PERSISTENCE_H0_IMG0)
+ cp = CubicalPersistence(homology_dimensions=[0, 2])
+ diags = cp._CubicalPersistence__transform(cells)
+ assert len(diags) == 2
+ np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0)
+
+
+def test_simple_constructor_from_top_cells_list():
+ digits = datasets.load_digits().images[:10]
+ cp = CubicalPersistence(homology_dimensions=0, n_jobs=-2)
+
+ diags = cp.fit_transform(digits)
+ assert len(diags) == 10
+ np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0)
+
+ cp = CubicalPersistence(homology_dimensions=[0, 1], n_jobs=-1)
+ diagsH0H1 = cp.fit_transform(digits)
+ assert len(diagsH0H1) == 10
+ for idx in range(10):
+ np.testing.assert_array_equal(diags[idx], diagsH0H1[idx][0])
+
+def test_simple_constructor_from_flattened_cells():
+ cells = datasets.load_digits().images[0]
+ # Non-square (extended) flattened cells
+ flat_cells = np.hstack((cells, np.zeros((cells.shape[0], 2)))).flatten()
+
+ cp = CubicalPersistence(homology_dimensions=0, newshape=[-1, 8, 10])
+ diags = cp.fit_transform([flat_cells])
+
+ np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0)
+
+ # Non-square (extended) non-flattened cells
+ cells = np.hstack((cells, np.zeros((cells.shape[0], 2))))
+
+ # The aim of this second part of the test is to check that reshaping works even when it is not strictly required
+ cp = CubicalPersistence(homology_dimensions=0, newshape=[-1, 8, 10])
+ diags = cp.fit_transform([cells])
+
+ np.testing.assert_array_equal(diags[0], CUBICAL_PERSISTENCE_H0_IMG0)
diff --git a/src/python/test/test_subsampling.py b/src/python/test/test_subsampling.py
index fe0985fa..c1cb4e3f 100755
--- a/src/python/test/test_subsampling.py
+++ b/src/python/test/test_subsampling.py
@@ -16,17 +16,9 @@ __license__ = "MIT"
def test_write_off_file_for_tests():
- file = open("subsample.off", "w")
- file.write("nOFF\n")
- file.write("2 7 0 0\n")
- file.write("1.0 1.0\n")
- file.write("7.0 0.0\n")
- file.write("4.0 6.0\n")
- file.write("9.0 6.0\n")
- file.write("0.0 14.0\n")
- file.write("2.0 19.0\n")
- file.write("9.0 17.0\n")
- file.close()
+ gudhi.write_points_to_off_file(
+ "subsample.off", [[1.0, 1.0], [7.0, 0.0], [4.0, 6.0], [9.0, 6.0], [0.0, 14.0], [2.0, 19.0], [9.0, 17.0]]
+ )
def test_simple_choose_n_farthest_points_with_a_starting_point():
@@ -34,54 +26,29 @@ def test_simple_choose_n_farthest_points_with_a_starting_point():
i = 0
for point in point_set:
# The iteration starts with the given starting point
- sub_set = gudhi.choose_n_farthest_points(
- points=point_set, nb_points=1, starting_point=i
- )
+ sub_set = gudhi.choose_n_farthest_points(points=point_set, nb_points=1, starting_point=i)
assert sub_set[0] == point_set[i]
i = i + 1
# The iteration finds then the farthest
- sub_set = gudhi.choose_n_farthest_points(
- points=point_set, nb_points=2, starting_point=1
- )
+ sub_set = gudhi.choose_n_farthest_points(points=point_set, nb_points=2, starting_point=1)
assert sub_set[1] == point_set[3]
- sub_set = gudhi.choose_n_farthest_points(
- points=point_set, nb_points=2, starting_point=3
- )
+ sub_set = gudhi.choose_n_farthest_points(points=point_set, nb_points=2, starting_point=3)
assert sub_set[1] == point_set[1]
- sub_set = gudhi.choose_n_farthest_points(
- points=point_set, nb_points=2, starting_point=0
- )
+ sub_set = gudhi.choose_n_farthest_points(points=point_set, nb_points=2, starting_point=0)
assert sub_set[1] == point_set[2]
- sub_set = gudhi.choose_n_farthest_points(
- points=point_set, nb_points=2, starting_point=2
- )
+ sub_set = gudhi.choose_n_farthest_points(points=point_set, nb_points=2, starting_point=2)
assert sub_set[1] == point_set[0]
# Test the limits
- assert (
- gudhi.choose_n_farthest_points(points=[], nb_points=0, starting_point=0) == []
- )
- assert (
- gudhi.choose_n_farthest_points(points=[], nb_points=1, starting_point=0) == []
- )
- assert (
- gudhi.choose_n_farthest_points(points=[], nb_points=0, starting_point=1) == []
- )
- assert (
- gudhi.choose_n_farthest_points(points=[], nb_points=1, starting_point=1) == []
- )
+ assert gudhi.choose_n_farthest_points(points=[], nb_points=0, starting_point=0) == []
+ assert gudhi.choose_n_farthest_points(points=[], nb_points=1, starting_point=0) == []
+ assert gudhi.choose_n_farthest_points(points=[], nb_points=0, starting_point=1) == []
+ assert gudhi.choose_n_farthest_points(points=[], nb_points=1, starting_point=1) == []
# From off file test
for i in range(0, 7):
- assert (
- len(
- gudhi.choose_n_farthest_points(
- off_file="subsample.off", nb_points=i, starting_point=i
- )
- )
- == i
- )
+ assert len(gudhi.choose_n_farthest_points(off_file="subsample.off", nb_points=i, starting_point=i)) == i
def test_simple_choose_n_farthest_points_randomed():
@@ -91,7 +58,7 @@ def test_simple_choose_n_farthest_points_randomed():
assert gudhi.choose_n_farthest_points(points=[], nb_points=1) == []
assert gudhi.choose_n_farthest_points(points=point_set, nb_points=0) == []
- # Go furter than point set on purpose
+ # Go further than point set on purpose
for iter in range(1, 10):
sub_set = gudhi.choose_n_farthest_points(points=point_set, nb_points=iter)
for sub in sub_set:
@@ -104,10 +71,7 @@ def test_simple_choose_n_farthest_points_randomed():
# From off file test
for i in range(0, 7):
- assert (
- len(gudhi.choose_n_farthest_points(off_file="subsample.off", nb_points=i))
- == i
- )
+ assert len(gudhi.choose_n_farthest_points(off_file="subsample.off", nb_points=i)) == i
def test_simple_pick_n_random_points():
@@ -117,10 +81,9 @@ def test_simple_pick_n_random_points():
assert gudhi.pick_n_random_points(points=[], nb_points=1) == []
assert gudhi.pick_n_random_points(points=point_set, nb_points=0) == []
- # Go furter than point set on purpose
+ # Go further than point set on purpose
for iter in range(1, 10):
sub_set = gudhi.pick_n_random_points(points=point_set, nb_points=iter)
- print(5)
for sub in sub_set:
found = False
for point in point_set:
@@ -131,9 +94,7 @@ def test_simple_pick_n_random_points():
# From off file test
for i in range(0, 7):
- assert (
- len(gudhi.pick_n_random_points(off_file="subsample.off", nb_points=i)) == i
- )
+ assert len(gudhi.pick_n_random_points(off_file="subsample.off", nb_points=i)) == i
def test_simple_sparsify_points():
@@ -142,38 +103,21 @@ def test_simple_sparsify_points():
# assert gudhi.sparsify_point_set(points = [], min_squared_dist = 0.0) == []
# assert gudhi.sparsify_point_set(points = [], min_squared_dist = 10.0) == []
assert gudhi.sparsify_point_set(points=point_set, min_squared_dist=0.0) == point_set
- assert gudhi.sparsify_point_set(points=point_set, min_squared_dist=1.0) == point_set
- assert gudhi.sparsify_point_set(points=point_set, min_squared_dist=2.0) == [
+ assert gudhi.sparsify_point_set(points=point_set, min_squared_dist=0.999) == point_set
+ assert gudhi.sparsify_point_set(points=point_set, min_squared_dist=1.001) == [
[0, 1],
[1, 0],
]
- assert gudhi.sparsify_point_set(points=point_set, min_squared_dist=2.01) == [[0, 1]]
-
- assert (
- len(gudhi.sparsify_point_set(off_file="subsample.off", min_squared_dist=0.0))
- == 7
- )
- assert (
- len(gudhi.sparsify_point_set(off_file="subsample.off", min_squared_dist=30.0))
- == 5
- )
- assert (
- len(gudhi.sparsify_point_set(off_file="subsample.off", min_squared_dist=40.0))
- == 4
- )
- assert (
- len(gudhi.sparsify_point_set(off_file="subsample.off", min_squared_dist=90.0))
- == 3
- )
- assert (
- len(gudhi.sparsify_point_set(off_file="subsample.off", min_squared_dist=100.0))
- == 2
- )
- assert (
- len(gudhi.sparsify_point_set(off_file="subsample.off", min_squared_dist=325.0))
- == 2
- )
- assert (
- len(gudhi.sparsify_point_set(off_file="subsample.off", min_squared_dist=325.01))
- == 1
- )
+ assert gudhi.sparsify_point_set(points=point_set, min_squared_dist=1.999) == [
+ [0, 1],
+ [1, 0],
+ ]
+ assert gudhi.sparsify_point_set(points=point_set, min_squared_dist=2.001) == [[0, 1]]
+
+ assert len(gudhi.sparsify_point_set(off_file="subsample.off", min_squared_dist=0.0)) == 7
+ assert len(gudhi.sparsify_point_set(off_file="subsample.off", min_squared_dist=30.0)) == 5
+ assert len(gudhi.sparsify_point_set(off_file="subsample.off", min_squared_dist=40.1)) == 4
+ assert len(gudhi.sparsify_point_set(off_file="subsample.off", min_squared_dist=89.9)) == 3
+ assert len(gudhi.sparsify_point_set(off_file="subsample.off", min_squared_dist=100.0)) == 2
+ assert len(gudhi.sparsify_point_set(off_file="subsample.off", min_squared_dist=324.9)) == 2
+ assert len(gudhi.sparsify_point_set(off_file="subsample.off", min_squared_dist=325.01)) == 1
diff --git a/src/python/test/test_tangential_complex.py b/src/python/test/test_tangential_complex.py
index e650e99c..8668a2e0 100755
--- a/src/python/test/test_tangential_complex.py
+++ b/src/python/test/test_tangential_complex.py
@@ -37,7 +37,7 @@ def test_tangential():
assert st.num_simplices() == 6
assert st.num_vertices() == 4
- assert st.get_filtration() == [
+ assert list(st.get_filtration()) == [
([0], 0.0),
([1], 0.0),
([2], 0.0),
@@ -45,6 +45,7 @@ def test_tangential():
([3], 0.0),
([1, 3], 0.0),
]
+
assert st.get_cofaces([0], 1) == [([0, 2], 0.0)]
assert point_list[0] == tc.get_point(0)
diff --git a/src/python/test/test_time_delay.py b/src/python/test/test_time_delay.py
new file mode 100755
index 00000000..1ead9bca
--- /dev/null
+++ b/src/python/test/test_time_delay.py
@@ -0,0 +1,43 @@
+from gudhi.point_cloud.timedelay import TimeDelayEmbedding
+import numpy as np
+
+
+def test_normal():
+ # Sample array
+ ts = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+ # Normal case.
+ prep = TimeDelayEmbedding()
+ pointclouds = prep(ts)
+ assert (pointclouds[0] == np.array([1, 2, 3])).all()
+ assert (pointclouds[1] == np.array([2, 3, 4])).all()
+ assert (pointclouds[2] == np.array([3, 4, 5])).all()
+ assert (pointclouds[3] == np.array([4, 5, 6])).all()
+ assert (pointclouds[4] == np.array([5, 6, 7])).all()
+ assert (pointclouds[5] == np.array([6, 7, 8])).all()
+ assert (pointclouds[6] == np.array([7, 8, 9])).all()
+ assert (pointclouds[7] == np.array([8, 9, 10])).all()
+ # Delay = 3
+ prep = TimeDelayEmbedding(delay=3)
+ pointclouds = prep(ts)
+ assert (pointclouds[0] == np.array([1, 4, 7])).all()
+ assert (pointclouds[1] == np.array([2, 5, 8])).all()
+ assert (pointclouds[2] == np.array([3, 6, 9])).all()
+ assert (pointclouds[3] == np.array([4, 7, 10])).all()
+ # Skip = 3
+ prep = TimeDelayEmbedding(skip=3)
+ pointclouds = prep(ts)
+ assert (pointclouds[0] == np.array([1, 2, 3])).all()
+ assert (pointclouds[1] == np.array([4, 5, 6])).all()
+ assert (pointclouds[2] == np.array([7, 8, 9])).all()
+ # Delay = 2 / Skip = 2
+ prep = TimeDelayEmbedding(delay=2, skip=2)
+ pointclouds = prep(ts)
+ assert (pointclouds[0] == np.array([1, 3, 5])).all()
+ assert (pointclouds[1] == np.array([3, 5, 7])).all()
+ assert (pointclouds[2] == np.array([5, 7, 9])).all()
+
+ # Vector series
+ ts = np.arange(0, 10).reshape(-1, 2)
+ prep = TimeDelayEmbedding(dim=4)
+ prep.fit([ts])
+ assert (prep.transform([ts])[0] == [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7], [6, 7, 8, 9]]).all()
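+ # Informal summary of the convention exercised above: point i of the embedding
+ # is [ts[i*skip], ts[i*skip + delay], ..., ts[i*skip + (dim-1)*delay]]; for a
+ # vector-valued series, consecutive samples are concatenated until the total
+ # dimension reaches dim.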
diff --git a/src/python/test/test_tomato.py b/src/python/test/test_tomato.py
new file mode 100755
index 00000000..c571f799
--- /dev/null
+++ b/src/python/test/test_tomato.py
@@ -0,0 +1,65 @@
+""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+ See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+ Author(s): Marc Glisse
+
+ Copyright (C) 2020 Inria
+
+ Modification(s):
+ - YYYY/MM Author: Description of the modification
+"""
+
+from gudhi.clustering.tomato import Tomato
+import numpy as np
+import pytest
+import matplotlib.pyplot as plt
+
+# Disable graphics for testing purposes
+plt.show = lambda: None
+
+
+def test_tomato_1():
+ a = [(1, 2), (1.1, 1.9), (0.9, 1.8), (10, 0), (10.1, 0.05), (10.2, -0.1), (5.4, 0)]
+ t = Tomato(metric="euclidean", n_clusters=2, k=4, n_jobs=-1, eps=0.05)
+ assert np.array_equal(t.fit_predict(a), [1, 1, 1, 0, 0, 0, 0]) # or with swapped 0 and 1
+ assert np.array_equal(t.children_, [[0, 1]])
+
+ t = Tomato(density_type="KDE", r=1, k=4)
+ t.fit(a)
+ assert np.array_equal(t.leaf_labels_, [1, 1, 1, 0, 0, 0, 0]) # or with swapped 0 and 1
+ assert t.n_clusters_ == 2
+ t.merge_threshold_ = 10
+ assert t.n_clusters_ == 1
+ assert (t.labels_ == 0).all()
+
+ t = Tomato(graph_type="radius", r=0.1, metric="cosine", k=3)
+ assert np.array_equal(t.fit_predict(a), [1, 1, 1, 0, 0, 0, 0]) # or with swapped 0 and 1
+
+ t = Tomato(metric="euclidean", graph_type="radius", r=4.7, k=4)
+ t.fit(a)
+ assert t.max_weight_per_cc_.size == 2
+ assert t.neighbors_ == [[0, 1, 2], [0, 1, 2], [0, 1, 2], [3, 4, 5, 6], [3, 4, 5], [3, 4, 5], [3, 6]]
+ t.plot_diagram()
+
+ t = Tomato(graph_type="radius", r=4.7, k=4, symmetrize_graph=True)
+ t.fit(a)
+ assert t.max_weight_per_cc_.size == 2
+ assert [set(i) for i in t.neighbors_] == [{1, 2}, {0, 2}, {0, 1}, {4, 5, 6}, {3, 5}, {3, 4}, {3}]
+
+ t = Tomato(n_clusters=2, k=4, symmetrize_graph=True)
+ t.fit(a)
+ assert [set(i) for i in t.neighbors_] == [
+ {1, 2, 6},
+ {0, 2, 6},
+ {0, 1, 6},
+ {4, 5, 6},
+ {3, 5, 6},
+ {3, 4, 6},
+ {0, 1, 2, 3, 4, 5},
+ ]
+ t.plot_diagram()
+
+ t = Tomato(k=6, metric="manhattan")
+ t.fit(a)
+ assert t.diagram_.size == 0
+ assert t.max_weight_per_cc_.size == 1
+ t.plot_diagram()
diff --git a/src/python/test/test_wasserstein_barycenter.py b/src/python/test/test_wasserstein_barycenter.py
new file mode 100755
index 00000000..f68c748e
--- /dev/null
+++ b/src/python/test/test_wasserstein_barycenter.py
@@ -0,0 +1,46 @@
+from gudhi.wasserstein.barycenter import lagrangian_barycenter
+import numpy as np
+
+""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+ See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+ Author(s): Theo Lacombe
+
+ Copyright (C) 2019 Inria
+
+ Modification(s):
+ - YYYY/MM Author: Description of the modification
+"""
+
+__author__ = "Theo Lacombe"
+__copyright__ = "Copyright (C) 2019 Inria"
+__license__ = "MIT"
+
+
+def test_lagrangian_barycenter():
+
+ dg1 = np.array([[0.2, 0.5]])
+ dg2 = np.array([[0.2, 0.7]])
+ dg3 = np.array([[0.3, 0.6], [0.7, 0.8], [0.2, 0.3]])
+ dg4 = np.array([])
+ dg5 = np.array([])
+ dg6 = np.array([])
+ res = np.array([[0.27916667, 0.55416667], [0.7375, 0.7625], [0.2375, 0.2625]])
+
+ dg7 = np.array([[0.1, 0.15], [0.1, 0.7], [0.2, 0.22], [0.55, 0.84], [0.11, 0.91], [0.61, 0.75], [0.33, 0.46], [0.12, 0.41], [0.32, 0.48]])
+ dg8 = np.array([[0., 4.], [4, 8]])
+
+ # error crit.
+ eps = 1e-7
+
+
+ assert np.linalg.norm(lagrangian_barycenter(pdiagset=[dg1, dg2, dg3, dg4],init=3, verbose=False) - res) < eps
+ assert np.array_equal(lagrangian_barycenter(pdiagset=[dg4, dg5, dg6], verbose=False), np.empty(shape=(0,2)))
+ assert np.linalg.norm(lagrangian_barycenter(pdiagset=[dg7], verbose=False) - dg7) < eps
+ Y, log = lagrangian_barycenter(pdiagset=[dg4, dg8], verbose=True)
+ assert np.linalg.norm(Y - np.array([[1,3], [5, 7]])) < eps
+ assert np.abs(log["energy"] - 2) < eps
+ assert np.array_equal(log["groupings"][0] , np.array([[0, -1], [1, -1]]))
+ assert np.array_equal(log["groupings"][1] , np.array([[0, 0], [1, 1]]))
+ assert np.linalg.norm(lagrangian_barycenter(pdiagset=[dg8, dg4], init=np.array([[0.2, 0.6], [0.5, 0.7]]), verbose=False) - np.array([[1, 3], [5, 7]])) < eps
+ assert lagrangian_barycenter(pdiagset = []) is None
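+ # Informal check of the dg8/dg4 case: each point of dg8 is matched to the
+ # diagonal of the empty diagram, and the barycenter point is the midpoint of
+ # the point and its diagonal projection: [0, 4] and [2, 2] give [1, 3], while
+ # [4, 8] and [6, 6] give [5, 7].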
+
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py
index 43dda77e..a76b6ce7 100755
--- a/src/python/test/test_wasserstein_distance.py
+++ b/src/python/test/test_wasserstein_distance.py
@@ -1,48 +1,201 @@
""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
- Author(s): Theo Lacombe
+ Author(s): Theo Lacombe, Marc Glisse
Copyright (C) 2019 Inria
Modification(s):
+ - 2020/07 Théo Lacombe: Added tests about handling essential parts in diagrams.
- YYYY/MM Author: Description of the modification
"""
-from gudhi.wasserstein import wasserstein_distance
+from gudhi.wasserstein.wasserstein import _proj_on_diag, _finite_part, _handle_essential_parts, _get_essential_parts
+from gudhi.wasserstein.wasserstein import _warn_infty
+from gudhi.wasserstein import wasserstein_distance as pot
+from gudhi.hera import wasserstein_distance as hera
import numpy as np
+import pytest
+
__author__ = "Theo Lacombe"
__copyright__ = "Copyright (C) 2019 Inria"
__license__ = "MIT"
-def test_basic_wasserstein():
+def test_proj_on_diag():
+ dgm = np.array([[1., 1.], [1., 2.], [3., 5.]])
+ assert np.array_equal(_proj_on_diag(dgm), [[1., 1.], [1.5, 1.5], [4., 4.]])
+ empty = np.empty((0, 2))
+ assert np.array_equal(_proj_on_diag(empty), empty)
+
+
+def test_finite_part():
+ diag = np.array([[0, 1], [3, 5], [2, np.inf], [3, np.inf], [-np.inf, 8], [-np.inf, 12], [-np.inf, -np.inf],
+ [np.inf, np.inf], [-np.inf, np.inf], [-np.inf, np.inf]])
+ assert np.array_equal(_finite_part(diag), [[0, 1], [3, 5]])
+
+
+def test_handle_essential_parts():
+ diag1 = np.array([[0, 1], [3, 5],
+ [2, np.inf], [3, np.inf],
+ [-np.inf, 8], [-np.inf, 12],
+ [-np.inf, -np.inf],
+ [np.inf, np.inf],
+ [-np.inf, np.inf], [-np.inf, np.inf]])
+
+ diag2 = np.array([[0, 2], [3, 5],
+ [2, np.inf], [4, np.inf],
+ [-np.inf, 8], [-np.inf, 11],
+ [-np.inf, -np.inf],
+ [np.inf, np.inf],
+ [-np.inf, np.inf], [-np.inf, np.inf]])
+
+ diag3 = np.array([[0, 2], [3, 5],
+ [2, np.inf], [4, np.inf], [6, np.inf],
+ [-np.inf, 8], [-np.inf, 11],
+ [-np.inf, -np.inf],
+ [np.inf, np.inf],
+ [-np.inf, np.inf], [-np.inf, np.inf]])
+
+ c, m = _handle_essential_parts(diag1, diag2, order=1)
+ assert c == pytest.approx(2, 0.0001) # Note: here c is only the cost due to the essential parts (thus 2, not 3)
+ # Similarly, the matching only corresponds to essential parts.
+ # Note that (-inf,-inf) and (+inf,+inf) coordinates are matched to the diagonal.
+ assert np.array_equal(m, [[4, 4], [5, 5], [2, 2], [3, 3], [8, 8], [9, 9], [6, -1], [7, -1], [-1, 6], [-1, 7]])
+
+ c, m = _handle_essential_parts(diag1, diag3, order=1)
+ assert c == np.inf
+ assert (m is None)
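+ # Informal check of the cost: essential parts are matched by their finite
+ # coordinate, so (3, inf) vs (4, inf) and (-inf, 12) vs (-inf, 11) each cost 1
+ # while the other essential bars match exactly, hence 2. diag3 has three
+ # (x, inf) bars against two in diag1, so its essential parts cannot be matched
+ # and the cost is infinite.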
+
+
+def test_get_essential_parts():
+ diag1 = np.array([[0, 1], [3, 5], [2, np.inf], [3, np.inf], [-np.inf, 8], [-np.inf, 12], [-np.inf, -np.inf],
+ [np.inf, np.inf], [-np.inf, np.inf], [-np.inf, np.inf]])
+
+ diag2 = np.array([[0, 1], [3, 5], [2, np.inf], [3, np.inf]])
+
+ res = _get_essential_parts(diag1)
+ res2 = _get_essential_parts(diag2)
+ assert np.array_equal(res[0], [4, 5])
+ assert np.array_equal(res[1], [2, 3])
+ assert np.array_equal(res[2], [8, 9])
+ assert np.array_equal(res[3], [6] )
+ assert np.array_equal(res[4], [7] )
+
+ assert np.array_equal(res2[0], [] )
+ assert np.array_equal(res2[1], [2, 3])
+ assert np.array_equal(res2[2], [] )
+ assert np.array_equal(res2[3], [] )
+ assert np.array_equal(res2[4], [] )
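+ # Informal reading of the indices above: res[0] collects (-inf, finite) bars,
+ # res[1] (finite, +inf), res[2] (-inf, +inf), res[3] (-inf, -inf) and
+ # res[4] (+inf, +inf).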
+
+
+def test_warn_infty():
+ with pytest.warns(UserWarning):
+ assert _warn_infty(matching=False)==np.inf
+ c, m = _warn_infty(matching=True)
+ assert (c == np.inf)
+ assert (m is None)
+
+
+def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_matching=True):
diag1 = np.array([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]])
diag2 = np.array([[2.8, 4.45], [9.5, 14.1]])
diag3 = np.array([[0, 2], [4, 6]])
diag4 = np.array([[0, 3], [4, 8]])
- emptydiag = np.array([[]])
+ emptydiag = np.array([])
+
+ # We just need to handle positive numbers here
+ def approx(x):
+ return pytest.approx(x, rel=delta)
assert wasserstein_distance(emptydiag, emptydiag, internal_p=2., order=1.) == 0.
assert wasserstein_distance(emptydiag, emptydiag, internal_p=np.inf, order=1.) == 0.
assert wasserstein_distance(emptydiag, emptydiag, internal_p=np.inf, order=2.) == 0.
assert wasserstein_distance(emptydiag, emptydiag, internal_p=2., order=2.) == 0.
- assert wasserstein_distance(diag3, emptydiag, internal_p=np.inf, order=1.) == 2.
- assert wasserstein_distance(diag3, emptydiag, internal_p=1., order=1.) == 4.
+ assert wasserstein_distance(diag3, emptydiag, internal_p=np.inf, order=1.) == approx(2.)
+ assert wasserstein_distance(diag3, emptydiag, internal_p=1., order=1.) == approx(4.)
+
+ assert wasserstein_distance(diag4, emptydiag, internal_p=1., order=2.) == approx(5.) # thank you Pythagorean triples
+ assert wasserstein_distance(diag4, emptydiag, internal_p=np.inf, order=2.) == approx(2.5)
+ assert wasserstein_distance(diag4, emptydiag, internal_p=2., order=2.) == approx(3.5355339059327378)
+
+ assert wasserstein_distance(diag1, diag2, internal_p=2., order=1.) == approx(1.4453593023967701)
+ assert wasserstein_distance(diag1, diag2, internal_p=2.35, order=1.74) == approx(0.9772734057168739)
+
+ assert wasserstein_distance(diag1, emptydiag, internal_p=2.35, order=1.7863) == approx(3.141592214572228)
+
+ assert wasserstein_distance(diag3, diag4, internal_p=1., order=1.) == approx(3.)
+ assert wasserstein_distance(diag3, diag4, internal_p=np.inf, order=1.) == approx(3.) # no diag matching here
+ assert wasserstein_distance(diag3, diag4, internal_p=np.inf, order=2.) == approx(np.sqrt(5))
+ assert wasserstein_distance(diag3, diag4, internal_p=1., order=2.) == approx(np.sqrt(5))
+ assert wasserstein_distance(diag3, diag4, internal_p=4.5, order=2.) == approx(np.sqrt(5))
+
+ if test_infinity:
+ diag5 = np.array([[0, 3], [4, np.inf]])
+ diag6 = np.array([[7, 8], [4, 6], [3, np.inf]])
+
+ assert wasserstein_distance(diag4, diag5) == np.inf
+ assert wasserstein_distance(diag5, diag6, order=1, internal_p=np.inf) == approx(4.)
+ assert wasserstein_distance(diag5, emptydiag) == np.inf
+
+ if test_matching:
+ match = wasserstein_distance(emptydiag, emptydiag, matching=True, internal_p=1., order=2)[1]
+ assert np.array_equal(match, [])
+ match = wasserstein_distance(emptydiag, emptydiag, matching=True, internal_p=np.inf, order=2.24)[1]
+ assert np.array_equal(match, [])
+ match = wasserstein_distance(emptydiag, diag2, matching=True, internal_p=np.inf, order=2.)[1]
+ assert np.array_equal(match , [[-1, 0], [-1, 1]])
+ match = wasserstein_distance(diag2, emptydiag, matching=True, internal_p=np.inf, order=2.24)[1]
+ assert np.array_equal(match , [[0, -1], [1, -1]])
+ match = wasserstein_distance(diag1, diag2, matching=True, internal_p=2., order=2.)[1]
+ assert np.array_equal(match, [[0, 0], [1, 1], [2, -1]])
+
+ if test_matching and test_infinity:
+ diag7 = np.array([[0, 3], [4, np.inf], [5, np.inf]])
+ diag8 = np.array([[0,1], [0, np.inf], [-np.inf, -np.inf], [np.inf, np.inf]])
+ diag9 = np.array([[-np.inf, -np.inf], [np.inf, np.inf]])
+ diag10 = np.array([[0,1], [-np.inf, -np.inf], [np.inf, np.inf]])
+
+ match = wasserstein_distance(diag5, diag6, matching=True, internal_p=2., order=2.)[1]
+ assert np.array_equal(match, [[0, -1], [-1,0], [-1, 1], [1, 2]])
+ match = wasserstein_distance(diag5, diag7, matching=True, internal_p=2., order=2.)[1]
+ assert (match is None)
+ cost, match = wasserstein_distance(diag7, emptydiag, matching=True, internal_p=2., order=2.3)
+ assert (cost == np.inf)
+ assert (match is None)
+ cost, match = wasserstein_distance(emptydiag, diag7, matching=True, internal_p=2.42, order=2.)
+ assert (cost == np.inf)
+ assert (match is None)
+ cost, match = wasserstein_distance(diag8, diag9, matching=True, internal_p=2., order=2.)
+ assert (cost == np.inf)
+ assert (match is None)
+ cost, match = wasserstein_distance(diag9, diag10, matching=True, internal_p=1., order=1.)
+ assert (cost == 1)
+ assert (match == [[0, -1],[1, -1],[-1, 0], [-1, 1], [-1, 2]]) # types 4 and 5 are matched to the diagonal anyway.
+ cost, match = wasserstein_distance(diag9, emptydiag, matching=True, internal_p=2., order=2.)
+ assert (cost == 0.)
+ assert (match == [[0, -1], [1, -1]])
+
+
+def hera_wrap(**extra):
+ def fun(*kargs,**kwargs):
+ return hera(*kargs,**kwargs,**extra)
+ return fun
+
+
+def pot_wrap(**extra):
+ def fun(*kargs,**kwargs):
+ return pot(*kargs,**kwargs,**extra)
+ return fun
- assert wasserstein_distance(diag4, emptydiag, internal_p=1., order=2.) == 5. # thank you Pythagorician triplets
- assert wasserstein_distance(diag4, emptydiag, internal_p=np.inf, order=2.) == 2.5
- assert wasserstein_distance(diag4, emptydiag, internal_p=2., order=2.) == 3.5355339059327378
- assert wasserstein_distance(diag1, diag2, internal_p=2., order=1.) == 1.4453593023967701
- assert wasserstein_distance(diag1, diag2, internal_p=2.35, order=1.74) == 0.9772734057168739
+def test_wasserstein_distance_pot():
+ _basic_wasserstein(pot, 1e-15, test_infinity=False, test_matching=True) # pot with its standard args
+ _basic_wasserstein(pot_wrap(enable_autodiff=True, keep_essential_parts=False), 1e-15, test_infinity=False, test_matching=False)
- assert wasserstein_distance(diag1, emptydiag, internal_p=2.35, order=1.7863) == 3.141592214572228
- assert wasserstein_distance(diag3, diag4, internal_p=1., order=1.) == 3.
- assert wasserstein_distance(diag3, diag4, internal_p=np.inf, order=1.) == 3. # no diag matching here
- assert wasserstein_distance(diag3, diag4, internal_p=np.inf, order=2.) == np.sqrt(5)
- assert wasserstein_distance(diag3, diag4, internal_p=1., order=2.) == np.sqrt(5)
- assert wasserstein_distance(diag3, diag4, internal_p=4.5, order=2.) == np.sqrt(5)
+def test_wasserstein_distance_hera():
+ _basic_wasserstein(hera_wrap(delta=1e-12), 1e-12, test_matching=False)
+ _basic_wasserstein(hera_wrap(delta=.1), .1, test_matching=False)
diff --git a/src/python/test/test_wasserstein_with_tensors.py b/src/python/test/test_wasserstein_with_tensors.py
new file mode 100755
index 00000000..e3f1411a
--- /dev/null
+++ b/src/python/test/test_wasserstein_with_tensors.py
@@ -0,0 +1,47 @@
+""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+ See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+ Author(s): Mathieu Carriere
+
+ Copyright (C) 2020 Inria
+
+ Modification(s):
+ - YYYY/MM Author: Description of the modification
+"""
+
+from gudhi.wasserstein import wasserstein_distance as pot
+import numpy as np
+import torch
+import tensorflow as tf
+
+def test_wasserstein_distance_grad():
+ diag1 = torch.tensor([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]], requires_grad=True)
+ diag2 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True)
+ diag3 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True)
+ assert diag1.grad is None and diag2.grad is None and diag3.grad is None
+ dist12 = pot(diag1, diag2, internal_p=2, order=2, enable_autodiff=True)
+ dist30 = pot(diag3, torch.tensor([]), internal_p=2, order=2, enable_autodiff=True)
+ dist12.backward()
+ dist30.backward()
+ assert not torch.isnan(diag1.grad).any() and not torch.isnan(diag2.grad).any() and not torch.isnan(diag3.grad).any()
+ diag4 = torch.tensor([[0., 10.]], requires_grad=True)
+ diag5 = torch.tensor([[1., 11.], [3., 4.]], requires_grad=True)
+ dist45 = pot(diag4, diag5, internal_p=1, order=1, enable_autodiff=True)
+ assert dist45 == 3.
+ dist45.backward()
+ assert np.array_equal(diag4.grad, [[-1., -1.]])
+ assert np.array_equal(diag5.grad, [[1., 1.], [-1., 1.]])
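+ # Informal check: with internal_p=1 and order=1 the optimal matching pairs
+ # [0, 10] with [1, 11] (cost 2) and sends [3, 4] to the diagonal (cost 1),
+ # hence dist45 == 3; differentiating those L1 distances gives the +/-1
+ # gradient entries asserted above.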
+ diag6 = torch.tensor([[5., 10.]], requires_grad=True)
+ pot(diag6, diag6, internal_p=2, order=2, enable_autodiff=True).backward()
+ # https://github.com/jonasrauber/eagerpy/issues/6
+ # assert np.array_equal(diag6.grad, [[0., 0.]])
+
+def test_wasserstein_distance_grad_tensorflow():
+ with tf.GradientTape() as tape:
+ diag4 = tf.convert_to_tensor(tf.Variable(initial_value=np.array([[0., 10.]]), trainable=True))
+ diag5 = tf.convert_to_tensor(tf.Variable(initial_value=np.array([[1., 11.], [3., 4.]]), trainable=True))
+ dist45 = pot(diag4, diag5, internal_p=1, order=1, enable_autodiff=True)
+ assert dist45 == 3.
+
+ grads = tape.gradient(dist45, [diag4, diag5])
+ assert np.array_equal(grads[0].values, [[-1., -1.]])
+ assert np.array_equal(grads[1].values, [[1., 1.], [-1., 1.]]) \ No newline at end of file
diff --git a/src/python/test/test_weighted_rips_complex.py b/src/python/test/test_weighted_rips_complex.py
new file mode 100644
index 00000000..7ef48333
--- /dev/null
+++ b/src/python/test/test_weighted_rips_complex.py
@@ -0,0 +1,63 @@
+""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+ See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+ Author(s): Yuichi Ike and Masatoshi Takenouchi
+
+ Copyright (C) 2020 Inria
+
+ Modification(s):
+ - YYYY/MM Author: Description of the modification
+"""
+
+from gudhi.weighted_rips_complex import WeightedRipsComplex
+from gudhi.point_cloud.dtm import DistanceToMeasure
+import numpy as np
+from math import sqrt
+from scipy.spatial.distance import cdist
+import pytest
+
+def test_non_dtm_rips_complex():
+ dist = [[], [1]]
+ weights = [1, 100]
+ w_rips = WeightedRipsComplex(distance_matrix=dist, weights=weights)
+ st = w_rips.create_simplex_tree(max_dimension=2)
+ assert st.filtration([0,1]) == pytest.approx(200.0)
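+ # An informal sketch of the convention this exercises (per the
+ # WeightedRipsComplex docstring, an assumption worth double-checking): vertex i
+ # enters at 2*w[i] and edge (i, j) at max(2*w[i], 2*w[j], d(i, j) + w[i] + w[j]);
+ # here max(2, 200, 1 + 1 + 100) == 200.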
+
+def test_compatibility_with_rips():
+ distance_matrix = [[0], [1, 0], [1, sqrt(2), 0], [sqrt(2), 1, 1, 0]]
+ w_rips = WeightedRipsComplex(distance_matrix=distance_matrix,max_filtration=42)
+ st = w_rips.create_simplex_tree(max_dimension=1)
+ assert list(st.get_filtration()) == [
+ ([0], 0.0),
+ ([1], 0.0),
+ ([2], 0.0),
+ ([3], 0.0),
+ ([0, 1], 1.0),
+ ([0, 2], 1.0),
+ ([1, 3], 1.0),
+ ([2, 3], 1.0),
+ ([1, 2], sqrt(2)),
+ ([0, 3], sqrt(2)),
+ ]
+
+def test_compatibility_with_filtered_rips():
+ distance_matrix = [[0], [1, 0], [1, sqrt(2), 0], [sqrt(2), 1, 1, 0]]
+ w_rips = WeightedRipsComplex(distance_matrix=distance_matrix,max_filtration=1.0)
+ st = w_rips.create_simplex_tree(max_dimension=1)
+
+ assert st.__is_defined() == True
+ assert st.__is_persistence_defined() == False
+
+ assert st.num_simplices() == 8
+ assert st.num_vertices() == 4
+
+def test_dtm_rips_complex():
+ pts = np.array([[2.0, 2.0], [0.0, 1.0], [3.0, 4.0]])
+ dist = cdist(pts,pts)
+ dtm = DistanceToMeasure(2, q=2, metric="precomputed")
+ r = dtm.fit_transform(dist)
+ w_rips = WeightedRipsComplex(distance_matrix=dist, weights=r)
+ st = w_rips.create_simplex_tree(max_dimension=2)
+ st.persistence()
+ persistence_intervals0 = st.persistence_intervals_in_dimension(0)
+ assert persistence_intervals0 == pytest.approx(np.array([[3.16227766, 5.39834564],[3.16227766, 5.39834564], [3.16227766, float("inf")]]))
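+ # Informal check of the birth value: for each point the two smallest squared
+ # distances (apparently including the point itself) average to 2.5, so each DTM
+ # weight is sqrt(2.5) and each vertex enters at 2*sqrt(2.5) = sqrt(10) ~ 3.162.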
+