diff options
Diffstat (limited to 'wasserstein/include')
-rw-r--r-- | wasserstein/include/wasserstein.h | 10 |
1 files changed, 4 insertions, 6 deletions
diff --git a/wasserstein/include/wasserstein.h b/wasserstein/include/wasserstein.h index 3d9fd17..142fcbb 100644 --- a/wasserstein/include/wasserstein.h +++ b/wasserstein/include/wasserstein.h @@ -73,21 +73,19 @@ namespace ws template<class PairContainer> inline bool are_equal(const PairContainer& dgm1, const PairContainer& dgm2) { - if (dgm1.size() != dgm2.size()) { - return false; - } - using Traits = typename hera::DiagramTraits<PairContainer>; using PointType = typename Traits::PointType; std::map<PointType, int> m1, m2; for(auto&& pair1 : dgm1) { - m1[pair1]++; + if (Traits::get_x(pair1) != Traits::get_y(pair1)) + m1[pair1]++; } for(auto&& pair2 : dgm2) { - m2[pair2]++; + if (Traits::get_x(pair2) != Traits::get_y(pair2)) + m2[pair2]++; } return m1 == m2; |