From a47ace987876cb52351ae9223d335629aedbd71e Mon Sep 17 00:00:00 2001 From: mathieu Date: Tue, 10 Mar 2020 19:44:57 -0400 Subject: new fixes --- src/python/gudhi/representations/metrics.py | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) (limited to 'src') diff --git a/src/python/gudhi/representations/metrics.py b/src/python/gudhi/representations/metrics.py index c5439a67..0659b457 100644 --- a/src/python/gudhi/representations/metrics.py +++ b/src/python/gudhi/representations/metrics.py @@ -10,17 +10,9 @@ import numpy as np from sklearn.base import BaseEstimator, TransformerMixin from sklearn.metrics import pairwise_distances -from gudhi.wasserstein import wasserstein_distance as pot_wasserstein_distance from gudhi.hera import wasserstein_distance as hera_wasserstein_distance from .preprocessing import Padding -try: - from .. import bottleneck_distance - USE_GUDHI = True -except ImportError: - USE_GUDHI = False - print("Gudhi built without CGAL: BottleneckDistance will return a null matrix") - ############################################# # Metrics ################################### ############################################# @@ -111,9 +103,13 @@ def pairwise_persistence_diagram_distances(X, Y=None, metric="bottleneck", **kwa YY = None if Y is None else np.reshape(np.arange(len(Y)), [-1,1]) if metric == "bottleneck": return pairwise_distances(XX, YY, metric=sklearn_wrapper(bottleneck_distance, X, Y, **kwargs)) - elif metric == "wasserstein" or metric == "pot_wasserstein": - return pairwise_distances(XX, YY, metric=sklearn_wrapper(pot_wasserstein_distance, X, Y, **kwargs)) - elif metric == "hera_wasserstein": + elif metric == "pot_wasserstein": + try: + from gudhi.wasserstein import wasserstein_distance as pot_wasserstein_distance + return pairwise_distances(XX, YY, metric=sklearn_wrapper(pot_wasserstein_distance, X, Y, **kwargs)) + except ImportError: + print("Gudhi built without POT") + elif metric == "wasserstein" or metric == "hera_wasserstein": return pairwise_distances(XX, YY, metric=sklearn_wrapper(hera_wasserstein_distance, X, Y, **kwargs)) elif metric == "sliced_wasserstein": return pairwise_distances(XX, YY, metric=sklearn_wrapper(sliced_wasserstein_distance, X, Y, **kwargs)) @@ -192,16 +188,17 @@ class BottleneckDistance(BaseEstimator, TransformerMixin): Returns: numpy array of shape (number of diagrams in **diagrams**) x (number of diagrams in X): matrix of pairwise bottleneck distances. """ - if not USE_GUDHI: - print("Gudhi built without CGAL: returning a null matrix") - Xfit = pairwise_persistence_diagram_distances(X, self.diagrams_, metric="bottleneck", e=self.epsilon) if USE_GUDHI else np.zeros((len(X), len(self.diagrams_))) + try: + Xfit = pairwise_persistence_diagram_distances(X, self.diagrams_, metric="bottleneck", e=self.epsilon) + except ImportError: + print("Gudhi built without CGAL") return Xfit class WassersteinDistance(BaseEstimator, TransformerMixin): """ This is a class for computing the Wasserstein distance matrix from a list of persistence diagrams. """ - def __init__(self, order=2, internal_p=2, mode="pot", delta=0.0001): + def __init__(self, order=2, internal_p=2, mode="pot", delta=0.01): """ Constructor for the WassersteinDistance class. -- cgit v1.2.3