diff options
Diffstat (limited to 'src/python/gudhi/representations/metrics.py')
-rw-r--r-- | src/python/gudhi/representations/metrics.py | 60 |
1 files changed, 30 insertions, 30 deletions
diff --git a/src/python/gudhi/representations/metrics.py b/src/python/gudhi/representations/metrics.py
index e2c30f8c..59440b1a 100644
--- a/src/python/gudhi/representations/metrics.py
+++ b/src/python/gudhi/representations/metrics.py
@@ -246,27 +246,23 @@ class BottleneckDistance(BaseEstimator, TransformerMixin):
         Xfit = pairwise_persistence_diagram_distances(X, self.diagrams_, metric="bottleneck", e=self.epsilon)
         return Xfit
 
-class WassersteinDistance(BaseEstimator, TransformerMixin):
+class PersistenceFisherDistance(BaseEstimator, TransformerMixin):
     """
-    This is a class for computing the Wasserstein distance matrix from a list of persistence diagrams.
+    This is a class for computing the persistence Fisher distance matrix from a list of persistence diagrams. The persistence Fisher distance is obtained by computing the original Fisher distance between the probability distributions associated to the persistence diagrams given by convolving them with a Gaussian kernel. See http://papers.nips.cc/paper/8205-persistence-fisher-kernel-a-riemannian-manifold-kernel-for-persistence-diagrams for more details.
     """
-    def __init__(self, order=2, internal_p=2, mode="pot", delta=0.01):
+    def __init__(self, bandwidth=1., kernel_approx=None):
         """
-        Constructor for the WassersteinDistance class.
+        Constructor for the PersistenceFisherDistance class.
 
         Parameters:
-            order (int): exponent for Wasserstein, default value is 2., see :func:`gudhi.wasserstein.wasserstein_distance`.
-            internal_p (int): ground metric on the (upper-half) plane (i.e. norm l_p in R^2), default value is 2 (euclidean norm), see :func:`gudhi.wasserstein.wasserstein_distance`.
-            mode (str): method for computing Wasserstein distance. Either "pot" or "hera".
-            delta (float): relative error 1+delta. Used only if mode == "hera".
+            bandwidth (double): bandwidth of the Gaussian kernel used to turn persistence diagrams into probability distributions (default 1.).
+            kernel_approx (class): kernel approximation class used to speed up computation (default None). Common kernel approximations classes can be found in the scikit-learn library (such as RBFSampler for instance).
         """
-        self.order, self.internal_p, self.mode = order, internal_p, mode
-        self.metric = "pot_wasserstein" if mode == "pot" else "hera_wasserstein"
-        self.delta = delta
+        self.bandwidth, self.kernel_approx = bandwidth, kernel_approx
 
     def fit(self, X, y=None):
         """
-        Fit the WassersteinDistance class on a list of persistence diagrams: persistence diagrams are stored in a numpy array called **diagrams**.
+        Fit the PersistenceFisherDistance class on a list of persistence diagrams: persistence diagrams are stored in a numpy array called **diagrams** and the kernel approximation class (if not None) is applied on them.
 
         Parameters:
             X (list of n x 2 numpy arrays): input persistence diagrams.
@@ -277,37 +273,37 @@ class WassersteinDistance(BaseEstimator, TransformerMixin):
 
     def transform(self, X):
         """
-        Compute all Wasserstein distances between the persistence diagrams that were stored after calling the fit() method, and a given list of (possibly different) persistence diagrams.
+        Compute all persistence Fisher distances between the persistence diagrams that were stored after calling the fit() method, and a given list of (possibly different) persistence diagrams.
 
         Parameters:
             X (list of n x 2 numpy arrays): input persistence diagrams.
 
         Returns:
