summaryrefslogtreecommitdiff
path: root/src/python/gudhi/representations/kernel_methods.py
diff options
context:
space:
mode:
authorMathieuCarriere <mathieu.carriere3@gmail.com>2020-04-28 13:28:10 -0400
committerMathieuCarriere <mathieu.carriere3@gmail.com>2020-04-28 13:28:10 -0400
commit87311ec2d59211320e763bc9bc531858b489ff7e (patch)
treeb9f6d88970315c7e8ba8015315692c270873de21 /src/python/gudhi/representations/kernel_methods.py
parent51ca2370ae27bec052bb4fffafc8a718ac306264 (diff)
added call methods + other fixes
Diffstat (limited to 'src/python/gudhi/representations/kernel_methods.py')
-rw-r--r--src/python/gudhi/representations/kernel_methods.py88
1 files changed, 70 insertions, 18 deletions
diff --git a/src/python/gudhi/representations/kernel_methods.py b/src/python/gudhi/representations/kernel_methods.py
index 50186d63..edd1382a 100644
--- a/src/python/gudhi/representations/kernel_methods.py
+++ b/src/python/gudhi/representations/kernel_methods.py
@@ -10,14 +10,14 @@
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.metrics import pairwise_distances, pairwise_kernels
-from .metrics import SlicedWassersteinDistance, PersistenceFisherDistance, sklearn_wrapper, pairwise_persistence_diagram_distances, sliced_wasserstein_distance, persistence_fisher_distance
+from .metrics import SlicedWassersteinDistance, PersistenceFisherDistance, _sklearn_wrapper, pairwise_persistence_diagram_distances, _sliced_wasserstein_distance, _persistence_fisher_distance
from .preprocessing import Padding
#############################################
# Kernel methods ############################
#############################################
-def persistence_weighted_gaussian_kernel(D1, D2, weight=lambda x: 1, kernel_approx=None, bandwidth=1.):
+def _persistence_weighted_gaussian_kernel(D1, D2, weight=lambda x: 1, kernel_approx=None, bandwidth=1.):
"""
This is a function for computing the persistence weighted Gaussian kernel value from two persistence diagrams. The persistence weighted Gaussian kernel is computed by convolving the persistence diagram points with weighted Gaussian kernels. See http://proceedings.mlr.press/v48/kusano16.html for more details.
@@ -25,7 +25,7 @@ def persistence_weighted_gaussian_kernel(D1, D2, weight=lambda x: 1, kernel_appr
D1: (n x 2) numpy.array encoding the (finite points of the) first diagram. Must not contain essential points (i.e. with infinite coordinate).
D2: (m x 2) numpy.array encoding the second diagram.
bandwidth (double): bandwidth of the Gaussian kernel with which persistence diagrams will be convolved
- weight: weight function for the persistence diagram points. This function must be defined on 2D points, ie lists or numpy arrays of the form [p_x,p_y].
+ weight: weight function for the persistence diagram points (default constant function, ie lambda x: 1). This function must be defined on 2D points, ie lists or numpy arrays of the form [p_x,p_y].
kernel_approx: kernel approximation class used to speed up computation. Common kernel approximations classes can be found in the scikit-learn library (such as RBFSampler for instance).
Returns:
@@ -42,7 +42,7 @@ def persistence_weighted_gaussian_kernel(D1, D2, weight=lambda x: 1, kernel_appr
E = (1./(np.sqrt(2*np.pi)*bandwidth)) * np.exp(-np.square(pairwise_distances(D1,D2))/(2*bandwidth*bandwidth))
return np.sum(np.multiply(W, E))
-def persistence_scale_space_kernel(D1, D2, kernel_approx=None, bandwidth=1.):
+def _persistence_scale_space_kernel(D1, D2, kernel_approx=None, bandwidth=1.):
"""
This is a function for computing the persistence scale space kernel value from two persistence diagrams. The persistence scale space kernel is computed by adding the symmetric to the diagonal of each point in each persistence diagram, with negative weight, and then convolving the points with a Gaussian kernel. See https://www.cv-foundation.org/openaccess/content_cvpr_2015/papers/Reininghaus_A_Stable_Multi-Scale_2015_CVPR_paper.pdf for more details.
@@ -58,32 +58,32 @@ def persistence_scale_space_kernel(D1, D2, kernel_approx=None, bandwidth=1.):
DD1 = np.concatenate([D1, D1[:,[1,0]]], axis=0)
DD2 = np.concatenate([D2, D2[:,[1,0]]], axis=0)
weight_pss = lambda x: 1 if x[1] >= x[0] else -1
- return 0.5 * persistence_weighted_gaussian_kernel(DD1, DD2, weight=weight_pss, kernel_approx=kernel_approx, bandwidth=bandwidth)
+ return 0.5 * _persistence_weighted_gaussian_kernel(DD1, DD2, weight=weight_pss, kernel_approx=kernel_approx, bandwidth=bandwidth)
-def pairwise_persistence_diagram_kernels(X, Y=None, metric="sliced_wasserstein", **kwargs):
+def pairwise_persistence_diagram_kernels(X, Y=None, kernel="sliced_wasserstein", **kwargs):
"""
This function computes the kernel matrix between two lists of persistence diagrams given as numpy arrays of shape (nx2).
Parameters:
X (list of n numpy arrays of shape (numx2)): first list of persistence diagrams.
Y (list of m numpy arrays of shape (numx2)): second list of persistence diagrams (optional). If None, pairwise kernel values are computed from the first list only.
- metric: kernel to use. It can be either a string ("sliced_wasserstein", "persistence_scale_space", "persistence_weighted_gaussian", "persistence_fisher") or a function taking two numpy arrays of shape (nx2) and (mx2) as inputs.
+ kernel: kernel to use. It can be either a string ("sliced_wasserstein", "persistence_scale_space", "persistence_weighted_gaussian", "persistence_fisher") or a function taking two numpy arrays of shape (nx2) and (mx2) as inputs.
Returns:
numpy array of shape (nxm): kernel matrix.
"""
XX = np.reshape(np.arange(len(X)), [-1,1])
YY = None if Y is None else np.reshape(np.arange(len(Y)), [-1,1])
- if metric == "sliced_wasserstein":
+ if kernel == "sliced_wasserstein":
return np.exp(-pairwise_persistence_diagram_distances(X, Y, metric="sliced_wasserstein", num_directions=kwargs["num_directions"]) / kwargs["bandwidth"])
- elif metric == "persistence_fisher":
+ elif kernel == "persistence_fisher":
return np.exp(-pairwise_persistence_diagram_distances(X, Y, metric="persistence_fisher", kernel_approx=kwargs["kernel_approx"], bandwidth=kwargs["bandwidth"]) / kwargs["bandwidth_fisher"])
- elif metric == "persistence_scale_space":
- return pairwise_kernels(XX, YY, metric=sklearn_wrapper(persistence_scale_space_kernel, X, Y, **kwargs))
- elif metric == "persistence_weighted_gaussian":
- return pairwise_kernels(XX, YY, metric=sklearn_wrapper(persistence_weighted_gaussian_kernel, X, Y, **kwargs))
+ elif kernel == "persistence_scale_space":
+ return pairwise_kernels(XX, YY, metric=_sklearn_wrapper(_persistence_scale_space_kernel, X, Y, **kwargs))
+ elif kernel == "persistence_weighted_gaussian":
+ return pairwise_kernels(XX, YY, metric=_sklearn_wrapper(_persistence_weighted_gaussian_kernel, X, Y, **kwargs))
else:
- return pairwise_kernels(XX, YY, metric=sklearn_wrapper(metric, **kwargs))
+ return pairwise_kernels(XX, YY, metric=_sklearn_wrapper(metric, **kwargs))
class SlicedWassersteinKernel(BaseEstimator, TransformerMixin):
"""
@@ -121,7 +121,20 @@ class SlicedWassersteinKernel(BaseEstimator, TransformerMixin):
Returns:
numpy array of shape (number of diagrams in **diagrams**) x (number of diagrams in X): matrix of pairwise sliced Wasserstein kernel values.
"""
- return pairwise_persistence_diagram_kernels(X, self.diagrams_, metric="sliced_wasserstein", bandwidth=self.bandwidth, num_directions=self.num_directions)
+ return pairwise_persistence_diagram_kernels(X, self.diagrams_, kernel="sliced_wasserstein", bandwidth=self.bandwidth, num_directions=self.num_directions)
+
+ def __call__(self, diag1, diag2):
+ """
+ Apply SlicedWassersteinKernel on a single pair of persistence diagrams and outputs the result.
+
+ Parameters:
+ diag1 (n x 2 numpy array): first input persistence diagram.
+ diag2 (n x 2 numpy array): second input persistence diagram.
+
+ Returns:
+ float: sliced Wasserstein kernel value.
+ """
+ return np.exp(-_sliced_wasserstein_distance(diag1, diag2, num_directions=self.num_directions)) / self.bandwidth
class PersistenceWeightedGaussianKernel(BaseEstimator, TransformerMixin):
"""
@@ -160,7 +173,20 @@ class PersistenceWeightedGaussianKernel(BaseEstimator, TransformerMixin):
Returns:
numpy array of shape (number of diagrams in **diagrams**) x (number of diagrams in X): matrix of pairwise persistence weighted Gaussian kernel values.
"""
- return pairwise_persistence_diagram_kernels(X, self.diagrams_, metric="persistence_weighted_gaussian", bandwidth=self.bandwidth, weight=self.weight, kernel_approx=self.kernel_approx)
+ return pairwise_persistence_diagram_kernels(X, self.diagrams_, kernel="persistence_weighted_gaussian", bandwidth=self.bandwidth, weight=self.weight, kernel_approx=self.kernel_approx)
+
+ def __call__(self, diag1, diag2):
+ """
+ Apply PersistenceWeightedGaussianKernel on a single pair of persistence diagrams and outputs the result.
+
+ Parameters:
+ diag1 (n x 2 numpy array): first input persistence diagram.
+ diag2 (n x 2 numpy array): second input persistence diagram.
+
+ Returns:
+ float: persistence weighted Gaussian kernel value.
+ """
+ return _persistence_weighted_gaussian_kernel(diag1, diag2, weight=self.weight, kernel_approx=self.kernel_approx, bandwidth=self.bandwidth)
class PersistenceScaleSpaceKernel(BaseEstimator, TransformerMixin):
"""
@@ -197,7 +223,20 @@ class PersistenceScaleSpaceKernel(BaseEstimator, TransformerMixin):
Returns:
numpy array of shape (number of diagrams in **diagrams**) x (number of diagrams in X): matrix of pairwise persistence scale space kernel values.
"""
- return pairwise_persistence_diagram_kernels(X, self.diagrams_, metric="persistence_scale_space", bandwidth=self.bandwidth, kernel_approx=self.kernel_approx)
+ return pairwise_persistence_diagram_kernels(X, self.diagrams_, kernel="persistence_scale_space", bandwidth=self.bandwidth, kernel_approx=self.kernel_approx)
+
+ def __call__(self, diag1, diag2):
+ """
+ Apply PersistenceScaleSpaceKernel on a single pair of persistence diagrams and outputs the result.
+
+ Parameters:
+ diag1 (n x 2 numpy array): first input persistence diagram.
+ diag2 (n x 2 numpy array): second input persistence diagram.
+
+ Returns:
+ float: persistence scale space kernel value.
+ """
+ return _persistence_scale_space_kernel(diag1, diag2, bandwidth=self.bandwidth, kernel_approx=self.kernel_approx)
class PersistenceFisherKernel(BaseEstimator, TransformerMixin):
"""
@@ -236,5 +275,18 @@ class PersistenceFisherKernel(BaseEstimator, TransformerMixin):
Returns:
numpy array of shape (number of diagrams in **diagrams**) x (number of diagrams in X): matrix of pairwise persistence Fisher kernel values.
"""
- return pairwise_persistence_diagram_kernels(X, self.diagrams_, metric="persistence_fisher", bandwidth=self.bandwidth, bandwidth_fisher=self.bandwidth_fisher, kernel_approx=self.kernel_approx)
+ return pairwise_persistence_diagram_kernels(X, self.diagrams_, kernel="persistence_fisher", bandwidth=self.bandwidth, bandwidth_fisher=self.bandwidth_fisher, kernel_approx=self.kernel_approx)
+
+ def __call__(self, diag1, diag2):
+ """
+ Apply PersistenceFisherKernel on a single pair of persistence diagrams and outputs the result.
+
+ Parameters:
+ diag1 (n x 2 numpy array): first input persistence diagram.
+ diag2 (n x 2 numpy array): second input persistence diagram.
+
+ Returns:
+ float: persistence Fisher kernel value.
+ """
+ return np.exp(-_persistence_fisher_distance(diag1, diag2, bandwidth=self.bandwidth, kernel_approx=self.kernel_approx)) / self.bandwidth_fisher