diff options
Diffstat (limited to 'geom_matching/wasserstein/src/wasserstein.cpp')
-rw-r--r-- | geom_matching/wasserstein/src/wasserstein.cpp | 44 |
1 files changed, 42 insertions, 2 deletions
diff --git a/geom_matching/wasserstein/src/wasserstein.cpp b/geom_matching/wasserstein/src/wasserstein.cpp index 2c536ed..8776b5f 100644 --- a/geom_matching/wasserstein/src/wasserstein.cpp +++ b/geom_matching/wasserstein/src/wasserstein.cpp @@ -74,6 +74,7 @@ double wassersteinDistVec(const std::vector<DiagramPoint>& A, throw std::runtime_error("Bad epsilon factor in Wasserstein " + std::to_string(_epsFactor)); } + #ifdef GAUSS_SEIDEL_AUCTION AuctionRunnerGS auction(A, B, q, delta, _internal_p, _initialEpsilon, _epsFactor); #else @@ -124,12 +125,12 @@ double wassersteinCostVec(const std::vector<DiagramPoint>& A, return result; } -bool readDiagramPointSet(const std::string& fname, std::vector<std::pair<double, double>>& result) +bool readDiagramPointSet(const std::string& fname, PairVector& result) { return readDiagramPointSet(fname.c_str(), result); } -bool readDiagramPointSet(const char* fname, std::vector<std::pair<double, double>>& result) +bool readDiagramPointSet(const char* fname, PairVector& result) { size_t lineNumber { 0 }; result.clear(); @@ -173,4 +174,43 @@ bool readDiagramPointSet(const char* fname, std::vector<std::pair<double, double return true; } + +void removeDuplicates(PairVector& dgmA, PairVector& dgmB) +{ + std::map<std::pair<double, double>, int> mapA, mapB; + // copy points to maps + for(const auto& ptA : dgmA) { + mapA[ptA]++; + } + for(const auto& ptB : dgmB) { + mapB[ptB]++; + } + // clear vectors + dgmA.clear(); + dgmB.clear(); + // remove duplicates from maps + for(auto& pointMultiplicityPair : mapA) { + auto iterB = mapB.find(pointMultiplicityPair.first); + if (iterB != mapB.end()) { + int duplicateMultiplicity = std::min(pointMultiplicityPair.second, iterB->second); + pointMultiplicityPair.second -= duplicateMultiplicity; + iterB->second -= duplicateMultiplicity; + } + } + // copy points back to vectors + for(const auto& pointMultiplicityPairA : mapA) { + assert( pointMultiplicityPairA.second >= 0); + for(int i = 0; i < pointMultiplicityPairA.second; ++i) { + dgmA.push_back(pointMultiplicityPairA.first); + } + } + + for(const auto& pointMultiplicityPairB : mapB) { + assert( pointMultiplicityPairB.second >= 0); + for(int i = 0; i < pointMultiplicityPairB.second; ++i) { + dgmB.push_back(pointMultiplicityPairB.first); + } + } +} + } // end of namespace geom_ws |