diff options
Diffstat (limited to 'src/python/gudhi/representations/kernel_methods.py')
-rw-r--r-- | src/python/gudhi/representations/kernel_methods.py | 14 |
1 files changed, 7 insertions, 7 deletions
diff --git a/src/python/gudhi/representations/kernel_methods.py b/src/python/gudhi/representations/kernel_methods.py index 596f4f07..c9bd9d01 100644 --- a/src/python/gudhi/representations/kernel_methods.py +++ b/src/python/gudhi/representations/kernel_methods.py @@ -10,7 +10,7 @@ import numpy as np from sklearn.base import BaseEstimator, TransformerMixin from sklearn.metrics import pairwise_distances, pairwise_kernels -from .metrics import SlicedWassersteinDistance, PersistenceFisherDistance, _sklearn_wrapper, pairwise_persistence_diagram_distances, _sliced_wasserstein_distance, _persistence_fisher_distance +from .metrics import SlicedWassersteinDistance, PersistenceFisherDistance, _sklearn_wrapper, _pairwise, pairwise_persistence_diagram_distances, _sliced_wasserstein_distance, _persistence_fisher_distance from .preprocessing import Padding ############################################# @@ -60,7 +60,7 @@ def _persistence_scale_space_kernel(D1, D2, kernel_approx=None, bandwidth=1.): weight_pss = lambda x: 1 if x[1] >= x[0] else -1 return 0.5 * _persistence_weighted_gaussian_kernel(DD1, DD2, weight=weight_pss, kernel_approx=kernel_approx, bandwidth=bandwidth) -def pairwise_persistence_diagram_kernels(X, Y=None, kernel="sliced_wasserstein", **kwargs): +def pairwise_persistence_diagram_kernels(X, Y=None, kernel="sliced_wasserstein", n_jobs=None, **kwargs): """ This function computes the kernel matrix between two lists of persistence diagrams given as numpy arrays of shape (nx2). @@ -76,15 +76,15 @@ def pairwise_persistence_diagram_kernels(X, Y=None, kernel="sliced_wasserstein", XX = np.reshape(np.arange(len(X)), [-1,1]) YY = None if Y is None else np.reshape(np.arange(len(Y)), [-1,1]) if kernel == "sliced_wasserstein": - return np.exp(-pairwise_persistence_diagram_distances(X, Y, metric="sliced_wasserstein", num_directions=kwargs["num_directions"]) / kwargs["bandwidth"]) + return np.exp(-pairwise_persistence_diagram_distances(X, Y, metric="sliced_wasserstein", num_directions=kwargs["num_directions"], n_jobs=n_jobs) / kwargs["bandwidth"]) elif kernel == "persistence_fisher": - return np.exp(-pairwise_persistence_diagram_distances(X, Y, metric="persistence_fisher", kernel_approx=kwargs["kernel_approx"], bandwidth=kwargs["bandwidth"]) / kwargs["bandwidth_fisher"]) + return np.exp(-pairwise_persistence_diagram_distances(X, Y, metric="persistence_fisher", kernel_approx=kwargs["kernel_approx"], bandwidth=kwargs["bandwidth"], n_jobs=n_jobs) / kwargs["bandwidth_fisher"]) elif kernel == "persistence_scale_space": - return pairwise_kernels(XX, YY, metric=_sklearn_wrapper(_persistence_scale_space_kernel, X, Y, **kwargs)) + return _pairwise(pairwise_kernels, False, XX, YY, metric=_sklearn_wrapper(_persistence_scale_space_kernel, X, Y, **kwargs), n_jobs=n_jobs) elif kernel == "persistence_weighted_gaussian": - return pairwise_kernels(XX, YY, metric=_sklearn_wrapper(_persistence_weighted_gaussian_kernel, X, Y, **kwargs)) + return _pairwise(pairwise_kernels, False, XX, YY, metric=_sklearn_wrapper(_persistence_weighted_gaussian_kernel, X, Y, **kwargs), n_jobs=n_jobs) else: - return pairwise_kernels(XX, YY, metric=_sklearn_wrapper(metric, **kwargs)) + return _pairwise(pairwise_kernels, False, XX, YY, metric=_sklearn_wrapper(metric, **kwargs), n_jobs=n_jobs) class SlicedWassersteinKernel(BaseEstimator, TransformerMixin): """ |