summary refs log tree commit diff
diff options
context:
space:
mode:
author Arnur Nigmetov <a.nigmetov@gmail.com> 2017-04-05 21:58:27 +0200
committer Arnur Nigmetov <a.nigmetov@gmail.com> 2017-04-05 21:58:27 +0200
commit 5cda650aa878a9b5929c65eb83f431e117d39811 (patch)
tree 21504e49bdf716e8671b9bc1e18df4ba0c66be80
parent 1607071fcd9d473eae295693fc97bee8c50d6a11 (diff)
Remove duplicate points for Wasserstein-1
-rw-r--r-- geom_matching/wasserstein/example/wasserstein_dist.cpp 8
-rw-r--r-- geom_matching/wasserstein/include/wasserstein.h 7
-rw-r--r-- geom_matching/wasserstein/src/auction_runner_gs.cpp 2
-rw-r--r-- geom_matching/wasserstein/src/wasserstein.cpp 44
4 files changed, 53 insertions, 8 deletions
diff --git a/geom_matching/wasserstein/example/wasserstein_dist.cpp b/geom_matching/wasserstein/example/wasserstein_dist.cpp
index c699282..a2ed234 100644
--- a/geom_matching/wasserstein/example/wasserstein_dist.cpp
+++ b/geom_matching/wasserstein/example/wasserstein_dist.cpp
@@ -39,11 +39,9 @@ derivative works thereof, in binary and source code form.
// any container of pairs of doubles can be used,
// we use vector in this example.
-typedef std::vector<std::pair<double, double>> PairVector;
-
int main(int argc, char* argv[])
{
- PairVector diagramA, diagramB;
+ geom_ws::PairVector diagramA, diagramB;
if (argc < 3 ) {
std::cerr << "Usage: " << argv[0] << " file1 file2 [wasserstein_degree] [relative_error] [internal norm] [output_actual_error]. By default power is 1.0, relative error is 0.01, internal norm is l_infinity, actual relative error is not printed." << std::endl;
@@ -64,6 +62,10 @@ int main(int argc, char* argv[])
std::exit(1);
}
+ if (wasserPower == 1.0) {
+ geom_ws::removeDuplicates(diagramA, diagramB);
+ }
+
//default relative error: 1%
double delta = (5 <= argc) ? atof(argv[4]) : 0.01;
if ( delta <= 0.0) {
diff --git a/geom_matching/wasserstein/include/wasserstein.h b/geom_matching/wasserstein/include/wasserstein.h
index fcc9e74..1d8f35b 100644
--- a/geom_matching/wasserstein/include/wasserstein.h
+++ b/geom_matching/wasserstein/include/wasserstein.h
@@ -40,6 +40,8 @@ derivative works thereof, in binary and source code form.
namespace geom_ws {
+using PairVector = std::vector<std::pair<double, double>>;
+
// get Wasserstein distance between two persistence diagrams
double wassersteinDistVec(const std::vector<DiagramPoint>& A,
const std::vector<DiagramPoint>& B,
@@ -138,8 +140,9 @@ double wassersteinCost(PairContainer& A, PairContainer& B, const double q, const
// fill in result with points from file fname
// return false if file can't be opened
// or error occurred while reading
-bool readDiagramPointSet(const char* fname, std::vector<std::pair<double, double>>& result);
-bool readDiagramPointSet(const std::string& fname, std::vector<std::pair<double, double>>& result);
+bool readDiagramPointSet(const char* fname, PairVector& result);
+bool readDiagramPointSet(const std::string& fname, PairVector& result);
+void removeDuplicates(PairVector& dgmA, PairVector& dgmB);
} // end of namespace geom_ws
diff --git a/geom_matching/wasserstein/src/auction_runner_gs.cpp b/geom_matching/wasserstein/src/auction_runner_gs.cpp
index 5865325..e568813 100644
--- a/geom_matching/wasserstein/src/auction_runner_gs.cpp
+++ b/geom_matching/wasserstein/src/auction_runner_gs.cpp
@@ -245,7 +245,7 @@ double AuctionRunnerGS::getDistanceToQthPowerInternal(void)
auto pA = bidders[bIdx];
assert( 0 <= biddersToItems[bIdx] and biddersToItems[bIdx] < static_cast<int>(items.size()) );
auto pB = items[biddersToItems[bIdx]];
- std::cout << "pA = " << pA << ", pB = " << pB << ", pow(distLp(pA, pB, internal_p), wassersteinPower) = " << pow(distLp(pA, pB, internal_p), wassersteinPower) << ", dist = " << distLp(pA, pB, internal_p) << std::endl;
+ //std::cout << "pA = " << pA << ", pB = " << pB << ", pow(distLp(pA, pB, internal_p), wassersteinPower) = " << pow(distLp(pA, pB, internal_p), wassersteinPower) << ", dist = " << distLp(pA, pB, internal_p) << std::endl;
result += pow(distLp(pA, pB, internal_p), wassersteinPower);
}
wassersteinCost = result;
diff --git a/geom_matching/wasserstein/src/wasserstein.cpp b/geom_matching/wasserstein/src/wasserstein.cpp
index 2c536ed..8776b5f 100644
--- a/geom_matching/wasserstein/src/wasserstein.cpp
+++ b/geom_matching/wasserstein/src/wasserstein.cpp
@@ -74,6 +74,7 @@ double wassersteinDistVec(const std::vector<DiagramPoint>& A,
throw std::runtime_error("Bad epsilon factor in Wasserstein " + std::to_string(_epsFactor));
}
+
#ifdef GAUSS_SEIDEL_AUCTION
AuctionRunnerGS auction(A, B, q, delta, _internal_p, _initialEpsilon, _epsFactor);
#else
@@ -124,12 +125,12 @@ double wassersteinCostVec(const std::vector<DiagramPoint>& A,
return result;
}
-bool readDiagramPointSet(const std::string& fname, std::vector<std::pair<double, double>>& result)
+bool readDiagramPointSet(const std::string& fname, PairVector& result) // std::string convenience overload
{
return readDiagramPointSet(fname.c_str(), result); // forward to the const char* overload and return its status unchanged
}
-bool readDiagramPointSet(const char* fname, std::vector<std::pair<double, double>>& result)
+bool readDiagramPointSet(const char* fname, PairVector& result)
{
size_t lineNumber { 0 };
result.clear();
@@ -173,4 +174,43 @@ bool readDiagramPointSet(const char* fname, std::vector<std::pair<double, double
return true;
}
+
+void removeDuplicates(PairVector& dgmA, PairVector& dgmB) // cancel points present in both diagrams, respecting multiplicity; both vectors are rewritten in place
+{
+ std::map<std::pair<double, double>, int> mapA, mapB; // point -> number of occurrences in the corresponding diagram
+ // count how many times each point occurs in each diagram
+ for(const auto& ptA : dgmA) {
+ mapA[ptA]++;
+ }
+ for(const auto& ptB : dgmB) {
+ mapB[ptB]++;
+ }
+ // empty both vectors; they are rebuilt below from the surviving multiplicities
+ dgmA.clear();
+ dgmB.clear();
+ // for every point of A that is also a key of B, subtract the shared multiplicity from both counts
+ for(auto& pointMultiplicityPair : mapA) {
+ auto iterB = mapB.find(pointMultiplicityPair.first); // look up the same (first, second) key in B
+ if (iterB != mapB.end()) {
+ int duplicateMultiplicity = std::min(pointMultiplicityPair.second, iterB->second); // number of copies that occur on both sides
+ pointMultiplicityPair.second -= duplicateMultiplicity;
+ iterB->second -= duplicateMultiplicity; // after this, at least one of the two counts is zero
+ }
+ }
+ // write the remaining copies back; points come out in std::map key order, not in their original input order
+ for(const auto& pointMultiplicityPairA : mapA) {
+ assert( pointMultiplicityPairA.second >= 0);
+ for(int i = 0; i < pointMultiplicityPairA.second; ++i) {
+ dgmA.push_back(pointMultiplicityPairA.first);
+ }
+ }
+
+ for(const auto& pointMultiplicityPairB : mapB) {
+ assert( pointMultiplicityPairB.second >= 0);
+ for(int i = 0; i < pointMultiplicityPairB.second; ++i) {
+ dgmB.push_back(pointMultiplicityPairB.first);
+ }
+ }
+}
+
} // end of namespace geom_ws