summaryrefslogtreecommitdiff
path: root/geom_matching/wasserstein/src/wasserstein.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'geom_matching/wasserstein/src/wasserstein.cpp')
-rw-r--r--geom_matching/wasserstein/src/wasserstein.cpp44
1 files changed, 42 insertions, 2 deletions
diff --git a/geom_matching/wasserstein/src/wasserstein.cpp b/geom_matching/wasserstein/src/wasserstein.cpp
index 2c536ed..8776b5f 100644
--- a/geom_matching/wasserstein/src/wasserstein.cpp
+++ b/geom_matching/wasserstein/src/wasserstein.cpp
@@ -74,6 +74,7 @@ double wassersteinDistVec(const std::vector<DiagramPoint>& A,
throw std::runtime_error("Bad epsilon factor in Wasserstein " + std::to_string(_epsFactor));
}
+
#ifdef GAUSS_SEIDEL_AUCTION
AuctionRunnerGS auction(A, B, q, delta, _internal_p, _initialEpsilon, _epsFactor);
#else
@@ -124,12 +125,12 @@ double wassersteinCostVec(const std::vector<DiagramPoint>& A,
return result;
}
-bool readDiagramPointSet(const std::string& fname, std::vector<std::pair<double, double>>& result)
+bool readDiagramPointSet(const std::string& fname, PairVector& result)
{
return readDiagramPointSet(fname.c_str(), result);
}
-bool readDiagramPointSet(const char* fname, std::vector<std::pair<double, double>>& result)
+bool readDiagramPointSet(const char* fname, PairVector& result)
{
size_t lineNumber { 0 };
result.clear();
@@ -173,4 +174,43 @@ bool readDiagramPointSet(const char* fname, std::vector<std::pair<double, double
return true;
}
+
+void removeDuplicates(PairVector& dgmA, PairVector& dgmB)
+{
+ std::map<std::pair<double, double>, int> mapA, mapB;
+ // copy points to maps
+ for(const auto& ptA : dgmA) {
+ mapA[ptA]++;
+ }
+ for(const auto& ptB : dgmB) {
+ mapB[ptB]++;
+ }
+ // clear vectors
+ dgmA.clear();
+ dgmB.clear();
+ // remove duplicates from maps
+ for(auto& pointMultiplicityPair : mapA) {
+ auto iterB = mapB.find(pointMultiplicityPair.first);
+ if (iterB != mapB.end()) {
+ int duplicateMultiplicity = std::min(pointMultiplicityPair.second, iterB->second);
+ pointMultiplicityPair.second -= duplicateMultiplicity;
+ iterB->second -= duplicateMultiplicity;
+ }
+ }
+ // copy points back to vectors
+ for(const auto& pointMultiplicityPairA : mapA) {
+ assert( pointMultiplicityPairA.second >= 0);
+ for(int i = 0; i < pointMultiplicityPairA.second; ++i) {
+ dgmA.push_back(pointMultiplicityPairA.first);
+ }
+ }
+
+ for(const auto& pointMultiplicityPairB : mapB) {
+ assert( pointMultiplicityPairB.second >= 0);
+ for(int i = 0; i < pointMultiplicityPairB.second; ++i) {
+ dgmB.push_back(pointMultiplicityPairB.first);
+ }
+ }
+}
+
} // end of namespace geom_ws