summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorArnur Nigmetov <a.nigmetov@gmail.com>2017-04-24 16:49:37 +0600
committerArnur Nigmetov <a.nigmetov@gmail.com>2017-04-24 16:49:37 +0600
commitbd3300343726981dbb7b7f45d1cabc9d781e28a1 (patch)
tree1ea381d8265df271fb4215f75ae8b301253d2c18
parent0d3a01257675ee637c09046c739eaa235984b515 (diff)
Empty diagram bug for Wasserstein fixed
-rw-r--r--geom_bottleneck/bottleneck/src/bottleneck.cpp4
-rw-r--r--geom_matching/wasserstein/include/basic_defs_ws.h2
-rw-r--r--geom_matching/wasserstein/include/wasserstein.h14
-rw-r--r--geom_matching/wasserstein/src/basic_defs.cpp14
-rw-r--r--geom_matching/wasserstein/src/wasserstein.cpp19
5 files changed, 51 insertions, 2 deletions
diff --git a/geom_bottleneck/bottleneck/src/bottleneck.cpp b/geom_bottleneck/bottleneck/src/bottleneck.cpp
index da0c425..82fdcfe 100644
--- a/geom_bottleneck/bottleneck/src/bottleneck.cpp
+++ b/geom_bottleneck/bottleneck/src/bottleneck.cpp
@@ -208,8 +208,8 @@ std::pair<double, double> bottleneckDistApproxIntervalHeur(DiagramPointSet& A, D
DiagramPointSet sampledA, sampledB;
sampleDiagramForHeur(A, sampledA);
sampleDiagramForHeur(B, sampledB);
- //std::cout << "A : " << A.size() << ", sampled: " << sampledA.size() << std::endl;
- //std::cout << "B : " << B.size() << ", sampled: " << sampledB.size() << std::endl;
+ std::cout << "A : " << A.size() << ", sampled: " << sampledA.size() << std::endl;
+ std::cout << "B : " << B.size() << ", sampled: " << sampledB.size() << std::endl;
std::pair<double, double> initGuess = bottleneckDistApproxInterval(sampledA, sampledB, epsilon);
//std::cout << "initial guess: " << initGuess.first << ", " << initGuess.second << std::endl;
return bottleneckDistApproxIntervalWithInitial(A, B, epsilon, initGuess);
diff --git a/geom_matching/wasserstein/include/basic_defs_ws.h b/geom_matching/wasserstein/include/basic_defs_ws.h
index 365c3bd..474af22 100644
--- a/geom_matching/wasserstein/include/basic_defs_ws.h
+++ b/geom_matching/wasserstein/include/basic_defs_ws.h
@@ -77,6 +77,7 @@ struct DiagramPoint
bool isNormal(void) const { return type == NORMAL; }
double getRealX() const; // return the x-coord
double getRealY() const; // return the y-coord
+ double persistenceLp(const double p) const;
#ifndef FOR_R_TDA
friend std::ostream& operator<<(std::ostream& output, const DiagramPoint p);
#endif
@@ -92,6 +93,7 @@ double sqrDist(const Point& a, const Point& b);
double dist(const Point& a, const Point& b);
double distLInf(const DiagramPoint& a, const DiagramPoint& b);
double distLp(const DiagramPoint& a, const DiagramPoint& b, const double p);
+double persistenceLp(const DiagramPoint& a, const double p);
template<typename DiagPointContainer>
double getFurthestDistance3Approx(DiagPointContainer& A, DiagPointContainer& B)
diff --git a/geom_matching/wasserstein/include/wasserstein.h b/geom_matching/wasserstein/include/wasserstein.h
index 1d8f35b..155d79d 100644
--- a/geom_matching/wasserstein/include/wasserstein.h
+++ b/geom_matching/wasserstein/include/wasserstein.h
@@ -89,10 +89,14 @@ double wassersteinDist(PairContainer& A, PairContainer& B, const double q, const
return 0.0;
}
+ bool a_empty { true };
+ bool b_empty { true };
+
std::vector<DiagramPoint> dgmA, dgmB;
// loop over A, add projections of A-points to corresponding positions
// in B-vector
for(auto& pairA : A) {
+ a_empty = false;
double x = pairA.first;
double y = pairA.second;
dgmA.push_back(DiagramPoint(x, y, DiagramPoint::NORMAL));
@@ -100,11 +104,21 @@ double wassersteinDist(PairContainer& A, PairContainer& B, const double q, const
}
// the same for B
for(auto& pairB : B) {
+ b_empty = false;
double x = pairB.first;
double y = pairB.second;
dgmA.push_back(DiagramPoint(x, y, DiagramPoint::DIAG));
dgmB.push_back(DiagramPoint(x, y, DiagramPoint::NORMAL));
}
+
+ if (a_empty && b_empty)
+ return 0.0;
+
+ if (a_empty)
+ dgmA.clear();
+
+ if (b_empty)
+ dgmB.clear();
return wassersteinDistVec(dgmA, dgmB, q, delta, _internal_p, _initialEpsilon, _epsFactor);
}
diff --git a/geom_matching/wasserstein/src/basic_defs.cpp b/geom_matching/wasserstein/src/basic_defs.cpp
index a46e6aa..ec5dcec 100644
--- a/geom_matching/wasserstein/src/basic_defs.cpp
+++ b/geom_matching/wasserstein/src/basic_defs.cpp
@@ -104,6 +104,20 @@ double distLp(const DiagramPoint& a, const DiagramPoint& b, const double p)
}
+double DiagramPoint::persistenceLp(const double p) const
+{
+ if (isDiagonal())
+ return 0.0;
+ else {
+ double u { 0.5 * (getRealY() + getRealX()) };
+ DiagramPoint a_proj(u, u, DiagramPoint::DIAG);
+ return distLp(*this, a_proj, p);
+ }
+
+
+}
+
+
#ifndef FOR_R_TDA
std::ostream& operator<<(std::ostream& output, const DiagramPoint p)
{
diff --git a/geom_matching/wasserstein/src/wasserstein.cpp b/geom_matching/wasserstein/src/wasserstein.cpp
index 8776b5f..b8a75ef 100644
--- a/geom_matching/wasserstein/src/wasserstein.cpp
+++ b/geom_matching/wasserstein/src/wasserstein.cpp
@@ -74,6 +74,25 @@ double wassersteinDistVec(const std::vector<DiagramPoint>& A,
throw std::runtime_error("Bad epsilon factor in Wasserstein " + std::to_string(_epsFactor));
}
+ if (A.empty() && B.empty())
+ return 0.0;
+
+ if (A.empty()) {
+ double result { 0.0 } ;
+ for(const auto& pt : B) {
+ result += pt.persistenceLp(_internal_p);
+ }
+ return result;
+ }
+
+ if (B.empty()) {
+ double result { 0.0 } ;
+ for(const auto& pt : A) {
+ result += pt.persistenceLp(_internal_p);
+ }
+ return result;
+ }
+
#ifdef GAUSS_SEIDEL_AUCTION
AuctionRunnerGS auction(A, B, q, delta, _internal_p, _initialEpsilon, _epsFactor);