summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authortlacombe <lacombe1993@gmail.com>2020-06-29 09:53:31 +0200
committertlacombe <lacombe1993@gmail.com>2020-06-29 09:53:31 +0200
commitb0c4b1920faa512b25e7c34cbe83c8e625ac3613 (patch)
tree548af42b0327df5937202a43980006a2ee6a3067 /src
parent4fe1d853d32846b8f34699c7121b4d7edfeb48ac (diff)
update default param for representation.metrics
Diffstat (limited to 'src')
-rw-r--r--src/python/gudhi/representations/metrics.py16
1 files changed, 11 insertions, 5 deletions
diff --git a/src/python/gudhi/representations/metrics.py b/src/python/gudhi/representations/metrics.py
index 8a32f7e9..3e8487c4 100644
--- a/src/python/gudhi/representations/metrics.py
+++ b/src/python/gudhi/representations/metrics.py
@@ -328,22 +328,28 @@ class PersistenceFisherDistance(BaseEstimator, TransformerMixin):
"""
return _persistence_fisher_distance(diag1, diag2, bandwidth=self.bandwidth, kernel_approx=self.kernel_approx)
+
class WassersteinDistance(BaseEstimator, TransformerMixin):
"""
This is a class for computing the Wasserstein distance matrix from a list of persistence diagrams.
"""
- def __init__(self, order=2, internal_p=2, mode="pot", delta=0.01):
+ def __init__(self, order=1, internal_p=np.inf, mode="hera", delta=0.01):
"""
Constructor for the WassersteinDistance class.
Parameters:
- order (int): exponent for Wasserstein, default value is 2., see :func:`gudhi.wasserstein.wasserstein_distance`.
- internal_p (int): ground metric on the (upper-half) plane (i.e. norm l_p in R^2), default value is 2 (euclidean norm), see :func:`gudhi.wasserstein.wasserstein_distance`.
- mode (str): method for computing Wasserstein distance. Either "pot" or "hera".
+ order (int): exponent for Wasserstein, default value is 1., see :func:`gudhi.wasserstein.wasserstein_distance`.
+ internal_p (int): ground metric on the (upper-half) plane (i.e. norm l_p in R^2), default value is `np.inf`, see :func:`gudhi.wasserstein.wasserstein_distance`.
+ mode (str): method for computing Wasserstein distance. Either "pot" or "hera". Default set to "hera".
delta (float): relative error 1+delta. Used only if mode == "hera".
"""
self.order, self.internal_p, self.mode = order, internal_p, mode
- self.metric = "pot_wasserstein" if mode == "pot" else "hera_wasserstein"
+ if mode == "pot":
+ self.metric = "pot_wasserstein"
+ elif mode == "hera":
+ self.metric = "hera_wasserstein"
+ else:
+ raise NameError("Unknown mode. Current available values for mode are 'hera' and 'pot'")
self.delta = delta
def fit(self, X, y=None):