summaryrefslogtreecommitdiff
path: root/src/python/gudhi/representations/metrics.py
diff options
context:
space:
mode:
authormathieu <mathieu.carriere3@gmail.com>2020-03-10 19:44:57 -0400
committermathieu <mathieu.carriere3@gmail.com>2020-03-10 19:44:57 -0400
commita47ace987876cb52351ae9223d335629aedbd71e (patch)
treebd67ef705b9723847bad9ffea261274b41a3b261 /src/python/gudhi/representations/metrics.py
parentfda2bc369960da0f112b275d94bcf0b4adc3a3c7 (diff)
new fixes
Diffstat (limited to 'src/python/gudhi/representations/metrics.py')
-rw-r--r--src/python/gudhi/representations/metrics.py27
1 files changed, 12 insertions, 15 deletions
diff --git a/src/python/gudhi/representations/metrics.py b/src/python/gudhi/representations/metrics.py
index c5439a67..0659b457 100644
--- a/src/python/gudhi/representations/metrics.py
+++ b/src/python/gudhi/representations/metrics.py
@@ -10,17 +10,9 @@
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.metrics import pairwise_distances
-from gudhi.wasserstein import wasserstein_distance as pot_wasserstein_distance
from gudhi.hera import wasserstein_distance as hera_wasserstein_distance
from .preprocessing import Padding
-try:
- from .. import bottleneck_distance
- USE_GUDHI = True
-except ImportError:
- USE_GUDHI = False
- print("Gudhi built without CGAL: BottleneckDistance will return a null matrix")
-
#############################################
# Metrics ###################################
#############################################
@@ -111,9 +103,13 @@ def pairwise_persistence_diagram_distances(X, Y=None, metric="bottleneck", **kwa
YY = None if Y is None else np.reshape(np.arange(len(Y)), [-1,1])
if metric == "bottleneck":
return pairwise_distances(XX, YY, metric=sklearn_wrapper(bottleneck_distance, X, Y, **kwargs))
- elif metric == "wasserstein" or metric == "pot_wasserstein":
- return pairwise_distances(XX, YY, metric=sklearn_wrapper(pot_wasserstein_distance, X, Y, **kwargs))
- elif metric == "hera_wasserstein":
+ elif metric == "pot_wasserstein":
+ try:
+ from gudhi.wasserstein import wasserstein_distance as pot_wasserstein_distance
+ return pairwise_distances(XX, YY, metric=sklearn_wrapper(pot_wasserstein_distance, X, Y, **kwargs))
+ except ImportError:
+ print("Gudhi built without POT")
+ elif metric == "wasserstein" or metric == "hera_wasserstein":
return pairwise_distances(XX, YY, metric=sklearn_wrapper(hera_wasserstein_distance, X, Y, **kwargs))
elif metric == "sliced_wasserstein":
return pairwise_distances(XX, YY, metric=sklearn_wrapper(sliced_wasserstein_distance, X, Y, **kwargs))
@@ -192,16 +188,17 @@ class BottleneckDistance(BaseEstimator, TransformerMixin):
Returns:
numpy array of shape (number of diagrams in **diagrams**) x (number of diagrams in X): matrix of pairwise bottleneck distances.
"""
- if not USE_GUDHI:
- print("Gudhi built without CGAL: returning a null matrix")
- Xfit = pairwise_persistence_diagram_distances(X, self.diagrams_, metric="bottleneck", e=self.epsilon) if USE_GUDHI else np.zeros((len(X), len(self.diagrams_)))
+ try:
+ Xfit = pairwise_persistence_diagram_distances(X, self.diagrams_, metric="bottleneck", e=self.epsilon)
+ except ImportError:
+ print("Gudhi built without CGAL")
return Xfit
class WassersteinDistance(BaseEstimator, TransformerMixin):
"""
This is a class for computing the Wasserstein distance matrix from a list of persistence diagrams.
"""
- def __init__(self, order=2, internal_p=2, mode="pot", delta=0.0001):
+ def __init__(self, order=2, internal_p=2, mode="pot", delta=0.01):
"""
Constructor for the WassersteinDistance class.