diff options
author | tlacombe <lacombe1993@gmail.com> | 2021-04-20 19:06:56 +0200 |
---|---|---|
committer | tlacombe <lacombe1993@gmail.com> | 2021-04-20 19:06:56 +0200 |
commit | 604b2cde0c7951c81d1c510f3038e2c65c19e6fe (patch) | |
tree | d2f22392f94fcb3c449453c79f773c2e56892ed0 /src | |
parent | bb0792ed7bfe9d718be3e8039e8fb89af6d160e5 (diff) |
update doc and tests
Diffstat (limited to 'src')
-rw-r--r-- | src/python/doc/wasserstein_distance_user.rst | 1 | ||||
-rwxr-xr-x | src/python/test/test_wasserstein_distance.py | 15 |
2 files changed, 13 insertions, 3 deletions
diff --git a/src/python/doc/wasserstein_distance_user.rst b/src/python/doc/wasserstein_distance_user.rst index 091c9fd9..76eb1469 100644 --- a/src/python/doc/wasserstein_distance_user.rst +++ b/src/python/doc/wasserstein_distance_user.rst @@ -92,6 +92,7 @@ any matching has a cost +inf and thus can be considered to be optimal. In such a for j in dgm2_to_diagonal: print("point %s in dgm2 is matched to the diagonal" %j) + # An example where essential part cardinalities differ dgm3 = np.array([[1, 2], [0, np.inf]]) dgm4 = np.array([[1, 2], [0, np.inf], [1, np.inf]]) cost, matchings = gudhi.wasserstein.wasserstein_distance(dgm3, dgm4, matching=True, order=1, internal_p=2) diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py index df7acc91..121ba065 100755 --- a/src/python/test/test_wasserstein_distance.py +++ b/src/python/test/test_wasserstein_distance.py @@ -67,16 +67,25 @@ def test_handle_essential_parts(): def test_get_essential_parts(): - diag = np.array([[0, 1], [3, 5], [2, np.inf], [3, np.inf], [-np.inf, 8], [-np.inf, 12], [-np.inf, -np.inf], + diag1 = np.array([[0, 1], [3, 5], [2, np.inf], [3, np.inf], [-np.inf, 8], [-np.inf, 12], [-np.inf, -np.inf], [np.inf, np.inf], [-np.inf, np.inf], [-np.inf, np.inf]]) - res = _get_essential_parts(diag) + diag2 = np.array([[0, 1], [3, 5], [2, np.inf], [3, np.inf]]) + + res = _get_essential_parts(diag1) + res2 = _get_essential_parts(diag2) assert np.array_equal(res[0], [4, 5]) assert np.array_equal(res[1], [2, 3]) assert np.array_equal(res[2], [8, 9]) assert np.array_equal(res[3], [6] ) assert np.array_equal(res[4], [7] ) + assert np.array_equal(res2[0], [] ) + assert np.array_equal(res2[1], [2, 3]) + assert np.array_equal(res2[2], [] ) + assert np.array_equal(res2[3], [] ) + assert np.array_equal(res2[4], [] ) + def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_matching=True): diag1 = np.array([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]]) @@ -152,7 +161,7 @@ def pot_wrap(**extra): return fun def test_wasserstein_distance_pot(): - _basic_wasserstein(pot, 1e-15, test_infinity=False, test_matching=True) + _basic_wasserstein(pot, 1e-15, test_infinity=False, test_matching=True) # pot with its standard args _basic_wasserstein(pot_wrap(enable_autodiff=True, keep_essential_parts=False), 1e-15, test_infinity=False, test_matching=False) def test_wasserstein_distance_hera(): |