diff options
author | tlacombe <lacombe1993@gmail.com> | 2020-07-07 12:37:51 +0200 |
---|---|---|
committer | tlacombe <lacombe1993@gmail.com> | 2020-07-07 12:37:51 +0200 |
commit | 42a399c273fde7c76ec23d2993957fcbb492ee79 (patch) | |
tree | 068c92bb31530fbdc0aa2029a7221be807b59c68 /src/python/test/test_wasserstein_distance.py | |
parent | e0eba14109e02676825f8c24563872a5b49c6120 (diff) |
correction mistake in tests
Diffstat (limited to 'src/python/test/test_wasserstein_distance.py')
-rwxr-xr-x | src/python/test/test_wasserstein_distance.py | 19 |
1 files changed, 11 insertions, 8 deletions
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py index e50091e9..285b95c9 100755 --- a/src/python/test/test_wasserstein_distance.py +++ b/src/python/test/test_wasserstein_distance.py @@ -9,12 +9,13 @@ - YYYY/MM Author: Description of the modification """ -from gudhi.wasserstein.wasserstein import _proj_on_diag, _offdiag, _handle_essential_parts +from gudhi.wasserstein.wasserstein import _proj_on_diag, _offdiag, _handle_essential_parts, _get_essential_parts from gudhi.wasserstein import wasserstein_distance as pot from gudhi.hera import wasserstein_distance as hera import numpy as np import pytest + __author__ = "Theo Lacombe" __copyright__ = "Copyright (C) 2019 Inria" __license__ = "MIT" @@ -56,8 +57,10 @@ def test_handle_essential_parts(): [-np.inf, np.inf], [-np.inf, np.inf]]) c, m = _handle_essential_parts(diag1, diag2, order=1) - assert c == pytest.approx(3, 0.0001) - assert np.array_equal(m, [[0,0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7], [8, 8], [9, 9]]) + assert c == pytest.approx(2, 0.0001) # Note: here c is only the cost due to essential part (thus 2, not 3) + # Similarly, the matching only corresponds to essential parts. + assert np.array_equal(m, [[4, 4], [5, 5], [2, 2], [3, 3], [8, 8], [9, 9], [6, 6], [7, 7]]) + c, m = _handle_essential_parts(diag1, diag3, order=1) assert c == np.inf assert (m is None) @@ -68,11 +71,11 @@ def test_get_essential_parts(): [np.inf, np.inf], [-np.inf, np.inf], [-np.inf, np.inf]]) res = _get_essential_parts(diag) - assert res[0] == [4, 5] - assert res[1] == [2, 3] - assert res[2] == [8, 9] - assert res[3] == [6] - assert res[4] == [7] + assert np.array_equal(res[0], [4, 5]) + assert np.array_equal(res[1], [2, 3]) + assert np.array_equal(res[2], [8, 9]) + assert np.array_equal(res[3], [6] ) + assert np.array_equal(res[4], [7] ) def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_matching=True): |