summaryrefslogtreecommitdiff
path: root/src/python/test/test_wasserstein_distance.py
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2020-02-06 22:14:08 +0100
committerMarc Glisse <marc.glisse@inria.fr>2020-02-06 22:14:08 +0100
commite8c908469cb4ac547d4fd46ad8daf5ee21739f58 (patch)
treec67aed9eb11a3fcdcefca0931203ff69eef82949 /src/python/test/test_wasserstein_distance.py
parent518c619d578dc6f168b6369417f15872e3cd0056 (diff)
pytest.approx
Diffstat (limited to 'src/python/test/test_wasserstein_distance.py')
-rwxr-xr-xsrc/python/test/test_wasserstein_distance.py34
1 files changed, 17 insertions, 17 deletions
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py
index 46a7079f..6a14c50e 100755
--- a/src/python/test/test_wasserstein_distance.py
+++ b/src/python/test/test_wasserstein_distance.py
@@ -11,6 +11,7 @@
from gudhi.wasserstein import wasserstein_distance as pot
from gudhi.hera import wasserstein_distance as hera
import numpy as np
+import pytest
__author__ = "Theo Lacombe"
__copyright__ = "Copyright (C) 2019 Inria"
@@ -24,32 +25,31 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True):
emptydiag = np.array([])
# We just need to handle positive numbers here
- def approx(a, b):
- f = 1 + delta
- return a <= b*f and b <= a*f
+ def approx(x):
+ return pytest.approx(x, rel=delta)
assert wasserstein_distance(emptydiag, emptydiag, internal_p=2., order=1.) == 0.
assert wasserstein_distance(emptydiag, emptydiag, internal_p=np.inf, order=1.) == 0.
assert wasserstein_distance(emptydiag, emptydiag, internal_p=np.inf, order=2.) == 0.
assert wasserstein_distance(emptydiag, emptydiag, internal_p=2., order=2.) == 0.
- assert approx(wasserstein_distance(diag3, emptydiag, internal_p=np.inf, order=1.), 2.)
- assert approx(wasserstein_distance(diag3, emptydiag, internal_p=1., order=1.), 4.)
+ assert wasserstein_distance(diag3, emptydiag, internal_p=np.inf, order=1.) == approx(2.)
+ assert wasserstein_distance(diag3, emptydiag, internal_p=1., order=1.) == approx(4.)
- assert approx(wasserstein_distance(diag4, emptydiag, internal_p=1., order=2.), 5.) # thank you Pythagorician triplets
- assert approx(wasserstein_distance(diag4, emptydiag, internal_p=np.inf, order=2.), 2.5)
- assert approx(wasserstein_distance(diag4, emptydiag, internal_p=2., order=2.), 3.5355339059327378)
+ assert wasserstein_distance(diag4, emptydiag, internal_p=1., order=2.) == approx(5.) # thank you Pythagorician triplets
+ assert wasserstein_distance(diag4, emptydiag, internal_p=np.inf, order=2.) == approx(2.5)
+ assert wasserstein_distance(diag4, emptydiag, internal_p=2., order=2.) == approx(3.5355339059327378)
- assert approx(wasserstein_distance(diag1, diag2, internal_p=2., order=1.) , 1.4453593023967701)
- assert approx(wasserstein_distance(diag1, diag2, internal_p=2.35, order=1.74), 0.9772734057168739)
+ assert wasserstein_distance(diag1, diag2, internal_p=2., order=1.) == approx(1.4453593023967701)
+ assert wasserstein_distance(diag1, diag2, internal_p=2.35, order=1.74) == approx(0.9772734057168739)
- assert approx(wasserstein_distance(diag1, emptydiag, internal_p=2.35, order=1.7863), 3.141592214572228)
+ assert wasserstein_distance(diag1, emptydiag, internal_p=2.35, order=1.7863) == approx(3.141592214572228)
- assert approx(wasserstein_distance(diag3, diag4, internal_p=1., order=1.), 3.)
- assert approx(wasserstein_distance(diag3, diag4, internal_p=np.inf, order=1.), 3.) # no diag matching here
- assert approx(wasserstein_distance(diag3, diag4, internal_p=np.inf, order=2.), np.sqrt(5))
- assert approx(wasserstein_distance(diag3, diag4, internal_p=1., order=2.), np.sqrt(5))
- assert approx(wasserstein_distance(diag3, diag4, internal_p=4.5, order=2.), np.sqrt(5))
+ assert wasserstein_distance(diag3, diag4, internal_p=1., order=1.) == approx(3.)
+ assert wasserstein_distance(diag3, diag4, internal_p=np.inf, order=1.) == approx(3.) # no diag matching here
+ assert wasserstein_distance(diag3, diag4, internal_p=np.inf, order=2.) == approx(np.sqrt(5))
+ assert wasserstein_distance(diag3, diag4, internal_p=1., order=2.) == approx(np.sqrt(5))
+ assert wasserstein_distance(diag3, diag4, internal_p=4.5, order=2.) == approx(np.sqrt(5))
if(not test_infinity):
return
@@ -58,7 +58,7 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True):
diag6 = np.array([[7, 8], [4, 6], [3, np.inf]])
assert wasserstein_distance(diag4, diag5) == np.inf
- assert approx(wasserstein_distance(diag5, diag6, order=1, internal_p=np.inf), 4.)
+ assert wasserstein_distance(diag5, diag6, order=1, internal_p=np.inf) == approx(4.)
def hera_wrap(delta):
def fun(*kargs,**kwargs):