summaryrefslogtreecommitdiff
path: root/src/python/gudhi/representations/metrics.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/gudhi/representations/metrics.py')
-rw-r--r--src/python/gudhi/representations/metrics.py17
1 files changed, 12 insertions, 5 deletions
diff --git a/src/python/gudhi/representations/metrics.py b/src/python/gudhi/representations/metrics.py
index cf2e0879..142ddef1 100644
--- a/src/python/gudhi/representations/metrics.py
+++ b/src/python/gudhi/representations/metrics.py
@@ -350,23 +350,30 @@ class PersistenceFisherDistance(BaseEstimator, TransformerMixin):
"""
return _persistence_fisher_distance(diag1, diag2, bandwidth=self.bandwidth, kernel_approx=self.kernel_approx)
+
class WassersteinDistance(BaseEstimator, TransformerMixin):
"""
This is a class for computing the Wasserstein distance matrix from a list of persistence diagrams.
"""
- def __init__(self, order=2, internal_p=2, mode="pot", delta=0.01, n_jobs=None):
+
+ def __init__(self, order=1, internal_p=np.inf, mode="hera", delta=0.01, n_jobs=None):
"""
Constructor for the WassersteinDistance class.
Parameters:
- order (int): exponent for Wasserstein, default value is 2., see :func:`gudhi.wasserstein.wasserstein_distance`.
- internal_p (int): ground metric on the (upper-half) plane (i.e. norm l_p in R^2), default value is 2 (euclidean norm), see :func:`gudhi.wasserstein.wasserstein_distance`.
- mode (str): method for computing Wasserstein distance. Either "pot" or "hera".
+ order (int): exponent for Wasserstein, default value is 1., see :func:`gudhi.wasserstein.wasserstein_distance`.
+ internal_p (int): ground metric on the (upper-half) plane (i.e. norm l_p in R^2), default value is `np.inf`, see :func:`gudhi.wasserstein.wasserstein_distance`.
+ mode (str): method for computing Wasserstein distance. Either "pot" or "hera". Default set to "hera".
delta (float): relative error 1+delta. Used only if mode == "hera".
n_jobs (int): number of jobs to use for the computation. See :func:`pairwise_persistence_diagram_distances` for details.
"""
self.order, self.internal_p, self.mode = order, internal_p, mode
- self.metric = "pot_wasserstein" if mode == "pot" else "hera_wasserstein"
+ if mode == "pot":
+ self.metric = "pot_wasserstein"
+ elif mode == "hera":
+ self.metric = "hera_wasserstein"
+ else:
+ raise NameError("Unknown mode. Current available values for mode are 'hera' and 'pot'")
self.delta = delta
self.n_jobs = n_jobs