diff options
author | Marc Glisse <marc.glisse@inria.fr> | 2020-02-06 22:14:08 +0100 |
---|---|---|
committer | Marc Glisse <marc.glisse@inria.fr> | 2020-02-06 22:14:08 +0100 |
commit | e8c908469cb4ac547d4fd46ad8daf5ee21739f58 (patch) | |
tree | c67aed9eb11a3fcdcefca0931203ff69eef82949 | |
parent | 518c619d578dc6f168b6369417f15872e3cd0056 (diff) |
pytest.approx
-rwxr-xr-x | src/python/test/test_wasserstein_distance.py | 34 |
1 files changed, 17 insertions, 17 deletions
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py index 46a7079f..6a14c50e 100755 --- a/src/python/test/test_wasserstein_distance.py +++ b/src/python/test/test_wasserstein_distance.py @@ -11,6 +11,7 @@ from gudhi.wasserstein import wasserstein_distance as pot from gudhi.hera import wasserstein_distance as hera import numpy as np +import pytest __author__ = "Theo Lacombe" __copyright__ = "Copyright (C) 2019 Inria" @@ -24,32 +25,31 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True): emptydiag = np.array([]) # We just need to handle positive numbers here - def approx(a, b): - f = 1 + delta - return a <= b*f and b <= a*f + def approx(x): + return pytest.approx(x, rel=delta) assert wasserstein_distance(emptydiag, emptydiag, internal_p=2., order=1.) == 0. assert wasserstein_distance(emptydiag, emptydiag, internal_p=np.inf, order=1.) == 0. assert wasserstein_distance(emptydiag, emptydiag, internal_p=np.inf, order=2.) == 0. assert wasserstein_distance(emptydiag, emptydiag, internal_p=2., order=2.) == 0. - assert approx(wasserstein_distance(diag3, emptydiag, internal_p=np.inf, order=1.), 2.) - assert approx(wasserstein_distance(diag3, emptydiag, internal_p=1., order=1.), 4.) + assert wasserstein_distance(diag3, emptydiag, internal_p=np.inf, order=1.) == approx(2.) + assert wasserstein_distance(diag3, emptydiag, internal_p=1., order=1.) == approx(4.) - assert approx(wasserstein_distance(diag4, emptydiag, internal_p=1., order=2.), 5.) # thank you Pythagorician triplets - assert approx(wasserstein_distance(diag4, emptydiag, internal_p=np.inf, order=2.), 2.5) - assert approx(wasserstein_distance(diag4, emptydiag, internal_p=2., order=2.), 3.5355339059327378) + assert wasserstein_distance(diag4, emptydiag, internal_p=1., order=2.) == approx(5.) # thank you Pythagorician triplets + assert wasserstein_distance(diag4, emptydiag, internal_p=np.inf, order=2.) == approx(2.5) + assert wasserstein_distance(diag4, emptydiag, internal_p=2., order=2.) == approx(3.5355339059327378) - assert approx(wasserstein_distance(diag1, diag2, internal_p=2., order=1.) , 1.4453593023967701) - assert approx(wasserstein_distance(diag1, diag2, internal_p=2.35, order=1.74), 0.9772734057168739) + assert wasserstein_distance(diag1, diag2, internal_p=2., order=1.) == approx(1.4453593023967701) + assert wasserstein_distance(diag1, diag2, internal_p=2.35, order=1.74) == approx(0.9772734057168739) - assert approx(wasserstein_distance(diag1, emptydiag, internal_p=2.35, order=1.7863), 3.141592214572228) + assert wasserstein_distance(diag1, emptydiag, internal_p=2.35, order=1.7863) == approx(3.141592214572228) - assert approx(wasserstein_distance(diag3, diag4, internal_p=1., order=1.), 3.) - assert approx(wasserstein_distance(diag3, diag4, internal_p=np.inf, order=1.), 3.) # no diag matching here - assert approx(wasserstein_distance(diag3, diag4, internal_p=np.inf, order=2.), np.sqrt(5)) - assert approx(wasserstein_distance(diag3, diag4, internal_p=1., order=2.), np.sqrt(5)) - assert approx(wasserstein_distance(diag3, diag4, internal_p=4.5, order=2.), np.sqrt(5)) + assert wasserstein_distance(diag3, diag4, internal_p=1., order=1.) == approx(3.) + assert wasserstein_distance(diag3, diag4, internal_p=np.inf, order=1.) == approx(3.) # no diag matching here + assert wasserstein_distance(diag3, diag4, internal_p=np.inf, order=2.) == approx(np.sqrt(5)) + assert wasserstein_distance(diag3, diag4, internal_p=1., order=2.) == approx(np.sqrt(5)) + assert wasserstein_distance(diag3, diag4, internal_p=4.5, order=2.) == approx(np.sqrt(5)) if(not test_infinity): return @@ -58,7 +58,7 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True): diag6 = np.array([[7, 8], [4, 6], [3, np.inf]]) assert wasserstein_distance(diag4, diag5) == np.inf - assert approx(wasserstein_distance(diag5, diag6, order=1, internal_p=np.inf), 4.) + assert wasserstein_distance(diag5, diag6, order=1, internal_p=np.inf) == approx(4.) def hera_wrap(delta): def fun(*kargs,**kwargs): |