diff options
author | Marc Glisse <marc.glisse@inria.fr> | 2022-11-16 16:32:08 +0100 |
---|---|---|
committer | Marc Glisse <marc.glisse@inria.fr> | 2022-11-16 16:32:08 +0100 |
commit | 50a772551a29fe0178ff05e3390754ac89cf5826 (patch) | |
tree | a332521632541340794dbf6e6285e4ab207f89da | |
parent | 03e20909d4219d177b512b5f798ae5a5552ae17d (diff) |
Tests more resilient to arbitrary order
-rwxr-xr-x | src/python/test/test_wasserstein_distance.py | 19 |
1 files changed, 13 insertions, 6 deletions
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py index 8700107b..72514099 100755 --- a/src/python/test/test_wasserstein_distance.py +++ b/src/python/test/test_wasserstein_distance.py @@ -96,6 +96,13 @@ def test_warn_infty(): assert (m is None) +def _to_set(X): + return { (i, j) for i, j in X } + +def _same_permuted(X, Y): + return _to_set(X) == _to_set(Y) + + def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_matching=True): diag1 = np.array([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]]) diag2 = np.array([[2.8, 4.45], [9.5, 14.1]]) @@ -145,11 +152,11 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_mat match = wasserstein_distance(emptydiag, emptydiag, matching=True, internal_p=np.inf, order=2.24)[1] assert len(match) == 0 match = wasserstein_distance(emptydiag, diag2, matching=True, internal_p=np.inf, order=2.)[1] - assert np.array_equal(match , [[-1, 0], [-1, 1]]) + assert _same_permuted(match, [[-1, 0], [-1, 1]]) match = wasserstein_distance(diag2, emptydiag, matching=True, internal_p=np.inf, order=2.24)[1] - assert np.array_equal(match , [[0, -1], [1, -1]]) + assert _same_permuted(match, [[0, -1], [1, -1]]) match = wasserstein_distance(diag1, diag2, matching=True, internal_p=2., order=2.)[1] - assert {(i,j) for i,j in match} == {(0, 0), (1, 1), (2, -1)} + assert _same_permuted(match, [[0, 0], [1, 1], [2, -1]]) if test_matching and test_infinity: diag7 = np.array([[0, 3], [4, np.inf], [5, np.inf]]) @@ -158,7 +165,7 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_mat diag10 = np.array([[0,1], [-np.inf, -np.inf], [np.inf, np.inf]]) match = wasserstein_distance(diag5, diag6, matching=True, internal_p=2., order=2.)[1] - assert np.array_equal(match, [[0, -1], [-1,0], [-1, 1], [1, 2]]) + assert _same_permuted(match, [[0, -1], [-1,0], [-1, 1], [1, 2]]) match = wasserstein_distance(diag5, diag7, matching=True, internal_p=2., order=2.)[1] assert (match is None) cost, match = wasserstein_distance(diag7, emptydiag, matching=True, internal_p=2., order=2.3) @@ -172,10 +179,10 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_mat assert (match is None) cost, match = wasserstein_distance(diag9, diag10, matching=True, internal_p=1., order=1.) assert (cost == 1) - assert {(i,j) for i,j in match} == {(0, -1),(1, -1),(-1, 0), (-1, 1), (-1, 2)} # type 4 and 5 are match to the diag anyway. + assert _same_permuted(match, [[0, -1],[1, -1],[-1, 0], [-1, 1], [-1, 2]]) # type 4 and 5 are match to the diag anyway. cost, match = wasserstein_distance(diag9, emptydiag, matching=True, internal_p=2., order=2.) assert (cost == 0.) - assert np.array_equal(match, [[0, -1], [1, -1]]) + assert _same_permuted(match, [[0, -1], [1, -1]]) def hera_wrap(**extra): |