diff options
Diffstat (limited to 'geom_matching/wasserstein/src')
-rw-r--r-- | geom_matching/wasserstein/src/basic_defs.cpp | 14 | ||||
-rw-r--r-- | geom_matching/wasserstein/src/wasserstein.cpp | 19 |
2 files changed, 33 insertions, 0 deletions
diff --git a/geom_matching/wasserstein/src/basic_defs.cpp b/geom_matching/wasserstein/src/basic_defs.cpp index a46e6aa..ec5dcec 100644 --- a/geom_matching/wasserstein/src/basic_defs.cpp +++ b/geom_matching/wasserstein/src/basic_defs.cpp @@ -104,6 +104,20 @@ double distLp(const DiagramPoint& a, const DiagramPoint& b, const double p) } +double DiagramPoint::persistenceLp(const double p) const +{ + if (isDiagonal()) + return 0.0; + else { + double u { 0.5 * (getRealY() + getRealX()) }; + DiagramPoint a_proj(u, u, DiagramPoint::DIAG); + return distLp(*this, a_proj, p); + } + + +} + + #ifndef FOR_R_TDA std::ostream& operator<<(std::ostream& output, const DiagramPoint p) { diff --git a/geom_matching/wasserstein/src/wasserstein.cpp b/geom_matching/wasserstein/src/wasserstein.cpp index 8776b5f..b8a75ef 100644 --- a/geom_matching/wasserstein/src/wasserstein.cpp +++ b/geom_matching/wasserstein/src/wasserstein.cpp @@ -74,6 +74,25 @@ double wassersteinDistVec(const std::vector<DiagramPoint>& A, throw std::runtime_error("Bad epsilon factor in Wasserstein " + std::to_string(_epsFactor)); } + if (A.empty() && B.empty()) + return 0.0; + + if (A.empty()) { + double result { 0.0 } ; + for(const auto& pt : B) { + result += pt.persistenceLp(_internal_p); + } + return result; + } + + if (B.empty()) { + double result { 0.0 } ; + for(const auto& pt : A) { + result += pt.persistenceLp(_internal_p); + } + return result; + } + #ifdef GAUSS_SEIDEL_AUCTION AuctionRunnerGS auction(A, B, q, delta, _internal_p, _initialEpsilon, _epsFactor); |