From 55e2b2e55bc50a7cfea9ca1edfca632488cf016a Mon Sep 17 00:00:00 2001 From: ROUVREAU Vincent Date: Tue, 8 Dec 2020 15:49:43 +0100 Subject: Make representations tests work if CGAL and/or POT is not there --- .../diagram_vectorizations_distances_kernels.py | 19 +++++++++++++------ src/python/test/test_representations.py | 10 +++++++++- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/src/python/example/diagram_vectorizations_distances_kernels.py b/src/python/example/diagram_vectorizations_distances_kernels.py index c4a71a7a..2801576e 100755 --- a/src/python/example/diagram_vectorizations_distances_kernels.py +++ b/src/python/example/diagram_vectorizations_distances_kernels.py @@ -5,11 +5,11 @@ import numpy as np from sklearn.kernel_approximation import RBFSampler from sklearn.preprocessing import MinMaxScaler -from gudhi.representations import DiagramSelector, Clamping, Landscape, Silhouette, BettiCurve, ComplexPolynomial,\ +from gudhi.representations import (DiagramSelector, Clamping, Landscape, Silhouette, BettiCurve, ComplexPolynomial,\ TopologicalVector, DiagramScaler, BirthPersistenceTransform,\ PersistenceImage, PersistenceWeightedGaussianKernel, Entropy, \ PersistenceScaleSpaceKernel, SlicedWassersteinDistance,\ - SlicedWassersteinKernel, BottleneckDistance, PersistenceFisherKernel, WassersteinDistance + SlicedWassersteinKernel, PersistenceFisherKernel, WassersteinDistance) D1 = np.array([[0.,4.],[1.,2.],[3.,8.],[6.,8.], [0., np.inf], [5., np.inf]]) @@ -93,14 +93,21 @@ print("SW distance is " + str(sW(D1, D2))) SW = SlicedWassersteinKernel(num_directions=100, bandwidth=1.) print("SW kernel is " + str(SW(D1, D2))) -W = WassersteinDistance(order=2, internal_p=2, mode="pot") -print("Wasserstein distance (POT) is " + str(W(D1, D2))) +try: + W = WassersteinDistance(order=2, internal_p=2, mode="pot") + print("Wasserstein distance (POT) is " + str(W(D1, D2))) +except ImportError: + print("WassersteinDistance (POT) is not available, you may be missing pot.") W = WassersteinDistance(order=2, internal_p=2, mode="hera", delta=0.0001) print("Wasserstein distance (hera) is " + str(W(D1, D2))) -W = BottleneckDistance(epsilon=.001) -print("Bottleneck distance is " + str(W(D1, D2))) +try: + from gudhi.representations import BottleneckDistance + W = BottleneckDistance(epsilon=.001) + print("Bottleneck distance is " + str(W(D1, D2))) +except ImportError: + print("BottleneckDistance is not available, you may be missing CGAL.") PF = PersistenceFisherKernel(bandwidth_fisher=1., bandwidth=1.) print("PF kernel is " + str(PF(D1, D2))) diff --git a/src/python/test/test_representations.py b/src/python/test/test_representations.py index 43c914f3..8ebd7888 100755 --- a/src/python/test/test_representations.py +++ b/src/python/test/test_representations.py @@ -41,7 +41,15 @@ def test_multiple(): assert d1 == pytest.approx(d2) assert d3 == pytest.approx(d2, abs=1e-5) # Because of 0 entries (on the diagonal) d1 = pairwise_persistence_diagram_distances(l1, l2, metric="wasserstein", order=2, internal_p=2) - d2 = WassersteinDistance(order=2, internal_p=2, n_jobs=4).fit(l2).transform(l1) + mode = "" + try: + import ot + mode = "pot" + except ImportError: + print("POT is not available, try with hera") + mode = "hera" + + d2 = WassersteinDistance(order=2, internal_p=2, mode=mode, n_jobs=4).fit(l2).transform(l1) print(d1.shape, d2.shape) assert d1 == pytest.approx(d2, rel=0.02) -- cgit v1.2.3