diff options
author | Vincent Rouvreau <10407034+VincentRouvreau@users.noreply.github.com> | 2020-05-04 14:09:48 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-05-04 14:09:48 +0200 |
commit | 81a4e6ff3ba732bd4e061fc5443ffed52b694e01 (patch) | |
tree | 66fc2615c48dd7a82e2ccb43a7dce54d10d087de /src/python/gudhi/representations/preprocessing.py | |
parent | 07a017ca26238847e9d9ab75dcb17e52c81e6865 (diff) | |
parent | b2177e897b575e0c8d17b8ae5ed3259541a06bea (diff) |
Merge pull request #205 from MathieuCarriere/wasserstein_representations
Integration of Wasserstein distances in representations module
Diffstat (limited to 'src/python/gudhi/representations/preprocessing.py')
-rw-r--r-- | src/python/gudhi/representations/preprocessing.py | 60 |
1 files changed, 60 insertions, 0 deletions
diff --git a/src/python/gudhi/representations/preprocessing.py b/src/python/gudhi/representations/preprocessing.py index a39b00e4..a8545349 100644 --- a/src/python/gudhi/representations/preprocessing.py +++ b/src/python/gudhi/representations/preprocessing.py @@ -54,6 +54,18 @@ class BirthPersistenceTransform(BaseEstimator, TransformerMixin): Xfit.append(new_diag) return Xfit + def __call__(self, diag): + """ + Apply BirthPersistenceTransform on a single persistence diagram and outputs the result. + + Parameters: + diag (n x 2 numpy array): input persistence diagram. + + Returns: + n x 2 numpy array: transformed persistence diagram. + """ + return self.fit_transform([diag])[0] + class Clamping(BaseEstimator, TransformerMixin): """ This is a class for clamping values. It can be used as a parameter for the DiagramScaler class, for instance if you want to clamp abscissae or ordinates of persistence diagrams. @@ -142,6 +154,18 @@ class DiagramScaler(BaseEstimator, TransformerMixin): Xfit[i][:,I] = np.squeeze(scaler.transform(np.reshape(Xfit[i][:,I], [-1,1]))) return Xfit + def __call__(self, diag): + """ + Apply DiagramScaler on a single persistence diagram and outputs the result. + + Parameters: + diag (n x 2 numpy array): input persistence diagram. + + Returns: + n x 2 numpy array: transformed persistence diagram. + """ + return self.fit_transform([diag])[0] + class Padding(BaseEstimator, TransformerMixin): """ This is a class for padding a list of persistence diagrams with dummy points, so that all persistence diagrams end up with the same number of points. @@ -186,6 +210,18 @@ class Padding(BaseEstimator, TransformerMixin): Xfit = X return Xfit + def __call__(self, diag): + """ + Apply Padding on a single persistence diagram and outputs the result. + + Parameters: + diag (n x 2 numpy array): input persistence diagram. + + Returns: + n x 2 numpy array: padded persistence diagram. + """ + return self.fit_transform([diag])[0] + class ProminentPoints(BaseEstimator, TransformerMixin): """ This is a class for removing points that are close or far from the diagonal in persistence diagrams. If persistence diagrams are n x 2 numpy arrays (i.e. persistence diagrams with ordinary features), points are ordered and thresholded by distance-to-diagonal. If persistence diagrams are n x 1 numpy arrays (i.e. persistence diagrams with essential features), points are not ordered and thresholded by first coordinate. @@ -259,6 +295,18 @@ class ProminentPoints(BaseEstimator, TransformerMixin): Xfit = X return Xfit + def __call__(self, diag): + """ + Apply ProminentPoints on a single persistence diagram and outputs the result. + + Parameters: + diag (n x 2 numpy array): input persistence diagram. + + Returns: + n x 2 numpy array: thresholded persistence diagram. + """ + return self.fit_transform([diag])[0] + class DiagramSelector(BaseEstimator, TransformerMixin): """ This is a class for extracting finite or essential points in persistence diagrams. @@ -303,3 +351,15 @@ class DiagramSelector(BaseEstimator, TransformerMixin): else: Xfit = X return Xfit + + def __call__(self, diag): + """ + Apply DiagramSelector on a single persistence diagram and outputs the result. + + Parameters: + diag (n x 2 numpy array): input persistence diagram. + + Returns: + n x 2 numpy array: extracted persistence diagram. + """ + return self.fit_transform([diag])[0] |