diff options
-rwxr-xr-x | src/python/test/test_wasserstein_distance.py | 2 |
1 files changed, 1 insertions, 1 deletions
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py index 1cac3e1a..8700107b 100755 --- a/src/python/test/test_wasserstein_distance.py +++ b/src/python/test/test_wasserstein_distance.py @@ -149,7 +149,7 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_mat match = wasserstein_distance(diag2, emptydiag, matching=True, internal_p=np.inf, order=2.24)[1] assert np.array_equal(match , [[0, -1], [1, -1]]) match = wasserstein_distance(diag1, diag2, matching=True, internal_p=2., order=2.)[1] - assert np.array_equal(match, [[0, 0], [1, 1], [2, -1]]) + assert {(i,j) for i,j in match} == {(0, 0), (1, 1), (2, -1)} if test_matching and test_infinity: diag7 = np.array([[0, 3], [4, np.inf], [5, np.inf]]) |