diff options
author | Arnur Nigmetov <anigmetov@lbl.gov> | 2021-04-21 13:18:41 -0700 |
---|---|---|
committer | Arnur Nigmetov <anigmetov@lbl.gov> | 2021-04-21 13:18:41 -0700 |
commit | b528c4067a8aac346eb307d3c23b82d5953cfe2d (patch) | |
tree | c46898aa945cc94fddd1e1c75827878623ac6d28 /wasserstein/tests | |
parent | 7af824834e97c703c7724a6649d059639c8a1e36 (diff) |
Fix Wasserstein: ignore diagonal points in are_equal.
Diffstat (limited to 'wasserstein/tests')
-rw-r--r-- | wasserstein/tests/test_hera_wasserstein.cpp | 12 |
1 files changed, 12 insertions, 0 deletions
diff --git a/wasserstein/tests/test_hera_wasserstein.cpp b/wasserstein/tests/test_hera_wasserstein.cpp index 621dd5a..6f5de3b 100644 --- a/wasserstein/tests/test_hera_wasserstein.cpp +++ b/wasserstein/tests/test_hera_wasserstein.cpp @@ -111,6 +111,18 @@ TEST_CASE("simple cases", "wasserstein_dist") } + SECTION("trivial: two diagrams differing by diagonal point") { + + diagram_A.emplace_back(0.0, 1.0); + diagram_B.emplace_back(0.0, 0.0); + diagram_B.emplace_back(0.0, 1.0); + + double d1 = hera::wasserstein_cost<>(diagram_A, diagram_B, params); + double d2 = hera::wasserstein_cost<>(diagram_B, diagram_A, params); + REQUIRE( fabs(d2) <= 0.00000000001 ); + REQUIRE( fabs(d1) <= 0.00000000001 ); + } + } |