summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormathieu <mathieu.carriere3@gmail.com>2020-02-13 16:01:29 -0500
committermathieu <mathieu.carriere3@gmail.com>2020-02-13 16:01:29 -0500
commitef0f82ef2155440827e17c552abb49b509866fc7 (patch)
treec012fa25aed54c23bb3f4c6cc5c271af7d1982c9
parent2f2db197a38e45ac4fe01dec0c029171c251029b (diff)
integrated hera
-rwxr-xr-xsrc/python/example/diagram_vectorizations_distances_kernels.py7
-rw-r--r--src/python/gudhi/representations/metrics.py23
2 files changed, 23 insertions, 7 deletions
diff --git a/src/python/example/diagram_vectorizations_distances_kernels.py b/src/python/example/diagram_vectorizations_distances_kernels.py
index 66c32cc2..6352d2b5 100755
--- a/src/python/example/diagram_vectorizations_distances_kernels.py
+++ b/src/python/example/diagram_vectorizations_distances_kernels.py
@@ -117,7 +117,12 @@ X = SW.fit(diags)
Y = SW.transform(diags2)
print("SW kernel is " + str(Y[0][0]))
-W = WassersteinDistance(order=2, internal_p=2)
+W = WassersteinDistance(order=2, internal_p=2, mode="pot")
+X = W.fit(diags)
+Y = W.transform(diags2)
+print("Wasserstein distance is " + str(Y[0][0]))
+
+W = WassersteinDistance(order=2, internal_p=2, mode="hera", delta=0.0001)
X = W.fit(diags)
Y = W.transform(diags2)
print("Wasserstein distance is " + str(Y[0][0]))
diff --git a/src/python/gudhi/representations/metrics.py b/src/python/gudhi/representations/metrics.py
index cc788994..ed998603 100644
--- a/src/python/gudhi/representations/metrics.py
+++ b/src/python/gudhi/representations/metrics.py
@@ -10,7 +10,8 @@
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.metrics import pairwise_distances
-from gudhi.wasserstein import wasserstein_distance
+from gudhi.wasserstein import wasserstein_distance as pot_wasserstein_distance
+from gudhi.hera import wasserstein_distance as hera_wasserstein_distance
from .preprocessing import Padding
try:
@@ -117,8 +118,10 @@ def pairwise_persistence_diagram_distances(X, Y=None, metric="bottleneck", **kwa
if metric == "bottleneck":
return pairwise_distances(XX, YY, metric=sklearn_wrapper(bottleneck_distance, **kwargs))
- elif metric == "wasserstein":
- return pairwise_distances(XX, YY, metric=sklearn_wrapper(wasserstein_distance, **kwargs))
+ elif metric == "wasserstein" or metric == "pot_wasserstein":
+ return pairwise_distances(XX, YY, metric=sklearn_wrapper(pot_wasserstein_distance, **kwargs))
+ elif metric == "hera_wasserstein":
+ return pairwise_distances(XX, YY, metric=sklearn_wrapper(hera_wasserstein_distance, **kwargs))
elif metric == "sliced_wasserstein":
return pairwise_distances(XX, YY, metric=sklearn_wrapper(sliced_wasserstein_distance, **kwargs))
elif metric == "persistence_fisher":
@@ -205,15 +208,19 @@ class WassersteinDistance(BaseEstimator, TransformerMixin):
"""
This is a class for computing the Wasserstein distance matrix from a list of persistence diagrams.
"""
- def __init__(self, order=2, internal_p=2):
+ def __init__(self, order=2, internal_p=2, mode="pot", delta=0.0001):
"""
Constructor for the WassersteinDistance class.
Parameters:
order (int): exponent for Wasserstein, default value is 2., see :func:`gudhi.wasserstein.wasserstein_distance`.
internal_p (int): ground metric on the (upper-half) plane (i.e. norm l_p in R^2), default value is 2 (euclidean norm), see :func:`gudhi.wasserstein.wasserstein_distance`.
+ mode (str): method for computing Wasserstein distance. Either "pot" or "hera".
+ delta (float): relative error 1+delta. Used only if mode == "hera".
"""
- self.order, self.internal_p = order, internal_p
+ self.order, self.internal_p, self.mode = order, internal_p, mode
+ self.metric = "pot_wasserstein" if mode == "pot" else "hera_wasserstein"
+ self.delta = delta
def fit(self, X, y=None):
"""
@@ -236,7 +243,11 @@ class WassersteinDistance(BaseEstimator, TransformerMixin):
Returns:
numpy array of shape (number of diagrams in **diagrams**) x (number of diagrams in X): matrix of pairwise Wasserstein distances.
"""
- return pairwise_persistence_diagram_distances(X, self.diagrams_, metric="wasserstein", order=self.order, internal_p=self.internal_p)
+ if self.metric == "hera_wasserstein":
+ Xfit = pairwise_persistence_diagram_distances(X, self.diagrams_, metric=self.metric, order=self.order, internal_p=self.internal_p, delta=self.delta)
+ else:
+ Xfit = pairwise_persistence_diagram_distances(X, self.diagrams_, metric=self.metric, order=self.order, internal_p=self.internal_p)
+ return Xfit
class PersistenceFisherDistance(BaseEstimator, TransformerMixin):
"""