summaryrefslogtreecommitdiff
path: root/src/python/test/test_representations.py
diff options
context:
space:
mode:
authorwreise <wojciech.reise@epfl.ch>2022-05-25 11:59:37 +0200
committerwreise <wojciech.reise@epfl.ch>2022-05-25 11:59:37 +0200
commitf911ed3882c03b049a18f2a638758cb5ef994bcc (patch)
tree7aee1ab45e2c0623ec4ce9f628dee275e8d244bc /src/python/test/test_representations.py
parentc7afe82e863eb91d122d34f1d34255c04e27b51f (diff)
Add tests for silhouettes
Diffstat (limited to 'src/python/test/test_representations.py')
-rwxr-xr-xsrc/python/test/test_representations.py29
1 files changed, 29 insertions, 0 deletions
diff --git a/src/python/test/test_representations.py b/src/python/test/test_representations.py
index d219ce7a..8be3a1db 100755
--- a/src/python/test/test_representations.py
+++ b/src/python/test/test_representations.py
@@ -168,3 +168,32 @@ def test_kernel_empty_diagrams():
# PersistenceFisherKernel(bandwidth_fisher=1., bandwidth=1.)(empty_diag, empty_diag)
# PersistenceFisherKernel(bandwidth_fisher=1., bandwidth=1., kernel_approx=RBFSampler(gamma=1./2, n_components=100000).fit(np.ones([1,2])))(empty_diag, empty_diag)
+
+def test_silhouette_permutation_invariance():
+ dgm = _n_diags(1)[0]
+ dgm_permuted = dgm[np.random.permutation(dgm.shape[0]).astype(int)]
+ random_resolution = random.randint(50, 100) * 10
+ slt = Silhouette(resolution=random_resolution, weight=pow(2))
+
+ assert np.all(np.isclose(slt(dgm), slt(dgm_permuted)))
+
+
+def test_silhouette_multiplication_invariance():
+ dgm = _n_diags(1)[0]
+ n_repetitions = np.random.randint(2, high=10)
+ dgm_augmented = np.repeat(dgm, repeats=n_repetitions, axis=0)
+
+ random_resolution = random.randint(50, 100) * 10
+ slt = Silhouette(resolution=random_resolution, weight=pow(2))
+ assert np.all(np.isclose(slt(dgm), slt(dgm_augmented)))
+
+
+def test_silhouette_numeric():
+ dgm = np.array([[2., 3.], [5., 6.]])
+ slt = Silhouette(resolution=9, weight=pow(1))
+ #slt.fit([dgm])
+ # x_values = array([2., 2.5, 3., 3.5, 4., 4.5, 5., 5.5, 6.])
+
+ expected_silhouette = np.array([0., 0.5, 0., 0., 0., 0., 0., 0.5, 0.])/np.sqrt(2)
+ output_silhouette = slt(dgm)
+ assert np.all(np.isclose(output_silhouette, expected_silhouette))