summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorROUVREAU Vincent <vincent.rouvreau@inria.fr>2020-12-08 15:49:43 +0100
committerROUVREAU Vincent <vincent.rouvreau@inria.fr>2020-12-08 15:49:43 +0100
commit55e2b2e55bc50a7cfea9ca1edfca632488cf016a (patch)
tree45d4a828a62e252d26361f0ed3e2824bb303e482
parentaa5ba4c9c4c72cfe650b82952785f7ae7955af5e (diff)
Make representations tests work if CGAL and/or POT is not there
-rwxr-xr-xsrc/python/example/diagram_vectorizations_distances_kernels.py19
-rwxr-xr-xsrc/python/test/test_representations.py10
2 files changed, 22 insertions, 7 deletions
diff --git a/src/python/example/diagram_vectorizations_distances_kernels.py b/src/python/example/diagram_vectorizations_distances_kernels.py
index c4a71a7a..2801576e 100755
--- a/src/python/example/diagram_vectorizations_distances_kernels.py
+++ b/src/python/example/diagram_vectorizations_distances_kernels.py
@@ -5,11 +5,11 @@ import numpy as np
from sklearn.kernel_approximation import RBFSampler
from sklearn.preprocessing import MinMaxScaler
-from gudhi.representations import DiagramSelector, Clamping, Landscape, Silhouette, BettiCurve, ComplexPolynomial,\
+from gudhi.representations import (DiagramSelector, Clamping, Landscape, Silhouette, BettiCurve, ComplexPolynomial,\
TopologicalVector, DiagramScaler, BirthPersistenceTransform,\
PersistenceImage, PersistenceWeightedGaussianKernel, Entropy, \
PersistenceScaleSpaceKernel, SlicedWassersteinDistance,\
- SlicedWassersteinKernel, BottleneckDistance, PersistenceFisherKernel, WassersteinDistance
+ SlicedWassersteinKernel, PersistenceFisherKernel, WassersteinDistance)
D1 = np.array([[0.,4.],[1.,2.],[3.,8.],[6.,8.], [0., np.inf], [5., np.inf]])
@@ -93,14 +93,21 @@ print("SW distance is " + str(sW(D1, D2)))
SW = SlicedWassersteinKernel(num_directions=100, bandwidth=1.)
print("SW kernel is " + str(SW(D1, D2)))
-W = WassersteinDistance(order=2, internal_p=2, mode="pot")
-print("Wasserstein distance (POT) is " + str(W(D1, D2)))
+try:
+ W = WassersteinDistance(order=2, internal_p=2, mode="pot")
+ print("Wasserstein distance (POT) is " + str(W(D1, D2)))
+except ImportError:
+ print("WassersteinDistance (POT) is not available, you may be missing pot.")
W = WassersteinDistance(order=2, internal_p=2, mode="hera", delta=0.0001)
print("Wasserstein distance (hera) is " + str(W(D1, D2)))
-W = BottleneckDistance(epsilon=.001)
-print("Bottleneck distance is " + str(W(D1, D2)))
+try:
+ from gudhi.representations import BottleneckDistance
+ W = BottleneckDistance(epsilon=.001)
+ print("Bottleneck distance is " + str(W(D1, D2)))
+except ImportError:
+ print("BottleneckDistance is not available, you may be missing CGAL.")
PF = PersistenceFisherKernel(bandwidth_fisher=1., bandwidth=1.)
print("PF kernel is " + str(PF(D1, D2)))
diff --git a/src/python/test/test_representations.py b/src/python/test/test_representations.py
index 43c914f3..8ebd7888 100755
--- a/src/python/test/test_representations.py
+++ b/src/python/test/test_representations.py
@@ -41,7 +41,15 @@ def test_multiple():
assert d1 == pytest.approx(d2)
assert d3 == pytest.approx(d2, abs=1e-5) # Because of 0 entries (on the diagonal)
d1 = pairwise_persistence_diagram_distances(l1, l2, metric="wasserstein", order=2, internal_p=2)
- d2 = WassersteinDistance(order=2, internal_p=2, n_jobs=4).fit(l2).transform(l1)
+ mode = ""
+ try:
+ import ot
+ mode = "pot"
+ except ImportError:
+ print("POT is not available, try with hera")
+ mode = "hera"
+
+ d2 = WassersteinDistance(order=2, internal_p=2, mode=mode, n_jobs=4).fit(l2).transform(l1)
print(d1.shape, d2.shape)
assert d1 == pytest.approx(d2, rel=0.02)