From 0cc35ad04f9c2997014d7cf62a12f697e79fb534 Mon Sep 17 00:00:00 2001 From: Arnur Nigmetov Date: Sat, 20 Jan 2018 19:11:29 +0100 Subject: Major rewrite, templatized version --- .../wasserstein/include/auction_runner_gs.h | 154 ++++++++++----------- 1 file changed, 75 insertions(+), 79 deletions(-) (limited to 'geom_matching/wasserstein/include/auction_runner_gs.h') diff --git a/geom_matching/wasserstein/include/auction_runner_gs.h b/geom_matching/wasserstein/include/auction_runner_gs.h index ce139f1..fc76987 100644 --- a/geom_matching/wasserstein/include/auction_runner_gs.h +++ b/geom_matching/wasserstein/include/auction_runner_gs.h @@ -29,98 +29,94 @@ derivative works thereof, in binary and source code form. #ifndef AUCTION_RUNNER_GS_H #define AUCTION_RUNNER_GS_H +#include #include +#include "spdlog/spdlog.h" #include "auction_oracle.h" -//#define KEEP_UNASSIGNED_ORDERED -// if this symbol is defined, -// unassigned bidders are processed in a lexicographic order. -// See LexicogrCompDiagramPoint comparator. - - -namespace geom_ws { - -//using AuctionOracle = AuctionOracleLazyHeapRestricted; -using AuctionOracle = AuctionOracleKDTreeRestricted; - -#ifdef KEEP_UNASSIGNED_ORDERED -using IdxPointPair = std::pair; - -struct LexicogrCompDiagramPoint { - bool operator ()(const IdxPointPair& a, const IdxPointPair& b) { - const auto& p1 = a.second; - const auto& p2 = b.second; - - return ( (not p1.isDiagonal() and p2.isDiagonal()) or - ( p1.isDiagonal() == p2.isDiagonal() and p1.getRealX() < p2.getRealX() ) or - ( p1.isDiagonal() == p2.isDiagonal() and p1.getRealX() == p2.getRealX() and p1.getRealY() < p2.getRealY() ) or - ( p1.isDiagonal() == p2.isDiagonal() and p1.getRealX() == p2.getRealX() and p1.getRealY() == p2.getRealY() and a.first < b.first ) ); - } -}; - -using OrderedUnassignedKeeper = std::set; -#endif - -// the two parameters that you can tweak in auction algorithm are: -// 1. epsilonCommonRatio -// 2. maxIterNum +namespace hera { +namespace ws { +template, class PointContainer_ = std::vector> > // alternatively: AuctionOracleLazyHeap --- TODO class AuctionRunnerGS { public: - AuctionRunnerGS(const std::vector& A, - const std::vector& B, - const double q, - const double _delta, - const double _internal_p, - const double _initialEpsilon, - const double _epsFactor); - void setEpsilon(double newVal) { assert(epsilon > 0.0); epsilon = newVal; }; - double getEpsilon() const { return epsilon; } - double getWassersteinDistance(); - double getWassersteinCost(); - double getRelativeError() const { return relativeError; }; - static constexpr int maxIterNum { 35 }; // maximal number of iterations of epsilon-scaling -private: + using Real = RealType_; + using AuctionOracle = AuctionOracle_; + using DgmPoint = typename AuctionOracle::DiagramPointR; + using IdxValPairR = IdxValPair; + using PointContainer = PointContainer_; + + + AuctionRunnerGS(const PointContainer& A, + const PointContainer& B, + const AuctionParams& params, + const std::string& _log_filename_prefix = ""); + + void set_epsilon(Real new_val) { assert(epsilon > 0.0); epsilon = new_val; }; + Real get_epsilon() const { return oracle.get_epsilon(); } + Real get_wasserstein_cost(); + Real get_wasserstein_distance(); + Real get_relative_error() const { return relative_error; }; + void enable_logging(const char* log_filename, const size_t _max_unassigned_to_log); +//private: // private data - std::vector bidders, items; - const size_t numBidders; - const size_t numItems; - std::vector itemsToBidders; - std::vector biddersToItems; - double wassersteinPower; - double epsilon; - double delta; - double internal_p; - double initialEpsilon; - double epsilonCommonRatio; // next epsilon = current epsilon / epsilonCommonRatio - double weightAdjConst; - double wassersteinDistance; - double wassersteinCost; - double relativeError; + PointContainer bidders, items; + const size_t num_bidders; + const size_t num_items; + std::vector items_to_bidders; + std::vector bidders_to_items; + Real wasserstein_power; + Real epsilon; + Real delta; + Real internal_p; + Real initial_epsilon; + Real epsilon_common_ratio; // next epsilon = current epsilon / epsilon_common_ratio + const int max_num_phases; // maximal number of iterations of epsilon-scaling + Real weight_adj_const; + Real wasserstein_cost; + Real relative_error; + int dimension; // to get the 2 best items - std::unique_ptr oracle; -#ifdef KEEP_UNASSIGNED_ORDERED - OrderedUnassignedKeeper unassignedBidders; -#else - std::unordered_set unassignedBidders; -#endif + AuctionOracle oracle; + std::unordered_set unassigned_bidders; // private methods - void assignItemToBidder(const IdxType bidderIdx, const IdxType itemsIdx); - void clearBidTable(void); - void runAuction(void); - void runAuctionPhase(void); - void flushAssignment(void); + void assign_item_to_bidder(const IdxType bidder_idx, const IdxType items_idx); + void run_auction(); + void run_auction_phases(const int max_num_phases, const Real _initial_epsilon); + void run_auction_phase(); + void flush_assignment(); + // return 0, if item_idx is invalid + Real get_item_bidder_cost(const size_t item_idx, const size_t bidder_idx, const bool tolerate_invalid_idx = false) const; // for debug only - void sanityCheck(void); - void printDebug(void); - int countUnhappy(void); - void printMatching(void); - double getDistanceToQthPowerInternal(void); - int numRounds { 0 }; + void sanity_check(); + void print_debug(); + int count_unhappy(); + void print_matching(); + Real getDistanceToQthPowerInternal(); + int num_phase { 0 }; + int num_rounds { 0 }; + bool is_distance_computed {false}; +#ifdef LOG_AUCTION + bool log_auction { false }; + std::shared_ptr console_logger; + std::shared_ptr plot_logger; + std::unordered_set unassigned_items; + size_t max_unassigned_to_log { 0 }; + const char* logger_name = "auction_detailed_logger"; // the name in spdlog registry; filename is provided as parameter in enable_logging + const Real total_items_persistence; + const Real total_bidders_persistence; + Real partial_cost; + Real unassigned_bidders_persistence; + Real unassigned_items_persistence; +#endif }; -} // end of namespace geom_ws +} // ws +} // hera + + +#include "auction_runner_gs.hpp" #endif -- cgit v1.2.3