summaryrefslogtreecommitdiff
path: root/src/python/gudhi/representations/preprocessing.py
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2019-11-26 18:03:05 +0100
committerGitHub <noreply@github.com>2019-11-26 18:03:05 +0100
commit71c1facc409f08f459c73e15c853782240e51d25 (patch)
tree27ce4fb94afd3c9f35c897082eef02705428483e /src/python/gudhi/representations/preprocessing.py
parentec9c03aa2788b66350760e702020948731823148 (diff)
parent177e80b653d60119acb4455feaba02615083532b (diff)
Merge pull request #147 from mglisse/sktda-tweaks-glisse
Minor tweaks to representations
Diffstat (limited to 'src/python/gudhi/representations/preprocessing.py')
-rw-r--r--src/python/gudhi/representations/preprocessing.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/src/python/gudhi/representations/preprocessing.py b/src/python/gudhi/representations/preprocessing.py
index 83227ca1..a39b00e4 100644
--- a/src/python/gudhi/representations/preprocessing.py
+++ b/src/python/gudhi/representations/preprocessing.py
@@ -30,7 +30,7 @@ class BirthPersistenceTransform(BaseEstimator, TransformerMixin):
Fit the BirthPersistenceTransform class on a list of persistence diagrams (this function actually does nothing but is useful when BirthPersistenceTransform is included in a scikit-learn Pipeline).
Parameters:
- X (n x 2 numpy array): input persistence diagrams.
+ X (list of n x 2 numpy array): input persistence diagrams.
y (n x 1 array): persistence diagram labels (unused).
"""
return self
@@ -58,14 +58,15 @@ class Clamping(BaseEstimator, TransformerMixin):
"""
This is a class for clamping values. It can be used as a parameter for the DiagramScaler class, for instance if you want to clamp abscissae or ordinates of persistence diagrams.
"""
- def __init__(self, limit=np.inf):
+ def __init__(self, minimum=-np.inf, maximum=np.inf):
"""
Constructor for the Clamping class.
Parameters:
limit (double): clamping value (default np.inf).
"""
- self.limit = limit
+ self.minimum = minimum
+ self.maximum = maximum
def fit(self, X, y=None):
"""
@@ -87,8 +88,7 @@ class Clamping(BaseEstimator, TransformerMixin):
Returns:
numpy array of size n: output list of values.
"""
- Xfit = np.minimum(X, self.limit)
- #Xfit = np.where(X >= self.limit, self.limit * np.ones(X.shape), X)
+ Xfit = np.clip(X, self.minimum, self.maximum)
return Xfit
class DiagramScaler(BaseEstimator, TransformerMixin):