summaryrefslogtreecommitdiff
path: root/src/python/test/test_wasserstein_distance.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/test/test_wasserstein_distance.py')
-rwxr-xr-xsrc/python/test/test_wasserstein_distance.py28
1 files changed, 18 insertions, 10 deletions
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py
index a76b6ce7..42bf3299 100755
--- a/src/python/test/test_wasserstein_distance.py
+++ b/src/python/test/test_wasserstein_distance.py
@@ -97,6 +97,13 @@ def test_warn_infty():
assert (m is None)
+def _to_set(X):
+ return { (i, j) for i, j in X }
+
+def _same_permuted(X, Y):
+ return _to_set(X) == _to_set(Y)
+
+
def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_matching=True):
diag1 = np.array([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]])
diag2 = np.array([[2.8, 4.45], [9.5, 14.1]])
@@ -141,15 +148,16 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_mat
if test_matching:
match = wasserstein_distance(emptydiag, emptydiag, matching=True, internal_p=1., order=2)[1]
- assert np.array_equal(match, [])
+ # Accept [] or np.array of shape (2, 0)
+ assert len(match) == 0
match = wasserstein_distance(emptydiag, emptydiag, matching=True, internal_p=np.inf, order=2.24)[1]
- assert np.array_equal(match, [])
+ assert len(match) == 0
match = wasserstein_distance(emptydiag, diag2, matching=True, internal_p=np.inf, order=2.)[1]
- assert np.array_equal(match , [[-1, 0], [-1, 1]])
+ assert _same_permuted(match, [[-1, 0], [-1, 1]])
match = wasserstein_distance(diag2, emptydiag, matching=True, internal_p=np.inf, order=2.24)[1]
- assert np.array_equal(match , [[0, -1], [1, -1]])
+ assert _same_permuted(match, [[0, -1], [1, -1]])
match = wasserstein_distance(diag1, diag2, matching=True, internal_p=2., order=2.)[1]
- assert np.array_equal(match, [[0, 0], [1, 1], [2, -1]])
+ assert _same_permuted(match, [[0, 0], [1, 1], [2, -1]])
if test_matching and test_infinity:
diag7 = np.array([[0, 3], [4, np.inf], [5, np.inf]])
@@ -158,7 +166,7 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_mat
diag10 = np.array([[0,1], [-np.inf, -np.inf], [np.inf, np.inf]])
match = wasserstein_distance(diag5, diag6, matching=True, internal_p=2., order=2.)[1]
- assert np.array_equal(match, [[0, -1], [-1,0], [-1, 1], [1, 2]])
+ assert _same_permuted(match, [[0, -1], [-1,0], [-1, 1], [1, 2]])
match = wasserstein_distance(diag5, diag7, matching=True, internal_p=2., order=2.)[1]
assert (match is None)
cost, match = wasserstein_distance(diag7, emptydiag, matching=True, internal_p=2., order=2.3)
@@ -172,10 +180,10 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_mat
assert (match is None)
cost, match = wasserstein_distance(diag9, diag10, matching=True, internal_p=1., order=1.)
assert (cost == 1)
- assert (match == [[0, -1],[1, -1],[-1, 0], [-1, 1], [-1, 2]]) # type 4 and 5 are match to the diag anyway.
+ assert _same_permuted(match, [[0, -1],[1, -1],[-1, 0], [-1, 1], [-1, 2]]) # type 4 and 5 are match to the diag anyway.
cost, match = wasserstein_distance(diag9, emptydiag, matching=True, internal_p=2., order=2.)
assert (cost == 0.)
- assert (match == [[0, -1], [1, -1]])
+ assert _same_permuted(match, [[0, -1], [1, -1]])
def hera_wrap(**extra):
@@ -196,6 +204,6 @@ def test_wasserstein_distance_pot():
def test_wasserstein_distance_hera():
- _basic_wasserstein(hera_wrap(delta=1e-12), 1e-12, test_matching=False)
- _basic_wasserstein(hera_wrap(delta=.1), .1, test_matching=False)
+ _basic_wasserstein(hera_wrap(delta=1e-12), 1e-12, test_matching=True)
+ _basic_wasserstein(hera_wrap(delta=.1), .1, test_matching=True)