From bd3300343726981dbb7b7f45d1cabc9d781e28a1 Mon Sep 17 00:00:00 2001 From: Arnur Nigmetov Date: Mon, 24 Apr 2017 16:49:37 +0600 Subject: Empty diagram bug for Wasserstein fixed --- geom_bottleneck/bottleneck/src/bottleneck.cpp | 4 ++-- geom_matching/wasserstein/include/basic_defs_ws.h | 2 ++ geom_matching/wasserstein/include/wasserstein.h | 14 ++++++++++++++ geom_matching/wasserstein/src/basic_defs.cpp | 14 ++++++++++++++ geom_matching/wasserstein/src/wasserstein.cpp | 19 +++++++++++++++++++ 5 files changed, 51 insertions(+), 2 deletions(-) diff --git a/geom_bottleneck/bottleneck/src/bottleneck.cpp b/geom_bottleneck/bottleneck/src/bottleneck.cpp index da0c425..82fdcfe 100644 --- a/geom_bottleneck/bottleneck/src/bottleneck.cpp +++ b/geom_bottleneck/bottleneck/src/bottleneck.cpp @@ -208,8 +208,8 @@ std::pair bottleneckDistApproxIntervalHeur(DiagramPointSet& A, D DiagramPointSet sampledA, sampledB; sampleDiagramForHeur(A, sampledA); sampleDiagramForHeur(B, sampledB); - //std::cout << "A : " << A.size() << ", sampled: " << sampledA.size() << std::endl; - //std::cout << "B : " << B.size() << ", sampled: " << sampledB.size() << std::endl; + std::cout << "A : " << A.size() << ", sampled: " << sampledA.size() << std::endl; + std::cout << "B : " << B.size() << ", sampled: " << sampledB.size() << std::endl; std::pair initGuess = bottleneckDistApproxInterval(sampledA, sampledB, epsilon); //std::cout << "initial guess: " << initGuess.first << ", " << initGuess.second << std::endl; return bottleneckDistApproxIntervalWithInitial(A, B, epsilon, initGuess); diff --git a/geom_matching/wasserstein/include/basic_defs_ws.h b/geom_matching/wasserstein/include/basic_defs_ws.h index 365c3bd..474af22 100644 --- a/geom_matching/wasserstein/include/basic_defs_ws.h +++ b/geom_matching/wasserstein/include/basic_defs_ws.h @@ -77,6 +77,7 @@ struct DiagramPoint bool isNormal(void) const { return type == NORMAL; } double getRealX() const; // return the x-coord double getRealY() const; // return the y-coord + double persistenceLp(const double p) const; #ifndef FOR_R_TDA friend std::ostream& operator<<(std::ostream& output, const DiagramPoint p); #endif @@ -92,6 +93,7 @@ double sqrDist(const Point& a, const Point& b); double dist(const Point& a, const Point& b); double distLInf(const DiagramPoint& a, const DiagramPoint& b); double distLp(const DiagramPoint& a, const DiagramPoint& b, const double p); +double persistenceLp(const DiagramPoint& a, const double p); template double getFurthestDistance3Approx(DiagPointContainer& A, DiagPointContainer& B) diff --git a/geom_matching/wasserstein/include/wasserstein.h b/geom_matching/wasserstein/include/wasserstein.h index 1d8f35b..155d79d 100644 --- a/geom_matching/wasserstein/include/wasserstein.h +++ b/geom_matching/wasserstein/include/wasserstein.h @@ -89,10 +89,14 @@ double wassersteinDist(PairContainer& A, PairContainer& B, const double q, const return 0.0; } + bool a_empty { true }; + bool b_empty { true }; + std::vector dgmA, dgmB; // loop over A, add projections of A-points to corresponding positions // in B-vector for(auto& pairA : A) { + a_empty = false; double x = pairA.first; double y = pairA.second; dgmA.push_back(DiagramPoint(x, y, DiagramPoint::NORMAL)); @@ -100,11 +104,21 @@ double wassersteinDist(PairContainer& A, PairContainer& B, const double q, const } // the same for B for(auto& pairB : B) { + b_empty = false; double x = pairB.first; double y = pairB.second; dgmA.push_back(DiagramPoint(x, y, DiagramPoint::DIAG)); dgmB.push_back(DiagramPoint(x, y, DiagramPoint::NORMAL)); } + + if (a_empty && b_empty) + return 0.0; + + if (a_empty) + dgmA.clear(); + + if (b_empty) + dgmB.clear(); return wassersteinDistVec(dgmA, dgmB, q, delta, _internal_p, _initialEpsilon, _epsFactor); } diff --git a/geom_matching/wasserstein/src/basic_defs.cpp b/geom_matching/wasserstein/src/basic_defs.cpp index a46e6aa..ec5dcec 100644 --- a/geom_matching/wasserstein/src/basic_defs.cpp +++ b/geom_matching/wasserstein/src/basic_defs.cpp @@ -104,6 +104,20 @@ double distLp(const DiagramPoint& a, const DiagramPoint& b, const double p) } +double DiagramPoint::persistenceLp(const double p) const +{ + if (isDiagonal()) + return 0.0; + else { + double u { 0.5 * (getRealY() + getRealX()) }; + DiagramPoint a_proj(u, u, DiagramPoint::DIAG); + return distLp(*this, a_proj, p); + } + + +} + + #ifndef FOR_R_TDA std::ostream& operator<<(std::ostream& output, const DiagramPoint p) { diff --git a/geom_matching/wasserstein/src/wasserstein.cpp b/geom_matching/wasserstein/src/wasserstein.cpp index 8776b5f..b8a75ef 100644 --- a/geom_matching/wasserstein/src/wasserstein.cpp +++ b/geom_matching/wasserstein/src/wasserstein.cpp @@ -74,6 +74,25 @@ double wassersteinDistVec(const std::vector& A, throw std::runtime_error("Bad epsilon factor in Wasserstein " + std::to_string(_epsFactor)); } + if (A.empty() && B.empty()) + return 0.0; + + if (A.empty()) { + double result { 0.0 } ; + for(const auto& pt : B) { + result += pt.persistenceLp(_internal_p); + } + return result; + } + + if (B.empty()) { + double result { 0.0 } ; + for(const auto& pt : A) { + result += pt.persistenceLp(_internal_p); + } + return result; + } + #ifdef GAUSS_SEIDEL_AUCTION AuctionRunnerGS auction(A, B, q, delta, _internal_p, _initialEpsilon, _epsFactor); -- cgit v1.2.3