diff options
Diffstat (limited to 'geom_matching/wasserstein/include')
5 files changed, 363 insertions, 5 deletions
diff --git a/geom_matching/wasserstein/include/auction_runner_gs.h b/geom_matching/wasserstein/include/auction_runner_gs.h index fc76987..f8b5a8b 100644 --- a/geom_matching/wasserstein/include/auction_runner_gs.h +++ b/geom_matching/wasserstein/include/auction_runner_gs.h @@ -73,6 +73,7 @@ public: Real initial_epsilon; Real epsilon_common_ratio; // next epsilon = current epsilon / epsilon_common_ratio const int max_num_phases; // maximal number of iterations of epsilon-scaling + bool tolerate_max_iter_exceeded; Real weight_adj_const; Real wasserstein_cost; Real relative_error; diff --git a/geom_matching/wasserstein/include/auction_runner_gs.hpp b/geom_matching/wasserstein/include/auction_runner_gs.hpp index 141cb2c..4ef94db 100644 --- a/geom_matching/wasserstein/include/auction_runner_gs.hpp +++ b/geom_matching/wasserstein/include/auction_runner_gs.hpp @@ -68,6 +68,7 @@ AuctionRunnerGS<R, AO, PC>::AuctionRunnerGS(const PC& A, initial_epsilon(params.initial_epsilon), epsilon_common_ratio(params.epsilon_common_ratio == 0.0 ? 5.0 : params.epsilon_common_ratio), max_num_phases(params.max_num_phases), + tolerate_max_iter_exceeded(params.tolerate_max_iter_exceeded), dimension(params.dim), oracle(bidders, items, params) #ifdef LOG_AUCTION @@ -294,7 +295,7 @@ void AuctionRunnerGS<R, AO, PC>::run_auction() double init_eps = ( initial_epsilon > 0.0 ) ? initial_epsilon : oracle.max_val_ / 4.0 ; run_auction_phases(max_num_phases, init_eps); is_distance_computed = true; - if (relative_error > delta) { + if (relative_error > delta and not tolerate_max_iter_exceeded) { #ifndef FOR_R_TDA std::cerr << "Maximum iteration number exceeded, exiting. Current result is: "; std::cerr << pow(wasserstein_cost, 1.0/wasserstein_power) << std::endl; diff --git a/geom_matching/wasserstein/include/basic_defs_ws.h b/geom_matching/wasserstein/include/basic_defs_ws.h index 1c5928f..1712ccf 100644 --- a/geom_matching/wasserstein/include/basic_defs_ws.h +++ b/geom_matching/wasserstein/include/basic_defs_ws.h @@ -88,8 +88,10 @@ struct AuctionParams Real epsilon_common_ratio { 5.0 }; Real gamma_threshold { 0.0 }; // for experiments, not in use now int max_num_phases { std::numeric_limits<decltype(max_num_phases)>::max() }; - size_t max_bids_per_round { 1 }; // imitate Gauss-Seidel is default behaviour + int max_bids_per_round { 1 }; // imitate Gauss-Seidel is default behaviour unsigned int dim { 2 }; // for pure geometric version only; ignored in persistence diagrams + Real final_relative_error; // out parameter - after auction terminates, contains the real relative error + bool tolerate_max_iter_exceeded { false }; // whether auction should throw an exception on max. iterations exceeded }; namespace ws diff --git a/geom_matching/wasserstein/include/opts/opts.h b/geom_matching/wasserstein/include/opts/opts.h new file mode 100755 index 0000000..74e788b --- /dev/null +++ b/geom_matching/wasserstein/include/opts/opts.h @@ -0,0 +1,353 @@ +/** + * Author: Dmitriy Morozov <dmitriy@mrzv.org> + * The interface is heavily influenced by GetOptPP (https://code.google.com/p/getoptpp/). + */ + +#ifndef OPTS_OPTS_H +#define OPTS_OPTS_H + +#include <iostream> +#include <sstream> +#include <string> +#include <list> +#include <vector> + +namespace opts { + +// Converters +template<class T> +struct Converter +{ + Converter() {} + static + T convert(const std::string& val) { std::istringstream iss(val); T res; iss >> res; return res; } +}; + +// Type +template<class T> +struct Traits +{ + static std::string type_string() { return "UNKNOWN TYPE"; } +}; + +template<> +struct Traits<int> +{ + static std::string type_string() { return "INT"; } +}; + +template<> +struct Traits<short int> +{ + static std::string type_string() { return "SHORT INT"; } +}; + +template<> +struct Traits<unsigned> +{ + static std::string type_string() { return "UNSIGNED INT"; } +}; + +template<> +struct Traits<short unsigned> +{ + static std::string type_string() { return "SHORT UNSIGNED INT"; } +}; + +template<> +struct Traits<float> +{ + static std::string type_string() { return "FLOAT"; } +}; + +template<> +struct Traits<double> +{ + static std::string type_string() { return "DOUBLE"; } +}; + +template<> +struct Traits<std::string> +{ + static std::string type_string() { return "STRING"; } +}; + + +struct BasicOption +{ + BasicOption(char s_, + std::string l_, + std::string default_, + std::string type_, + std::string help_): + s(s_), l(l_), d(default_), t(type_), help(help_) {} + + int long_size() const { return l.size() + 1 + t.size(); } + + void output(std::ostream& out, int max_long) const + { + out << " "; + if (s) + out << '-' << s << ", "; + else + out << " "; + + out << "--" << l << ' '; + + if (!t.empty()) + out << t; + + for (int i = long_size(); i < max_long; ++i) + out << ' '; + + out << " " << help; + + if (!d.empty()) + { + out << " [default: " << d << "]"; + } + out << '\n'; + } + + char s; + std::string l; + std::string d; + std::string t; + std::string help; +}; + +// Option +template<class T> +struct OptionContainer: public BasicOption +{ + OptionContainer(char s_, + const std::string& l_, + T& var_, + const std::string& help_, + const std::string& type_ = Traits<T>::type_string()): + BasicOption(s_, l_, default_value(var_), type_, help_), + var(&var_) {} + + static + std::string default_value(const T& def) + { + std::ostringstream oss; + oss << def; + return oss.str(); + } + + void parse(std::list<std::string>& args) const + { + std::string short_opt = "-"; short_opt += s; + std::string long_opt = "--" + l; + for (std::list<std::string>::iterator cur = args.begin(); cur != args.end(); ++cur) + { + if (*cur == short_opt || *cur == long_opt) + { + cur = args.erase(cur); + if (cur != args.end()) + { + *var = Converter<T>::convert(*cur); + cur = args.erase(cur); + break; // finds first occurrence + } + else + break; // if the last option's value is missing, it remains default + + } + } + } + + T* var; +}; + +template<class T> +struct OptionContainer< std::vector<T> >: public BasicOption +{ + OptionContainer(char s_, + const std::string& l_, + std::vector<T>& var_, + const std::string& help_, + const std::string& type_ = "SEQUENCE"): + BasicOption(s_, l_, default_value(var_), type_, help_), + var(&var_) { } + + static + std::string default_value(const std::vector<T>& def) + { + std::ostringstream oss; + oss << "("; + if (def.size()) + oss << def[0]; + for (int i = 1; i < def.size(); ++i) + oss << ", " << def[i]; + oss << ")"; + return oss.str(); + } + + void parse(std::list<std::string>& args) const + { + std::string short_opt = "-"; short_opt += s; + std::string long_opt = "--" + l; + for (std::list<std::string>::iterator cur = args.begin(); cur != args.end(); ++cur) + { + if (*cur == short_opt || *cur == long_opt) + { + cur = args.erase(cur); + if (cur != args.end()) + { + var->push_back(Converter<T>::convert(*cur)); + cur = args.erase(cur); + } + --cur; + } + } + } + + std::vector<T>* var; +}; + + +template<class T> +OptionContainer<T> +Option(char s, const std::string& l, T& var, const std::string& help) { return OptionContainer<T>(s, l, var, help); } + +template<class T> +OptionContainer<T> +Option(char s, const std::string& l, T& var, + const std::string& type, const std::string& help) { return OptionContainer<T>(s, l, var, help, type); } + +template<class T> +OptionContainer<T> +Option(const std::string& l, T& var, const std::string& help) { return OptionContainer<T>(0, l, var, help); } + +template<class T> +OptionContainer<T> +Option(const std::string& l, T& var, + const std::string& type, const std::string& help) { return OptionContainer<T>(0, l, var, help, type); } + +// Present +struct PresentContainer: public BasicOption +{ + PresentContainer(char s, const std::string& l, const std::string& help): + BasicOption(s,l,"","",help) {} +}; + +inline +PresentContainer +Present(char s, const std::string& l, const std::string& help) { return PresentContainer(s, l, help); } + +inline +PresentContainer +Present(const std::string& l, const std::string& help) { return PresentContainer(0, l, help); } + +// PosOption +template<class T> +struct PosOptionContainer +{ + PosOptionContainer(T& var_): + var(&var_) {} + + bool parse(std::list<std::string>& args) const + { + if (args.empty()) + return false; + + *var = Converter<T>::convert(args.front()); + args.pop_front(); + return true; + } + + T* var; +}; + +template<class T> +PosOptionContainer<T> +PosOption(T& var) { return PosOptionContainer<T>(var); } + + +// Options +struct Options +{ + Options(int argc_, char** argv_): + args(argv_ + 1, argv_ + argc_), + failed(false) {} + + template<class T> + Options& operator>>(const OptionContainer<T>& oc); + bool operator>>(const PresentContainer& pc); + template<class T> + Options& operator>>(const PosOptionContainer<T>& poc); + + operator bool() { return !failed; } + + + friend + std::ostream& + operator<<(std::ostream& out, const Options& ops) + { + int max_long = 0; + for (std::list<BasicOption>::const_iterator cur = ops.options.begin(); + cur != ops.options.end(); + ++cur) + { + int cur_long = cur->long_size(); + if (cur_long > max_long) + max_long = cur_long; + } + + out << "Options:\n"; + for (std::list<BasicOption>::const_iterator cur = ops.options.begin(); + cur != ops.options.end(); + ++cur) + cur->output(out, max_long); + + return out; + } + + + private: + std::list<std::string> args; + std::list<BasicOption> options; + bool failed; +}; + +template<class T> +Options& +Options::operator>>(const OptionContainer<T>& oc) +{ + options.push_back(oc); + oc.parse(args); + return *this; +} + +inline +bool +Options::operator>>(const PresentContainer& pc) +{ + options.push_back(pc); + + for(std::list<std::string>::iterator cur = args.begin(); cur != args.end(); ++cur) + { + std::string short_opt = "-"; short_opt += pc.s; + std::string long_opt = "--" + pc.l; + if (*cur == short_opt || *cur == long_opt) + { + args.erase(cur); + return true; + } + } + return false; +} + +template<class T> +Options& +Options::operator>>(const PosOptionContainer<T>& poc) +{ + failed = !poc.parse(args); + return *this; +} + +} + +#endif diff --git a/geom_matching/wasserstein/include/wasserstein.h b/geom_matching/wasserstein/include/wasserstein.h index a24bada..35d0bf6 100644 --- a/geom_matching/wasserstein/include/wasserstein.h +++ b/geom_matching/wasserstein/include/wasserstein.h @@ -212,7 +212,7 @@ namespace ws template<class RealType> inline RealType wasserstein_cost_vec(const std::vector<DiagramPoint<RealType>>& A, const std::vector<DiagramPoint<RealType>>& B, - const AuctionParams<RealType>& params, + AuctionParams<RealType>& params, const std::string& _log_filename_prefix) { if (params.wasserstein_power < 1.0) { @@ -237,6 +237,7 @@ namespace ws AuctionRunnerGS<RealType> auction(A, B, params, _log_filename_prefix); auction.run_auction(); result = auction.get_wasserstein_cost(); + params.final_relative_error = auction.get_relative_error(); return result; } @@ -248,7 +249,7 @@ template<class PairContainer> inline typename DiagramTraits<PairContainer>::RealType wasserstein_cost(const PairContainer& A, const PairContainer& B, - const AuctionParams< typename DiagramTraits<PairContainer>::RealType >& params, + AuctionParams< typename DiagramTraits<PairContainer>::RealType >& params, const std::string& _log_filename_prefix = "") { using Traits = DiagramTraits<PairContainer>; @@ -335,7 +336,7 @@ template<class PairContainer> inline typename DiagramTraits<PairContainer>::RealType wasserstein_dist(const PairContainer& A, const PairContainer& B, - const AuctionParams<typename DiagramTraits<PairContainer>::RealType> params, + AuctionParams<typename DiagramTraits<PairContainer>::RealType>& params, const std::string& _log_filename_prefix = "") { using Real = typename DiagramTraits<PairContainer>::RealType; |