summaryrefslogtreecommitdiff
path: root/geom_matching/wasserstein/include/auction_runner_gs.h
diff options
context:
space:
mode:
Diffstat (limited to 'geom_matching/wasserstein/include/auction_runner_gs.h')
-rw-r--r--geom_matching/wasserstein/include/auction_runner_gs.h154
1 files changed, 75 insertions, 79 deletions
diff --git a/geom_matching/wasserstein/include/auction_runner_gs.h b/geom_matching/wasserstein/include/auction_runner_gs.h
index ce139f1..fc76987 100644
--- a/geom_matching/wasserstein/include/auction_runner_gs.h
+++ b/geom_matching/wasserstein/include/auction_runner_gs.h
@@ -29,98 +29,94 @@ derivative works thereof, in binary and source code form.
#ifndef AUCTION_RUNNER_GS_H
#define AUCTION_RUNNER_GS_H
+#include <memory>
#include <unordered_set>
+#include "spdlog/spdlog.h"
#include "auction_oracle.h"
-//#define KEEP_UNASSIGNED_ORDERED
-// if this symbol is defined,
-// unassigned bidders are processed in a lexicographic order.
-// See LexicogrCompDiagramPoint comparator.
-
-
-namespace geom_ws {
-
-//using AuctionOracle = AuctionOracleLazyHeapRestricted;
-using AuctionOracle = AuctionOracleKDTreeRestricted;
-
-#ifdef KEEP_UNASSIGNED_ORDERED
-using IdxPointPair = std::pair<size_t, DiagramPoint>;
-
-struct LexicogrCompDiagramPoint {
- bool operator ()(const IdxPointPair& a, const IdxPointPair& b) {
- const auto& p1 = a.second;
- const auto& p2 = b.second;
-
- return ( (not p1.isDiagonal() and p2.isDiagonal()) or
- ( p1.isDiagonal() == p2.isDiagonal() and p1.getRealX() < p2.getRealX() ) or
- ( p1.isDiagonal() == p2.isDiagonal() and p1.getRealX() == p2.getRealX() and p1.getRealY() < p2.getRealY() ) or
- ( p1.isDiagonal() == p2.isDiagonal() and p1.getRealX() == p2.getRealX() and p1.getRealY() == p2.getRealY() and a.first < b.first ) );
- }
-};
-
-using OrderedUnassignedKeeper = std::set<IdxPointPair, LexicogrCompDiagramPoint>;
-#endif
-
-// the two parameters that you can tweak in auction algorithm are:
-// 1. epsilonCommonRatio
-// 2. maxIterNum
+namespace hera {
+namespace ws {
+template<class RealType_ = double, class AuctionOracle_ = AuctionOracleKDTreeRestricted<RealType_>, class PointContainer_ = std::vector<DiagramPoint<RealType_>> > // alternatively: AuctionOracleLazyHeap --- TODO
class AuctionRunnerGS {
public:
- AuctionRunnerGS(const std::vector<DiagramPoint>& A,
- const std::vector<DiagramPoint>& B,
- const double q,
- const double _delta,
- const double _internal_p,
- const double _initialEpsilon,
- const double _epsFactor);
- void setEpsilon(double newVal) { assert(epsilon > 0.0); epsilon = newVal; };
- double getEpsilon() const { return epsilon; }
- double getWassersteinDistance();
- double getWassersteinCost();
- double getRelativeError() const { return relativeError; };
- static constexpr int maxIterNum { 35 }; // maximal number of iterations of epsilon-scaling
-private:
+ using Real = RealType_;
+ using AuctionOracle = AuctionOracle_;
+ using DgmPoint = typename AuctionOracle::DiagramPointR;
+ using IdxValPairR = IdxValPair<Real>;
+ using PointContainer = PointContainer_;
+
+
+ AuctionRunnerGS(const PointContainer& A,
+ const PointContainer& B,
+ const AuctionParams<Real>& params,
+ const std::string& _log_filename_prefix = "");
+
+ void set_epsilon(Real new_val) { assert(epsilon > 0.0); epsilon = new_val; };
+ Real get_epsilon() const { return oracle.get_epsilon(); }
+ Real get_wasserstein_cost();
+ Real get_wasserstein_distance();
+ Real get_relative_error() const { return relative_error; };
+ void enable_logging(const char* log_filename, const size_t _max_unassigned_to_log);
+//private:
// private data
- std::vector<DiagramPoint> bidders, items;
- const size_t numBidders;
- const size_t numItems;
- std::vector<IdxType> itemsToBidders;
- std::vector<IdxType> biddersToItems;
- double wassersteinPower;
- double epsilon;
- double delta;
- double internal_p;
- double initialEpsilon;
- double epsilonCommonRatio; // next epsilon = current epsilon / epsilonCommonRatio
- double weightAdjConst;
- double wassersteinDistance;
- double wassersteinCost;
- double relativeError;
+ PointContainer bidders, items;
+ const size_t num_bidders;
+ const size_t num_items;
+ std::vector<IdxType> items_to_bidders;
+ std::vector<IdxType> bidders_to_items;
+ Real wasserstein_power;
+ Real epsilon;
+ Real delta;
+ Real internal_p;
+ Real initial_epsilon;
+ Real epsilon_common_ratio; // next epsilon = current epsilon / epsilon_common_ratio
+ const int max_num_phases; // maximal number of iterations of epsilon-scaling
+ Real weight_adj_const;
+ Real wasserstein_cost;
+ Real relative_error;
+ int dimension;
// to get the 2 best items
- std::unique_ptr<AuctionOracle> oracle;
-#ifdef KEEP_UNASSIGNED_ORDERED
- OrderedUnassignedKeeper unassignedBidders;
-#else
- std::unordered_set<size_t> unassignedBidders;
-#endif
+ AuctionOracle oracle;
+ std::unordered_set<size_t> unassigned_bidders;
// private methods
- void assignItemToBidder(const IdxType bidderIdx, const IdxType itemsIdx);
- void clearBidTable(void);
- void runAuction(void);
- void runAuctionPhase(void);
- void flushAssignment(void);
+ void assign_item_to_bidder(const IdxType bidder_idx, const IdxType items_idx);
+ void run_auction();
+ void run_auction_phases(const int max_num_phases, const Real _initial_epsilon);
+ void run_auction_phase();
+ void flush_assignment();
+ // return 0, if item_idx is invalid
+ Real get_item_bidder_cost(const size_t item_idx, const size_t bidder_idx, const bool tolerate_invalid_idx = false) const;
// for debug only
- void sanityCheck(void);
- void printDebug(void);
- int countUnhappy(void);
- void printMatching(void);
- double getDistanceToQthPowerInternal(void);
- int numRounds { 0 };
+ void sanity_check();
+ void print_debug();
+ int count_unhappy();
+ void print_matching();
+ Real getDistanceToQthPowerInternal();
+ int num_phase { 0 };
+ int num_rounds { 0 };
+ bool is_distance_computed {false};
+#ifdef LOG_AUCTION
+ bool log_auction { false };
+ std::shared_ptr<spdlog::logger> console_logger;
+ std::shared_ptr<spdlog::logger> plot_logger;
+ std::unordered_set<size_t> unassigned_items;
+ size_t max_unassigned_to_log { 0 };
+ const char* logger_name = "auction_detailed_logger"; // the name in spdlog registry; filename is provided as parameter in enable_logging
+ const Real total_items_persistence;
+ const Real total_bidders_persistence;
+ Real partial_cost;
+ Real unassigned_bidders_persistence;
+ Real unassigned_items_persistence;
+#endif
};
-} // end of namespace geom_ws
+} // ws
+} // hera
+
+
+#include "auction_runner_gs.hpp"
#endif