From 7b850b8ee43fb7f8a0b2a1565ed01102d40b0a14 Mon Sep 17 00:00:00 2001 From: Arnur Nigmetov Date: Mon, 5 Sep 2016 13:32:05 +0200 Subject: Technical changes for R integration Avoid including iostream (R complains about that). All output protected by preprocessor directive (R checker should not see an instance of std::cout << in your code). Also added getWassersteinCost to be in line with the Dionysus implementation used in TDA. --- .../wasserstein/src/auction_runner_gs.cpp | 61 +++++++++++++++++++--- 1 file changed, 53 insertions(+), 8 deletions(-) (limited to 'geom_matching/wasserstein/src/auction_runner_gs.cpp') diff --git a/geom_matching/wasserstein/src/auction_runner_gs.cpp b/geom_matching/wasserstein/src/auction_runner_gs.cpp index 4ff495f..bd25442 100644 --- a/geom_matching/wasserstein/src/auction_runner_gs.cpp +++ b/geom_matching/wasserstein/src/auction_runner_gs.cpp @@ -27,6 +27,7 @@ derivative works thereof, in binary and source code form. #include +#include #include #include #include @@ -36,6 +37,10 @@ derivative works thereof, in binary and source code form. #include "auction_runner_gs.h" #include "wasserstein.h" +#ifdef FOR_R_TDA +#include "Rcpp.h" +#endif + //#define PRINT_DETAILED_TIMING namespace geom_ws { @@ -144,11 +149,13 @@ void AuctionRunnerGS::runAuction(void) double denominator = currentResult - numBidders * oracle->getEpsilon(); currentResult = pow(currentResult, 1.0 / wassersteinPower); #ifdef PRINT_DETAILED_TIMING +#ifndef FOR_R_TDA iterResults.push_back(currentResult); iterTimes.push_back(hrClock.now()); std::cout << "Iteration " << iterNum << " finished. "; std::cout << "Current result is " << currentResult << ", epsilon = " << oracle->getEpsilon() << std::endl; std::cout << "Number of rounds (cumulative): " << numRounds << std::endl; +#endif #endif if ( denominator <= 0 ) { //std::cout << "Epsilon is too big." << std::endl; @@ -157,8 +164,10 @@ void AuctionRunnerGS::runAuction(void) denominator = pow(denominator, 1.0 / wassersteinPower); double numerator = currentResult - denominator; #ifdef PRINT_DETAILED_TIMING +#ifndef FOR_R_TDA std::cout << " numerator: " << numerator << " denominator: " << denominator; std::cout << "; error bound: " << numerator / denominator << std::endl; +#endif #endif // if relative error is greater than delta, continue notDone = ( numerator / denominator > delta ); @@ -166,13 +175,16 @@ void AuctionRunnerGS::runAuction(void) // decrease epsilon for the next iteration oracle->setEpsilon( oracle->getEpsilon() / epsilonCommonRatio ); if (iterNum > maxIterNum) { +#ifndef FOR_R_TDA std::cerr << "Maximum iteration number exceeded, exiting. Current result is:"; std::cerr << wassersteinDistance << std::endl; - std::exit(1); +#endif + throw std::runtime_error("Maximum iteration number exceeded"); } } while ( notDone ); //printMatching(); #ifdef PRINT_DETAILED_TIMING +#ifndef FOR_R_TDA for(size_t iterIdx = 0; iterIdx < iterResults.size(); ++iterIdx) { double trueRelError = ( iterResults.at(iterIdx) - currentResult ) / currentResult; auto iterCumulativeTime = iterTimes.at(iterIdx) - startMoment; @@ -183,6 +195,7 @@ void AuctionRunnerGS::runAuction(void) ", iteration time " << iterTime.count() << std::endl; } #endif +#endif } void AuctionRunnerGS::runAuctionPhase(void) @@ -200,15 +213,22 @@ void AuctionRunnerGS::runAuctionPhase(void) assignItemToBidder(optimalBid.first, bidderIdx); oracle->setPrice(optimalItemIdx, bidValue); //printDebug(); +#ifdef FOR_R_TDA + if ( numRounds % 10000 == 0 ) { + Rcpp::checkUserInterrupt(); + } +#endif } while (not unassignedBidders.empty()); //std::cout << "runAuctionPhase finished" << std::endl; #ifdef DEBUG_AUCTION for(size_t bidderIdx = 0; bidderIdx < numBidders; ++bidderIdx) { if ( biddersToItems[bidderIdx] < 0) { +#ifndef FOR_R_TDA std::cerr << "After auction terminated bidder " << bidderIdx; std::cerr << " has no items assigned" << std::endl; - throw "Auction did not give a perfect matching"; +#endif + throw std::runtime_error("Auction did not give a perfect matching"); } } #endif @@ -225,6 +245,7 @@ double AuctionRunnerGS::getDistanceToQthPowerInternal(void) auto pB = items[biddersToItems[bIdx]]; result += pow(distLp(pA, pB, internal_p), wassersteinPower); } + wassersteinCost = result; wassersteinDistance = pow(result, 1.0 / wassersteinPower); return result; } @@ -235,11 +256,20 @@ double AuctionRunnerGS::getWassersteinDistance(void) return wassersteinDistance; } +double AuctionRunnerGS::getWassersteinCost(void) +{ + runAuction(); + return wassersteinCost; +} + + + // Debug routines void AuctionRunnerGS::printDebug(void) { #ifdef DEBUG_AUCTION +#ifndef FOR_R_TDA sanityCheck(); std::cout << "**********************" << std::endl; std::cout << "Current assignment:" << std::endl; @@ -259,6 +289,7 @@ void AuctionRunnerGS::printDebug(void) } std::cout << "**********************" << std::endl; #endif +#endif } @@ -266,13 +297,17 @@ void AuctionRunnerGS::sanityCheck(void) { #ifdef DEBUG_AUCTION if (biddersToItems.size() != numBidders) { +#ifndef FOR_R_TDA std::cerr << "Wrong size of biddersToItems, must be " << numBidders << ", is " << biddersToItems.size() << std::endl; - throw "Wrong size of biddersToItems"; +#endif + throw std::runtime_error("Wrong size of biddersToItems"); } if (itemsToBidders.size() != numBidders) { +#ifndef FOR_R_TDA std::cerr << "Wrong size of itemsToBidders, must be " << numBidders << ", is " << itemsToBidders.size() << std::endl; - throw "Wrong size of itemsToBidders"; +#endif + throw std::runtime_error("Wrong size of itemsToBidders"); } for(size_t bidderIdx = 0; bidderIdx < numBidders; ++bidderIdx) { @@ -281,18 +316,22 @@ void AuctionRunnerGS::sanityCheck(void) if ( std::count(biddersToItems.begin(), biddersToItems.end(), biddersToItems[bidderIdx]) > 1 ) { +#ifndef FOR_R_TDA std::cerr << "Item " << biddersToItems[bidderIdx]; std::cerr << " appears in biddersToItems more than once" << std::endl; - throw "Duplicate in biddersToItems"; +#endif + throw std::runtime_error("Duplicate in biddersToItems"); } if (itemsToBidders.at(biddersToItems[bidderIdx]) != static_cast(bidderIdx)) { +#ifndef FOR_R_TDA std::cerr << "Inconsitency: bidderIdx = " << bidderIdx; std::cerr << ", itemIdx in biddersToItems = "; std::cerr << biddersToItems[bidderIdx]; std::cerr << ", bidderIdx in itemsToBidders = "; std::cerr << itemsToBidders[biddersToItems[bidderIdx]] << std::endl; - throw "inconsistent mapping"; +#endif + throw std::runtime_error("inconsistent mapping"); } } } @@ -304,18 +343,22 @@ void AuctionRunnerGS::sanityCheck(void) if ( std::count(itemsToBidders.begin(), itemsToBidders.end(), itemsToBidders[itemIdx]) > 1 ) { +#ifndef FOR_R_TDA std::cerr << "Bidder " << itemsToBidders[itemIdx]; std::cerr << " appears in itemsToBidders more than once" << std::endl; - throw "Duplicate in itemsToBidders"; +#endif + throw std::runtime_error("Duplicate in itemsToBidders"); } // check for consistency if (biddersToItems.at(itemsToBidders[itemIdx]) != static_cast(itemIdx)) { +#ifndef FOR_R_TDA std::cerr << "Inconsitency: itemIdx = " << itemIdx; std::cerr << ", bidderIdx in itemsToBidders = "; std::cerr << itemsToBidders[itemIdx]; std::cerr << ", itemIdx in biddersToItems= "; std::cerr << biddersToItems[itemsToBidders[itemIdx]] << std::endl; - throw "inconsistent mapping"; +#endif + throw std::runtime_error("inconsistent mapping"); } } } @@ -325,6 +368,7 @@ void AuctionRunnerGS::sanityCheck(void) void AuctionRunnerGS::printMatching(void) { //#ifdef DEBUG_AUCTION +#ifndef FOR_R_TDA sanityCheck(); for(size_t bIdx = 0; bIdx < biddersToItems.size(); ++bIdx) { if (biddersToItems[bIdx] >= 0) { @@ -335,6 +379,7 @@ void AuctionRunnerGS::printMatching(void) assert(false); } } +#endif //#endif } -- cgit v1.2.3