summaryrefslogtreecommitdiff
path: root/src/python/gudhi/sktda/metrics.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/gudhi/sktda/metrics.py')
-rw-r--r--src/python/gudhi/sktda/metrics.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/src/python/gudhi/sktda/metrics.py b/src/python/gudhi/sktda/metrics.py
index f55f553b..c51b8f3b 100644
--- a/src/python/gudhi/sktda/metrics.py
+++ b/src/python/gudhi/sktda/metrics.py
@@ -58,7 +58,7 @@ class SlicedWassersteinDistance(BaseEstimator, TransformerMixin):
X (list of n x 2 numpy arrays): input persistence diagrams.
Returns:
- Xfit (numpy array of shape (number of diagrams in **diagrams**) x (number of diagrams in X)): matrix of pairwise sliced Wasserstein distances.
+ numpy array of shape (number of diagrams in **diagrams**) x (number of diagrams in X): matrix of pairwise sliced Wasserstein distances.
"""
Xfit = np.zeros((len(X), len(self.approx_)))
if len(self.diagrams_) == len(X) and np.all([np.array_equal(self.diagrams_[i], X[i]) for i in range(len(X))]):
@@ -114,7 +114,7 @@ class BottleneckDistance(BaseEstimator, TransformerMixin):
X (list of n x 2 numpy arrays): input persistence diagrams.
Returns:
- Xfit (numpy array of shape (number of diagrams in **diagrams**) x (number of diagrams in X)): matrix of pairwise bottleneck distances.
+ numpy array of shape (number of diagrams in **diagrams**) x (number of diagrams in X): matrix of pairwise bottleneck distances.
"""
num_diag1 = len(X)
@@ -182,7 +182,7 @@ class PersistenceFisherDistance(BaseEstimator, TransformerMixin):
X (list of n x 2 numpy arrays): input persistence diagrams.
Returns:
- Xfit (numpy array of shape (number of diagrams in **diagrams**) x (number of diagrams in X)): matrix of pairwise persistence Fisher distances.
+ numpy array of shape (number of diagrams in **diagrams**) x (number of diagrams in X): matrix of pairwise persistence Fisher distances.
"""
Xfit = np.zeros((len(X), len(self.diagrams_)))
if len(self.diagrams_) == len(X) and np.all([np.array_equal(self.diagrams_[i], X[i]) for i in range(len(X))]):