summaryrefslogtreecommitdiff
path: root/src/python/gudhi/representations/metrics.py
diff options
context:
space:
mode:
authorMathieuCarriere <mathieu.carriere3@gmail.com>2020-04-28 13:28:10 -0400
committerMathieuCarriere <mathieu.carriere3@gmail.com>2020-04-28 13:28:10 -0400
commit87311ec2d59211320e763bc9bc531858b489ff7e (patch)
treeb9f6d88970315c7e8ba8015315692c270873de21 /src/python/gudhi/representations/metrics.py
parent51ca2370ae27bec052bb4fffafc8a718ac306264 (diff)
added call methods + other fixes
Diffstat (limited to 'src/python/gudhi/representations/metrics.py')
-rw-r--r--src/python/gudhi/representations/metrics.py97
1 files changed, 81 insertions, 16 deletions
diff --git a/src/python/gudhi/representations/metrics.py b/src/python/gudhi/representations/metrics.py
index 59440b1a..a4bf19a6 100644
--- a/src/python/gudhi/representations/metrics.py
+++ b/src/python/gudhi/representations/metrics.py
@@ -17,7 +17,7 @@ from .preprocessing import Padding
# Metrics ###################################
#############################################
-def sliced_wasserstein_distance(D1, D2, num_directions):
+def _sliced_wasserstein_distance(D1, D2, num_directions):
"""
This is a function for computing the sliced Wasserstein distance from two persistence diagrams. The Sliced Wasserstein distance is computed by projecting the persistence diagrams onto lines, comparing the projections with the 1-norm, and finally averaging over the lines. See http://proceedings.mlr.press/v70/carriere17a.html for more details.
@@ -42,7 +42,7 @@ def sliced_wasserstein_distance(D1, D2, num_directions):
L1 = np.sum(np.abs(A-B), axis=0)
return np.mean(L1)
-def compute_persistence_diagram_projections(X, num_directions):
+def _compute_persistence_diagram_projections(X, num_directions):
"""
This is a function for projecting the points of a list of persistence diagrams (as well as their diagonal projections) onto a fixed number of lines sampled uniformly on [-pi/2, pi/2]. This function can be used as a preprocessing step in order to speed up the running time for computing all pairwise sliced Wasserstein distances / kernel values on a list of persistence diagrams.
@@ -51,14 +51,14 @@ def compute_persistence_diagram_projections(X, num_directions):
num_directions (int): number of lines evenly sampled from [-pi/2,pi/2] in order to approximate and speed up the distance computation.
Returns:
- XX (list of n numpy arrays of shape (2*numx2)): list of projected persistence diagrams.
+ list of n numpy arrays of shape (2*numx2): list of projected persistence diagrams.
"""
thetas = np.linspace(-np.pi/2, np.pi/2, num=num_directions+1)[np.newaxis,:-1]
lines = np.concatenate([np.cos(thetas), np.sin(thetas)], axis=0)
XX = [np.vstack([np.matmul(D, lines), np.matmul(np.matmul(D, .5 * np.ones((2,2))), lines)]) for D in X]
return XX
-def sliced_wasserstein_distance_on_projections(D1, D2):
+def _sliced_wasserstein_distance_on_projections(D1, D2):
"""
This is a function for computing the sliced Wasserstein distance between two persistence diagrams that have already been projected onto some lines. It simply amounts to comparing the sorted projections with the 1-norm, and averaging over the lines. See http://proceedings.mlr.press/v70/carriere17a.html for more details.
@@ -76,7 +76,7 @@ def sliced_wasserstein_distance_on_projections(D1, D2):
L1 = np.sum(np.abs(A-B), axis=0)
return np.mean(L1)
-def persistence_fisher_distance(D1, D2, kernel_approx=None, bandwidth=1.):
+def _persistence_fisher_distance(D1, D2, kernel_approx=None, bandwidth=1.):
"""
This is a function for computing the persistence Fisher distance from two persistence diagrams. The persistence Fisher distance is obtained by computing the original Fisher distance between the probability distributions associated to the persistence diagrams given by convolving them with a Gaussian kernel. See http://papers.nips.cc/paper/8205-persistence-fisher-kernel-a-riemannian-manifold-kernel-for-persistence-diagrams for more details.
@@ -118,7 +118,7 @@ def persistence_fisher_distance(D1, D2, kernel_approx=None, bandwidth=1.):
vectorj = vectorj/vectorj_sum
return np.arccos( min(np.dot(np.sqrt(vectori), np.sqrt(vectorj)), 1.) )
-def sklearn_wrapper(metric, X, Y, **kwargs):
+def _sklearn_wrapper(metric, X, Y, **kwargs):
"""
This function is a wrapper for any metric between two persistence diagrams that takes two numpy arrays of shapes (nx2) and (mx2) as arguments.
"""
@@ -133,7 +133,7 @@ def sklearn_wrapper(metric, X, Y, **kwargs):
PAIRWISE_DISTANCE_FUNCTIONS = {
"wasserstein": hera_wasserstein_distance,
"hera_wasserstein": hera_wasserstein_distance,
- "persistence_fisher": persistence_fisher_distance,
+ "persistence_fisher": _persistence_fisher_distance,
}
def pairwise_persistence_diagram_distances(X, Y=None, metric="bottleneck", **kwargs):
@@ -143,7 +143,7 @@ def pairwise_persistence_diagram_distances(X, Y=None, metric="bottleneck", **kwa
Parameters:
X (list of n numpy arrays of shape (numx2)): first list of persistence diagrams.
Y (list of m numpy arrays of shape (numx2)): second list of persistence diagrams (optional). If None, pairwise distances are computed from the first list only.
- metric: distance to use. It can be either a string ("sliced_wasserstein", "wasserstein", "hera_wasserstein" (Wasserstein distance computed with Hera---note that Hera is also used for the default option "wasserstein"), "pot_wasserstein" (Wasserstein distance computed with POT), "bottleneck", "persistence_fisher") or a function taking two numpy arrays of shape (nx2) and (mx2) as inputs.
+ metric: distance to use. It can be either a string ("sliced_wasserstein", "wasserstein", "hera_wasserstein" (Wasserstein distance computed with Hera---note that Hera is also used for the default option "wasserstein"), "pot_wasserstein" (Wasserstein distance computed with POT), "bottleneck", "persistence_fisher") or a symmetric function taking two numpy arrays of shape (nx2) and (mx2) as inputs.
Returns:
numpy array of shape (nxm): distance matrix
@@ -153,25 +153,25 @@ def pairwise_persistence_diagram_distances(X, Y=None, metric="bottleneck", **kwa
if metric == "bottleneck":
try:
from .. import bottleneck_distance
- return pairwise_distances(XX, YY, metric=sklearn_wrapper(bottleneck_distance, X, Y, **kwargs))
+ return pairwise_distances(XX, YY, metric=_sklearn_wrapper(bottleneck_distance, X, Y, **kwargs))
except ImportError:
print("Gudhi built without CGAL")
raise
elif metric == "pot_wasserstein":
try:
from gudhi.wasserstein import wasserstein_distance as pot_wasserstein_distance
- return pairwise_distances(XX, YY, metric=sklearn_wrapper(pot_wasserstein_distance, X, Y, **kwargs))
+ return pairwise_distances(XX, YY, metric=_sklearn_wrapper(pot_wasserstein_distance, X, Y, **kwargs))
except ImportError:
print("POT (Python Optimal Transport) is not installed. Please install POT or use metric='wasserstein' or metric='hera_wasserstein'")
raise
elif metric == "sliced_wasserstein":
- Xproj = compute_persistence_diagram_projections(X, **kwargs)
- Yproj = None if Y is None else compute_persistence_diagram_projections(Y, **kwargs)
- return pairwise_distances(XX, YY, metric=sklearn_wrapper(sliced_wasserstein_distance_on_projections, Xproj, Yproj))
+ Xproj = _compute_persistence_diagram_projections(X, **kwargs)
+ Yproj = None if Y is None else _compute_persistence_diagram_projections(Y, **kwargs)
+ return pairwise_distances(XX, YY, metric=_sklearn_wrapper(_sliced_wasserstein_distance_on_projections, Xproj, Yproj))
elif type(metric) == str:
- return pairwise_distances(XX, YY, metric=sklearn_wrapper(PAIRWISE_DISTANCE_FUNCTIONS[metric], X, Y, **kwargs))
+ return pairwise_distances(XX, YY, metric=_sklearn_wrapper(PAIRWISE_DISTANCE_FUNCTIONS[metric], X, Y, **kwargs))
else:
- return pairwise_distances(XX, YY, metric=sklearn_wrapper(metric, X, Y, **kwargs))
+ return pairwise_distances(XX, YY, metric=_sklearn_wrapper(metric, X, Y, **kwargs))
class SlicedWassersteinDistance(BaseEstimator, TransformerMixin):
"""
@@ -209,6 +209,19 @@ class SlicedWassersteinDistance(BaseEstimator, TransformerMixin):
"""
return pairwise_persistence_diagram_distances(X, self.diagrams_, metric="sliced_wasserstein", num_directions=self.num_directions)
+ def __call__(self, diag1, diag2):
+ """
+ Apply SlicedWassersteinDistance on a single pair of persistence diagrams and outputs the result.
+
+ Parameters:
+ diag1 (n x 2 numpy array): first input persistence diagram.
+ diag2 (n x 2 numpy array): second input persistence diagram.
+
+ Returns:
+ float: sliced Wasserstein distance.
+ """
+ return _sliced_wasserstein_distance(diag1, diag2, num_directions=self.num_directions)
+
class BottleneckDistance(BaseEstimator, TransformerMixin):
"""
This is a class for computing the bottleneck distance matrix from a list of persistence diagrams.
@@ -246,6 +259,24 @@ class BottleneckDistance(BaseEstimator, TransformerMixin):
Xfit = pairwise_persistence_diagram_distances(X, self.diagrams_, metric="bottleneck", e=self.epsilon)
return Xfit
+ def __call__(self, diag1, diag2):
+ """
+ Apply BottleneckDistance on a single pair of persistence diagrams and outputs the result.
+
+ Parameters:
+ diag1 (n x 2 numpy array): first input persistence diagram.
+ diag2 (n x 2 numpy array): second input persistence diagram.
+
+ Returns:
+ float: bottleneck distance.
+ """
+ try:
+ from .. import bottleneck_distance
+ return bottleneck_distance(diag1, diag2, e=self.epsilon)
+ except ImportError:
+ print("Gudhi built without CGAL")
+ raise
+
class PersistenceFisherDistance(BaseEstimator, TransformerMixin):
"""
This is a class for computing the persistence Fisher distance matrix from a list of persistence diagrams. The persistence Fisher distance is obtained by computing the original Fisher distance between the probability distributions associated to the persistence diagrams given by convolving them with a Gaussian kernel. See http://papers.nips.cc/paper/8205-persistence-fisher-kernel-a-riemannian-manifold-kernel-for-persistence-diagrams for more details.
@@ -283,6 +314,19 @@ class PersistenceFisherDistance(BaseEstimator, TransformerMixin):
"""
return pairwise_persistence_diagram_distances(X, self.diagrams_, metric="persistence_fisher", bandwidth=self.bandwidth, kernel_approx=self.kernel_approx)
+ def __call__(self, diag1, diag2):
+ """
+ Apply PersistenceFisherDistance on a single pair of persistence diagrams and outputs the result.
+
+ Parameters:
+ diag1 (n x 2 numpy array): first input persistence diagram.
+ diag2 (n x 2 numpy array): second input persistence diagram.
+
+ Returns:
+ float: persistence Fisher distance.
+ """
+ return _persistence_fisher_distance(diag1, diag2, bandwidth=self.bandwidth, kernel_approx=self.kernel_approx)
+
class WassersteinDistance(BaseEstimator, TransformerMixin):
"""
This is a class for computing the Wasserstein distance matrix from a list of persistence diagrams.
@@ -325,5 +369,26 @@ class WassersteinDistance(BaseEstimator, TransformerMixin):
if self.metric == "hera_wasserstein":
Xfit = pairwise_persistence_diagram_distances(X, self.diagrams_, metric=self.metric, order=self.order, internal_p=self.internal_p, delta=self.delta)
else:
- Xfit = pairwise_persistence_diagram_distances(X, self.diagrams_, metric=self.metric, order=self.order, internal_p=self.internal_p)
+ Xfit = pairwise_persistence_diagram_distances(X, self.diagrams_, metric=self.metric, order=self.order, internal_p=self.internal_p, matching=False)
return Xfit
+
+ def __call__(self, diag1, diag2):
+ """
+ Apply WassersteinDistance on a single pair of persistence diagrams and outputs the result.
+
+ Parameters:
+ diag1 (n x 2 numpy array): first input persistence diagram.
+ diag2 (n x 2 numpy array): second input persistence diagram.
+
+ Returns:
+ float: Wasserstein distance.
+ """
+ if self.metric == "hera_wasserstein":
+ return hera_wasserstein_distance(diag1, diag2, order=self.order, internal_p=self.internal_p, delta=self.delta)
+ else:
+ try:
+ from gudhi.wasserstein import wasserstein_distance as pot_wasserstein_distance
+ return pot_wasserstein_distance(diag1, diag2, order=self.order, internal_p=self.internal_p, matching=False)
+ except ImportError:
+ print("POT (Python Optimal Transport) is not installed. Please install POT or use metric='wasserstein' or metric='hera_wasserstein'")
+ raise