From 5cda650aa878a9b5929c65eb83f431e117d39811 Mon Sep 17 00:00:00 2001 From: Arnur Nigmetov Date: Wed, 5 Apr 2017 21:58:27 +0200 Subject: Remove duplicate points for Wasserstein-1 --- .../wasserstein/src/auction_runner_gs.cpp | 2 +- geom_matching/wasserstein/src/wasserstein.cpp | 44 +++++++++++++++++++++- 2 files changed, 43 insertions(+), 3 deletions(-) (limited to 'geom_matching/wasserstein/src') diff --git a/geom_matching/wasserstein/src/auction_runner_gs.cpp b/geom_matching/wasserstein/src/auction_runner_gs.cpp index 5865325..e568813 100644 --- a/geom_matching/wasserstein/src/auction_runner_gs.cpp +++ b/geom_matching/wasserstein/src/auction_runner_gs.cpp @@ -245,7 +245,7 @@ double AuctionRunnerGS::getDistanceToQthPowerInternal(void) auto pA = bidders[bIdx]; assert( 0 <= biddersToItems[bIdx] and biddersToItems[bIdx] < static_cast(items.size()) ); auto pB = items[biddersToItems[bIdx]]; - std::cout << "pA = " << pA << ", pB = " << pB << ", pow(distLp(pA, pB, internal_p), wassersteinPower) = " << pow(distLp(pA, pB, internal_p), wassersteinPower) << ", dist = " << distLp(pA, pB, internal_p) << std::endl; + //std::cout << "pA = " << pA << ", pB = " << pB << ", pow(distLp(pA, pB, internal_p), wassersteinPower) = " << pow(distLp(pA, pB, internal_p), wassersteinPower) << ", dist = " << distLp(pA, pB, internal_p) << std::endl; result += pow(distLp(pA, pB, internal_p), wassersteinPower); } wassersteinCost = result; diff --git a/geom_matching/wasserstein/src/wasserstein.cpp b/geom_matching/wasserstein/src/wasserstein.cpp index 2c536ed..8776b5f 100644 --- a/geom_matching/wasserstein/src/wasserstein.cpp +++ b/geom_matching/wasserstein/src/wasserstein.cpp @@ -74,6 +74,7 @@ double wassersteinDistVec(const std::vector& A, throw std::runtime_error("Bad epsilon factor in Wasserstein " + std::to_string(_epsFactor)); } + #ifdef GAUSS_SEIDEL_AUCTION AuctionRunnerGS auction(A, B, q, delta, _internal_p, _initialEpsilon, _epsFactor); #else @@ -124,12 +125,12 @@ double wassersteinCostVec(const std::vector& A, return result; } -bool readDiagramPointSet(const std::string& fname, std::vector>& result) +bool readDiagramPointSet(const std::string& fname, PairVector& result) { return readDiagramPointSet(fname.c_str(), result); } -bool readDiagramPointSet(const char* fname, std::vector>& result) +bool readDiagramPointSet(const char* fname, PairVector& result) { size_t lineNumber { 0 }; result.clear(); @@ -173,4 +174,43 @@ bool readDiagramPointSet(const char* fname, std::vector, int> mapA, mapB; + // copy points to maps + for(const auto& ptA : dgmA) { + mapA[ptA]++; + } + for(const auto& ptB : dgmB) { + mapB[ptB]++; + } + // clear vectors + dgmA.clear(); + dgmB.clear(); + // remove duplicates from maps + for(auto& pointMultiplicityPair : mapA) { + auto iterB = mapB.find(pointMultiplicityPair.first); + if (iterB != mapB.end()) { + int duplicateMultiplicity = std::min(pointMultiplicityPair.second, iterB->second); + pointMultiplicityPair.second -= duplicateMultiplicity; + iterB->second -= duplicateMultiplicity; + } + } + // copy points back to vectors + for(const auto& pointMultiplicityPairA : mapA) { + assert( pointMultiplicityPairA.second >= 0); + for(int i = 0; i < pointMultiplicityPairA.second; ++i) { + dgmA.push_back(pointMultiplicityPairA.first); + } + } + + for(const auto& pointMultiplicityPairB : mapB) { + assert( pointMultiplicityPairB.second >= 0); + for(int i = 0; i < pointMultiplicityPairB.second; ++i) { + dgmB.push_back(pointMultiplicityPairB.first); + } + } +} + } // end of namespace geom_ws -- cgit v1.2.3