From 0cc35ad04f9c2997014d7cf62a12f697e79fb534 Mon Sep 17 00:00:00 2001 From: Arnur Nigmetov Date: Sat, 20 Jan 2018 19:11:29 +0100 Subject: Major rewrite, templatized version --- .../wasserstein/include/wasserstein_pure_geom.hpp | 87 ++++++++++++++++++++++ 1 file changed, 87 insertions(+) create mode 100644 geom_matching/wasserstein/include/wasserstein_pure_geom.hpp (limited to 'geom_matching/wasserstein/include/wasserstein_pure_geom.hpp') diff --git a/geom_matching/wasserstein/include/wasserstein_pure_geom.hpp b/geom_matching/wasserstein/include/wasserstein_pure_geom.hpp new file mode 100644 index 0000000..2a57599 --- /dev/null +++ b/geom_matching/wasserstein/include/wasserstein_pure_geom.hpp @@ -0,0 +1,87 @@ +#ifndef WASSERSTEIN_PURE_GEOM_HPP +#define WASSERSTEIN_PURE_GEOM_HPP + +#define WASSERSTEIN_PURE_GEOM + + +#include "diagram_reader.h" +#include "auction_oracle_kdtree_pure_geom.h" +#include "auction_runner_gs.h" +#include "auction_runner_jac.h" + +namespace hera +{ +namespace ws +{ + + template + using DynamicTraits = typename hera::ws::dnn::DynamicPointTraits; + + template + using DynamicPoint = typename hera::ws::dnn::DynamicPointTraits::PointType; + + template + using DynamicPointVector = typename hera::ws::dnn::DynamicPointVector; + + template + using AuctionRunnerGSR = typename hera::ws::AuctionRunnerGS, hera::ws::dnn::DynamicPointVector>; + + template + using AuctionRunnerJacR = typename hera::ws::AuctionRunnerJac, hera::ws::dnn::DynamicPointVector>; + + +double wasserstein_cost(const DynamicPointVector& set_A, const DynamicPointVector& set_B, const AuctionParams& params) +{ + if (params.wasserstein_power < 1.0) { + throw std::runtime_error("Bad q in Wasserstein " + std::to_string(params.wasserstein_power)); + } + + if (params.delta < 0.0) { + throw std::runtime_error("Bad delta in Wasserstein " + std::to_string(params.delta)); + } + + if (params.initial_epsilon < 0.0) { + throw std::runtime_error("Bad initial epsilon in Wasserstein" + std::to_string(params.initial_epsilon)); + } + + if (params.epsilon_common_ratio < 0.0) { + throw std::runtime_error("Bad epsilon factor in Wasserstein " + std::to_string(params.epsilon_common_ratio)); + } + + if (set_A.size() != set_B.size()) { + throw std::runtime_error("Different cardinalities of point clouds: " + std::to_string(set_A.size()) + " != " + std::to_string(set_B.size())); + } + + DynamicTraits traits(params.dim); + + DynamicPointVector set_A_copy(set_A); + DynamicPointVector set_B_copy(set_B); + + // set point id to the index in vector + for(size_t i = 0; i < set_A.size(); ++i) { + traits.id(set_A_copy[i]) = i; + traits.id(set_B_copy[i]) = i; + } + + if (params.max_bids_per_round == 1) { + hera::ws::AuctionRunnerGSR auction(set_A_copy, set_B_copy, params); + auction.run_auction(); + return auction.get_wasserstein_cost(); + } else { + hera::ws::AuctionRunnerJacR auction(set_A_copy, set_B_copy, params); + auction.run_auction(); + return auction.get_wasserstein_cost(); + } + +} + +double wasserstein_dist(const DynamicPointVector& set_A, const DynamicPointVector& set_B, const AuctionParams& params) +{ + return std::pow(wasserstein_cost(set_A, set_B, params), 1.0 / params.wasserstein_power); +} + +} // ws +} // hera + + +#endif -- cgit v1.2.3