summaryrefslogtreecommitdiff
path: root/src/python/gudhi/representations/kernel_methods.py
diff options
context:
space:
mode:
authorMathieuCarriere <mathieu.carriere3@gmail.com>2020-02-07 21:00:17 -0500
committerMathieuCarriere <mathieu.carriere3@gmail.com>2020-02-07 21:00:17 -0500
commit29e81d5038116aef0ec505e4d21d29f1c5920e34 (patch)
treef5665ca1bec314c443d8ae75338f8981ea2b58bc /src/python/gudhi/representations/kernel_methods.py
parentd21640a16113a3c56389efcb060b3430af9f256d (diff)
added sklearn trick
Diffstat (limited to 'src/python/gudhi/representations/kernel_methods.py')
-rw-r--r--src/python/gudhi/representations/kernel_methods.py20
1 files changed, 5 insertions, 15 deletions
diff --git a/src/python/gudhi/representations/kernel_methods.py b/src/python/gudhi/representations/kernel_methods.py
index bbbb7c31..d89f69ab 100644
--- a/src/python/gudhi/representations/kernel_methods.py
+++ b/src/python/gudhi/representations/kernel_methods.py
@@ -62,27 +62,17 @@ def pairwise_persistence_diagram_kernels(X, Y=None, metric="sliced_wasserstein",
:param metric: kernel to use. It can be either a string ("sliced_wasserstein", "persistence_scale_space", "persistence_weighted_gaussian", "persistence_fisher") or a function taking two numpy arrays of shape (nx2) and (mx2) as inputs.
:returns: kernel matrix, i.e., numpy array of shape (num diagrams 1 x num diagrams 2)
:rtype: float
- """
- if Y is None:
- YY = None
- pX = Padding(use=True).fit_transform(X)
- diag_len = len(pX[0])
- XX = np.reshape(np.vstack(pX), [-1, diag_len*3])
- else:
- nX, nY = len(X), len(Y)
- pD = Padding(use=True).fit_transform(X + Y)
- diag_len = len(pD[0])
- XX = np.reshape(np.vstack(pD[:nX]), [-1, diag_len*3])
- YY = np.reshape(np.vstack(pD[nX:]), [-1, diag_len*3])
-
+ """
+ XX = np.reshape(np.arange(len(X)), [-1,1])
+ YY = None if Y is None else np.reshape(np.arange(len(Y)), [-1,1])
if metric == "sliced_wasserstein":
return np.exp(-pairwise_persistence_diagram_distances(X, Y, metric="sliced_wasserstein", num_directions=kwargs["num_directions"]) / kwargs["bandwidth"])
elif metric == "persistence_fisher":
return np.exp(-pairwise_persistence_diagram_distances(X, Y, metric="persistence_fisher", kernel_approx=kwargs["kernel_approx"], bandwidth=kwargs["bandwidth"]) / kwargs["bandwidth_fisher"])
elif metric == "persistence_scale_space":
- return pairwise_kernels(XX, YY, metric=sklearn_wrapper(persistence_scale_space_kernel, **kwargs))
+ return pairwise_kernels(XX, YY, metric=sklearn_wrapper(persistence_scale_space_kernel, X, Y, **kwargs))
elif metric == "persistence_weighted_gaussian":
- return pairwise_kernels(XX, YY, metric=sklearn_wrapper(persistence_weighted_gaussian_kernel, **kwargs))
+ return pairwise_kernels(XX, YY, metric=sklearn_wrapper(persistence_weighted_gaussian_kernel, X, Y, **kwargs))
else:
return pairwise_kernels(XX, YY, metric=sklearn_wrapper(metric, **kwargs))