diff options
Diffstat (limited to 'src/python/test/test_wasserstein_distance.py')
-rwxr-xr-x | src/python/test/test_wasserstein_distance.py | 28 |
1 files changed, 18 insertions, 10 deletions
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py index a76b6ce7..42bf3299 100755 --- a/src/python/test/test_wasserstein_distance.py +++ b/src/python/test/test_wasserstein_distance.py @@ -97,6 +97,13 @@ def test_warn_infty(): assert (m is None) +def _to_set(X): + return { (i, j) for i, j in X } + +def _same_permuted(X, Y): + return _to_set(X) == _to_set(Y) + + def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_matching=True): diag1 = np.array([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]]) diag2 = np.array([[2.8, 4.45], [9.5, 14.1]]) @@ -141,15 +148,16 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_mat if test_matching: match = wasserstein_distance(emptydiag, emptydiag, matching=True, internal_p=1., order=2)[1] - assert np.array_equal(match, []) + # Accept [] or np.array of shape (2, 0) + assert len(match) == 0 match = wasserstein_distance(emptydiag, emptydiag, matching=True, internal_p=np.inf, order=2.24)[1] - assert np.array_equal(match, []) + assert len(match) == 0 match = wasserstein_distance(emptydiag, diag2, matching=True, internal_p=np.inf, order=2.)[1] - assert np.array_equal(match , [[-1, 0], [-1, 1]]) + assert _same_permuted(match, [[-1, 0], [-1, 1]]) match = wasserstein_distance(diag2, emptydiag, matching=True, internal_p=np.inf, order=2.24)[1] - assert np.array_equal(match , [[0, -1], [1, -1]]) + assert _same_permuted(match, [[0, -1], [1, -1]]) match = wasserstein_distance(diag1, diag2, matching=True, internal_p=2., order=2.)[1] - assert np.array_equal(match, [[0, 0], [1, 1], [2, -1]]) + assert _same_permuted(match, [[0, 0], [1, 1], [2, -1]]) if test_matching and test_infinity: diag7 = np.array([[0, 3], [4, np.inf], [5, np.inf]]) @@ -158,7 +166,7 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_mat diag10 = np.array([[0,1], [-np.inf, -np.inf], [np.inf, np.inf]]) match = wasserstein_distance(diag5, diag6, matching=True, internal_p=2., order=2.)[1] - assert np.array_equal(match, [[0, -1], [-1,0], [-1, 1], [1, 2]]) + assert _same_permuted(match, [[0, -1], [-1,0], [-1, 1], [1, 2]]) match = wasserstein_distance(diag5, diag7, matching=True, internal_p=2., order=2.)[1] assert (match is None) cost, match = wasserstein_distance(diag7, emptydiag, matching=True, internal_p=2., order=2.3) @@ -172,10 +180,10 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_mat assert (match is None) cost, match = wasserstein_distance(diag9, diag10, matching=True, internal_p=1., order=1.) assert (cost == 1) - assert (match == [[0, -1],[1, -1],[-1, 0], [-1, 1], [-1, 2]]) # type 4 and 5 are match to the diag anyway. + assert _same_permuted(match, [[0, -1],[1, -1],[-1, 0], [-1, 1], [-1, 2]]) # type 4 and 5 are match to the diag anyway. cost, match = wasserstein_distance(diag9, emptydiag, matching=True, internal_p=2., order=2.) assert (cost == 0.) - assert (match == [[0, -1], [1, -1]]) + assert _same_permuted(match, [[0, -1], [1, -1]]) def hera_wrap(**extra): @@ -196,6 +204,6 @@ def test_wasserstein_distance_pot(): def test_wasserstein_distance_hera(): - _basic_wasserstein(hera_wrap(delta=1e-12), 1e-12, test_matching=False) - _basic_wasserstein(hera_wrap(delta=.1), .1, test_matching=False) + _basic_wasserstein(hera_wrap(delta=1e-12), 1e-12, test_matching=True) + _basic_wasserstein(hera_wrap(delta=.1), .1, test_matching=True) |