summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorArnur Nigmetov <anigmetov@lbl.gov>2021-04-21 13:18:41 -0700
committerArnur Nigmetov <anigmetov@lbl.gov>2021-04-21 13:18:41 -0700
commitb528c4067a8aac346eb307d3c23b82d5953cfe2d (patch)
treec46898aa945cc94fddd1e1c75827878623ac6d28
parent7af824834e97c703c7724a6649d059639c8a1e36 (diff)
Fix Wasserstein: ignore diagonal points in are_equal.
-rw-r--r--wasserstein/include/wasserstein.h10
-rw-r--r--wasserstein/tests/test_hera_wasserstein.cpp12
2 files changed, 16 insertions, 6 deletions
diff --git a/wasserstein/include/wasserstein.h b/wasserstein/include/wasserstein.h
index 3d9fd17..142fcbb 100644
--- a/wasserstein/include/wasserstein.h
+++ b/wasserstein/include/wasserstein.h
@@ -73,21 +73,19 @@ namespace ws
template<class PairContainer>
inline bool are_equal(const PairContainer& dgm1, const PairContainer& dgm2)
{
- if (dgm1.size() != dgm2.size()) {
- return false;
- }
-
using Traits = typename hera::DiagramTraits<PairContainer>;
using PointType = typename Traits::PointType;
std::map<PointType, int> m1, m2;
for(auto&& pair1 : dgm1) {
- m1[pair1]++;
+ if (Traits::get_x(pair1) != Traits::get_y(pair1))
+ m1[pair1]++;
}
for(auto&& pair2 : dgm2) {
- m2[pair2]++;
+ if (Traits::get_x(pair2) != Traits::get_y(pair2))
+ m2[pair2]++;
}
return m1 == m2;
diff --git a/wasserstein/tests/test_hera_wasserstein.cpp b/wasserstein/tests/test_hera_wasserstein.cpp
index 621dd5a..6f5de3b 100644
--- a/wasserstein/tests/test_hera_wasserstein.cpp
+++ b/wasserstein/tests/test_hera_wasserstein.cpp
@@ -111,6 +111,18 @@ TEST_CASE("simple cases", "wasserstein_dist")
}
+ SECTION("trivial: two diagrams differing by diagonal point") {
+
+ diagram_A.emplace_back(0.0, 1.0);
+ diagram_B.emplace_back(0.0, 0.0);
+ diagram_B.emplace_back(0.0, 1.0);
+
+ double d1 = hera::wasserstein_cost<>(diagram_A, diagram_B, params);
+ double d2 = hera::wasserstein_cost<>(diagram_B, diagram_A, params);
+ REQUIRE( fabs(d2) <= 0.00000000001 );
+ REQUIRE( fabs(d1) <= 0.00000000001 );
+ }
+
}