summaryrefslogtreecommitdiff
path: root/src/python/gudhi/sktda/preprocessing.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/gudhi/sktda/preprocessing.py')
-rw-r--r--src/python/gudhi/sktda/preprocessing.py16
1 files changed, 9 insertions, 7 deletions
diff --git a/src/python/gudhi/sktda/preprocessing.py b/src/python/gudhi/sktda/preprocessing.py
index 3c625053..784e300f 100644
--- a/src/python/gudhi/sktda/preprocessing.py
+++ b/src/python/gudhi/sktda/preprocessing.py
@@ -43,8 +43,9 @@ class BirthPersistenceTransform(BaseEstimator, TransformerMixin):
"""
Xfit = []
for diag in X:
- new_diag = np.empty(diag.shape)
- np.copyto(new_diag, diag)
+ #new_diag = np.empty(diag.shape)
+ #np.copyto(new_diag, diag)
+ new_diag = np.copy(diag)
new_diag[:,1] = new_diag[:,1] - new_diag[:,0]
Xfit.append(new_diag)
return Xfit
@@ -82,7 +83,8 @@ class Clamping(BaseEstimator, TransformerMixin):
Returns:
Xfit (numpy array of size n): output list of values.
"""
- Xfit = np.where(X >= self.limit, self.limit * np.ones(X.shape), X)
+ Xfit = np.minimum(X, self.limit)
+ #Xfit = np.where(X >= self.limit, self.limit * np.ones(X.shape), X)
return Xfit
class DiagramScaler(BaseEstimator, TransformerMixin):
@@ -91,7 +93,7 @@ class DiagramScaler(BaseEstimator, TransformerMixin):
"""
def __init__(self, use=False, scalers=[]):
"""
- Constructor for the DiagramPreprocessor class.
+ Constructor for the DiagramScaler class.
Attributes:
use (bool): whether to use the class or not (default False).
@@ -102,7 +104,7 @@ class DiagramScaler(BaseEstimator, TransformerMixin):
def fit(self, X, y=None):
"""
- Fit the DiagramPreprocessor class on a list of persistence diagrams: persistence diagrams are concatenated in a big numpy array, and scalers are fit (by calling their fit() method) on their corresponding coordinates in this big array.
+ Fit the DiagramScaler class on a list of persistence diagrams: persistence diagrams are concatenated in a big numpy array, and scalers are fit (by calling their fit() method) on their corresponding coordinates in this big array.
Parameters:
X (list of n x 2 or n x 1 numpy arrays): input persistence diagrams.
@@ -119,7 +121,7 @@ class DiagramScaler(BaseEstimator, TransformerMixin):
def transform(self, X):
"""
- Apply the DiagramPreprocessor function on the persistence diagrams. The fitted scalers are applied (by calling their transform() method) to their corresponding coordinates in each persistence diagram individually.
+ Apply the DiagramScaler function on the persistence diagrams. The fitted scalers are applied (by calling their transform() method) to their corresponding coordinates in each persistence diagram individually.
Parameters:
X (list of n x 2 or n x 1 numpy arrays): input persistence diagrams.
@@ -293,7 +295,7 @@ class DiagramSelector(BaseEstimator, TransformerMixin):
if self.point_type == "finite":
Xfit = [ diag[diag[:,1] < self.limit] if diag.shape[0] != 0 else diag for diag in X]
else:
- Xfit = [ diag[diag[:,1] == self.limit, 0:1] if diag.shape[0] != 0 else diag for diag in X]
+ Xfit = [ diag[diag[:,1] >= self.limit, 0:1] if diag.shape[0] != 0 else diag for diag in X]
else:
Xfit = X
return Xfit