summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2022-11-16 16:32:08 +0100
committerMarc Glisse <marc.glisse@inria.fr>2022-11-16 16:32:08 +0100
commit50a772551a29fe0178ff05e3390754ac89cf5826 (patch)
treea332521632541340794dbf6e6285e4ab207f89da
parent03e20909d4219d177b512b5f798ae5a5552ae17d (diff)
Tests more resilient to arbitrary order
-rwxr-xr-xsrc/python/test/test_wasserstein_distance.py19
1 files changed, 13 insertions, 6 deletions
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py
index 8700107b..72514099 100755
--- a/src/python/test/test_wasserstein_distance.py
+++ b/src/python/test/test_wasserstein_distance.py
@@ -96,6 +96,13 @@ def test_warn_infty():
assert (m is None)
+def _to_set(X):
+ return { (i, j) for i, j in X }
+
+def _same_permuted(X, Y):
+ return _to_set(X) == _to_set(Y)
+
+
def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_matching=True):
diag1 = np.array([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]])
diag2 = np.array([[2.8, 4.45], [9.5, 14.1]])
@@ -145,11 +152,11 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_mat
match = wasserstein_distance(emptydiag, emptydiag, matching=True, internal_p=np.inf, order=2.24)[1]
assert len(match) == 0
match = wasserstein_distance(emptydiag, diag2, matching=True, internal_p=np.inf, order=2.)[1]
- assert np.array_equal(match , [[-1, 0], [-1, 1]])
+ assert _same_permuted(match, [[-1, 0], [-1, 1]])
match = wasserstein_distance(diag2, emptydiag, matching=True, internal_p=np.inf, order=2.24)[1]
- assert np.array_equal(match , [[0, -1], [1, -1]])
+ assert _same_permuted(match, [[0, -1], [1, -1]])
match = wasserstein_distance(diag1, diag2, matching=True, internal_p=2., order=2.)[1]
- assert {(i,j) for i,j in match} == {(0, 0), (1, 1), (2, -1)}
+ assert _same_permuted(match, [[0, 0], [1, 1], [2, -1]])
if test_matching and test_infinity:
diag7 = np.array([[0, 3], [4, np.inf], [5, np.inf]])
@@ -158,7 +165,7 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_mat
diag10 = np.array([[0,1], [-np.inf, -np.inf], [np.inf, np.inf]])
match = wasserstein_distance(diag5, diag6, matching=True, internal_p=2., order=2.)[1]
- assert np.array_equal(match, [[0, -1], [-1,0], [-1, 1], [1, 2]])
+ assert _same_permuted(match, [[0, -1], [-1,0], [-1, 1], [1, 2]])
match = wasserstein_distance(diag5, diag7, matching=True, internal_p=2., order=2.)[1]
assert (match is None)
cost, match = wasserstein_distance(diag7, emptydiag, matching=True, internal_p=2., order=2.3)
@@ -172,10 +179,10 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_mat
assert (match is None)
cost, match = wasserstein_distance(diag9, diag10, matching=True, internal_p=1., order=1.)
assert (cost == 1)
- assert {(i,j) for i,j in match} == {(0, -1),(1, -1),(-1, 0), (-1, 1), (-1, 2)} # type 4 and 5 are match to the diag anyway.
+ assert _same_permuted(match, [[0, -1],[1, -1],[-1, 0], [-1, 1], [-1, 2]]) # type 4 and 5 are match to the diag anyway.
cost, match = wasserstein_distance(diag9, emptydiag, matching=True, internal_p=2., order=2.)
assert (cost == 0.)
- assert np.array_equal(match, [[0, -1], [1, -1]])
+ assert _same_permuted(match, [[0, -1], [1, -1]])
def hera_wrap(**extra):