From ce58cc97866605fe64df479e96d455e90f56f8e2 Mon Sep 17 00:00:00 2001 From: MathieuCarriere Date: Sun, 8 Dec 2019 21:22:09 -0500 Subject: fixed useless coordinates in Landscape if min and max are computed from data --- src/python/gudhi/representations/vector_methods.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) (limited to 'src/python/gudhi') diff --git a/src/python/gudhi/representations/vector_methods.py b/src/python/gudhi/representations/vector_methods.py index 61c4fb84..083551a4 100644 --- a/src/python/gudhi/representations/vector_methods.py +++ b/src/python/gudhi/representations/vector_methods.py @@ -104,10 +104,11 @@ class Landscape(BaseEstimator, TransformerMixin): X (list of n x 2 numpy arrays): input persistence diagrams. y (n x 1 array): persistence diagram labels (unused). """ + self.nan_in_range = np.isnan(np.array(self.sample_range)) if np.isnan(np.array(self.sample_range)).any(): pre = DiagramScaler(use=True, scalers=[([0], MinMaxScaler()), ([1], MinMaxScaler())]).fit(X,y) [mx,my],[Mx,My] = [pre.scalers[0][1].data_min_[0], pre.scalers[1][1].data_min_[0]], [pre.scalers[0][1].data_max_[0], pre.scalers[1][1].data_max_[0]] - self.sample_range = np.where(np.isnan(np.array(self.sample_range)), np.array([mx, My]), np.array(self.sample_range)) + self.sample_range = np.where(self.nan_in_range, np.array([mx, My]), np.array(self.sample_range)) return self def transform(self, X): @@ -121,7 +122,7 @@ class Landscape(BaseEstimator, TransformerMixin): numpy array with shape (number of diagrams) x (number of samples = **num_landscapes** x **resolution**): output persistence landscapes. """ num_diag, Xfit = len(X), [] - x_values = np.linspace(self.sample_range[0], self.sample_range[1], self.resolution) + x_values = np.linspace(self.sample_range[0], self.sample_range[1], self.resolution + self.nan_in_range.sum()) step_x = x_values[1] - x_values[0] for i in range(num_diag): @@ -157,7 +158,12 @@ class Landscape(BaseEstimator, TransformerMixin): for k in range( min(self.num_landscapes, len(events[j])) ): ls[k,j] = events[j][k] - Xfit.append(np.sqrt(2)*np.reshape(ls,[1,-1])) + if self.nan_in_range[0]: + ls = ls[:,1:] + if self.nan_in_range[1]: + ls = ls[:,:-1] + ls = np.sqrt(2)*np.reshape(ls,[1,-1]) + Xfit.append(ls) Xfit = np.concatenate(Xfit,0) -- cgit v1.2.3