summaryrefslogtreecommitdiff
path: root/wasserstein/include/wasserstein.h
diff options
context:
space:
mode:
Diffstat (limited to 'wasserstein/include/wasserstein.h')
-rw-r--r--wasserstein/include/wasserstein.h10
1 files changed, 4 insertions, 6 deletions
diff --git a/wasserstein/include/wasserstein.h b/wasserstein/include/wasserstein.h
index 3d9fd17..142fcbb 100644
--- a/wasserstein/include/wasserstein.h
+++ b/wasserstein/include/wasserstein.h
@@ -73,21 +73,19 @@ namespace ws
template<class PairContainer>
inline bool are_equal(const PairContainer& dgm1, const PairContainer& dgm2)
{
- if (dgm1.size() != dgm2.size()) {
- return false;
- }
-
using Traits = typename hera::DiagramTraits<PairContainer>;
using PointType = typename Traits::PointType;
std::map<PointType, int> m1, m2;
for(auto&& pair1 : dgm1) {
- m1[pair1]++;
+ if (Traits::get_x(pair1) != Traits::get_y(pair1))
+ m1[pair1]++;
}
for(auto&& pair2 : dgm2) {
- m2[pair2]++;
+ if (Traits::get_x(pair2) != Traits::get_y(pair2))
+ m2[pair2]++;
}
return m1 == m2;