diff options
Diffstat (limited to 'src/python/test/test_representations.py')
-rwxr-xr-x | src/python/test/test_representations.py | 171 |
1 files changed, 170 insertions, 1 deletions
diff --git a/src/python/test/test_representations.py b/src/python/test/test_representations.py index cda1a15b..f4ffbdc1 100755 --- a/src/python/test/test_representations.py +++ b/src/python/test/test_representations.py @@ -3,9 +3,23 @@ import sys import matplotlib.pyplot as plt import numpy as np import pytest +import random from sklearn.cluster import KMeans +# Vectorization +from gudhi.representations import (Landscape, Silhouette, BettiCurve, ComplexPolynomial,\ + TopologicalVector, PersistenceImage, Entropy) + +# Preprocessing +from gudhi.representations import (BirthPersistenceTransform, Clamping, DiagramScaler, Padding, ProminentPoints, \ + DiagramSelector) + +# Kernel +from gudhi.representations import (PersistenceWeightedGaussianKernel, \ + PersistenceScaleSpaceKernel, SlicedWassersteinDistance,\ + SlicedWassersteinKernel, PersistenceFisherKernel, WassersteinDistance) + def test_representations_examples(): # Disable graphics for testing purposes @@ -91,10 +105,165 @@ def test_dummy_atol(): from gudhi.representations.vector_methods import BettiCurve - def test_infinity(): a = np.array([[1.0, 8.0], [2.0, np.inf], [3.0, 4.0]]) c = BettiCurve(20, [0.0, 10.0])(a) assert c[1] == 0 assert c[7] == 3 assert c[9] == 2 + +def test_preprocessing_empty_diagrams(): + empty_diag = np.empty(shape = [0, 2]) + assert not np.any(BirthPersistenceTransform()(empty_diag)) + assert not np.any(Clamping().fit_transform(empty_diag)) + assert not np.any(DiagramScaler()(empty_diag)) + assert not np.any(Padding()(empty_diag)) + assert not np.any(ProminentPoints()(empty_diag)) + assert not np.any(DiagramSelector()(empty_diag)) + +def pow(n): + return lambda x: np.power(x[1]-x[0],n) + +def test_vectorization_empty_diagrams(): + empty_diag = np.empty(shape = [0, 2]) + random_resolution = random.randint(50,100)*10 # between 500 and 1000 + print("resolution = ", random_resolution) + lsc = Landscape(resolution=random_resolution)(empty_diag) + assert not np.any(lsc) + assert lsc.shape[0]%random_resolution == 0 + slt = Silhouette(resolution=random_resolution, weight=pow(2))(empty_diag) + assert not np.any(slt) + assert slt.shape[0] == random_resolution + btc = BettiCurve(resolution=random_resolution)(empty_diag) + assert not np.any(btc) + assert btc.shape[0] == random_resolution + cpp = ComplexPolynomial(threshold=random_resolution, polynomial_type="T")(empty_diag) + assert not np.any(cpp) + assert cpp.shape[0] == random_resolution + tpv = TopologicalVector(threshold=random_resolution)(empty_diag) + assert tpv.shape[0] == random_resolution + assert not np.any(tpv) + prmg = PersistenceImage(resolution=[random_resolution,random_resolution])(empty_diag) + assert not np.any(prmg) + assert prmg.shape[0] == random_resolution * random_resolution + sce = Entropy(mode="scalar", resolution=random_resolution)(empty_diag) + assert not np.any(sce) + assert sce.shape[0] == 1 + scv = Entropy(mode="vector", normalized=False, resolution=random_resolution)(empty_diag) + assert not np.any(scv) + assert scv.shape[0] == random_resolution + +def test_entropy_miscalculation(): + diag_ex = np.array([[0.0,1.0], [0.0,1.0], [0.0,2.0]]) + def pe(pd): + l = pd[:,1] - pd[:,0] + l = l/sum(l) + return -np.dot(l, np.log(l)) + sce = Entropy(mode="scalar") + assert [[pe(diag_ex)]] == sce.fit_transform([diag_ex]) + sce = Entropy(mode="vector", resolution=4, normalized=False, keep_endpoints=True) + pef = [-1/4*np.log(1/4)-1/4*np.log(1/4)-1/2*np.log(1/2), + -1/4*np.log(1/4)-1/4*np.log(1/4)-1/2*np.log(1/2), + -1/2*np.log(1/2), + 0.0] + assert all(([pef] == sce.fit_transform([diag_ex]))[0]) + sce = Entropy(mode="vector", resolution=4, normalized=True) + pefN = (sce.fit_transform([diag_ex]))[0] + area = np.linalg.norm(pefN, ord=1) + assert area==pytest.approx(1) + +def test_kernel_empty_diagrams(): + empty_diag = np.empty(shape = [0, 2]) + assert SlicedWassersteinDistance(num_directions=100)(empty_diag, empty_diag) == 0. + assert SlicedWassersteinKernel(num_directions=100, bandwidth=1.)(empty_diag, empty_diag) == 1. + assert WassersteinDistance(mode="hera", delta=0.0001)(empty_diag, empty_diag) == 0. + assert WassersteinDistance(mode="pot")(empty_diag, empty_diag) == 0. + assert BottleneckDistance(epsilon=.001)(empty_diag, empty_diag) == 0. + assert BottleneckDistance()(empty_diag, empty_diag) == 0. +# PersistenceWeightedGaussianKernel(bandwidth=1., kernel_approx=None, weight=arctan(1.,1.))(empty_diag, empty_diag) +# PersistenceWeightedGaussianKernel(kernel_approx=RBFSampler(gamma=1./2, n_components=100000).fit(np.ones([1,2])), weight=arctan(1.,1.))(empty_diag, empty_diag) +# PersistenceScaleSpaceKernel(bandwidth=1.)(empty_diag, empty_diag) +# PersistenceScaleSpaceKernel(kernel_approx=RBFSampler(gamma=1./2, n_components=100000).fit(np.ones([1,2])))(empty_diag, empty_diag) +# PersistenceFisherKernel(bandwidth_fisher=1., bandwidth=1.)(empty_diag, empty_diag) +# PersistenceFisherKernel(bandwidth_fisher=1., bandwidth=1., kernel_approx=RBFSampler(gamma=1./2, n_components=100000).fit(np.ones([1,2])))(empty_diag, empty_diag) + + +def test_silhouette_permutation_invariance(): + dgm = _n_diags(1)[0] + dgm_permuted = dgm[np.random.permutation(dgm.shape[0]).astype(int)] + random_resolution = random.randint(50, 100) * 10 + slt = Silhouette(resolution=random_resolution, weight=pow(2)) + + assert np.all(np.isclose(slt(dgm), slt(dgm_permuted))) + + +def test_silhouette_multiplication_invariance(): + dgm = _n_diags(1)[0] + n_repetitions = np.random.randint(2, high=10) + dgm_augmented = np.repeat(dgm, repeats=n_repetitions, axis=0) + + random_resolution = random.randint(50, 100) * 10 + slt = Silhouette(resolution=random_resolution, weight=pow(2)) + assert np.all(np.isclose(slt(dgm), slt(dgm_augmented))) + + +def test_silhouette_numeric(): + dgm = np.array([[2., 3.], [5., 6.]]) + slt = Silhouette(resolution=9, weight=pow(1), sample_range=[2., 6.]) + #slt.fit([dgm]) + # x_values = array([2., 2.5, 3., 3.5, 4., 4.5, 5., 5.5, 6.]) + + expected_silhouette = np.array([0., 0.5, 0., 0., 0., 0., 0., 0.5, 0.])/np.sqrt(2) + output_silhouette = slt(dgm) + assert np.all(np.isclose(output_silhouette, expected_silhouette)) + + +def test_landscape_small_persistence_invariance(): + dgm = np.array([[2., 6.], [2., 5.], [3., 7.]]) + small_persistence_pts = np.random.rand(10, 2) + small_persistence_pts[:, 1] += small_persistence_pts[:, 0] + small_persistence_pts += np.min(dgm) + dgm_augmented = np.concatenate([dgm, small_persistence_pts], axis=0) + + lds = Landscape(num_landscapes=2, resolution=5) + lds_dgm, lds_dgm_augmented = lds(dgm), lds(dgm_augmented) + + assert np.all(np.isclose(lds_dgm, lds_dgm_augmented)) + + +def test_landscape_numeric(): + dgm = np.array([[2., 6.], [3., 5.]]) + lds_ref = np.array([ + 0., 0.5, 1., 1.5, 2., 1.5, 1., 0.5, 0., # tent of [2, 6] + 0., 0., 0., 0.5, 1., 0.5, 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., + ]) + lds_ref *= np.sqrt(2) + lds = Landscape(num_landscapes=4, resolution=9, sample_range=[2., 6.]) + lds_dgm = lds(dgm) + assert np.all(np.isclose(lds_dgm, lds_ref)) + + +def test_landscape_nan_range(): + dgm = np.array([[2., 6.], [3., 5.]]) + lds = Landscape(num_landscapes=2, resolution=9, sample_range=[np.nan, 6.]) + lds_dgm = lds(dgm) + assert (lds.sample_range_fixed[0] == 2) & (lds.sample_range_fixed[1] == 6) + assert lds.new_resolution == 10 + +def test_endpoints(): + diags = [ np.array([[2., 3.]]) ] + for vec in [ Landscape(), Silhouette(), BettiCurve(), Entropy(mode="vector") ]: + vec.fit(diags) + assert vec.grid_[0] > 2 and vec.grid_[-1] < 3 + for vec in [ Landscape(keep_endpoints=True), Silhouette(keep_endpoints=True), BettiCurve(keep_endpoints=True), Entropy(mode="vector", keep_endpoints=True)]: + vec.fit(diags) + assert vec.grid_[0] == 2 and vec.grid_[-1] == 3 + vec = BettiCurve(resolution=None) + vec.fit(diags) + assert np.equal(vec.grid_, [-np.inf, 2., 3.]).all() + +def test_get_params(): + for vec in [ Landscape(), Silhouette(), BettiCurve(), Entropy(mode="vector") ]: + vec.get_params() |