summaryrefslogtreecommitdiff
path: root/src/python/test/test_wasserstein_distance.py
diff options
context:
space:
mode:
authortlacombe <lacombe1993@gmail.com>2021-04-20 19:06:56 +0200
committertlacombe <lacombe1993@gmail.com>2021-04-20 19:06:56 +0200
commit604b2cde0c7951c81d1c510f3038e2c65c19e6fe (patch)
treed2f22392f94fcb3c449453c79f773c2e56892ed0 /src/python/test/test_wasserstein_distance.py
parentbb0792ed7bfe9d718be3e8039e8fb89af6d160e5 (diff)
update doc and tests
Diffstat (limited to 'src/python/test/test_wasserstein_distance.py')
-rwxr-xr-xsrc/python/test/test_wasserstein_distance.py15
1 files changed, 12 insertions, 3 deletions
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py
index df7acc91..121ba065 100755
--- a/src/python/test/test_wasserstein_distance.py
+++ b/src/python/test/test_wasserstein_distance.py
@@ -67,16 +67,25 @@ def test_handle_essential_parts():
def test_get_essential_parts():
- diag = np.array([[0, 1], [3, 5], [2, np.inf], [3, np.inf], [-np.inf, 8], [-np.inf, 12], [-np.inf, -np.inf],
+ diag1 = np.array([[0, 1], [3, 5], [2, np.inf], [3, np.inf], [-np.inf, 8], [-np.inf, 12], [-np.inf, -np.inf],
[np.inf, np.inf], [-np.inf, np.inf], [-np.inf, np.inf]])
- res = _get_essential_parts(diag)
+ diag2 = np.array([[0, 1], [3, 5], [2, np.inf], [3, np.inf]])
+
+ res = _get_essential_parts(diag1)
+ res2 = _get_essential_parts(diag2)
assert np.array_equal(res[0], [4, 5])
assert np.array_equal(res[1], [2, 3])
assert np.array_equal(res[2], [8, 9])
assert np.array_equal(res[3], [6] )
assert np.array_equal(res[4], [7] )
+ assert np.array_equal(res2[0], [] )
+ assert np.array_equal(res2[1], [2, 3])
+ assert np.array_equal(res2[2], [] )
+ assert np.array_equal(res2[3], [] )
+ assert np.array_equal(res2[4], [] )
+
def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_matching=True):
diag1 = np.array([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]])
@@ -152,7 +161,7 @@ def pot_wrap(**extra):
return fun
def test_wasserstein_distance_pot():
- _basic_wasserstein(pot, 1e-15, test_infinity=False, test_matching=True)
+ _basic_wasserstein(pot, 1e-15, test_infinity=False, test_matching=True) # pot with its standard args
_basic_wasserstein(pot_wrap(enable_autodiff=True, keep_essential_parts=False), 1e-15, test_infinity=False, test_matching=False)
def test_wasserstein_distance_hera():