summaryrefslogtreecommitdiff
path: root/src/python/gudhi/representations
diff options
context:
space:
mode:
authorMathieuCarriere <mathieu.carriere3@gmail.com>2020-04-28 10:31:12 -0400
committerMathieuCarriere <mathieu.carriere3@gmail.com>2020-04-28 10:31:12 -0400
commit51ca2370ae27bec052bb4fffafc8a718ac306264 (patch)
tree67e1ea14d4c290295b0298dba0d701303a53f3b6 /src/python/gudhi/representations
parent910f29401e053f668e4d277af098295ab05e6022 (diff)
parent0fb22e4c499b665ad505e5d9d2c325f7561f69c4 (diff)
fix conflict
Diffstat (limited to 'src/python/gudhi/representations')
-rw-r--r--src/python/gudhi/representations/metrics.py60
1 files changed, 30 insertions, 30 deletions
diff --git a/src/python/gudhi/representations/metrics.py b/src/python/gudhi/representations/metrics.py
index e2c30f8c..59440b1a 100644
--- a/src/python/gudhi/representations/metrics.py
+++ b/src/python/gudhi/representations/metrics.py
@@ -246,27 +246,23 @@ class BottleneckDistance(BaseEstimator, TransformerMixin):
Xfit = pairwise_persistence_diagram_distances(X, self.diagrams_, metric="bottleneck", e=self.epsilon)
return Xfit
-class WassersteinDistance(BaseEstimator, TransformerMixin):
+class PersistenceFisherDistance(BaseEstimator, TransformerMixin):
"""
- This is a class for computing the Wasserstein distance matrix from a list of persistence diagrams.
+ This is a class for computing the persistence Fisher distance matrix from a list of persistence diagrams. The persistence Fisher distance is obtained by computing the original Fisher distance between the probability distributions associated to the persistence diagrams given by convolving them with a Gaussian kernel. See http://papers.nips.cc/paper/8205-persistence-fisher-kernel-a-riemannian-manifold-kernel-for-persistence-diagrams for more details.
"""
- def __init__(self, order=2, internal_p=2, mode="pot", delta=0.01):
+ def __init__(self, bandwidth=1., kernel_approx=None):
"""
- Constructor for the WassersteinDistance class.
+ Constructor for the PersistenceFisherDistance class.
Parameters:
- order (int): exponent for Wasserstein, default value is 2., see :func:`gudhi.wasserstein.wasserstein_distance`.
- internal_p (int): ground metric on the (upper-half) plane (i.e. norm l_p in R^2), default value is 2 (euclidean norm), see :func:`gudhi.wasserstein.wasserstein_distance`.
- mode (str): method for computing Wasserstein distance. Either "pot" or "hera".
- delta (float): relative error 1+delta. Used only if mode == "hera".
+ bandwidth (double): bandwidth of the Gaussian kernel used to turn persistence diagrams into probability distributions (default 1.).
+ kernel_approx (class): kernel approximation class used to speed up computation (default None). Common kernel approximations classes can be found in the scikit-learn library (such as RBFSampler for instance).
"""
- self.order, self.internal_p, self.mode = order, internal_p, mode
- self.metric = "pot_wasserstein" if mode == "pot" else "hera_wasserstein"
- self.delta = delta
+ self.bandwidth, self.kernel_approx = bandwidth, kernel_approx
def fit(self, X, y=None):
"""
- Fit the WassersteinDistance class on a list of persistence diagrams: persistence diagrams are stored in a numpy array called **diagrams**.
+ Fit the PersistenceFisherDistance class on a list of persistence diagrams: persistence diagrams are stored in a numpy array called **diagrams** and the kernel approximation class (if not None) is applied on them.
Parameters:
X (list of n x 2 numpy arrays): input persistence diagrams.
@@ -277,37 +273,37 @@ class WassersteinDistance(BaseEstimator, TransformerMixin):
def transform(self, X):
"""
- Compute all Wasserstein distances between the persistence diagrams that were stored after calling the fit() method, and a given list of (possibly different) persistence diagrams.
+ Compute all persistence Fisher distances between the persistence diagrams that were stored after calling the fit() method, and a given list of (possibly different) persistence diagrams.
Parameters:
X (list of n x 2 numpy arrays): input persistence diagrams.
Returns:
- numpy array of shape (number of diagrams in **diagrams**) x (number of diagrams in X): matrix of pairwise Wasserstein distances.
+ numpy array of shape (number of diagrams in **diagrams**) x (number of diagrams in X): matrix of pairwise persistence Fisher distances.
"""
- if self.metric == "hera_wasserstein":
- Xfit = pairwise_persistence_diagram_distances(X, self.diagrams_, metric=self.metric, order=self.order, internal_p=self.internal_p, delta=self.delta)
- else:
- Xfit = pairwise_persistence_diagram_distances(X, self.diagrams_, metric=self.metric, order=self.order, internal_p=self.internal_p)
- return Xfit
+ return pairwise_persistence_diagram_distances(X, self.diagrams_, metric="persistence_fisher", bandwidth=self.bandwidth, kernel_approx=self.kernel_approx)
-class PersistenceFisherDistance(BaseEstimator, TransformerMixin):
+class WassersteinDistance(BaseEstimator, TransformerMixin):
"""
- This is a class for computing the persistence Fisher distance matrix from a list of persistence diagrams. The persistence Fisher distance is obtained by computing the original Fisher distance between the probability distributions associated to the persistence diagrams given by convolving them with a Gaussian kernel. See http://papers.nips.cc/paper/8205-persistence-fisher-kernel-a-riemannian-manifold-kernel-for-persistence-diagrams for more details.
+ This is a class for computing the Wasserstein distance matrix from a list of persistence diagrams.
"""
- def __init__(self, bandwidth=1., kernel_approx=None):
+ def __init__(self, order=2, internal_p=2, mode="pot", delta=0.01):
"""
- Constructor for the PersistenceFisherDistance class.
+ Constructor for the WassersteinDistance class.
Parameters:
- bandwidth (double): bandwidth of the Gaussian kernel used to turn persistence diagrams into probability distributions (default 1.).
- kernel_approx (class): kernel approximation class used to speed up computation (default None). Common kernel approximations classes can be found in the scikit-learn library (such as RBFSampler for instance).
+ order (int): exponent for Wasserstein, default value is 2., see :func:`gudhi.wasserstein.wasserstein_distance`.
+ internal_p (int): ground metric on the (upper-half) plane (i.e. norm l_p in R^2), default value is 2 (euclidean norm), see :func:`gudhi.wasserstein.wasserstein_distance`.
+ mode (str): method for computing Wasserstein distance. Either "pot" or "hera".
+ delta (float): relative error 1+delta. Used only if mode == "hera".
"""
- self.bandwidth, self.kernel_approx = bandwidth, kernel_approx
+ self.order, self.internal_p, self.mode = order, internal_p, mode
+ self.metric = "pot_wasserstein" if mode == "pot" else "hera_wasserstein"
+ self.delta = delta
def fit(self, X, y=None):
"""
- Fit the PersistenceFisherDistance class on a list of persistence diagrams: persistence diagrams are stored in a numpy array called **diagrams** and the kernel approximation class (if not None) is applied on them.
+ Fit the WassersteinDistance class on a list of persistence diagrams: persistence diagrams are stored in a numpy array called **diagrams**.
Parameters:
X (list of n x 2 numpy arrays): input persistence diagrams.
@@ -318,12 +314,16 @@ class PersistenceFisherDistance(BaseEstimator, TransformerMixin):
def transform(self, X):
"""
- Compute all persistence Fisher distances between the persistence diagrams that were stored after calling the fit() method, and a given list of (possibly different) persistence diagrams.
+ Compute all Wasserstein distances between the persistence diagrams that were stored after calling the fit() method, and a given list of (possibly different) persistence diagrams.
Parameters:
X (list of n x 2 numpy arrays): input persistence diagrams.
Returns:
- numpy array of shape (number of diagrams in **diagrams**) x (number of diagrams in X): matrix of pairwise persistence Fisher distances.
+ numpy array of shape (number of diagrams in **diagrams**) x (number of diagrams in X): matrix of pairwise Wasserstein distances.
"""
- return pairwise_persistence_diagram_distances(X, self.diagrams_, metric="persistence_fisher", bandwidth=self.bandwidth, kernel_approx=self.kernel_approx)
+ if self.metric == "hera_wasserstein":
+ Xfit = pairwise_persistence_diagram_distances(X, self.diagrams_, metric=self.metric, order=self.order, internal_p=self.internal_p, delta=self.delta)
+ else:
+ Xfit = pairwise_persistence_diagram_distances(X, self.diagrams_, metric=self.metric, order=self.order, internal_p=self.internal_p)
+ return Xfit