class WassersteinDistance(BaseEstimator, TransformerMixin):
    """
    This is a class for computing the Wasserstein distance matrix from a list of persistence diagrams.
    """
    def __init__(self, order=2, internal_p=2):
        """
        Constructor for the WassersteinDistance class.

        Parameters:
            order (int): exponent for Wasserstein, default value is 2., see :func:`gudhi.wasserstein.wasserstein_distance`.
            internal_p (int): ground metric on the (upper-half) plane (i.e. norm l_p in R^2), default value is 2 (euclidean norm), see :func:`gudhi.wasserstein.wasserstein_distance`.
        """
        self.order, self.internal_p = order, internal_p

    def fit(self, X, y=None):
        """
        Fit the WassersteinDistance class on a list of persistence diagrams: the input is stored as is (by reference, no copy) in the attribute **diagrams_**.

        Parameters:
            X (list of n x 2 numpy arrays): input persistence diagrams.
            y (n x 1 array): persistence diagram labels (unused).

        Returns:
            self: the fitted instance, so that calls can be chained.
        """
        self.diagrams_ = X
        return self

    def transform(self, X):
        """
        Compute all Wasserstein distances between a given list of persistence diagrams and the (possibly different) persistence diagrams that were stored after calling the fit() method.

        Parameters:
            X (list of n x 2 numpy arrays): input persistence diagrams.

        Returns:
            numpy array of shape (number of diagrams in X) x (number of diagrams in **diagrams_**): matrix of pairwise Wasserstein distances, where entry (i, j) is the distance between X[i] and the j-th fitted diagram.
        """
        num_diag1 = len(X)

        if X is self.diagrams_:
            # X is the very object that was given to fit(): the matrix is square,
            # so each unordered pair is computed once and mirrored, and the
            # diagonal is left at zero. Identity (not equality) is tested on
            # purpose: comparing the diagrams element-wise would cost extra work.
            matrix = np.zeros((num_diag1, num_diag1))

            for i in range(num_diag1):
                for j in range(i + 1, num_diag1):
                    # order / internal_p are passed by keyword so that the call does
                    # not depend on the positional layout of wasserstein_distance.
                    matrix[i, j] = wasserstein_distance(
                        X[i], X[j], order=self.order, internal_p=self.internal_p
                    )
                    matrix[j, i] = matrix[i, j]

            return matrix

        # General case: rows follow X, columns follow the fitted diagrams.
        num_diag2 = len(self.diagrams_)
        matrix = np.zeros((num_diag1, num_diag2))

        for i in range(num_diag1):
            for j in range(num_diag2):
                matrix[i, j] = wasserstein_distance(
                    X[i], self.diagrams_[j], order=self.order, internal_p=self.internal_p
                )

        return matrix