summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorVincent Rouvreau <10407034+VincentRouvreau@users.noreply.github.com>2023-01-05 10:34:00 +0100
committerGitHub <noreply@github.com>2023-01-05 10:34:00 +0100
commitb6171a12cfdeb26aa782a9dfe5d58177ca2239fb (patch)
tree6089ddf51ed57c5592ba5843d1b5777cd251f114
parent5b38d86d386c0a30eac401a35a5a737e7e49e68b (diff)
parent86689f89bf896e41683fd7b1a4568f2b34ea505d (diff)
Merge pull request #782 from mglisse/sklearn-getparams
Fix get_params for vector_methods
-rw-r--r--src/python/gudhi/representations/vector_methods.py18
-rwxr-xr-xsrc/python/test/test_representations.py6
2 files changed, 14 insertions, 10 deletions
diff --git a/src/python/gudhi/representations/vector_methods.py b/src/python/gudhi/representations/vector_methods.py
index 745fe1e5..ce74aee5 100644
--- a/src/python/gudhi/representations/vector_methods.py
+++ b/src/python/gudhi/representations/vector_methods.py
@@ -138,13 +138,13 @@ def _trim_endpoints(x, are_endpoints_nan):
def _grid_from_sample_range(self, X):
- sample_range = np.array(self.sample_range_init)
+ sample_range = np.array(self.sample_range)
self.nan_in_range = np.isnan(sample_range)
self.new_resolution = self.resolution
if not self.keep_endpoints:
self.new_resolution += self.nan_in_range.sum()
- self.sample_range = _automatic_sample_range(sample_range, X)
- self.grid_ = np.linspace(self.sample_range[0], self.sample_range[1], self.new_resolution)
+ self.sample_range_fixed = _automatic_sample_range(sample_range, X)
+ self.grid_ = np.linspace(self.sample_range_fixed[0], self.sample_range_fixed[1], self.new_resolution)
if not self.keep_endpoints:
self.grid_ = _trim_endpoints(self.grid_, self.nan_in_range)
@@ -166,7 +166,7 @@ class Landscape(BaseEstimator, TransformerMixin):
sample_range ([double, double]): minimum and maximum of all piecewise-linear function domains, of the form [x_min, x_max] (default [numpy.nan, numpy.nan]). It is the interval on which samples will be drawn evenly. If one of the values is numpy.nan, it can be computed from the persistence diagrams with the fit() method.
keep_endpoints (bool): when computing `sample_range`, use the exact extremities (where the value is always 0). This is mostly useful for plotting, the default is to use a slightly smaller range.
"""
- self.num_landscapes, self.resolution, self.sample_range_init = num_landscapes, resolution, sample_range
+ self.num_landscapes, self.resolution, self.sample_range = num_landscapes, resolution, sample_range
self.keep_endpoints = keep_endpoints
def fit(self, X, y=None):
@@ -240,7 +240,7 @@ class Silhouette(BaseEstimator, TransformerMixin):
sample_range ([double, double]): minimum and maximum for the weighted average domain, of the form [x_min, x_max] (default [numpy.nan, numpy.nan]). It is the interval on which samples will be drawn evenly. If one of the values is numpy.nan, it can be computed from the persistence diagrams with the fit() method.
keep_endpoints (bool): when computing `sample_range`, use the exact extremities (where the value is always 0). This is mostly useful for plotting, the default is to use a slightly smaller range.
"""
- self.weight, self.resolution, self.sample_range_init = weight, resolution, sample_range
+ self.weight, self.resolution, self.sample_range = weight, resolution, sample_range
self.keep_endpoints = keep_endpoints
def fit(self, X, y=None):
@@ -334,7 +334,7 @@ class BettiCurve(BaseEstimator, TransformerMixin):
self.predefined_grid = predefined_grid
self.resolution = resolution
- self.sample_range_init = sample_range
+ self.sample_range = sample_range
self.keep_endpoints = keep_endpoints
def is_fitted(self):
@@ -468,7 +468,7 @@ class Entropy(BaseEstimator, TransformerMixin):
sample_range ([double, double]): minimum and maximum of the entropy summary function domain, of the form [x_min, x_max] (default [numpy.nan, numpy.nan]). It is the interval on which samples will be drawn evenly. If one of the values is numpy.nan, it can be computed from the persistence diagrams with the fit() method. Used only if **mode** = "vector".
keep_endpoints (bool): when computing `sample_range`, use the exact extremities. This is mostly useful for plotting, the default is to use a slightly smaller range.
"""
- self.mode, self.normalized, self.resolution, self.sample_range_init = mode, normalized, resolution, sample_range
+ self.mode, self.normalized, self.resolution, self.sample_range = mode, normalized, resolution, sample_range
self.keep_endpoints = keep_endpoints
def fit(self, X, y=None):
@@ -509,8 +509,8 @@ class Entropy(BaseEstimator, TransformerMixin):
ent = np.zeros(self.resolution)
for j in range(num_pts_in_diag):
[px,py] = orig_diagram[j,:2]
- min_idx = np.clip(np.ceil((px - self.sample_range[0]) / self.step_).astype(int), 0, self.resolution)
- max_idx = np.clip(np.ceil((py - self.sample_range[0]) / self.step_).astype(int), 0, self.resolution)
+ min_idx = np.clip(np.ceil((px - self.sample_range_fixed[0]) / self.step_).astype(int), 0, self.resolution)
+ max_idx = np.clip(np.ceil((py - self.sample_range_fixed[0]) / self.step_).astype(int), 0, self.resolution)
ent[min_idx:max_idx]-=p[j]*np.log(p[j])
if self.normalized:
ent = ent / np.linalg.norm(ent, ord=1)
diff --git a/src/python/test/test_representations.py b/src/python/test/test_representations.py
index ae0362f8..f4ffbdc1 100755
--- a/src/python/test/test_representations.py
+++ b/src/python/test/test_representations.py
@@ -249,7 +249,7 @@ def test_landscape_nan_range():
dgm = np.array([[2., 6.], [3., 5.]])
lds = Landscape(num_landscapes=2, resolution=9, sample_range=[np.nan, 6.])
lds_dgm = lds(dgm)
- assert (lds.sample_range[0] == 2) & (lds.sample_range[1] == 6)
+ assert (lds.sample_range_fixed[0] == 2) & (lds.sample_range_fixed[1] == 6)
assert lds.new_resolution == 10
def test_endpoints():
@@ -263,3 +263,7 @@ def test_endpoints():
vec = BettiCurve(resolution=None)
vec.fit(diags)
assert np.equal(vec.grid_, [-np.inf, 2., 3.]).all()
+
+def test_get_params():
+ for vec in [ Landscape(), Silhouette(), BettiCurve(), Entropy(mode="vector") ]:
+ vec.get_params()