diff options
author | yuichi-ike <yuichi.ike.1990@gmail.com> | 2020-04-07 09:36:03 +0900 |
---|---|---|
committer | yuichi-ike <yuichi.ike.1990@gmail.com> | 2020-04-07 09:36:03 +0900 |
commit | 4294e5fc6e1bff246a7d22f1bd98f91b62f14163 (patch) | |
tree | cffe6fc9df2e2da9f839486799aaa8f986a5730c /src/python | |
parent | a4fa5f673784a842e9fac13003c843d454c888a4 (diff) |
filtration value fixed
Diffstat (limited to 'src/python')
-rw-r--r-- | src/python/gudhi/weighted_rips_complex.py | 2 | ||||
-rw-r--r-- | src/python/test/test_weighted_rips.py | 12 |
2 files changed, 12 insertions, 2 deletions
diff --git a/src/python/gudhi/weighted_rips_complex.py b/src/python/gudhi/weighted_rips_complex.py index 9df2ddf9..7e504b2c 100644 --- a/src/python/gudhi/weighted_rips_complex.py +++ b/src/python/gudhi/weighted_rips_complex.py @@ -51,7 +51,7 @@ class WeightedRipsComplex: st.insert([i], F[i]) for i in range(num_pts): for j in range(i): - value = (dist[i][j] + F[i] + F[j]) / 2 + value = max(F[i], F[j], (dist[i][j] + F[i] + F[j]) / 2) if value < self.max_filtration: st.insert([i,j], filtration=value) diff --git a/src/python/test/test_weighted_rips.py b/src/python/test/test_weighted_rips.py index 7896fb78..a3235276 100644 --- a/src/python/test/test_weighted_rips.py +++ b/src/python/test/test_weighted_rips.py @@ -14,13 +14,23 @@ import numpy as np from scipy.spatial.distance import cdist import pytest +def test_non_dtm_rips_complex(): + dist = [[], [1]] + weights = [1, 100] + w_rips = WeightedRipsComplex(distance_matrix=dist, weights=weights) + st = w_rips.create_simplex_tree(max_dimension=2) + assert st.filtration([0,1]) == pytest.approx(100.0) + + def test_dtm_rips_complex(): pts = np.array([[2.0, 2], [0, 1], [3, 4]]) dist = cdist(pts,pts) dtm = DTM(2, q=2, metric="precomputed") r = dtm.fit_transform(dist) - w_rips = WeightedRipsComplex(distance_mattix=dist, weights=r) + w_rips = WeightedRipsComplex(distance_matrix=dist, weights=r) st = w_rips.create_simplex_tree(max_dimension=2) + st.persistence() persistence_intervals0 = st.persistence_intervals_in_dimension(0) assert persistence_intervals0 == pytest.approx(np.array([[1.58113883, 2.69917282],[1.58113883, 2.69917282], [1.58113883, float("inf")]])) + |