summaryrefslogtreecommitdiff
path: root/geom_matching/wasserstein/include/wasserstein.h
diff options
context:
space:
mode:
Diffstat (limited to 'geom_matching/wasserstein/include/wasserstein.h')
-rw-r--r--geom_matching/wasserstein/include/wasserstein.h36
1 files changed, 36 insertions, 0 deletions
diff --git a/geom_matching/wasserstein/include/wasserstein.h b/geom_matching/wasserstein/include/wasserstein.h
index 38ac6bd..fcc9e74 100644
--- a/geom_matching/wasserstein/include/wasserstein.h
+++ b/geom_matching/wasserstein/include/wasserstein.h
@@ -49,6 +49,15 @@ double wassersteinDistVec(const std::vector<DiagramPoint>& A,
const double _initialEpsilon = 0.0,
const double _epsFactor = 0.0);
+// get Wasserstein cost (distance^q) between two persistence diagrams
+double wassersteinCostVec(const std::vector<DiagramPoint>& A,
+ const std::vector<DiagramPoint>& B,
+ const double q,
+ const double delta,
+ const double _internal_p = std::numeric_limits<double>::infinity(),
+ const double _initialEpsilon = 0.0,
+ const double _epsFactor = 0.0);
+
// compare as multisets
template<class PairContainer>
@@ -98,6 +107,33 @@ double wassersteinDist(PairContainer& A, PairContainer& B, const double q, const
return wassersteinDistVec(dgmA, dgmB, q, delta, _internal_p, _initialEpsilon, _epsFactor);
}
+template<class PairContainer>
+double wassersteinCost(PairContainer& A, PairContainer& B, const double q, const double delta, const double _internal_p = std::numeric_limits<double>::infinity(), const double _initialEpsilon = 0.0, const double _epsFactor = 0.0)
+{
+ if (areEqual(A, B)) {
+ return 0.0;
+ }
+
+ std::vector<DiagramPoint> dgmA, dgmB;
+ // loop over A, add projections of A-points to corresponding positions
+ // in B-vector
+ for(auto& pairA : A) {
+ double x = pairA.first;
+ double y = pairA.second;
+ dgmA.push_back(DiagramPoint(x, y, DiagramPoint::NORMAL));
+ dgmB.push_back(DiagramPoint(x, y, DiagramPoint::DIAG));
+ }
+ // the same for B
+ for(auto& pairB : B) {
+ double x = pairB.first;
+ double y = pairB.second;
+ dgmA.push_back(DiagramPoint(x, y, DiagramPoint::DIAG));
+ dgmB.push_back(DiagramPoint(x, y, DiagramPoint::NORMAL));
+ }
+
+ return wassersteinCostVec(dgmA, dgmB, q, delta, _internal_p, _initialEpsilon, _epsFactor);
+}
+
// fill in result with points from file fname
// return false if file can't be opened