diff options
author | Marc Glisse <marc.glisse@inria.fr> | 2020-05-07 15:52:56 +0200 |
---|---|---|
committer | Marc Glisse <marc.glisse@inria.fr> | 2020-05-07 15:52:56 +0200 |
commit | a5044e7c6465afd7d1c368b697c559995742552c (patch) | |
tree | 451e8b1cf1c7b813d11376d6d813447646bb47c8 /src/python/gudhi/representations/preprocessing.py | |
parent | d61bfd349274456f8d7e0ccd64839a2d84eea0a0 (diff) | |
parent | 89b34f069e632a8fea0642556a4010de821ed6c9 (diff) |
Merge remote-tracking branch 'origin/master' into bottlepy2
Diffstat (limited to 'src/python/gudhi/representations/preprocessing.py')
-rw-r--r-- | src/python/gudhi/representations/preprocessing.py | 60 |
1 files changed, 60 insertions, 0 deletions
diff --git a/src/python/gudhi/representations/preprocessing.py b/src/python/gudhi/representations/preprocessing.py index a39b00e4..a8545349 100644 --- a/src/python/gudhi/representations/preprocessing.py +++ b/src/python/gudhi/representations/preprocessing.py @@ -54,6 +54,18 @@ class BirthPersistenceTransform(BaseEstimator, TransformerMixin): Xfit.append(new_diag) return Xfit + def __call__(self, diag): + """ + Apply BirthPersistenceTransform on a single persistence diagram and outputs the result. + + Parameters: + diag (n x 2 numpy array): input persistence diagram. + + Returns: + n x 2 numpy array: transformed persistence diagram. + """ + return self.fit_transform([diag])[0] + class Clamping(BaseEstimator, TransformerMixin): """ This is a class for clamping values. It can be used as a parameter for the DiagramScaler class, for instance if you want to clamp abscissae or ordinates of persistence diagrams. @@ -142,6 +154,18 @@ class DiagramScaler(BaseEstimator, TransformerMixin): Xfit[i][:,I] = np.squeeze(scaler.transform(np.reshape(Xfit[i][:,I], [-1,1]))) return Xfit + def __call__(self, diag): + """ + Apply DiagramScaler on a single persistence diagram and outputs the result. + + Parameters: + diag (n x 2 numpy array): input persistence diagram. + + Returns: + n x 2 numpy array: transformed persistence diagram. + """ + return self.fit_transform([diag])[0] + class Padding(BaseEstimator, TransformerMixin): """ This is a class for padding a list of persistence diagrams with dummy points, so that all persistence diagrams end up with the same number of points. @@ -186,6 +210,18 @@ class Padding(BaseEstimator, TransformerMixin): Xfit = X return Xfit + def __call__(self, diag): + """ + Apply Padding on a single persistence diagram and outputs the result. + + Parameters: + diag (n x 2 numpy array): input persistence diagram. + + Returns: + n x 2 numpy array: padded persistence diagram. + """ + return self.fit_transform([diag])[0] + class ProminentPoints(BaseEstimator, TransformerMixin): """ This is a class for removing points that are close or far from the diagonal in persistence diagrams. If persistence diagrams are n x 2 numpy arrays (i.e. persistence diagrams with ordinary features), points are ordered and thresholded by distance-to-diagonal. If persistence diagrams are n x 1 numpy arrays (i.e. persistence diagrams with essential features), points are not ordered and thresholded by first coordinate. @@ -259,6 +295,18 @@ class ProminentPoints(BaseEstimator, TransformerMixin): Xfit = X return Xfit + def __call__(self, diag): + """ + Apply ProminentPoints on a single persistence diagram and outputs the result. + + Parameters: + diag (n x 2 numpy array): input persistence diagram. + + Returns: + n x 2 numpy array: thresholded persistence diagram. + """ + return self.fit_transform([diag])[0] + class DiagramSelector(BaseEstimator, TransformerMixin): """ This is a class for extracting finite or essential points in persistence diagrams. @@ -303,3 +351,15 @@ class DiagramSelector(BaseEstimator, TransformerMixin): else: Xfit = X return Xfit + + def __call__(self, diag): + """ + Apply DiagramSelector on a single persistence diagram and outputs the result. + + Parameters: + diag (n x 2 numpy array): input persistence diagram. + + Returns: + n x 2 numpy array: extracted persistence diagram. + """ + return self.fit_transform([diag])[0] |