diff options
author | Marc Glisse <marc.glisse@inria.fr> | 2019-11-26 18:03:05 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-11-26 18:03:05 +0100 |
commit | 71c1facc409f08f459c73e15c853782240e51d25 (patch) | |
tree | 27ce4fb94afd3c9f35c897082eef02705428483e /src/python/gudhi/representations/preprocessing.py | |
parent | ec9c03aa2788b66350760e702020948731823148 (diff) | |
parent | 177e80b653d60119acb4455feaba02615083532b (diff) |
Merge pull request #147 from mglisse/sktda-tweaks-glisse
Minor tweaks to representations
Diffstat (limited to 'src/python/gudhi/representations/preprocessing.py')
-rw-r--r-- | src/python/gudhi/representations/preprocessing.py | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/src/python/gudhi/representations/preprocessing.py b/src/python/gudhi/representations/preprocessing.py index 83227ca1..a39b00e4 100644 --- a/src/python/gudhi/representations/preprocessing.py +++ b/src/python/gudhi/representations/preprocessing.py @@ -30,7 +30,7 @@ class BirthPersistenceTransform(BaseEstimator, TransformerMixin): Fit the BirthPersistenceTransform class on a list of persistence diagrams (this function actually does nothing but is useful when BirthPersistenceTransform is included in a scikit-learn Pipeline). Parameters: - X (n x 2 numpy array): input persistence diagrams. + X (list of n x 2 numpy array): input persistence diagrams. y (n x 1 array): persistence diagram labels (unused). """ return self @@ -58,14 +58,15 @@ class Clamping(BaseEstimator, TransformerMixin): """ This is a class for clamping values. It can be used as a parameter for the DiagramScaler class, for instance if you want to clamp abscissae or ordinates of persistence diagrams. """ - def __init__(self, limit=np.inf): + def __init__(self, minimum=-np.inf, maximum=np.inf): """ Constructor for the Clamping class. Parameters: limit (double): clamping value (default np.inf). """ - self.limit = limit + self.minimum = minimum + self.maximum = maximum def fit(self, X, y=None): """ @@ -87,8 +88,7 @@ class Clamping(BaseEstimator, TransformerMixin): Returns: numpy array of size n: output list of values. """ - Xfit = np.minimum(X, self.limit) - #Xfit = np.where(X >= self.limit, self.limit * np.ones(X.shape), X) + Xfit = np.clip(X, self.minimum, self.maximum) return Xfit class DiagramScaler(BaseEstimator, TransformerMixin): |