summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGard Spreemann <gspr@nonempty.org>2021-04-28 16:11:34 +0200
committerGard Spreemann <gspr@nonempty.org>2021-04-28 16:11:34 +0200
commit7d3fba5d1561b3241b914583ac420434e788e27f (patch)
treeeebea33029714bbe1920461600bea8e3cce155de
parent79f002efaa1584e89f85928e464dd73ea64593b6 (diff)
Handle an empty list of persistence diagrams
-rw-r--r--src/python/gudhi/representations/vector_methods.py6
-rwxr-xr-xsrc/python/test/test_betti_curve_representations.py15
2 files changed, 21 insertions, 0 deletions
diff --git a/src/python/gudhi/representations/vector_methods.py b/src/python/gudhi/representations/vector_methods.py
index 5133a64c..82f071d7 100644
--- a/src/python/gudhi/representations/vector_methods.py
+++ b/src/python/gudhi/representations/vector_methods.py
@@ -417,6 +417,9 @@ class BettiCurve2(BaseEstimator, TransformerMixin):
"""
if self.predefined_grid is None:
+ if not X:
+ X = [np.zeros((0, 2))]
+
N = len(X)
events = np.concatenate([pd.flatten(order="F") for pd in X], axis=0)
@@ -469,6 +472,9 @@ class BettiCurve2(BaseEstimator, TransformerMixin):
if not self.is_fitted():
raise NotFittedError("Not fitted.")
+ if not X:
+ X = [np.zeros((0, 2))]
+
N = len(X)
events = np.concatenate([pd.flatten(order="F") for pd in X], axis=0)
diff --git a/src/python/test/test_betti_curve_representations.py b/src/python/test/test_betti_curve_representations.py
index 5b95fa2c..475839ee 100755
--- a/src/python/test/test_betti_curve_representations.py
+++ b/src/python/test/test_betti_curve_representations.py
@@ -37,3 +37,18 @@ def test_betti_curve_is_irregular_betti_curve_followed_by_interpolation():
interp = scipy.interpolate.interp1d(bc.grid_, bettis[i, :], kind="previous", fill_value="extrapolate")
bettis_interp = np.array(interp(grid), dtype=int)
assert((bettis_interp == bettis_gridded).all())
+
+
+def test_empty_with_predefined_grid():
+ random_grid = np.sort(np.random.uniform(0, 1, 100))
+ bc = BettiCurve2(random_grid)
+ bettis = bc.fit_transform([])
+ assert((bc.grid_ == random_grid).all())
+ assert((bettis == 0).all())
+
+
+def test_empty():
+ bc = BettiCurve2()
+ bettis = bc.fit_transform([])
+ assert(bc.grid_ == [-np.inf])
+ assert((bettis == 0).all())