From ef0f82ef2155440827e17c552abb49b509866fc7 Mon Sep 17 00:00:00 2001 From: mathieu Date: Thu, 13 Feb 2020 16:01:29 -0500 Subject: integrated hera --- src/python/gudhi/representations/metrics.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) (limited to 'src/python/gudhi/representations/metrics.py') diff --git a/src/python/gudhi/representations/metrics.py b/src/python/gudhi/representations/metrics.py index cc788994..ed998603 100644 --- a/src/python/gudhi/representations/metrics.py +++ b/src/python/gudhi/representations/metrics.py @@ -10,7 +10,8 @@ import numpy as np from sklearn.base import BaseEstimator, TransformerMixin from sklearn.metrics import pairwise_distances -from gudhi.wasserstein import wasserstein_distance +from gudhi.wasserstein import wasserstein_distance as pot_wasserstein_distance +from gudhi.hera import wasserstein_distance as hera_wasserstein_distance from .preprocessing import Padding try: @@ -117,8 +118,10 @@ def pairwise_persistence_diagram_distances(X, Y=None, metric="bottleneck", **kwa if metric == "bottleneck": return pairwise_distances(XX, YY, metric=sklearn_wrapper(bottleneck_distance, **kwargs)) - elif metric == "wasserstein": - return pairwise_distances(XX, YY, metric=sklearn_wrapper(wasserstein_distance, **kwargs)) + elif metric == "wasserstein" or metric == "pot_wasserstein": + return pairwise_distances(XX, YY, metric=sklearn_wrapper(pot_wasserstein_distance, **kwargs)) + elif metric == "hera_wasserstein": + return pairwise_distances(XX, YY, metric=sklearn_wrapper(hera_wasserstein_distance, **kwargs)) elif metric == "sliced_wasserstein": return pairwise_distances(XX, YY, metric=sklearn_wrapper(sliced_wasserstein_distance, **kwargs)) elif metric == "persistence_fisher": @@ -205,15 +208,19 @@ class WassersteinDistance(BaseEstimator, TransformerMixin): """ This is a class for computing the Wasserstein distance matrix from a list of persistence diagrams. """ - def __init__(self, order=2, internal_p=2): + def __init__(self, order=2, internal_p=2, mode="pot", delta=0.0001): """ Constructor for the WassersteinDistance class. Parameters: order (int): exponent for Wasserstein, default value is 2., see :func:`gudhi.wasserstein.wasserstein_distance`. internal_p (int): ground metric on the (upper-half) plane (i.e. norm l_p in R^2), default value is 2 (euclidean norm), see :func:`gudhi.wasserstein.wasserstein_distance`. + mode (str): method for computing Wasserstein distance. Either "pot" or "hera". + delta (float): relative error 1+delta. Used only if mode == "hera". """ - self.order, self.internal_p = order, internal_p + self.order, self.internal_p, self.mode = order, internal_p, mode + self.metric = "pot_wasserstein" if mode == "pot" else "hera_wasserstein" + self.delta = delta def fit(self, X, y=None): """ @@ -236,7 +243,11 @@ class WassersteinDistance(BaseEstimator, TransformerMixin): Returns: numpy array of shape (number of diagrams in **diagrams**) x (number of diagrams in X): matrix of pairwise Wasserstein distances. """ - return pairwise_persistence_diagram_distances(X, self.diagrams_, metric="wasserstein", order=self.order, internal_p=self.internal_p) + if self.metric == "hera_wasserstein": + Xfit = pairwise_persistence_diagram_distances(X, self.diagrams_, metric=self.metric, order=self.order, internal_p=self.internal_p, delta=self.delta) + else: + Xfit = pairwise_persistence_diagram_distances(X, self.diagrams_, metric=self.metric, order=self.order, internal_p=self.internal_p) + return Xfit class PersistenceFisherDistance(BaseEstimator, TransformerMixin): """ -- cgit v1.2.3