summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: mathieu <mathieu.carriere3@gmail.com> 2020-02-13 16:09:22 -0500
committer: mathieu <mathieu.carriere3@gmail.com> 2020-02-13 16:09:22 -0500
commit: 7327abc115ae3e06a512782ec5833783086c3866 (patch)
tree: a426c589eee5849f157bbdd96d05d1345746d222
parent: ef0f82ef2155440827e17c552abb49b509866fc7 (diff)
parent: 29e81d5038116aef0ec505e4d21d29f1c5920e34 (diff)
integrated hera
-rw-r--r-- src/python/gudhi/representations/kernel_methods.py | 20
-rw-r--r-- src/python/gudhi/representations/metrics.py | 39
2 files changed, 21 insertions, 38 deletions
diff --git a/src/python/gudhi/representations/kernel_methods.py b/src/python/gudhi/representations/kernel_methods.py
index bbbb7c31..d89f69ab 100644
--- a/src/python/gudhi/representations/kernel_methods.py
+++ b/src/python/gudhi/representations/kernel_methods.py
@@ -62,27 +62,17 @@ def pairwise_persistence_diagram_kernels(X, Y=None, metric="sliced_wasserstein",
:param metric: kernel to use. It can be either a string ("sliced_wasserstein", "persistence_scale_space", "persistence_weighted_gaussian", "persistence_fisher") or a function taking two numpy arrays of shape (nx2) and (mx2) as inputs.
:returns: kernel matrix, i.e., numpy array of shape (num diagrams 1 x num diagrams 2)
:rtype: float
- """
- if Y is None:
- YY = None
- pX = Padding(use=True).fit_transform(X)
- diag_len = len(pX[0])
- XX = np.reshape(np.vstack(pX), [-1, diag_len*3])
- else:
- nX, nY = len(X), len(Y)
- pD = Padding(use=True).fit_transform(X + Y)
- diag_len = len(pD[0])
- XX = np.reshape(np.vstack(pD[:nX]), [-1, diag_len*3])
- YY = np.reshape(np.vstack(pD[nX:]), [-1, diag_len*3])
-
+ """
+ XX = np.reshape(np.arange(len(X)), [-1,1])
+ YY = None if Y is None else np.reshape(np.arange(len(Y)), [-1,1])
if metric == "sliced_wasserstein":
return np.exp(-pairwise_persistence_diagram_distances(X, Y, metric="sliced_wasserstein", num_directions=kwargs["num_directions"]) / kwargs["bandwidth"])
elif metric == "persistence_fisher":
return np.exp(-pairwise_persistence_diagram_distances(X, Y, metric="persistence_fisher", kernel_approx=kwargs["kernel_approx"], bandwidth=kwargs["bandwidth"]) / kwargs["bandwidth_fisher"])
elif metric == "persistence_scale_space":
- return pairwise_kernels(XX, YY, metric=sklearn_wrapper(persistence_scale_space_kernel, **kwargs))
+ return pairwise_kernels(XX, YY, metric=sklearn_wrapper(persistence_scale_space_kernel, X, Y, **kwargs))
elif metric == "persistence_weighted_gaussian":
- return pairwise_kernels(XX, YY, metric=sklearn_wrapper(persistence_weighted_gaussian_kernel, **kwargs))
+ return pairwise_kernels(XX, YY, metric=sklearn_wrapper(persistence_weighted_gaussian_kernel, X, Y, **kwargs))
else:
return pairwise_kernels(XX, YY, metric=sklearn_wrapper(metric, **kwargs))
diff --git a/src/python/gudhi/representations/metrics.py b/src/python/gudhi/representations/metrics.py
index ed998603..c5439a67 100644
--- a/src/python/gudhi/representations/metrics.py
+++ b/src/python/gudhi/representations/metrics.py
@@ -86,13 +86,16 @@ def persistence_fisher_distance(D1, D2, kernel_approx=None, bandwidth=1.):
vectorj = vectorj/vectorj_sum
return np.arccos( min(np.dot(np.sqrt(vectori), np.sqrt(vectorj)), 1.) )
-def sklearn_wrapper(metric, **kwargs):
+def sklearn_wrapper(metric, X, Y, **kwargs):
"""
- This function is a wrapper for any metric between two persistence diagrams that takes two numpy arrays of shapes (nx2) and (mx2) as arguments. It turns the metric into another that takes flattened and padded diagrams as inputs.
+ This function is a wrapper for any metric between two persistence diagrams that takes two numpy arrays of shapes (nx2) and (mx2) as arguments.
"""
- def flat_metric(D1, D2):
- DD1, DD2 = np.reshape(D1, [-1,3]), np.reshape(D2, [-1,3])
- return metric(DD1[DD1[:,2]==1,0:2], DD2[DD2[:,2]==1,0:2], **kwargs)
+ if Y is None:
+ def flat_metric(a, b):
+ return metric(X[int(a[0])], X[int(b[0])], **kwargs)
+ else:
+ def flat_metric(a, b):
+ return metric(X[int(a[0])], Y[int(b[0])], **kwargs)
return flat_metric
def pairwise_persistence_diagram_distances(X, Y=None, metric="bottleneck", **kwargs):
@@ -104,30 +107,20 @@ def pairwise_persistence_diagram_distances(X, Y=None, metric="bottleneck", **kwa
:returns: distance matrix, i.e., numpy array of shape (num diagrams 1 x num diagrams 2)
:rtype: float
"""
- if Y is None:
- YY = None
- pX = Padding(use=True).fit_transform(X)
- diag_len = len(pX[0])
- XX = np.reshape(np.vstack(pX), [-1, diag_len*3])
- else:
- nX, nY = len(X), len(Y)
- pD = Padding(use=True).fit_transform(X + Y)
- diag_len = len(pD[0])
- XX = np.reshape(np.vstack(pD[:nX]), [-1, diag_len*3])
- YY = np.reshape(np.vstack(pD[nX:]), [-1, diag_len*3])
-
+ XX = np.reshape(np.arange(len(X)), [-1,1])
+ YY = None if Y is None else np.reshape(np.arange(len(Y)), [-1,1])
if metric == "bottleneck":
- return pairwise_distances(XX, YY, metric=sklearn_wrapper(bottleneck_distance, **kwargs))
+ return pairwise_distances(XX, YY, metric=sklearn_wrapper(bottleneck_distance, X, Y, **kwargs))
elif metric == "wasserstein" or metric == "pot_wasserstein":
- return pairwise_distances(XX, YY, metric=sklearn_wrapper(pot_wasserstein_distance, **kwargs))
+ return pairwise_distances(XX, YY, metric=sklearn_wrapper(pot_wasserstein_distance, X, Y, **kwargs))
elif metric == "hera_wasserstein":
- return pairwise_distances(XX, YY, metric=sklearn_wrapper(hera_wasserstein_distance, **kwargs))
+ return pairwise_distances(XX, YY, metric=sklearn_wrapper(hera_wasserstein_distance, X, Y, **kwargs))
elif metric == "sliced_wasserstein":
- return pairwise_distances(XX, YY, metric=sklearn_wrapper(sliced_wasserstein_distance, **kwargs))
+ return pairwise_distances(XX, YY, metric=sklearn_wrapper(sliced_wasserstein_distance, X, Y, **kwargs))
elif metric == "persistence_fisher":
- return pairwise_distances(XX, YY, metric=sklearn_wrapper(persistence_fisher_distance, **kwargs))
+ return pairwise_distances(XX, YY, metric=sklearn_wrapper(persistence_fisher_distance, X, Y, **kwargs))
else:
- return pairwise_distances(XX, YY, metric=sklearn_wrapper(metric, **kwargs))
+ return pairwise_distances(XX, YY, metric=sklearn_wrapper(metric, X, Y, **kwargs))
class SlicedWassersteinDistance(BaseEstimator, TransformerMixin):
"""