diff options
author | MathieuCarriere <mathieu.carriere3@gmail.com> | 2020-03-17 12:17:43 -0400 |
---|---|---|
committer | MathieuCarriere <mathieu.carriere3@gmail.com> | 2020-03-17 12:17:43 -0400 |
commit | 50427c9fe5c63f3bfe27270db0ea26480e73cb1a (patch) | |
tree | d4f867c6076ad4ae425a77c4346b370921e6251f /src/python/test | |
parent | 58d923b13afb9b18a2d5b028c6575baee691d182 (diff) | |
parent | 5fdbad3fdd350edc3a5340110f4c5332c73517b3 (diff) |
fix conflict
Diffstat (limited to 'src/python/test')
-rwxr-xr-x | src/python/test/test_wasserstein_distance.py | 33 |
1 files changed, 23 insertions, 10 deletions
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py index 6a6b217b..0d70e11a 100755 --- a/src/python/test/test_wasserstein_distance.py +++ b/src/python/test/test_wasserstein_distance.py @@ -17,7 +17,7 @@ __author__ = "Theo Lacombe" __copyright__ = "Copyright (C) 2019 Inria" __license__ = "MIT" -def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True): +def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_matching=True): diag1 = np.array([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]]) diag2 = np.array([[2.8, 4.45], [9.5, 14.1]]) diag3 = np.array([[0, 2], [4, 6]]) @@ -51,14 +51,27 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True): assert wasserstein_distance(diag3, diag4, internal_p=1., order=2.) == approx(np.sqrt(5)) assert wasserstein_distance(diag3, diag4, internal_p=4.5, order=2.) == approx(np.sqrt(5)) - if(not test_infinity): - return + if test_infinity: + diag5 = np.array([[0, 3], [4, np.inf]]) + diag6 = np.array([[7, 8], [4, 6], [3, np.inf]]) - diag5 = np.array([[0, 3], [4, np.inf]]) - diag6 = np.array([[7, 8], [4, 6], [3, np.inf]]) + assert wasserstein_distance(diag4, diag5) == np.inf + assert wasserstein_distance(diag5, diag6, order=1, internal_p=np.inf) == approx(4.) + + + if test_matching: + match = wasserstein_distance(emptydiag, emptydiag, matching=True, internal_p=1., order=2)[1] + assert np.array_equal(match, []) + match = wasserstein_distance(emptydiag, emptydiag, matching=True, internal_p=np.inf, order=2.24)[1] + assert np.array_equal(match, []) + match = wasserstein_distance(emptydiag, diag2, matching=True, internal_p=np.inf, order=2.)[1] + assert np.array_equal(match , [[-1, 0], [-1, 1]]) + match = wasserstein_distance(diag2, emptydiag, matching=True, internal_p=np.inf, order=2.24)[1] + assert np.array_equal(match , [[0, -1], [1, -1]]) + match = wasserstein_distance(diag1, diag2, matching=True, internal_p=2., order=2.)[1] + assert np.array_equal(match, [[0, 0], [1, 1], [2, -1]]) + - assert wasserstein_distance(diag4, diag5) == np.inf - assert wasserstein_distance(diag5, diag6, order=1, internal_p=np.inf) == approx(4.) def hera_wrap(delta): def fun(*kargs,**kwargs): @@ -66,8 +79,8 @@ def hera_wrap(delta): return fun def test_wasserstein_distance_pot(): - _basic_wasserstein(pot, 1e-15, test_infinity=False) + _basic_wasserstein(pot, 1e-15, test_infinity=False, test_matching=True) def test_wasserstein_distance_hera(): - _basic_wasserstein(hera_wrap(1e-12), 1e-12) - _basic_wasserstein(hera_wrap(.1), .1) + _basic_wasserstein(hera_wrap(1e-12), 1e-12, test_matching=False) + _basic_wasserstein(hera_wrap(.1), .1, test_matching=False) |