summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorVincent Rouvreau <10407034+VincentRouvreau@users.noreply.github.com>2020-07-01 03:09:08 -0700
committerGitHub <noreply@github.com>2020-07-01 03:09:08 -0700
commit8d316e831c6af51efb9c362a5b203528a9fd3b15 (patch)
treed06a69eaea859b36a15a90812d6c2de7f7018633 /src
parentd58fe553c806874d723c7e6d4fee368cc308e474 (diff)
parent0b4de61a18bc30f66a7fb45cc246cff2f55ba1a1 (diff)
Merge pull request #349 from tlacombe/fix342
fix #342
Diffstat (limited to 'src')
-rw-r--r--src/python/gudhi/representations/metrics.py17
-rw-r--r--src/python/gudhi/wasserstein/wasserstein.py10
2 files changed, 17 insertions, 10 deletions
diff --git a/src/python/gudhi/representations/metrics.py b/src/python/gudhi/representations/metrics.py
index cf2e0879..142ddef1 100644
--- a/src/python/gudhi/representations/metrics.py
+++ b/src/python/gudhi/representations/metrics.py
@@ -350,23 +350,30 @@ class PersistenceFisherDistance(BaseEstimator, TransformerMixin):
"""
return _persistence_fisher_distance(diag1, diag2, bandwidth=self.bandwidth, kernel_approx=self.kernel_approx)
+
class WassersteinDistance(BaseEstimator, TransformerMixin):
"""
This is a class for computing the Wasserstein distance matrix from a list of persistence diagrams.
"""
- def __init__(self, order=2, internal_p=2, mode="pot", delta=0.01, n_jobs=None):
+
+ def __init__(self, order=1, internal_p=np.inf, mode="hera", delta=0.01, n_jobs=None):
"""
Constructor for the WassersteinDistance class.
Parameters:
- order (int): exponent for Wasserstein, default value is 2., see :func:`gudhi.wasserstein.wasserstein_distance`.
- internal_p (int): ground metric on the (upper-half) plane (i.e. norm l_p in R^2), default value is 2 (euclidean norm), see :func:`gudhi.wasserstein.wasserstein_distance`.
- mode (str): method for computing Wasserstein distance. Either "pot" or "hera".
+ order (int): exponent for Wasserstein, default value is 1., see :func:`gudhi.wasserstein.wasserstein_distance`.
+ internal_p (int): ground metric on the (upper-half) plane (i.e. norm l_p in R^2), default value is `np.inf`, see :func:`gudhi.wasserstein.wasserstein_distance`.
+ mode (str): method for computing Wasserstein distance. Either "pot" or "hera". Default set to "hera".
delta (float): relative error 1+delta. Used only if mode == "hera".
n_jobs (int): number of jobs to use for the computation. See :func:`pairwise_persistence_diagram_distances` for details.
"""
self.order, self.internal_p, self.mode = order, internal_p, mode
- self.metric = "pot_wasserstein" if mode == "pot" else "hera_wasserstein"
+ if mode == "pot":
+ self.metric = "pot_wasserstein"
+ elif mode == "hera":
+ self.metric = "hera_wasserstein"
+ else:
+ raise NameError("Unknown mode. Current available values for mode are 'hera' and 'pot'")
self.delta = delta
self.n_jobs = n_jobs
diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py
index 89ecab1c..b37d30bb 100644
--- a/src/python/gudhi/wasserstein/wasserstein.py
+++ b/src/python/gudhi/wasserstein/wasserstein.py
@@ -73,8 +73,8 @@ def _perstot_autodiff(X, order, internal_p):
def _perstot(X, order, internal_p, enable_autodiff):
'''
:param X: (n x 2) numpy.array (points of a given diagram).
- :param order: exponent for Wasserstein. Default value is 2.
- :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2); Default value is 2 (Euclidean norm).
+ :param order: exponent for Wasserstein.
+ :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2).
:param enable_autodiff: If X is torch.tensor, tensorflow.Tensor or jax.numpy.ndarray, make the computation
transparent to automatic differentiation.
:type enable_autodiff: bool
@@ -88,7 +88,7 @@ def _perstot(X, order, internal_p, enable_autodiff):
return np.linalg.norm(_dist_to_diag(X, internal_p), ord=order)
-def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2., enable_autodiff=False):
+def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enable_autodiff=False):
'''
:param X: (n x 2) numpy.array encoding the (finite points of the) first diagram. Must not contain essential points
(i.e. with infinite coordinate).
@@ -96,9 +96,9 @@ def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2., enable_a
:param matching: if True, computes and returns the optimal matching between X and Y, encoded as
a (n x 2) np.array [...[i,j]...], meaning the i-th point in X is matched to
the j-th point in Y, with the convention (-1) represents the diagonal.
- :param order: exponent for Wasserstein; Default value is 2.
+ :param order: exponent for Wasserstein; Default value is 1.
:param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2);
- Default value is 2 (Euclidean norm).
+ Default value is `np.inf`.
:param enable_autodiff: If X and Y are torch.tensor, tensorflow.Tensor or jax.numpy.ndarray, make the computation
transparent to automatic differentiation. This requires the package EagerPy and is currently incompatible
with `matching=True`.