diff options
Diffstat (limited to 'geom_matching/wasserstein/include/wasserstein.h')
-rw-r--r-- | geom_matching/wasserstein/include/wasserstein.h | 36 |
1 files changed, 36 insertions, 0 deletions
diff --git a/geom_matching/wasserstein/include/wasserstein.h b/geom_matching/wasserstein/include/wasserstein.h index 38ac6bd..fcc9e74 100644 --- a/geom_matching/wasserstein/include/wasserstein.h +++ b/geom_matching/wasserstein/include/wasserstein.h @@ -49,6 +49,15 @@ double wassersteinDistVec(const std::vector<DiagramPoint>& A, const double _initialEpsilon = 0.0, const double _epsFactor = 0.0); +// get Wasserstein cost (distance^q) between two persistence diagrams +double wassersteinCostVec(const std::vector<DiagramPoint>& A, + const std::vector<DiagramPoint>& B, + const double q, + const double delta, + const double _internal_p = std::numeric_limits<double>::infinity(), + const double _initialEpsilon = 0.0, + const double _epsFactor = 0.0); + // compare as multisets template<class PairContainer> @@ -98,6 +107,33 @@ double wassersteinDist(PairContainer& A, PairContainer& B, const double q, const return wassersteinDistVec(dgmA, dgmB, q, delta, _internal_p, _initialEpsilon, _epsFactor); } +template<class PairContainer> +double wassersteinCost(PairContainer& A, PairContainer& B, const double q, const double delta, const double _internal_p = std::numeric_limits<double>::infinity(), const double _initialEpsilon = 0.0, const double _epsFactor = 0.0) +{ + if (areEqual(A, B)) { + return 0.0; + } + + std::vector<DiagramPoint> dgmA, dgmB; + // loop over A, add projections of A-points to corresponding positions + // in B-vector + for(auto& pairA : A) { + double x = pairA.first; + double y = pairA.second; + dgmA.push_back(DiagramPoint(x, y, DiagramPoint::NORMAL)); + dgmB.push_back(DiagramPoint(x, y, DiagramPoint::DIAG)); + } + // the same for B + for(auto& pairB : B) { + double x = pairB.first; + double y = pairB.second; + dgmA.push_back(DiagramPoint(x, y, DiagramPoint::DIAG)); + dgmB.push_back(DiagramPoint(x, y, DiagramPoint::NORMAL)); + } + + return wassersteinCostVec(dgmA, dgmB, q, delta, _internal_p, _initialEpsilon, _epsFactor); +} + // fill in result with points from file fname // return false if file can't be opened |