From af98fb120eea4ebc09531de9f74684b50212ab7a Mon Sep 17 00:00:00 2001 From: mathieu Date: Wed, 11 Sep 2019 20:16:02 -0400 Subject: fixed error in DiagramScaler --- src/python/gudhi/sktda/preprocessing.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/python/gudhi/sktda/preprocessing.py b/src/python/gudhi/sktda/preprocessing.py index 512b02f3..3c625053 100644 --- a/src/python/gudhi/sktda/preprocessing.py +++ b/src/python/gudhi/sktda/preprocessing.py @@ -64,25 +64,25 @@ class Clamping(BaseEstimator, TransformerMixin): def fit(self, X, y=None): """ - Fit the Clamping class on a list of list of values (this function actually does nothing but is useful when Clamping is included in a scikit-learn Pipeline). + Fit the Clamping class on a list of values (this function actually does nothing but is useful when Clamping is included in a scikit-learn Pipeline). Parameters: - X (list of numpy arrays of size n): input values. + X (numpy array of size n): input values. y (n x 1 array): value labels (unused). """ return self def transform(self, X): """ - Clamp each list of values individually. + Clamp list of values. Parameters: - X (list of numpy arrays of size n): input list of list of values. + X (numpy array of size n): input list of values. Returns: - Xfit (list of numpy arrays of size n): output list of list of values. + Xfit (numpy array of size n): output list of values. """ - Xfit = [np.where(L >= self.limit, self.limit * np.ones(L.shape), L) for L in X] + Xfit = np.where(X >= self.limit, self.limit * np.ones(X.shape), X) return Xfit class DiagramScaler(BaseEstimator, TransformerMixin): @@ -132,7 +132,8 @@ class DiagramScaler(BaseEstimator, TransformerMixin): for i in range(len(Xfit)): if Xfit[i].shape[0] > 0: for (indices, scaler) in self.scalers: - Xfit[i][:,indices] = scaler.transform(Xfit[i][:,indices]) + for I in indices: + Xfit[i][:,I] = np.squeeze(scaler.transform(np.reshape(Xfit[i][:,I], [-1,1]))) return Xfit class Padding(BaseEstimator, TransformerMixin): -- cgit v1.2.3