summaryrefslogtreecommitdiff
path: root/geom_matching/wasserstein/src/auction_runner_gs.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'geom_matching/wasserstein/src/auction_runner_gs.cpp')
-rw-r--r--geom_matching/wasserstein/src/auction_runner_gs.cpp61
1 files changed, 53 insertions, 8 deletions
diff --git a/geom_matching/wasserstein/src/auction_runner_gs.cpp b/geom_matching/wasserstein/src/auction_runner_gs.cpp
index 4ff495f..bd25442 100644
--- a/geom_matching/wasserstein/src/auction_runner_gs.cpp
+++ b/geom_matching/wasserstein/src/auction_runner_gs.cpp
@@ -27,6 +27,7 @@ derivative works thereof, in binary and source code form.
#include <assert.h>
+#include <stdexcept>
#include <algorithm>
#include <functional>
#include <iterator>
@@ -36,6 +37,10 @@ derivative works thereof, in binary and source code form.
#include "auction_runner_gs.h"
#include "wasserstein.h"
+#ifdef FOR_R_TDA
+#include "Rcpp.h"
+#endif
+
//#define PRINT_DETAILED_TIMING
namespace geom_ws {
@@ -144,12 +149,14 @@ void AuctionRunnerGS::runAuction(void)
double denominator = currentResult - numBidders * oracle->getEpsilon();
currentResult = pow(currentResult, 1.0 / wassersteinPower);
#ifdef PRINT_DETAILED_TIMING
+#ifndef FOR_R_TDA
iterResults.push_back(currentResult);
iterTimes.push_back(hrClock.now());
std::cout << "Iteration " << iterNum << " finished. ";
std::cout << "Current result is " << currentResult << ", epsilon = " << oracle->getEpsilon() << std::endl;
std::cout << "Number of rounds (cumulative): " << numRounds << std::endl;
#endif
+#endif
if ( denominator <= 0 ) {
//std::cout << "Epsilon is too big." << std::endl;
notDone = true;
@@ -157,22 +164,27 @@ void AuctionRunnerGS::runAuction(void)
denominator = pow(denominator, 1.0 / wassersteinPower);
double numerator = currentResult - denominator;
#ifdef PRINT_DETAILED_TIMING
+#ifndef FOR_R_TDA
std::cout << " numerator: " << numerator << " denominator: " << denominator;
std::cout << "; error bound: " << numerator / denominator << std::endl;
#endif
+#endif
// if relative error is greater than delta, continue
notDone = ( numerator / denominator > delta );
}
// decrease epsilon for the next iteration
oracle->setEpsilon( oracle->getEpsilon() / epsilonCommonRatio );
if (iterNum > maxIterNum) {
+#ifndef FOR_R_TDA
std::cerr << "Maximum iteration number exceeded, exiting. Current result is:";
std::cerr << wassersteinDistance << std::endl;
- std::exit(1);
+#endif
+ throw std::runtime_error("Maximum iteration number exceeded");
}
} while ( notDone );
//printMatching();
#ifdef PRINT_DETAILED_TIMING
+#ifndef FOR_R_TDA
for(size_t iterIdx = 0; iterIdx < iterResults.size(); ++iterIdx) {
double trueRelError = ( iterResults.at(iterIdx) - currentResult ) / currentResult;
auto iterCumulativeTime = iterTimes.at(iterIdx) - startMoment;
@@ -183,6 +195,7 @@ void AuctionRunnerGS::runAuction(void)
", iteration time " << iterTime.count() << std::endl;
}
#endif
+#endif
}
void AuctionRunnerGS::runAuctionPhase(void)
@@ -200,15 +213,22 @@ void AuctionRunnerGS::runAuctionPhase(void)
assignItemToBidder(optimalBid.first, bidderIdx);
oracle->setPrice(optimalItemIdx, bidValue);
//printDebug();
+#ifdef FOR_R_TDA
+ if ( numRounds % 10000 == 0 ) {
+ Rcpp::checkUserInterrupt();
+ }
+#endif
} while (not unassignedBidders.empty());
//std::cout << "runAuctionPhase finished" << std::endl;
#ifdef DEBUG_AUCTION
for(size_t bidderIdx = 0; bidderIdx < numBidders; ++bidderIdx) {
if ( biddersToItems[bidderIdx] < 0) {
+#ifndef FOR_R_TDA
std::cerr << "After auction terminated bidder " << bidderIdx;
std::cerr << " has no items assigned" << std::endl;
- throw "Auction did not give a perfect matching";
+#endif
+ throw std::runtime_error("Auction did not give a perfect matching");
}
}
#endif
@@ -225,6 +245,7 @@ double AuctionRunnerGS::getDistanceToQthPowerInternal(void)
auto pB = items[biddersToItems[bIdx]];
result += pow(distLp(pA, pB, internal_p), wassersteinPower);
}
+ wassersteinCost = result;
wassersteinDistance = pow(result, 1.0 / wassersteinPower);
return result;
}
@@ -235,11 +256,20 @@ double AuctionRunnerGS::getWassersteinDistance(void)
return wassersteinDistance;
}
+double AuctionRunnerGS::getWassersteinCost(void)
+{
+ runAuction();
+ return wassersteinCost;
+}
+
+
+
// Debug routines
void AuctionRunnerGS::printDebug(void)
{
#ifdef DEBUG_AUCTION
+#ifndef FOR_R_TDA
sanityCheck();
std::cout << "**********************" << std::endl;
std::cout << "Current assignment:" << std::endl;
@@ -259,6 +289,7 @@ void AuctionRunnerGS::printDebug(void)
}
std::cout << "**********************" << std::endl;
#endif
+#endif
}
@@ -266,13 +297,17 @@ void AuctionRunnerGS::sanityCheck(void)
{
#ifdef DEBUG_AUCTION
if (biddersToItems.size() != numBidders) {
+#ifndef FOR_R_TDA
std::cerr << "Wrong size of biddersToItems, must be " << numBidders << ", is " << biddersToItems.size() << std::endl;
- throw "Wrong size of biddersToItems";
+#endif
+ throw std::runtime_error("Wrong size of biddersToItems");
}
if (itemsToBidders.size() != numBidders) {
+#ifndef FOR_R_TDA
std::cerr << "Wrong size of itemsToBidders, must be " << numBidders << ", is " << itemsToBidders.size() << std::endl;
- throw "Wrong size of itemsToBidders";
+#endif
+ throw std::runtime_error("Wrong size of itemsToBidders");
}
for(size_t bidderIdx = 0; bidderIdx < numBidders; ++bidderIdx) {
@@ -281,18 +316,22 @@ void AuctionRunnerGS::sanityCheck(void)
if ( std::count(biddersToItems.begin(),
biddersToItems.end(),
biddersToItems[bidderIdx]) > 1 ) {
+#ifndef FOR_R_TDA
std::cerr << "Item " << biddersToItems[bidderIdx];
std::cerr << " appears in biddersToItems more than once" << std::endl;
- throw "Duplicate in biddersToItems";
+#endif
+ throw std::runtime_error("Duplicate in biddersToItems");
}
if (itemsToBidders.at(biddersToItems[bidderIdx]) != static_cast<int>(bidderIdx)) {
+#ifndef FOR_R_TDA
std::cerr << "Inconsitency: bidderIdx = " << bidderIdx;
std::cerr << ", itemIdx in biddersToItems = ";
std::cerr << biddersToItems[bidderIdx];
std::cerr << ", bidderIdx in itemsToBidders = ";
std::cerr << itemsToBidders[biddersToItems[bidderIdx]] << std::endl;
- throw "inconsistent mapping";
+#endif
+ throw std::runtime_error("inconsistent mapping");
}
}
}
@@ -304,18 +343,22 @@ void AuctionRunnerGS::sanityCheck(void)
if ( std::count(itemsToBidders.begin(),
itemsToBidders.end(),
itemsToBidders[itemIdx]) > 1 ) {
+#ifndef FOR_R_TDA
std::cerr << "Bidder " << itemsToBidders[itemIdx];
std::cerr << " appears in itemsToBidders more than once" << std::endl;
- throw "Duplicate in itemsToBidders";
+#endif
+ throw std::runtime_error("Duplicate in itemsToBidders");
}
// check for consistency
if (biddersToItems.at(itemsToBidders[itemIdx]) != static_cast<int>(itemIdx)) {
+#ifndef FOR_R_TDA
std::cerr << "Inconsitency: itemIdx = " << itemIdx;
std::cerr << ", bidderIdx in itemsToBidders = ";
std::cerr << itemsToBidders[itemIdx];
std::cerr << ", itemIdx in biddersToItems= ";
std::cerr << biddersToItems[itemsToBidders[itemIdx]] << std::endl;
- throw "inconsistent mapping";
+#endif
+ throw std::runtime_error("inconsistent mapping");
}
}
}
@@ -325,6 +368,7 @@ void AuctionRunnerGS::sanityCheck(void)
void AuctionRunnerGS::printMatching(void)
{
//#ifdef DEBUG_AUCTION
+#ifndef FOR_R_TDA
sanityCheck();
for(size_t bIdx = 0; bIdx < biddersToItems.size(); ++bIdx) {
if (biddersToItems[bIdx] >= 0) {
@@ -335,6 +379,7 @@ void AuctionRunnerGS::printMatching(void)
assert(false);
}
}
+#endif
//#endif
}