summaryrefslogtreecommitdiff
path: root/geom_matching/wasserstein/include/wasserstein.h
diff options
context:
space:
mode:
Diffstat (limited to 'geom_matching/wasserstein/include/wasserstein.h')
-rw-r--r--geom_matching/wasserstein/include/wasserstein.h14
1 files changed, 14 insertions, 0 deletions
diff --git a/geom_matching/wasserstein/include/wasserstein.h b/geom_matching/wasserstein/include/wasserstein.h
index 1d8f35b..155d79d 100644
--- a/geom_matching/wasserstein/include/wasserstein.h
+++ b/geom_matching/wasserstein/include/wasserstein.h
@@ -89,10 +89,14 @@ double wassersteinDist(PairContainer& A, PairContainer& B, const double q, const
return 0.0;
}
+ bool a_empty { true };
+ bool b_empty { true };
+
std::vector<DiagramPoint> dgmA, dgmB;
// loop over A, add projections of A-points to corresponding positions
// in B-vector
for(auto& pairA : A) {
+ a_empty = false;
double x = pairA.first;
double y = pairA.second;
dgmA.push_back(DiagramPoint(x, y, DiagramPoint::NORMAL));
@@ -100,11 +104,21 @@ double wassersteinDist(PairContainer& A, PairContainer& B, const double q, const
}
// the same for B
for(auto& pairB : B) {
+ b_empty = false;
double x = pairB.first;
double y = pairB.second;
dgmA.push_back(DiagramPoint(x, y, DiagramPoint::DIAG));
dgmB.push_back(DiagramPoint(x, y, DiagramPoint::NORMAL));
}
+
+ if (a_empty && b_empty)
+ return 0.0;
+
+ if (a_empty)
+ dgmA.clear();
+
+ if (b_empty)
+ dgmB.clear();
return wassersteinDistVec(dgmA, dgmB, q, delta, _internal_p, _initialEpsilon, _epsFactor);
}