summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rwxr-xr-xsrc/python/test/test_wasserstein_distance.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py
index 1cac3e1a..8700107b 100755
--- a/src/python/test/test_wasserstein_distance.py
+++ b/src/python/test/test_wasserstein_distance.py
@@ -149,7 +149,7 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_mat
match = wasserstein_distance(diag2, emptydiag, matching=True, internal_p=np.inf, order=2.24)[1]
assert np.array_equal(match , [[0, -1], [1, -1]])
match = wasserstein_distance(diag1, diag2, matching=True, internal_p=2., order=2.)[1]
- assert np.array_equal(match, [[0, 0], [1, 1], [2, -1]])
+ assert {(i,j) for i,j in match} == {(0, 0), (1, 1), (2, -1)}
if test_matching and test_infinity:
diag7 = np.array([[0, 3], [4, np.inf], [5, np.inf]])