summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormathieu <mathieu.carriere3@gmail.com>2019-09-11 20:16:02 -0400
committermathieu <mathieu.carriere3@gmail.com>2019-09-11 20:16:02 -0400
commitaf98fb120eea4ebc09531de9f74684b50212ab7a (patch)
treef7bd8f4564f094ba478050787d358d1008671448
parent2ebe5c2a9e6c82d567109c0da788303298a27357 (diff)
fixed error in DiagramScaler
-rw-r--r--src/python/gudhi/sktda/preprocessing.py15
1 files changed, 8 insertions, 7 deletions
diff --git a/src/python/gudhi/sktda/preprocessing.py b/src/python/gudhi/sktda/preprocessing.py
index 512b02f3..3c625053 100644
--- a/src/python/gudhi/sktda/preprocessing.py
+++ b/src/python/gudhi/sktda/preprocessing.py
@@ -64,25 +64,25 @@ class Clamping(BaseEstimator, TransformerMixin):
def fit(self, X, y=None):
"""
- Fit the Clamping class on a list of list of values (this function actually does nothing but is useful when Clamping is included in a scikit-learn Pipeline).
+ Fit the Clamping class on a list of values (this function actually does nothing but is useful when Clamping is included in a scikit-learn Pipeline).
Parameters:
- X (list of numpy arrays of size n): input values.
+ X (numpy array of size n): input values.
y (n x 1 array): value labels (unused).
"""
return self
def transform(self, X):
"""
- Clamp each list of values individually.
+ Clamp list of values.
Parameters:
- X (list of numpy arrays of size n): input list of list of values.
+ X (numpy array of size n): input list of values.
Returns:
- Xfit (list of numpy arrays of size n): output list of list of values.
+ Xfit (numpy array of size n): output list of values.
"""
- Xfit = [np.where(L >= self.limit, self.limit * np.ones(L.shape), L) for L in X]
+ Xfit = np.where(X >= self.limit, self.limit * np.ones(X.shape), X)
return Xfit
class DiagramScaler(BaseEstimator, TransformerMixin):
@@ -132,7 +132,8 @@ class DiagramScaler(BaseEstimator, TransformerMixin):
for i in range(len(Xfit)):
if Xfit[i].shape[0] > 0:
for (indices, scaler) in self.scalers:
- Xfit[i][:,indices] = scaler.transform(Xfit[i][:,indices])
+ for I in indices:
+ Xfit[i][:,I] = np.squeeze(scaler.transform(np.reshape(Xfit[i][:,I], [-1,1])))
return Xfit
class Padding(BaseEstimator, TransformerMixin):