summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author mathieu <mathieu.carriere3@gmail.com> 2020-01-16 17:02:55 -0500
committer mathieu <mathieu.carriere3@gmail.com> 2020-01-16 17:02:55 -0500
commit 85ceea9512634a62664208cd2d0f1ce48bafa171 (patch)
tree 16ece19cf25d11fd875b2599f72c09007e469b98
parent cabc43b34723efa7640313348b844eabe9971e38 (diff)
added wasserstein class
-rwxr-xr-x src/python/example/diagram_vectorizations_distances_kernels.py 7
-rw-r--r-- src/python/gudhi/representations/metrics.py 59
2 files changed, 65 insertions, 1 deletions
diff --git a/src/python/example/diagram_vectorizations_distances_kernels.py b/src/python/example/diagram_vectorizations_distances_kernels.py
index 119072eb..66c32cc2 100755
--- a/src/python/example/diagram_vectorizations_distances_kernels.py
+++ b/src/python/example/diagram_vectorizations_distances_kernels.py
@@ -9,7 +9,7 @@ from gudhi.representations import DiagramSelector, Clamping, Landscape, Silhouet
TopologicalVector, DiagramScaler, BirthPersistenceTransform,\
PersistenceImage, PersistenceWeightedGaussianKernel, Entropy, \
PersistenceScaleSpaceKernel, SlicedWassersteinDistance,\
- SlicedWassersteinKernel, BottleneckDistance, PersistenceFisherKernel
+ SlicedWassersteinKernel, BottleneckDistance, WassersteinDistance, PersistenceFisherKernel
D = np.array([[0.,4.],[1.,2.],[3.,8.],[6.,8.], [0., np.inf], [5., np.inf]])
diags = [D]
@@ -117,6 +117,11 @@ X = SW.fit(diags)
Y = SW.transform(diags2)
print("SW kernel is " + str(Y[0][0]))
+W = WassersteinDistance(order=2, internal_p=2)
+X = W.fit(diags)
+Y = W.transform(diags2)
+print("Wasserstein distance is " + str(Y[0][0]))
+
W = BottleneckDistance(epsilon=.001)
X = W.fit(diags)
Y = W.transform(diags2)
diff --git a/src/python/gudhi/representations/metrics.py b/src/python/gudhi/representations/metrics.py
index 5f9ec6ab..290c1d07 100644
--- a/src/python/gudhi/representations/metrics.py
+++ b/src/python/gudhi/representations/metrics.py
@@ -10,6 +10,7 @@
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.metrics import pairwise_distances
+from gudhi.wasserstein import wasserstein_distance
try:
from .. import bottleneck_distance
USE_GUDHI = True
@@ -145,6 +146,64 @@ class BottleneckDistance(BaseEstimator, TransformerMixin):
return Xfit
+class WassersteinDistance(BaseEstimator, TransformerMixin):
+ """
+ This is a class for computing the Wasserstein distance matrix between the persistence diagrams given to fit() and the persistence diagrams given to transform().
+ """
+ def __init__(self, order=2, internal_p=2):
+ """
+ Constructor for the WassersteinDistance class.
+
+ Parameters:
+ order (int): exponent for Wasserstein, default value is 2, see :func:`gudhi.wasserstein.wasserstein_distance`.
+ internal_p (int): ground metric on the (upper-half) plane (i.e. norm l_p in R^2), default value is 2 (Euclidean norm), see :func:`gudhi.wasserstein.wasserstein_distance`.
+ """
+ self.order, self.internal_p = order, internal_p
+
+ def fit(self, X, y=None):
+ """
+ Fit the WassersteinDistance class on a list of persistence diagrams: the list is stored as-is (no copy is made) in the attribute **diagrams_**, and the estimator itself is returned.
+
+ Parameters:
+ X (list of n x 2 numpy arrays): input persistence diagrams.
+ y (n x 1 array): persistence diagram labels (unused).
+ """
+ self.diagrams_ = X
+ return self
+
+ def transform(self, X):
+ """
+ Compute all Wasserstein distances between a given list of (possibly different) persistence diagrams and the persistence diagrams that were stored after calling the fit() method.
+
+ Parameters:
+ X (list of n x 2 numpy arrays): input persistence diagrams.
+
+ Returns:
+ numpy array of shape (number of diagrams in X) x (number of diagrams in **diagrams_**): matrix of pairwise Wasserstein distances; entry (i, j) is the distance between X[i] and **diagrams_**[j].
+ """
+ num_diag1 = len(X)
+
+ # Identity test (an element-wise comparison of the two lists was tried and left disabled): when X is the very object given to fit(), the diagonal stays 0 and only entries with i < j are computed, then mirrored.
+ if X is self.diagrams_:
+ matrix = np.zeros((num_diag1, num_diag1))
+
+ for i in range(num_diag1):
+ for j in range(i+1, num_diag1):
+ matrix[i,j] = wasserstein_distance(X[i], X[j], self.order, self.internal_p)
+ matrix[j,i] = matrix[i,j]
+
+ else:
+ num_diag2 = len(self.diagrams_)
+ matrix = np.zeros((num_diag1, num_diag2))
+
+ for i in range(num_diag1):
+ for j in range(num_diag2):
+ matrix[i,j] = wasserstein_distance(X[i], self.diagrams_[j], self.order, self.internal_p)
+
+ Xfit = matrix
+
+ return Xfit
+
class PersistenceFisherDistance(BaseEstimator, TransformerMixin):
"""
This is a class for computing the persistence Fisher distance matrix from a list of persistence diagrams. The persistence Fisher distance is obtained by computing the original Fisher distance between the probability distributions associated to the persistence diagrams given by convolving them with a Gaussian kernel. See http://papers.nips.cc/paper/8205-persistence-fisher-kernel-a-riemannian-manifold-kernel-for-persistence-diagrams for more details.