summaryrefslogtreecommitdiff
path: root/src/python/gudhi/representations/metrics.py
diff options
context:
space:
mode:
authormathieu <mathieu.carriere3@gmail.com>2020-02-13 16:09:22 -0500
committermathieu <mathieu.carriere3@gmail.com>2020-02-13 16:09:22 -0500
commit7327abc115ae3e06a512782ec5833783086c3866 (patch)
treea426c589eee5849f157bbdd96d05d1345746d222 /src/python/gudhi/representations/metrics.py
parentef0f82ef2155440827e17c552abb49b509866fc7 (diff)
parent29e81d5038116aef0ec505e4d21d29f1c5920e34 (diff)
integrated hera
Diffstat (limited to 'src/python/gudhi/representations/metrics.py')
-rw-r--r--src/python/gudhi/representations/metrics.py39
1 files changed, 16 insertions, 23 deletions
diff --git a/src/python/gudhi/representations/metrics.py b/src/python/gudhi/representations/metrics.py
index ed998603..c5439a67 100644
--- a/src/python/gudhi/representations/metrics.py
+++ b/src/python/gudhi/representations/metrics.py
@@ -86,13 +86,16 @@ def persistence_fisher_distance(D1, D2, kernel_approx=None, bandwidth=1.):
vectorj = vectorj/vectorj_sum
return np.arccos( min(np.dot(np.sqrt(vectori), np.sqrt(vectorj)), 1.) )
-def sklearn_wrapper(metric, **kwargs):
+def sklearn_wrapper(metric, X, Y, **kwargs):
"""
- This function is a wrapper for any metric between two persistence diagrams that takes two numpy arrays of shapes (nx2) and (mx2) as arguments. It turns the metric into another that takes flattened and padded diagrams as inputs.
+ This function is a wrapper for any metric between two persistence diagrams that takes two numpy arrays of shapes (nx2) and (mx2) as arguments.
"""
- def flat_metric(D1, D2):
- DD1, DD2 = np.reshape(D1, [-1,3]), np.reshape(D2, [-1,3])
- return metric(DD1[DD1[:,2]==1,0:2], DD2[DD2[:,2]==1,0:2], **kwargs)
+ if Y is None:
+ def flat_metric(a, b):
+ return metric(X[int(a[0])], X[int(b[0])], **kwargs)
+ else:
+ def flat_metric(a, b):
+ return metric(X[int(a[0])], Y[int(b[0])], **kwargs)
return flat_metric
def pairwise_persistence_diagram_distances(X, Y=None, metric="bottleneck", **kwargs):
@@ -104,30 +107,20 @@ def pairwise_persistence_diagram_distances(X, Y=None, metric="bottleneck", **kwa
:returns: distance matrix, i.e., numpy array of shape (num diagrams 1 x num diagrams 2)
:rtype: float
"""
- if Y is None:
- YY = None
- pX = Padding(use=True).fit_transform(X)
- diag_len = len(pX[0])
- XX = np.reshape(np.vstack(pX), [-1, diag_len*3])
- else:
- nX, nY = len(X), len(Y)
- pD = Padding(use=True).fit_transform(X + Y)
- diag_len = len(pD[0])
- XX = np.reshape(np.vstack(pD[:nX]), [-1, diag_len*3])
- YY = np.reshape(np.vstack(pD[nX:]), [-1, diag_len*3])
-
+ XX = np.reshape(np.arange(len(X)), [-1,1])
+ YY = None if Y is None else np.reshape(np.arange(len(Y)), [-1,1])
if metric == "bottleneck":
- return pairwise_distances(XX, YY, metric=sklearn_wrapper(bottleneck_distance, **kwargs))
+ return pairwise_distances(XX, YY, metric=sklearn_wrapper(bottleneck_distance, X, Y, **kwargs))
elif metric == "wasserstein" or metric == "pot_wasserstein":
- return pairwise_distances(XX, YY, metric=sklearn_wrapper(pot_wasserstein_distance, **kwargs))
+ return pairwise_distances(XX, YY, metric=sklearn_wrapper(pot_wasserstein_distance, X, Y, **kwargs))
elif metric == "hera_wasserstein":
- return pairwise_distances(XX, YY, metric=sklearn_wrapper(hera_wasserstein_distance, **kwargs))
+ return pairwise_distances(XX, YY, metric=sklearn_wrapper(hera_wasserstein_distance, X, Y, **kwargs))
elif metric == "sliced_wasserstein":
- return pairwise_distances(XX, YY, metric=sklearn_wrapper(sliced_wasserstein_distance, **kwargs))
+ return pairwise_distances(XX, YY, metric=sklearn_wrapper(sliced_wasserstein_distance, X, Y, **kwargs))
elif metric == "persistence_fisher":
- return pairwise_distances(XX, YY, metric=sklearn_wrapper(persistence_fisher_distance, **kwargs))
+ return pairwise_distances(XX, YY, metric=sklearn_wrapper(persistence_fisher_distance, X, Y, **kwargs))
else:
- return pairwise_distances(XX, YY, metric=sklearn_wrapper(metric, **kwargs))
+ return pairwise_distances(XX, YY, metric=sklearn_wrapper(metric, X, Y, **kwargs))
class SlicedWassersteinDistance(BaseEstimator, TransformerMixin):
"""