From ef0f82ef2155440827e17c552abb49b509866fc7 Mon Sep 17 00:00:00 2001 From: mathieu Date: Thu, 13 Feb 2020 16:01:29 -0500 Subject: integrated hera --- .../diagram_vectorizations_distances_kernels.py | 7 ++++++- src/python/gudhi/representations/metrics.py | 23 ++++++++++++++++------ 2 files changed, 23 insertions(+), 7 deletions(-) (limited to 'src') diff --git a/src/python/example/diagram_vectorizations_distances_kernels.py b/src/python/example/diagram_vectorizations_distances_kernels.py index 66c32cc2..6352d2b5 100755 --- a/src/python/example/diagram_vectorizations_distances_kernels.py +++ b/src/python/example/diagram_vectorizations_distances_kernels.py @@ -117,7 +117,12 @@ X = SW.fit(diags) Y = SW.transform(diags2) print("SW kernel is " + str(Y[0][0])) -W = WassersteinDistance(order=2, internal_p=2) +W = WassersteinDistance(order=2, internal_p=2, mode="pot") +X = W.fit(diags) +Y = W.transform(diags2) +print("Wasserstein distance is " + str(Y[0][0])) + +W = WassersteinDistance(order=2, internal_p=2, mode="hera", delta=0.0001) X = W.fit(diags) Y = W.transform(diags2) print("Wasserstein distance is " + str(Y[0][0])) diff --git a/src/python/gudhi/representations/metrics.py b/src/python/gudhi/representations/metrics.py index cc788994..ed998603 100644 --- a/src/python/gudhi/representations/metrics.py +++ b/src/python/gudhi/representations/metrics.py @@ -10,7 +10,8 @@ import numpy as np from sklearn.base import BaseEstimator, TransformerMixin from sklearn.metrics import pairwise_distances -from gudhi.wasserstein import wasserstein_distance +from gudhi.wasserstein import wasserstein_distance as pot_wasserstein_distance +from gudhi.hera import wasserstein_distance as hera_wasserstein_distance from .preprocessing import Padding try: @@ -117,8 +118,10 @@ def pairwise_persistence_diagram_distances(X, Y=None, metric="bottleneck", **kwa if metric == "bottleneck": return pairwise_distances(XX, YY, metric=sklearn_wrapper(bottleneck_distance, **kwargs)) - elif metric == "wasserstein": - return pairwise_distances(XX, YY, metric=sklearn_wrapper(wasserstein_distance, **kwargs)) + elif metric == "wasserstein" or metric == "pot_wasserstein": + return pairwise_distances(XX, YY, metric=sklearn_wrapper(pot_wasserstein_distance, **kwargs)) + elif metric == "hera_wasserstein": + return pairwise_distances(XX, YY, metric=sklearn_wrapper(hera_wasserstein_distance, **kwargs)) elif metric == "sliced_wasserstein": return pairwise_distances(XX, YY, metric=sklearn_wrapper(sliced_wasserstein_distance, **kwargs)) elif metric == "persistence_fisher": @@ -205,15 +208,19 @@ class WassersteinDistance(BaseEstimator, TransformerMixin): """ This is a class for computing the Wasserstein distance matrix from a list of persistence diagrams. """ - def __init__(self, order=2, internal_p=2): + def __init__(self, order=2, internal_p=2, mode="pot", delta=0.0001): """ Constructor for the WassersteinDistance class. Parameters: order (int): exponent for Wasserstein, default value is 2., see :func:`gudhi.wasserstein.wasserstein_distance`. internal_p (int): ground metric on the (upper-half) plane (i.e. norm l_p in R^2), default value is 2 (euclidean norm), see :func:`gudhi.wasserstein.wasserstein_distance`. + mode (str): method for computing Wasserstein distance. Either "pot" or "hera". + delta (float): relative error 1+delta. Used only if mode == "hera". """ - self.order, self.internal_p = order, internal_p + self.order, self.internal_p, self.mode = order, internal_p, mode + self.metric = "pot_wasserstein" if mode == "pot" else "hera_wasserstein" + self.delta = delta def fit(self, X, y=None): """ @@ -236,7 +243,11 @@ class WassersteinDistance(BaseEstimator, TransformerMixin): Returns: numpy array of shape (number of diagrams in **diagrams**) x (number of diagrams in X): matrix of pairwise Wasserstein distances. """ - return pairwise_persistence_diagram_distances(X, self.diagrams_, metric="wasserstein", order=self.order, internal_p=self.internal_p) + if self.metric == "hera_wasserstein": + Xfit = pairwise_persistence_diagram_distances(X, self.diagrams_, metric=self.metric, order=self.order, internal_p=self.internal_p, delta=self.delta) + else: + Xfit = pairwise_persistence_diagram_distances(X, self.diagrams_, metric=self.metric, order=self.order, internal_p=self.internal_p) + return Xfit class PersistenceFisherDistance(BaseEstimator, TransformerMixin): """ -- cgit v1.2.3