diff options
Diffstat (limited to 'wasserstein')
-rw-r--r-- | wasserstein/include/wasserstein.h | 10 | ||||
-rw-r--r-- | wasserstein/tests/test_hera_wasserstein.cpp | 12 |
2 files changed, 16 insertions, 6 deletions
diff --git a/wasserstein/include/wasserstein.h b/wasserstein/include/wasserstein.h index 3d9fd17..142fcbb 100644 --- a/wasserstein/include/wasserstein.h +++ b/wasserstein/include/wasserstein.h @@ -73,21 +73,19 @@ namespace ws template<class PairContainer> inline bool are_equal(const PairContainer& dgm1, const PairContainer& dgm2) { - if (dgm1.size() != dgm2.size()) { - return false; - } - using Traits = typename hera::DiagramTraits<PairContainer>; using PointType = typename Traits::PointType; std::map<PointType, int> m1, m2; for(auto&& pair1 : dgm1) { - m1[pair1]++; + if (Traits::get_x(pair1) != Traits::get_y(pair1)) + m1[pair1]++; } for(auto&& pair2 : dgm2) { - m2[pair2]++; + if (Traits::get_x(pair2) != Traits::get_y(pair2)) + m2[pair2]++; } return m1 == m2; diff --git a/wasserstein/tests/test_hera_wasserstein.cpp b/wasserstein/tests/test_hera_wasserstein.cpp index 621dd5a..6f5de3b 100644 --- a/wasserstein/tests/test_hera_wasserstein.cpp +++ b/wasserstein/tests/test_hera_wasserstein.cpp @@ -111,6 +111,18 @@ TEST_CASE("simple cases", "wasserstein_dist") } + SECTION("trivial: two diagrams differing by diagonal point") { + + diagram_A.emplace_back(0.0, 1.0); + diagram_B.emplace_back(0.0, 0.0); + diagram_B.emplace_back(0.0, 1.0); + + double d1 = hera::wasserstein_cost<>(diagram_A, diagram_B, params); + double d2 = hera::wasserstein_cost<>(diagram_B, diagram_A, params); + REQUIRE( fabs(d2) <= 0.00000000001 ); + REQUIRE( fabs(d1) <= 0.00000000001 ); + } + } |