summaryrefslogtreecommitdiff
path: root/src/python/test/test_representations.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/test/test_representations.py')
-rwxr-xr-xsrc/python/test/test_representations.py10
1 files changed, 9 insertions, 1 deletions
diff --git a/src/python/test/test_representations.py b/src/python/test/test_representations.py
index 43c914f3..8ebd7888 100755
--- a/src/python/test/test_representations.py
+++ b/src/python/test/test_representations.py
@@ -41,7 +41,15 @@ def test_multiple():
assert d1 == pytest.approx(d2)
assert d3 == pytest.approx(d2, abs=1e-5) # Because of 0 entries (on the diagonal)
d1 = pairwise_persistence_diagram_distances(l1, l2, metric="wasserstein", order=2, internal_p=2)
- d2 = WassersteinDistance(order=2, internal_p=2, n_jobs=4).fit(l2).transform(l1)
+ mode = ""
+ try:
+ import ot
+ mode = "pot"
+ except ImportError:
+ print("POT is not available, try with hera")
+ mode = "hera"
+
+ d2 = WassersteinDistance(order=2, internal_p=2, mode=mode, n_jobs=4).fit(l2).transform(l1)
print(d1.shape, d2.shape)
assert d1 == pytest.approx(d2, rel=0.02)