diff options
author | mathieu <mathieu.carriere3@gmail.com> | 2020-03-10 19:44:57 -0400 |
---|---|---|
committer | mathieu <mathieu.carriere3@gmail.com> | 2020-03-10 19:44:57 -0400 |
commit | a47ace987876cb52351ae9223d335629aedbd71e (patch) | |
tree | bd67ef705b9723847bad9ffea261274b41a3b261 /src/python/gudhi | |
parent | fda2bc369960da0f112b275d94bcf0b4adc3a3c7 (diff) |
new fixes
Diffstat (limited to 'src/python/gudhi')
-rw-r--r-- | src/python/gudhi/representations/metrics.py | 27 |
1 files changed, 12 insertions, 15 deletions
diff --git a/src/python/gudhi/representations/metrics.py b/src/python/gudhi/representations/metrics.py index c5439a67..0659b457 100644 --- a/src/python/gudhi/representations/metrics.py +++ b/src/python/gudhi/representations/metrics.py @@ -10,17 +10,9 @@ import numpy as np from sklearn.base import BaseEstimator, TransformerMixin from sklearn.metrics import pairwise_distances -from gudhi.wasserstein import wasserstein_distance as pot_wasserstein_distance from gudhi.hera import wasserstein_distance as hera_wasserstein_distance from .preprocessing import Padding -try: - from .. import bottleneck_distance - USE_GUDHI = True -except ImportError: - USE_GUDHI = False - print("Gudhi built without CGAL: BottleneckDistance will return a null matrix") - ############################################# # Metrics ################################### ############################################# @@ -111,9 +103,13 @@ def pairwise_persistence_diagram_distances(X, Y=None, metric="bottleneck", **kwa YY = None if Y is None else np.reshape(np.arange(len(Y)), [-1,1]) if metric == "bottleneck": return pairwise_distances(XX, YY, metric=sklearn_wrapper(bottleneck_distance, X, Y, **kwargs)) - elif metric == "wasserstein" or metric == "pot_wasserstein": - return pairwise_distances(XX, YY, metric=sklearn_wrapper(pot_wasserstein_distance, X, Y, **kwargs)) - elif metric == "hera_wasserstein": + elif metric == "pot_wasserstein": + try: + from gudhi.wasserstein import wasserstein_distance as pot_wasserstein_distance + return pairwise_distances(XX, YY, metric=sklearn_wrapper(pot_wasserstein_distance, X, Y, **kwargs)) + except ImportError: + print("Gudhi built without POT") + elif metric == "wasserstein" or metric == "hera_wasserstein": return pairwise_distances(XX, YY, metric=sklearn_wrapper(hera_wasserstein_distance, X, Y, **kwargs)) elif metric == "sliced_wasserstein": return pairwise_distances(XX, YY, metric=sklearn_wrapper(sliced_wasserstein_distance, X, Y, **kwargs)) @@ -192,16 +188,17 @@ class BottleneckDistance(BaseEstimator, TransformerMixin): Returns: numpy array of shape (number of diagrams in **diagrams**) x (number of diagrams in X): matrix of pairwise bottleneck distances. """ - if not USE_GUDHI: - print("Gudhi built without CGAL: returning a null matrix") - Xfit = pairwise_persistence_diagram_distances(X, self.diagrams_, metric="bottleneck", e=self.epsilon) if USE_GUDHI else np.zeros((len(X), len(self.diagrams_))) + try: + Xfit = pairwise_persistence_diagram_distances(X, self.diagrams_, metric="bottleneck", e=self.epsilon) + except ImportError: + print("Gudhi built without CGAL") return Xfit class WassersteinDistance(BaseEstimator, TransformerMixin): """ This is a class for computing the Wasserstein distance matrix from a list of persistence diagrams. """ - def __init__(self, order=2, internal_p=2, mode="pot", delta=0.0001): + def __init__(self, order=2, internal_p=2, mode="pot", delta=0.01): """ Constructor for the WassersteinDistance class. |