diff options
Diffstat (limited to 'src/python/gudhi/sktda/preprocessing.py')
-rw-r--r-- | src/python/gudhi/sktda/preprocessing.py | 16 |
1 files changed, 9 insertions, 7 deletions
diff --git a/src/python/gudhi/sktda/preprocessing.py b/src/python/gudhi/sktda/preprocessing.py index 3c625053..784e300f 100644 --- a/src/python/gudhi/sktda/preprocessing.py +++ b/src/python/gudhi/sktda/preprocessing.py @@ -43,8 +43,9 @@ class BirthPersistenceTransform(BaseEstimator, TransformerMixin): """ Xfit = [] for diag in X: - new_diag = np.empty(diag.shape) - np.copyto(new_diag, diag) + #new_diag = np.empty(diag.shape) + #np.copyto(new_diag, diag) + new_diag = np.copy(diag) new_diag[:,1] = new_diag[:,1] - new_diag[:,0] Xfit.append(new_diag) return Xfit @@ -82,7 +83,8 @@ class Clamping(BaseEstimator, TransformerMixin): Returns: Xfit (numpy array of size n): output list of values. """ - Xfit = np.where(X >= self.limit, self.limit * np.ones(X.shape), X) + Xfit = np.minimum(X, self.limit) + #Xfit = np.where(X >= self.limit, self.limit * np.ones(X.shape), X) return Xfit class DiagramScaler(BaseEstimator, TransformerMixin): @@ -91,7 +93,7 @@ class DiagramScaler(BaseEstimator, TransformerMixin): """ def __init__(self, use=False, scalers=[]): """ - Constructor for the DiagramPreprocessor class. + Constructor for the DiagramScaler class. Attributes: use (bool): whether to use the class or not (default False). @@ -102,7 +104,7 @@ class DiagramScaler(BaseEstimator, TransformerMixin): def fit(self, X, y=None): """ - Fit the DiagramPreprocessor class on a list of persistence diagrams: persistence diagrams are concatenated in a big numpy array, and scalers are fit (by calling their fit() method) on their corresponding coordinates in this big array. + Fit the DiagramScaler class on a list of persistence diagrams: persistence diagrams are concatenated in a big numpy array, and scalers are fit (by calling their fit() method) on their corresponding coordinates in this big array. Parameters: X (list of n x 2 or n x 1 numpy arrays): input persistence diagrams. @@ -119,7 +121,7 @@ class DiagramScaler(BaseEstimator, TransformerMixin): def transform(self, X): """ - Apply the DiagramPreprocessor function on the persistence diagrams. The fitted scalers are applied (by calling their transform() method) to their corresponding coordinates in each persistence diagram individually. + Apply the DiagramScaler function on the persistence diagrams. The fitted scalers are applied (by calling their transform() method) to their corresponding coordinates in each persistence diagram individually. Parameters: X (list of n x 2 or n x 1 numpy arrays): input persistence diagrams. @@ -293,7 +295,7 @@ class DiagramSelector(BaseEstimator, TransformerMixin): if self.point_type == "finite": Xfit = [ diag[diag[:,1] < self.limit] if diag.shape[0] != 0 else diag for diag in X] else: - Xfit = [ diag[diag[:,1] == self.limit, 0:1] if diag.shape[0] != 0 else diag for diag in X] + Xfit = [ diag[diag[:,1] >= self.limit, 0:1] if diag.shape[0] != 0 else diag for diag in X] else: Xfit = X return Xfit |