-            numpy array of shape (number of diagrams in **diagrams**) x (number of diagrams in X): matrix of pairwise Wasserstein distances.
+            numpy array of shape (number of diagrams in **diagrams**) x (number of diagrams in X): matrix of pairwise persistence Fisher distances.
         """
-        if self.metric == "hera_wasserstein":
-            Xfit = pairwise_persistence_diagram_distances(X, self.diagrams_, metric=self.metric, order=self.order, internal_p=self.internal_p, delta=self.delta)
-        else:
-            Xfit = pairwise_persistence_diagram_distances(X, self.diagrams_, metric=self.metric, order=self.order, internal_p=self.internal_p)
-        return Xfit
+        return pairwise_persistence_diagram_distances(X, self.diagrams_, metric="persistence_fisher", bandwidth=self.bandwidth, kernel_approx=self.kernel_approx)
 
-class PersistenceFisherDistance(BaseEstimator, TransformerMixin):
+class WassersteinDistance(BaseEstimator, TransformerMixin):
     """
-    This is a class for computing the persistence Fisher distance matrix from a list of persistence diagrams. The persistence Fisher distance is obtained by computing the original Fisher distance between the probability distributions associated to the persistence diagrams given by convolving them with a Gaussian kernel. See http://papers.nips.cc/paper/8205-persistence-fisher-kernel-a-riemannian-manifold-kernel-for-persistence-diagrams for more details.
+    This is a class for computing the Wasserstein distance matrix from a list of persistence diagrams.
     """
-    def __init__(self, bandwidth=1., kernel_approx=None):
+    def __init__(self, order=2, internal_p=2, mode="pot", delta=0.01):
         """
-        Constructor for the PersistenceFisherDistance class.
+        Constructor for the WassersteinDistance class.
 
         Parameters:
-            bandwidth (double): bandwidth of the Gaussian kernel used to turn persistence diagrams into probability distributions (default 1.).
-            kernel_approx (class): kernel approximation class used to speed up computation (default None). Common kernel approximations classes can be found in the scikit-learn library (such as RBFSampler for instance).
+            order (int): exponent for Wasserstein, default value is 2., see :func:`gudhi.wasserstein.wasserstein_distance`.
+            internal_p (int): ground metric on the (upper-half) plane (i.e. norm l_p in R^2), default value is 2 (euclidean norm), see :func:`gudhi.wasserstein.wasserstein_distance`.
+            mode (str): method for computing Wasserstein distance. Either "pot" or "hera".
+            delta (float): relative error 1+delta. Used only if mode == "hera".
         """
-        self.bandwidth, self.kernel_approx = bandwidth, kernel_approx
+        self.order, self.internal_p, self.mode = order, internal_p, mode
+        self.metric = "pot_wasserstein" if mode == "pot" else "hera_wasserstein"
+        self.delta = delta
 
     def fit(self, X, y=None):
         """
-        Fit the PersistenceFisherDistance class on a list of persistence diagrams: persistence diagrams are stored in a numpy array called **diagrams** and the kernel approximation class (if not None) is applied on them.
+        Fit the WassersteinDistance class on a list of persistence diagrams: persistence diagrams are stored in a numpy array called **diagrams**.
 
         Parameters:
             X (list of n x 2 numpy arrays): input persistence diagrams.
@@ -318,12 +314,16 @@ class PersistenceFisherDistance(BaseEstimator, TransformerMixin):
 
     def transform(self, X):
         """
-        Compute all persistence Fisher distances between the persistence diagrams that were stored after calling the fit() method, and a given list of (possibly different) persistence diagrams.
+        Compute all Wasserstein distances between the persistence diagrams that were stored after calling the fit() method, and a given list of (possibly different) persistence diagrams.
 
         Parameters:
             X (list of n x 2 numpy arrays): input persistence diagrams.
 
         Returns:
-            numpy array of shape (number of diagrams in **diagrams**) x (number of diagrams in X): matrix of pairwise persistence Fisher distances.
+            numpy array of shape (number of diagrams in **diagrams**) x (number of diagrams in X): matrix of pairwise Wasserstein distances.
         """
-        return pairwise_persistence_diagram_distances(X, self.diagrams_, metric="persistence_fisher", bandwidth=self.bandwidth, kernel_approx=self.kernel_approx)
+        if self.metric == "hera_wasserstein":
+            Xfit = pairwise_persistence_diagram_distances(X, self.diagrams_, metric=self.metric, order=self.order, internal_p=self.internal_p, delta=self.delta)
+        else:
+            Xfit = pairwise_persistence_diagram_distances(X, self.diagrams_, metric=self.metric, order=self.order, internal_p=self.internal_p)
+        return Xfit