diff options
-rw-r--r-- | src/python/gudhi/sktda/preprocessing.py | 15 |
1 files changed, 8 insertions, 7 deletions
diff --git a/src/python/gudhi/sktda/preprocessing.py b/src/python/gudhi/sktda/preprocessing.py index 512b02f3..3c625053 100644 --- a/src/python/gudhi/sktda/preprocessing.py +++ b/src/python/gudhi/sktda/preprocessing.py @@ -64,25 +64,25 @@ class Clamping(BaseEstimator, TransformerMixin): def fit(self, X, y=None): """ - Fit the Clamping class on a list of list of values (this function actually does nothing but is useful when Clamping is included in a scikit-learn Pipeline). + Fit the Clamping class on a list of values (this function actually does nothing but is useful when Clamping is included in a scikit-learn Pipeline). Parameters: - X (list of numpy arrays of size n): input values. + X (numpy array of size n): input values. y (n x 1 array): value labels (unused). """ return self def transform(self, X): """ - Clamp each list of values individually. + Clamp list of values. Parameters: - X (list of numpy arrays of size n): input list of list of values. + X (numpy array of size n): input list of values. Returns: - Xfit (list of numpy arrays of size n): output list of list of values. + Xfit (numpy array of size n): output list of values. """ - Xfit = [np.where(L >= self.limit, self.limit * np.ones(L.shape), L) for L in X] + Xfit = np.where(X >= self.limit, self.limit * np.ones(X.shape), X) return Xfit class DiagramScaler(BaseEstimator, TransformerMixin): @@ -132,7 +132,8 @@ class DiagramScaler(BaseEstimator, TransformerMixin): for i in range(len(Xfit)): if Xfit[i].shape[0] > 0: for (indices, scaler) in self.scalers: - Xfit[i][:,indices] = scaler.transform(Xfit[i][:,indices]) + for I in indices: + Xfit[i][:,I] = np.squeeze(scaler.transform(np.reshape(Xfit[i][:,I], [-1,1]))) return Xfit class Padding(BaseEstimator, TransformerMixin): |