summaryrefslogtreecommitdiff
path: root/geom_matching/wasserstein/include
diff options
context:
space:
mode:
Diffstat (limited to 'geom_matching/wasserstein/include')
-rw-r--r--geom_matching/wasserstein/include/basic_defs_ws.h2
-rw-r--r--geom_matching/wasserstein/include/wasserstein.h14
2 files changed, 16 insertions, 0 deletions
diff --git a/geom_matching/wasserstein/include/basic_defs_ws.h b/geom_matching/wasserstein/include/basic_defs_ws.h
index 365c3bd..474af22 100644
--- a/geom_matching/wasserstein/include/basic_defs_ws.h
+++ b/geom_matching/wasserstein/include/basic_defs_ws.h
@@ -77,6 +77,7 @@ struct DiagramPoint
bool isNormal(void) const { return type == NORMAL; }
double getRealX() const; // return the x-coord
double getRealY() const; // return the y-coord
+ double persistenceLp(const double p) const;
#ifndef FOR_R_TDA
friend std::ostream& operator<<(std::ostream& output, const DiagramPoint p);
#endif
@@ -92,6 +93,7 @@ double sqrDist(const Point& a, const Point& b);
double dist(const Point& a, const Point& b);
double distLInf(const DiagramPoint& a, const DiagramPoint& b);
double distLp(const DiagramPoint& a, const DiagramPoint& b, const double p);
+double persistenceLp(const DiagramPoint& a, const double p);
template<typename DiagPointContainer>
double getFurthestDistance3Approx(DiagPointContainer& A, DiagPointContainer& B)
diff --git a/geom_matching/wasserstein/include/wasserstein.h b/geom_matching/wasserstein/include/wasserstein.h
index 1d8f35b..155d79d 100644
--- a/geom_matching/wasserstein/include/wasserstein.h
+++ b/geom_matching/wasserstein/include/wasserstein.h
@@ -89,10 +89,14 @@ double wassersteinDist(PairContainer& A, PairContainer& B, const double q, const
return 0.0;
}
+ bool a_empty { true };
+ bool b_empty { true };
+
std::vector<DiagramPoint> dgmA, dgmB;
// loop over A, add projections of A-points to corresponding positions
// in B-vector
for(auto& pairA : A) {
+ a_empty = false;
double x = pairA.first;
double y = pairA.second;
dgmA.push_back(DiagramPoint(x, y, DiagramPoint::NORMAL));
@@ -100,11 +104,21 @@ double wassersteinDist(PairContainer& A, PairContainer& B, const double q, const
}
// the same for B
for(auto& pairB : B) {
+ b_empty = false;
double x = pairB.first;
double y = pairB.second;
dgmA.push_back(DiagramPoint(x, y, DiagramPoint::DIAG));
dgmB.push_back(DiagramPoint(x, y, DiagramPoint::NORMAL));
}
+
+ if (a_empty && b_empty)
+ return 0.0;
+
+ if (a_empty)
+ dgmA.clear();
+
+ if (b_empty)
+ dgmB.clear();
return wassersteinDistVec(dgmA, dgmB, q, delta, _internal_p, _initialEpsilon, _epsFactor);
}