summaryrefslogtreecommitdiff
path: root/src/python/gudhi/representations
diff options
context:
space:
mode:
authormathieu <mathieu.carriere3@gmail.com>2020-03-11 15:35:37 -0400
committermathieu <mathieu.carriere3@gmail.com>2020-03-11 15:35:37 -0400
commit25e40a52ec7bc9e1bfe418fb1aa16e2a06994d1b (patch)
tree913df06025f5ff18fe6b9e2b1bf01996cded0fd5 /src/python/gudhi/representations
parenta47ace987876cb52351ae9223d335629aedbd71e (diff)
new fixes
Diffstat (limited to 'src/python/gudhi/representations')
-rw-r--r--src/python/gudhi/representations/metrics.py63
1 files changed, 50 insertions, 13 deletions
diff --git a/src/python/gudhi/representations/metrics.py b/src/python/gudhi/representations/metrics.py
index 0659b457..f913f1fc 100644
--- a/src/python/gudhi/representations/metrics.py
+++ b/src/python/gudhi/representations/metrics.py
@@ -19,7 +19,7 @@ from .preprocessing import Padding
def sliced_wasserstein_distance(D1, D2, num_directions):
"""
- This is a function for computing the sliced Wasserstein distance from two persistence diagrams. The Sliced Wasserstein distance is computed by projecting the persistence diagrams onto lines, comparing the projections with the 1-norm, and finally integrating over all possible lines. See http://proceedings.mlr.press/v70/carriere17a.html for more details.
+ This is a function for computing the sliced Wasserstein distance from two persistence diagrams. The Sliced Wasserstein distance is computed by projecting the persistence diagrams onto lines, comparing the projections with the 1-norm, and finally averaging over the lines. See http://proceedings.mlr.press/v70/carriere17a.html for more details.
:param D1: (n x 2) numpy.array encoding the (finite points of the) first diagram. Must not contain essential points (i.e. with infinite coordinate).
:param D2: (m x 2) numpy.array encoding the second diagram.
:param num_directions: number of lines evenly sampled from [-pi/2,pi/2] in order to approximate and speed up the distance computation.
@@ -39,6 +39,34 @@ def sliced_wasserstein_distance(D1, D2, num_directions):
L1 = np.sum(np.abs(A-B), axis=0)
return np.mean(L1)
+def compute_persistence_diagram_projections(X, num_directions):
+ """
+ This is a function for projecting the points of a list of persistence diagrams (as well as their diagonal projections) onto a fixed number of lines sampled uniformly on [-pi/2, pi/2]. This function can be used as a preprocessing step in order to speed up the running time for computing all pairwise sliced Wasserstein distances / kernel values on a list of persistence diagrams.
+ :param X: list of persistence diagrams.
+ :param num_directions: number of lines evenly sampled from [-pi/2,pi/2] in order to approximate and speed up the distance computation.
+ :returns: list of projected persistence diagrams.
+ :rtype: float
+ """
+ thetas = np.linspace(-np.pi/2, np.pi/2, num=num_directions+1)[np.newaxis,:-1]
+ lines = np.concatenate([np.cos(thetas), np.sin(thetas)], axis=0)
+ XX = [np.vstack([np.matmul(D, lines), np.matmul(np.matmul(D, .5 * np.ones((2,2))), lines)]) for D in X]
+ return XX
+
+def sliced_wasserstein_distance_on_projections(D1, D2):
+ """
+ This is a function for computing the sliced Wasserstein distance between two persistence diagrams that have already been projected onto some lines. It simply amounts to comparing the sorted projections with the 1-norm, and averaging over the lines. See http://proceedings.mlr.press/v70/carriere17a.html for more details.
+ :param D1: (2n x number_of_lines) numpy.array containing the n projected points of the first diagram, and the n projections of their diagonal projections.
+ :param D2: (2m x number_of_lines) numpy.array containing the m projected points of the second diagram, and the m projections of their diagonal projections.
+ :returns: the sliced Wasserstein distance between the projected persistence diagrams.
+ :rtype: float
+ """
+ lim1, lim2 = int(len(D1)/2), int(len(D2)/2)
+ approx1, approx_diag1, approx2, approx_diag2 = D1[:lim1], D1[lim1:], D2[:lim2], D2[lim2:]
+ A = np.sort(np.concatenate([approx1, approx_diag2], axis=0), axis=0)
+ B = np.sort(np.concatenate([approx2, approx_diag1], axis=0), axis=0)
+ L1 = np.sum(np.abs(A-B), axis=0)
+ return np.mean(L1)
+
def persistence_fisher_distance(D1, D2, kernel_approx=None, bandwidth=1.):
"""
This is a function for computing the persistence Fisher distance from two persistence diagrams. The persistence Fisher distance is obtained by computing the original Fisher distance between the probability distributions associated to the persistence diagrams given by convolving them with a Gaussian kernel. See http://papers.nips.cc/paper/8205-persistence-fisher-kernel-a-riemannian-manifold-kernel-for-persistence-diagrams for more details.
@@ -90,31 +118,43 @@ def sklearn_wrapper(metric, X, Y, **kwargs):
return metric(X[int(a[0])], Y[int(b[0])], **kwargs)
return flat_metric
+PAIRWISE_DISTANCE_FUNCTIONS = {
+ "wasserstein": hera_wasserstein_distance,
+ "hera_wasserstein": hera_wasserstein_distance,
+ "persistence_fisher": persistence_fisher_distance,
+}
+
def pairwise_persistence_diagram_distances(X, Y=None, metric="bottleneck", **kwargs):
"""
This function computes the distance matrix between two lists of persistence diagrams given as numpy arrays of shape (nx2).
:param X: first list of persistence diagrams.
:param Y: second list of persistence diagrams (optional). If None, pairwise distances are computed from the first list only.
- :param metric: distance to use. It can be either a string ("sliced_wasserstein", "wasserstein", "bottleneck", "persistence_fisher") or a function taking two numpy arrays of shape (nx2) and (mx2) as inputs.
+ :param metric: distance to use. It can be either a string ("sliced_wasserstein", "wasserstein", "hera_wasserstein" (Wasserstein distance computed with Hera---note that Hera is also used for the default option "wasserstein"), "pot_wasserstein" (Wasserstein distance computed with POT), "bottleneck", "persistence_fisher") or a function taking two numpy arrays of shape (nx2) and (mx2) as inputs.
:returns: distance matrix, i.e., numpy array of shape (num diagrams 1 x num diagrams 2)
:rtype: float
"""
XX = np.reshape(np.arange(len(X)), [-1,1])
YY = None if Y is None else np.reshape(np.arange(len(Y)), [-1,1])
if metric == "bottleneck":
- return pairwise_distances(XX, YY, metric=sklearn_wrapper(bottleneck_distance, X, Y, **kwargs))
+ try:
+ from .. import bottleneck_distance
+ return pairwise_distances(XX, YY, metric=sklearn_wrapper(bottleneck_distance, X, Y, **kwargs))
+ except ImportError:
+ print("Gudhi built without CGAL")
+ raise
elif metric == "pot_wasserstein":
try:
from gudhi.wasserstein import wasserstein_distance as pot_wasserstein_distance
return pairwise_distances(XX, YY, metric=sklearn_wrapper(pot_wasserstein_distance, X, Y, **kwargs))
except ImportError:
- print("Gudhi built without POT")
- elif metric == "wasserstein" or metric == "hera_wasserstein":
- return pairwise_distances(XX, YY, metric=sklearn_wrapper(hera_wasserstein_distance, X, Y, **kwargs))
+ print("Gudhi built without POT. Please install POT or use metric='wasserstein' or metric='hera_wasserstein'")
+ raise
elif metric == "sliced_wasserstein":
- return pairwise_distances(XX, YY, metric=sklearn_wrapper(sliced_wasserstein_distance, X, Y, **kwargs))
- elif metric == "persistence_fisher":
- return pairwise_distances(XX, YY, metric=sklearn_wrapper(persistence_fisher_distance, X, Y, **kwargs))
+ Xproj = compute_persistence_diagram_projections(X, **kwargs)
+ Yproj = None if Y is None else compute_persistence_diagram_projections(Y, **kwargs)
+ return pairwise_distances(XX, YY, metric=sklearn_wrapper(sliced_wasserstein_distance_on_projections, Xproj, Yproj))
+ elif type(metric) == str:
+ return pairwise_distances(XX, YY, metric=sklearn_wrapper(PAIRWISE_DISTANCE_FUNCTIONS[metric], X, Y, **kwargs))
else:
return pairwise_distances(XX, YY, metric=sklearn_wrapper(metric, X, Y, **kwargs))
@@ -188,10 +228,7 @@ class BottleneckDistance(BaseEstimator, TransformerMixin):
Returns:
numpy array of shape (number of diagrams in **diagrams**) x (number of diagrams in X): matrix of pairwise bottleneck distances.
"""
- try:
- Xfit = pairwise_persistence_diagram_distances(X, self.diagrams_, metric="bottleneck", e=self.epsilon)
- except ImportError:
- print("Gudhi built without CGAL")
+ Xfit = pairwise_persistence_diagram_distances(X, self.diagrams_, metric="bottleneck", e=self.epsilon)
return Xfit
class WassersteinDistance(BaseEstimator, TransformerMixin):