diff options
Diffstat (limited to 'src/python/gudhi/sktda/metrics.py')
-rw-r--r-- | src/python/gudhi/sktda/metrics.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/src/python/gudhi/sktda/metrics.py b/src/python/gudhi/sktda/metrics.py index f55f553b..c51b8f3b 100644 --- a/src/python/gudhi/sktda/metrics.py +++ b/src/python/gudhi/sktda/metrics.py @@ -58,7 +58,7 @@ class SlicedWassersteinDistance(BaseEstimator, TransformerMixin): X (list of n x 2 numpy arrays): input persistence diagrams. Returns: - Xfit (numpy array of shape (number of diagrams in **diagrams**) x (number of diagrams in X)): matrix of pairwise sliced Wasserstein distances. + numpy array of shape (number of diagrams in **diagrams**) x (number of diagrams in X): matrix of pairwise sliced Wasserstein distances. """ Xfit = np.zeros((len(X), len(self.approx_))) if len(self.diagrams_) == len(X) and np.all([np.array_equal(self.diagrams_[i], X[i]) for i in range(len(X))]): @@ -114,7 +114,7 @@ class BottleneckDistance(BaseEstimator, TransformerMixin): X (list of n x 2 numpy arrays): input persistence diagrams. Returns: - Xfit (numpy array of shape (number of diagrams in **diagrams**) x (number of diagrams in X)): matrix of pairwise bottleneck distances. + numpy array of shape (number of diagrams in **diagrams**) x (number of diagrams in X): matrix of pairwise bottleneck distances. """ num_diag1 = len(X) @@ -182,7 +182,7 @@ class PersistenceFisherDistance(BaseEstimator, TransformerMixin): X (list of n x 2 numpy arrays): input persistence diagrams. Returns: - Xfit (numpy array of shape (number of diagrams in **diagrams**) x (number of diagrams in X)): matrix of pairwise persistence Fisher distances. + numpy array of shape (number of diagrams in **diagrams**) x (number of diagrams in X): matrix of pairwise persistence Fisher distances. """ Xfit = np.zeros((len(X), len(self.diagrams_))) if len(self.diagrams_) == len(X) and np.all([np.array_equal(self.diagrams_[i], X[i]) for i in range(len(X))]): |