summaryrefslogtreecommitdiff
path: root/src/python/gudhi/representations/kernel_methods.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/gudhi/representations/kernel_methods.py')
-rw-r--r--src/python/gudhi/representations/kernel_methods.py41
1 files changed, 25 insertions, 16 deletions
diff --git a/src/python/gudhi/representations/kernel_methods.py b/src/python/gudhi/representations/kernel_methods.py
index 596f4f07..23fd23c7 100644
--- a/src/python/gudhi/representations/kernel_methods.py
+++ b/src/python/gudhi/representations/kernel_methods.py
@@ -10,7 +10,7 @@
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.metrics import pairwise_distances, pairwise_kernels
-from .metrics import SlicedWassersteinDistance, PersistenceFisherDistance, _sklearn_wrapper, pairwise_persistence_diagram_distances, _sliced_wasserstein_distance, _persistence_fisher_distance
+from .metrics import SlicedWassersteinDistance, PersistenceFisherDistance, _sklearn_wrapper, _pairwise, pairwise_persistence_diagram_distances, _sliced_wasserstein_distance, _persistence_fisher_distance
from .preprocessing import Padding
#############################################
@@ -60,7 +60,7 @@ def _persistence_scale_space_kernel(D1, D2, kernel_approx=None, bandwidth=1.):
weight_pss = lambda x: 1 if x[1] >= x[0] else -1
return 0.5 * _persistence_weighted_gaussian_kernel(DD1, DD2, weight=weight_pss, kernel_approx=kernel_approx, bandwidth=bandwidth)
-def pairwise_persistence_diagram_kernels(X, Y=None, kernel="sliced_wasserstein", **kwargs):
+def pairwise_persistence_diagram_kernels(X, Y=None, kernel="sliced_wasserstein", n_jobs=None, **kwargs):
"""
This function computes the kernel matrix between two lists of persistence diagrams given as numpy arrays of shape (nx2).
@@ -68,38 +68,41 @@ def pairwise_persistence_diagram_kernels(X, Y=None, kernel="sliced_wasserstein",
X (list of n numpy arrays of shape (numx2)): first list of persistence diagrams.
Y (list of m numpy arrays of shape (numx2)): second list of persistence diagrams (optional). If None, pairwise kernel values are computed from the first list only.
kernel: kernel to use. It can be either a string ("sliced_wasserstein", "persistence_scale_space", "persistence_weighted_gaussian", "persistence_fisher") or a function taking two numpy arrays of shape (nx2) and (mx2) as inputs. If it is a function, make sure that it is symmetric.
+ n_jobs (int): number of jobs to use for the computation. This uses joblib.Parallel(prefer="threads"), so kernels that do not release the GIL may not scale unless run inside a `joblib.parallel_backend <https://joblib.readthedocs.io/en/latest/parallel.html#joblib.parallel_backend>`_ block.
**kwargs: optional keyword parameters. Any further parameters are passed directly to the kernel function. See the docs of the various kernel classes in this module.
Returns:
numpy array of shape (nxm): kernel matrix.
"""
XX = np.reshape(np.arange(len(X)), [-1,1])
- YY = None if Y is None else np.reshape(np.arange(len(Y)), [-1,1])
+ YY = None if Y is None or Y is X else np.reshape(np.arange(len(Y)), [-1,1])
if kernel == "sliced_wasserstein":
- return np.exp(-pairwise_persistence_diagram_distances(X, Y, metric="sliced_wasserstein", num_directions=kwargs["num_directions"]) / kwargs["bandwidth"])
+ return np.exp(-pairwise_persistence_diagram_distances(X, Y, metric="sliced_wasserstein", num_directions=kwargs["num_directions"], n_jobs=n_jobs) / kwargs["bandwidth"])
elif kernel == "persistence_fisher":
- return np.exp(-pairwise_persistence_diagram_distances(X, Y, metric="persistence_fisher", kernel_approx=kwargs["kernel_approx"], bandwidth=kwargs["bandwidth"]) / kwargs["bandwidth_fisher"])
+ return np.exp(-pairwise_persistence_diagram_distances(X, Y, metric="persistence_fisher", kernel_approx=kwargs["kernel_approx"], bandwidth=kwargs["bandwidth"], n_jobs=n_jobs) / kwargs["bandwidth_fisher"])
elif kernel == "persistence_scale_space":
- return pairwise_kernels(XX, YY, metric=_sklearn_wrapper(_persistence_scale_space_kernel, X, Y, **kwargs))
+ return _pairwise(pairwise_kernels, False, XX, YY, metric=_sklearn_wrapper(_persistence_scale_space_kernel, X, Y, **kwargs), n_jobs=n_jobs)
elif kernel == "persistence_weighted_gaussian":
- return pairwise_kernels(XX, YY, metric=_sklearn_wrapper(_persistence_weighted_gaussian_kernel, X, Y, **kwargs))
+ return _pairwise(pairwise_kernels, False, XX, YY, metric=_sklearn_wrapper(_persistence_weighted_gaussian_kernel, X, Y, **kwargs), n_jobs=n_jobs)
else:
- return pairwise_kernels(XX, YY, metric=_sklearn_wrapper(metric, **kwargs))
+ return _pairwise(pairwise_kernels, False, XX, YY, metric=_sklearn_wrapper(metric, **kwargs), n_jobs=n_jobs)
class SlicedWassersteinKernel(BaseEstimator, TransformerMixin):
"""
This is a class for computing the sliced Wasserstein kernel matrix from a list of persistence diagrams. The sliced Wasserstein kernel is computed by exponentiating the corresponding sliced Wasserstein distance with a Gaussian kernel. See http://proceedings.mlr.press/v70/carriere17a.html for more details.
"""
- def __init__(self, num_directions=10, bandwidth=1.0):
+ def __init__(self, num_directions=10, bandwidth=1.0, n_jobs=None):
"""
Constructor for the SlicedWassersteinKernel class.
Parameters:
bandwidth (double): bandwidth of the Gaussian kernel applied to the sliced Wasserstein distance (default 1.).
num_directions (int): number of lines evenly sampled from [-pi/2,pi/2] in order to approximate and speed up the kernel computation (default 10).
+ n_jobs (int): number of jobs to use for the computation. See :func:`pairwise_persistence_diagram_kernels` for details.
"""
self.bandwidth = bandwidth
self.num_directions = num_directions
+ self.n_jobs = n_jobs
def fit(self, X, y=None):
"""
@@ -122,7 +125,7 @@ class SlicedWassersteinKernel(BaseEstimator, TransformerMixin):
Returns:
numpy array of shape (number of diagrams in **diagrams**) x (number of diagrams in X): matrix of pairwise sliced Wasserstein kernel values.
"""
- return pairwise_persistence_diagram_kernels(X, self.diagrams_, kernel="sliced_wasserstein", bandwidth=self.bandwidth, num_directions=self.num_directions)
+ return pairwise_persistence_diagram_kernels(X, self.diagrams_, kernel="sliced_wasserstein", bandwidth=self.bandwidth, num_directions=self.num_directions, n_jobs=self.n_jobs)
def __call__(self, diag1, diag2):
"""
@@ -141,7 +144,7 @@ class PersistenceWeightedGaussianKernel(BaseEstimator, TransformerMixin):
"""
This is a class for computing the persistence weighted Gaussian kernel matrix from a list of persistence diagrams. The persistence weighted Gaussian kernel is computed by convolving the persistence diagram points with weighted Gaussian kernels. See http://proceedings.mlr.press/v48/kusano16.html for more details.
"""
- def __init__(self, bandwidth=1., weight=lambda x: 1, kernel_approx=None):
+ def __init__(self, bandwidth=1., weight=lambda x: 1, kernel_approx=None, n_jobs=None):
"""
Constructor for the PersistenceWeightedGaussianKernel class.
@@ -149,9 +152,11 @@ class PersistenceWeightedGaussianKernel(BaseEstimator, TransformerMixin):
bandwidth (double): bandwidth of the Gaussian kernel with which persistence diagrams will be convolved (default 1.)
weight (function): weight function for the persistence diagram points (default constant function, ie lambda x: 1). This function must be defined on 2D points, ie lists or numpy arrays of the form [p_x,p_y].
kernel_approx (class): kernel approximation class used to speed up computation (default None). Common kernel approximations classes can be found in the scikit-learn library (such as RBFSampler for instance).
+ n_jobs (int): number of jobs to use for the computation. See :func:`pairwise_persistence_diagram_kernels` for details.
"""
self.bandwidth, self.weight = bandwidth, weight
self.kernel_approx = kernel_approx
+ self.n_jobs = n_jobs
def fit(self, X, y=None):
"""
@@ -174,7 +179,7 @@ class PersistenceWeightedGaussianKernel(BaseEstimator, TransformerMixin):
Returns:
numpy array of shape (number of diagrams in **diagrams**) x (number of diagrams in X): matrix of pairwise persistence weighted Gaussian kernel values.
"""
- return pairwise_persistence_diagram_kernels(X, self.diagrams_, kernel="persistence_weighted_gaussian", bandwidth=self.bandwidth, weight=self.weight, kernel_approx=self.kernel_approx)
+ return pairwise_persistence_diagram_kernels(X, self.diagrams_, kernel="persistence_weighted_gaussian", bandwidth=self.bandwidth, weight=self.weight, kernel_approx=self.kernel_approx, n_jobs=self.n_jobs)
def __call__(self, diag1, diag2):
"""
@@ -193,15 +198,17 @@ class PersistenceScaleSpaceKernel(BaseEstimator, TransformerMixin):
"""
This is a class for computing the persistence scale space kernel matrix from a list of persistence diagrams. The persistence scale space kernel is computed by adding the symmetric to the diagonal of each point in each persistence diagram, with negative weight, and then convolving the points with a Gaussian kernel. See https://www.cv-foundation.org/openaccess/content_cvpr_2015/papers/Reininghaus_A_Stable_Multi-Scale_2015_CVPR_paper.pdf for more details.
"""
- def __init__(self, bandwidth=1., kernel_approx=None):
+ def __init__(self, bandwidth=1., kernel_approx=None, n_jobs=None):
"""
Constructor for the PersistenceScaleSpaceKernel class.
Parameters:
bandwidth (double): bandwidth of the Gaussian kernel with which persistence diagrams will be convolved (default 1.)
kernel_approx (class): kernel approximation class used to speed up computation (default None). Common kernel approximations classes can be found in the scikit-learn library (such as RBFSampler for instance).
+ n_jobs (int): number of jobs to use for the computation. See :func:`pairwise_persistence_diagram_kernels` for details.
"""
self.bandwidth, self.kernel_approx = bandwidth, kernel_approx
+ self.n_jobs = n_jobs
def fit(self, X, y=None):
"""
@@ -224,7 +231,7 @@ class PersistenceScaleSpaceKernel(BaseEstimator, TransformerMixin):
Returns:
numpy array of shape (number of diagrams in **diagrams**) x (number of diagrams in X): matrix of pairwise persistence scale space kernel values.
"""
- return pairwise_persistence_diagram_kernels(X, self.diagrams_, kernel="persistence_scale_space", bandwidth=self.bandwidth, kernel_approx=self.kernel_approx)
+ return pairwise_persistence_diagram_kernels(X, self.diagrams_, kernel="persistence_scale_space", bandwidth=self.bandwidth, kernel_approx=self.kernel_approx, n_jobs=self.n_jobs)
def __call__(self, diag1, diag2):
"""
@@ -243,7 +250,7 @@ class PersistenceFisherKernel(BaseEstimator, TransformerMixin):
"""
This is a class for computing the persistence Fisher kernel matrix from a list of persistence diagrams. The persistence Fisher kernel is computed by exponentiating the corresponding persistence Fisher distance with a Gaussian kernel. See papers.nips.cc/paper/8205-persistence-fisher-kernel-a-riemannian-manifold-kernel-for-persistence-diagrams for more details.
"""
- def __init__(self, bandwidth_fisher=1., bandwidth=1., kernel_approx=None):
+ def __init__(self, bandwidth_fisher=1., bandwidth=1., kernel_approx=None, n_jobs=None):
"""
Constructor for the PersistenceFisherKernel class.
@@ -251,9 +258,11 @@ class PersistenceFisherKernel(BaseEstimator, TransformerMixin):
bandwidth (double): bandwidth of the Gaussian kernel applied to the persistence Fisher distance (default 1.).
bandwidth_fisher (double): bandwidth of the Gaussian kernel used to turn persistence diagrams into probability distributions by PersistenceFisherDistance class (default 1.).
kernel_approx (class): kernel approximation class used to speed up computation (default None). Common kernel approximations classes can be found in the scikit-learn library (such as RBFSampler for instance).
+ n_jobs (int): number of jobs to use for the computation. See :func:`pairwise_persistence_diagram_kernels` for details.
"""
self.bandwidth = bandwidth
self.bandwidth_fisher, self.kernel_approx = bandwidth_fisher, kernel_approx
+ self.n_jobs = n_jobs
def fit(self, X, y=None):
"""
@@ -276,7 +285,7 @@ class PersistenceFisherKernel(BaseEstimator, TransformerMixin):
Returns:
numpy array of shape (number of diagrams in **diagrams**) x (number of diagrams in X): matrix of pairwise persistence Fisher kernel values.
"""
- return pairwise_persistence_diagram_kernels(X, self.diagrams_, kernel="persistence_fisher", bandwidth=self.bandwidth, bandwidth_fisher=self.bandwidth_fisher, kernel_approx=self.kernel_approx)
+ return pairwise_persistence_diagram_kernels(X, self.diagrams_, kernel="persistence_fisher", bandwidth=self.bandwidth, bandwidth_fisher=self.bandwidth_fisher, kernel_approx=self.kernel_approx, n_jobs=self.n_jobs)
def __call__(self, diag1, diag2):
"""