diff options
author | Marc Glisse <marc.glisse@inria.fr> | 2020-06-02 07:33:31 +0200 |
---|---|---|
committer | Marc Glisse <marc.glisse@inria.fr> | 2020-06-02 07:33:31 +0200 |
commit | cc42bcdf3323f2eb6edeaca105d29b32b394ca66 (patch) | |
tree | ebf0a9419169272c5cad23de892dec843b55294d /src/python/gudhi/representations/metrics.py | |
parent | c53567c85f936f78000471fcee6234e75f7742ca (diff) |
Parallelism in pairwise_distances
Diffstat (limited to 'src/python/gudhi/representations/metrics.py')
-rw-r--r-- | src/python/gudhi/representations/metrics.py | 27 |
1 files changed, 21 insertions, 6 deletions
diff --git a/src/python/gudhi/representations/metrics.py b/src/python/gudhi/representations/metrics.py index 8a32f7e9..23bccd68 100644 --- a/src/python/gudhi/representations/metrics.py +++ b/src/python/gudhi/representations/metrics.py @@ -12,6 +12,7 @@ from sklearn.base import BaseEstimator, TransformerMixin from sklearn.metrics import pairwise_distances from gudhi.hera import wasserstein_distance as hera_wasserstein_distance from .preprocessing import Padding +from joblib import Parallel, delayed, effective_n_jobs ############################################# # Metrics ################################### @@ -116,6 +117,20 @@ def _persistence_fisher_distance(D1, D2, kernel_approx=None, bandwidth=1.): vectorj = vectorj/vectorj_sum return np.arccos( min(np.dot(np.sqrt(vectori), np.sqrt(vectorj)), 1.) ) +def _pairwise(fallback, skipdiag, X, Y, metric, n_jobs): + if Y is not None: + return fallback(X, Y, metric=metric, n_jobs=n_jobs) + triu = np.triu_indices(len(X), k=skipdiag) + tril = (triu[1], triu[0]) + par = Parallel(n_jobs=n_jobs, prefer="threads") + d = par(delayed(metric)([triu[0][i]], [triu[1][i]]) for i in range(len(triu[0]))) + m = np.empty((len(X), len(X))) + m[triu] = d + m[tril] = d + if skipdiag: + np.fill_diagonal(m, 0) + return m + def _sklearn_wrapper(metric, X, Y, **kwargs): """ This function is a wrapper for any metric between two persistence diagrams that takes two numpy arrays of shapes (nx2) and (mx2) as arguments. @@ -134,7 +149,7 @@ PAIRWISE_DISTANCE_FUNCTIONS = { "persistence_fisher": _persistence_fisher_distance, } -def pairwise_persistence_diagram_distances(X, Y=None, metric="bottleneck", **kwargs): +def pairwise_persistence_diagram_distances(X, Y=None, metric="bottleneck", n_jobs=None, **kwargs): """ This function computes the distance matrix between two lists of persistence diagrams given as numpy arrays of shape (nx2). @@ -152,25 +167,25 @@ def pairwise_persistence_diagram_distances(X, Y=None, metric="bottleneck", **kwa if metric == "bottleneck": try: from .. import bottleneck_distance - return pairwise_distances(XX, YY, metric=_sklearn_wrapper(bottleneck_distance, X, Y, **kwargs)) + return _pairwise(pairwise_distances, True, XX, YY, metric=_sklearn_wrapper(bottleneck_distance, X, Y, **kwargs), n_jobs=n_jobs) except ImportError: print("Gudhi built without CGAL") raise elif metric == "pot_wasserstein": try: from gudhi.wasserstein import wasserstein_distance as pot_wasserstein_distance - return pairwise_distances(XX, YY, metric=_sklearn_wrapper(pot_wasserstein_distance, X, Y, **kwargs)) + return _pairwise(pairwise_distances, True, XX, YY, metric=_sklearn_wrapper(pot_wasserstein_distance, X, Y, **kwargs), n_jobs=n_jobs) except ImportError: print("POT (Python Optimal Transport) is not installed. Please install POT or use metric='wasserstein' or metric='hera_wasserstein'") raise elif metric == "sliced_wasserstein": Xproj = _compute_persistence_diagram_projections(X, **kwargs) Yproj = None if Y is None else _compute_persistence_diagram_projections(Y, **kwargs) - return pairwise_distances(XX, YY, metric=_sklearn_wrapper(_sliced_wasserstein_distance_on_projections, Xproj, Yproj)) + return _pairwise(pairwise_distances, True, XX, YY, metric=_sklearn_wrapper(_sliced_wasserstein_distance_on_projections, Xproj, Yproj), n_jobs=n_jobs) elif type(metric) == str: - return pairwise_distances(XX, YY, metric=_sklearn_wrapper(PAIRWISE_DISTANCE_FUNCTIONS[metric], X, Y, **kwargs)) + return _pairwise(pairwise_distances, True, XX, YY, metric=_sklearn_wrapper(PAIRWISE_DISTANCE_FUNCTIONS[metric], X, Y, **kwargs), n_jobs=n_jobs) else: - return pairwise_distances(XX, YY, metric=_sklearn_wrapper(metric, X, Y, **kwargs)) + return _pairwise(pairwise_distances, True, XX, YY, metric=_sklearn_wrapper(metric, X, Y, **kwargs), n_jobs=n_jobs) class SlicedWassersteinDistance(BaseEstimator, TransformerMixin): """ |