#ifndef WASSERSTEIN_PURE_GEOM_HPP #define WASSERSTEIN_PURE_GEOM_HPP #define WASSERSTEIN_PURE_GEOM #include "diagram_reader.h" #include "auction_oracle_kdtree_pure_geom.h" #include "auction_runner_gs.h" #include "auction_runner_jac.h" namespace hera { namespace ws { template using DynamicTraits = typename hera::ws::dnn::DynamicPointTraits; template using DynamicPoint = typename hera::ws::dnn::DynamicPointTraits::PointType; template using DynamicPointVector = typename hera::ws::dnn::DynamicPointVector; template using AuctionRunnerGSR = typename hera::ws::AuctionRunnerGS, hera::ws::dnn::DynamicPointVector>; template using AuctionRunnerJacR = typename hera::ws::AuctionRunnerJac, hera::ws::dnn::DynamicPointVector>; double wasserstein_cost(const DynamicPointVector& set_A, const DynamicPointVector& set_B, const AuctionParams& params) { if (params.wasserstein_power < 1.0) { throw std::runtime_error("Bad q in Wasserstein " + std::to_string(params.wasserstein_power)); } if (params.delta < 0.0) { throw std::runtime_error("Bad delta in Wasserstein " + std::to_string(params.delta)); } if (params.initial_epsilon < 0.0) { throw std::runtime_error("Bad initial epsilon in Wasserstein" + std::to_string(params.initial_epsilon)); } if (params.epsilon_common_ratio < 0.0) { throw std::runtime_error("Bad epsilon factor in Wasserstein " + std::to_string(params.epsilon_common_ratio)); } if (set_A.size() != set_B.size()) { throw std::runtime_error("Different cardinalities of point clouds: " + std::to_string(set_A.size()) + " != " + std::to_string(set_B.size())); } DynamicTraits traits(params.dim); DynamicPointVector set_A_copy(set_A); DynamicPointVector set_B_copy(set_B); // set point id to the index in vector for(size_t i = 0; i < set_A.size(); ++i) { traits.id(set_A_copy[i]) = i; traits.id(set_B_copy[i]) = i; } if (params.max_bids_per_round == 1) { hera::ws::AuctionRunnerGSR auction(set_A_copy, set_B_copy, params); auction.run_auction(); return auction.get_wasserstein_cost(); } else { hera::ws::AuctionRunnerJacR auction(set_A_copy, set_B_copy, params); auction.run_auction(); return auction.get_wasserstein_cost(); } } double wasserstein_dist(const DynamicPointVector& set_A, const DynamicPointVector& set_B, const AuctionParams& params) { return std::pow(wasserstein_cost(set_A, set_B, params), 1.0 / params.wasserstein_power); } } // ws } // hera #endif