diff options
author | Arnur Nigmetov <nigmetov@tugraz.at> | 2019-02-06 22:45:18 +0100 |
---|---|---|
committer | Arnur Nigmetov <nigmetov@tugraz.at> | 2019-02-06 22:45:18 +0100 |
commit | ba25f264b5d309efcf77a6b72d1b784ae97f741f (patch) | |
tree | fb536fc0d0e9f241e6c7b37b45a94e31d25f1a8f | |
parent | 657f73321f04d5d1c4cec8085ec43a73633b96af (diff) |
Switch to opts, tolerate max-iterations-exceeded.
1. Use opts.h for command-line parsing in wasserstein_dist.
2. Real relative error at the end of auction is stored in params,
params is passed by reference.
3. If -e command line option is given to wasserstein_dist,
relative error will be printed.
4. If -t option is given to wasserstein_dist,
no exception will be thrown, if maximum number of iterations is
exceeded.
5. Run wasserstein_dist -h to see all options.
8 files changed, 411 insertions, 28 deletions
diff --git a/geom_matching/wasserstein/example/wasserstein_dist.cpp b/geom_matching/wasserstein/example/wasserstein_dist.cpp index fcbc641..25e1f68 100644 --- a/geom_matching/wasserstein/example/wasserstein_dist.cpp +++ b/geom_matching/wasserstein/example/wasserstein_dist.cpp @@ -31,6 +31,8 @@ derivative works thereof, in binary and source code form. #include <iomanip> #include <vector> +#include "opts/opts.h" + //#define LOG_AUCTION //#include "auction_runner_fr.h" @@ -47,23 +49,54 @@ int main(int argc, char* argv[]) PairVector diagramA, diagramB; hera::AuctionParams<double> params; + params.max_num_phases = 800; + + opts::Options ops(argc, argv); + ops >> opts::Option('q', "degree", params.wasserstein_power, "Wasserstein degree") + >> opts::Option('d', "error", params.delta, "Relative error") + >> opts::Option('p', "internal-p", params.internal_p, "Internal norm") + >> opts::Option("initial-epsilon", params.initial_epsilon, "Initial epsilon") + >> opts::Option("epsilon-factor", params.epsilon_common_ratio, "Epsilon factor") + >> opts::Option("max-bids-per-round", params.max_bids_per_round, "Maximal number of bids per round") + >> opts::Option('m', "max-rounds", params.max_num_phases, "Maximal number of iterations"); + + + bool print_relative_error = ops >> opts::Present('e', "--print-error", "Print real relative error"); + + params.tolerate_max_iter_exceeded = ops >> opts::Present('t', "tolerate", "Suppress max-iterations-exceeded error and print the best result."); + + std::string dgm_fname_1, dgm_fname_2; + bool dgm_1_given = (ops >> opts::PosOption(dgm_fname_1)); + bool dgm_2_given = (ops >> opts::PosOption(dgm_fname_2)); - if (argc < 3 ) { - std::cerr << "Usage: " << argv[0] << " file1 file2 [wasserstein_degree] [relative_error] [internal norm] [initial epsilon] [epsilon_factor] [max_bids_per_round] [gamma_threshold][log_filename_prefix]. By default power is 1.0, relative error is 0.01, internal norm is l_infinity, initall epsilon is chosen automatically, epsilon factor is 5.0, Jacobi variant is used (max bids per round is maximal), gamma_threshold = 0.0." << std::endl; + //std::cout << "q = " << params.wasserstein_power << ", delta = " << params.delta << ", p = " << params.internal_p << ", max_round = " << params.max_num_phases << std::endl; + //std::cout << "print relative error: " << print_relative_error << std::endl; + //std::cout << "dgm1: " << dgm_fname_1 << std::endl; + //std::cout << "dgm2: " << dgm_fname_2 << std::endl; + + if (not dgm_1_given or not dgm_2_given) { + std::cerr << "Usage: " << argv[0] << " file1 file2 " << std::endl; + std::cerr << "compute Wasserstein distance between persistence diagrams in file1 and file2.\n"; + std::cerr << ops << std::endl; return 1; } - if (!hera::read_diagram_point_set<double, PairVector>(argv[1], diagramA)) { + if (ops >> opts::Present('h', "help", "show help message")) { + std::cout << "Usage: " << argv[0] << " file1 file2 " << std::endl; + std::cout << "compute Wasserstein distance between persistence diagrams in file1 and file2.\n"; + std::cout << ops << std::endl; + } + + if (!hera::read_diagram_point_set<double, PairVector>(dgm_fname_1, diagramA)) { std::exit(1); } - if (!hera::read_diagram_point_set(argv[2], diagramB)) { + if (!hera::read_diagram_point_set(dgm_fname_2, diagramB)) { std::exit(1); } - params.wasserstein_power = (4 <= argc) ? atof(argv[3]) : 1.0; if (params.wasserstein_power < 1.0) { - std::cerr << "The third argument (wasserstein_degree) was \"" << argv[3] << "\", must be a number >= 1.0. Cannot proceed. " << std::endl; + std::cerr << "Wasserstein_degree was \"" << params.wasserstein_power << "\", must be a number >= 1.0. Cannot proceed. " << std::endl; std::exit(1); } @@ -72,50 +105,40 @@ int main(int argc, char* argv[]) } //default relative error: 1% - params.delta = (5 <= argc) ? atof(argv[4]) : 0.01; if ( params.delta <= 0.0) { - std::cerr << "The 4th argument (relative error) was \"" << argv[4] << "\", must be a number > 0.0. Cannot proceed. " << std::endl; + std::cerr << "relative error was \"" << params.delta << "\", must be a number > 0.0. Cannot proceed. " << std::endl; std::exit(1); } // default for internal metric is l_infinity - params.internal_p = ( 6 <= argc ) ? atof(argv[5]) : hera::get_infinity<double>(); if (std::isinf(params.internal_p)) { params.internal_p = hera::get_infinity<double>(); } if (not hera::is_p_valid_norm<double>(params.internal_p)) { - std::cerr << "The 5th argument (internal norm) was \"" << argv[5] << "\", must be a number >= 1.0 or inf. Cannot proceed. " << std::endl; + std::cerr << "internal-p was \"" << params.internal_p << "\", must be a number >= 1.0 or inf. Cannot proceed. " << std::endl; std::exit(1); } // if you want to specify initial value for epsilon and the factor // for epsilon-scaling - params.initial_epsilon= ( 7 <= argc ) ? atof(argv[6]) : 0.0 ; - if (params.initial_epsilon < 0.0) { - std::cerr << "The 6th argument (initial epsilon) was \"" << argv[6] << "\", must be a non-negative number. Cannot proceed." << std::endl; + std::cerr << "initial-epsilon was \"" << params.initial_epsilon << "\", must be a non-negative number. Cannot proceed." << std::endl; std::exit(1); } - params.epsilon_common_ratio = ( 8 <= argc ) ? atof(argv[7]) : 0.0 ; if (params.epsilon_common_ratio <= 1.0 and params.epsilon_common_ratio != 0.0) { - std::cerr << "The 7th argument (epsilon factor) was \"" << argv[7] << "\", must be a number greater than 1. Cannot proceed." << std::endl; + std::cerr << "The 7th argument (epsilon factor) was \"" << params.epsilon_common_ratio << "\", must be a number greater than 1. Cannot proceed." << std::endl; std::exit(1); } - - params.max_bids_per_round = ( 9 <= argc ) ? atoi(argv[8]) : 0; if (params.max_bids_per_round == 0) - params.max_bids_per_round = std::numeric_limits<size_t>::max(); - + params.max_bids_per_round = std::numeric_limits<decltype(params.max_bids_per_round)>::max(); - params.gamma_threshold = (10 <= argc) ? atof(argv[9]) : 0.0; std::string log_filename_prefix = ( 11 <= argc ) ? argv[10] : ""; - params.max_num_phases = 800; #ifdef LOG_AUCTION spdlog::set_level(spdlog::level::info); @@ -124,6 +147,8 @@ int main(int argc, char* argv[]) double res = hera::wasserstein_dist(diagramA, diagramB, params, log_filename_prefix); std::cout << std::setprecision(15) << res << std::endl; + if (print_relative_error) + std::cout << "Relative error: " << params.final_relative_error << std::endl; return 0; diff --git a/geom_matching/wasserstein/example/wasserstein_dist_dipha.cpp b/geom_matching/wasserstein/example/wasserstein_dist_dipha.cpp index cd8c61a..2ed9c2c 100644 --- a/geom_matching/wasserstein/example/wasserstein_dist_dipha.cpp +++ b/geom_matching/wasserstein/example/wasserstein_dist_dipha.cpp @@ -111,7 +111,7 @@ int main(int argc, char* argv[]) params.max_bids_per_round = ( 10 <= argc ) ? atoi(argv[9]) : 0; if (params.max_bids_per_round == 0) - params.max_bids_per_round = std::numeric_limits<size_t>::max(); + params.max_bids_per_round = std::numeric_limits<decltype(params.max_bids_per_round)>::max(); params.gamma_threshold = (11 <= argc) ? atof(argv[10]) : 0.0; diff --git a/geom_matching/wasserstein/example/wasserstein_dist_point_cloud.cpp b/geom_matching/wasserstein/example/wasserstein_dist_point_cloud.cpp index 2f9718e..ab7ff4f 100644 --- a/geom_matching/wasserstein/example/wasserstein_dist_point_cloud.cpp +++ b/geom_matching/wasserstein/example/wasserstein_dist_point_cloud.cpp @@ -154,7 +154,7 @@ int main(int argc, char* argv[]) params.max_bids_per_round = ( 9 <= argc ) ? atoi(argv[8]) : 0; if (params.max_bids_per_round == 0) - params.max_bids_per_round = std::numeric_limits<size_t>::max(); + params.max_bids_per_round = std::numeric_limits<decltype(params.max_bids_per_round)>::max(); params.gamma_threshold = (10 <= argc) ? atof(argv[9]) : 0.0; diff --git a/geom_matching/wasserstein/include/auction_runner_gs.h b/geom_matching/wasserstein/include/auction_runner_gs.h index fc76987..f8b5a8b 100644 --- a/geom_matching/wasserstein/include/auction_runner_gs.h +++ b/geom_matching/wasserstein/include/auction_runner_gs.h @@ -73,6 +73,7 @@ public: Real initial_epsilon; Real epsilon_common_ratio; // next epsilon = current epsilon / epsilon_common_ratio const int max_num_phases; // maximal number of iterations of epsilon-scaling + bool tolerate_max_iter_exceeded; Real weight_adj_const; Real wasserstein_cost; Real relative_error; diff --git a/geom_matching/wasserstein/include/auction_runner_gs.hpp b/geom_matching/wasserstein/include/auction_runner_gs.hpp index 141cb2c..4ef94db 100644 --- a/geom_matching/wasserstein/include/auction_runner_gs.hpp +++ b/geom_matching/wasserstein/include/auction_runner_gs.hpp @@ -68,6 +68,7 @@ AuctionRunnerGS<R, AO, PC>::AuctionRunnerGS(const PC& A, initial_epsilon(params.initial_epsilon), epsilon_common_ratio(params.epsilon_common_ratio == 0.0 ? 5.0 : params.epsilon_common_ratio), max_num_phases(params.max_num_phases), + tolerate_max_iter_exceeded(params.tolerate_max_iter_exceeded), dimension(params.dim), oracle(bidders, items, params) #ifdef LOG_AUCTION @@ -294,7 +295,7 @@ void AuctionRunnerGS<R, AO, PC>::run_auction() double init_eps = ( initial_epsilon > 0.0 ) ? initial_epsilon : oracle.max_val_ / 4.0 ; run_auction_phases(max_num_phases, init_eps); is_distance_computed = true; - if (relative_error > delta) { + if (relative_error > delta and not tolerate_max_iter_exceeded) { #ifndef FOR_R_TDA std::cerr << "Maximum iteration number exceeded, exiting. Current result is: "; std::cerr << pow(wasserstein_cost, 1.0/wasserstein_power) << std::endl; diff --git a/geom_matching/wasserstein/include/basic_defs_ws.h b/geom_matching/wasserstein/include/basic_defs_ws.h index 1c5928f..1712ccf 100644 --- a/geom_matching/wasserstein/include/basic_defs_ws.h +++ b/geom_matching/wasserstein/include/basic_defs_ws.h @@ -88,8 +88,10 @@ struct AuctionParams Real epsilon_common_ratio { 5.0 }; Real gamma_threshold { 0.0 }; // for experiments, not in use now int max_num_phases { std::numeric_limits<decltype(max_num_phases)>::max() }; - size_t max_bids_per_round { 1 }; // imitate Gauss-Seidel is default behaviour + int max_bids_per_round { 1 }; // imitate Gauss-Seidel is default behaviour unsigned int dim { 2 }; // for pure geometric version only; ignored in persistence diagrams + Real final_relative_error; // out parameter - after auction terminates, contains the real relative error + bool tolerate_max_iter_exceeded { false }; // whether auction should throw an exception on max. iterations exceeded }; namespace ws diff --git a/geom_matching/wasserstein/include/opts/opts.h b/geom_matching/wasserstein/include/opts/opts.h new file mode 100755 index 0000000..74e788b --- /dev/null +++ b/geom_matching/wasserstein/include/opts/opts.h @@ -0,0 +1,353 @@ +/** + * Author: Dmitriy Morozov <dmitriy@mrzv.org> + * The interface is heavily influenced by GetOptPP (https://code.google.com/p/getoptpp/). + */ + +#ifndef OPTS_OPTS_H +#define OPTS_OPTS_H + +#include <iostream> +#include <sstream> +#include <string> +#include <list> +#include <vector> + +namespace opts { + +// Converters +template<class T> +struct Converter +{ + Converter() {} + static + T convert(const std::string& val) { std::istringstream iss(val); T res; iss >> res; return res; } +}; + +// Type +template<class T> +struct Traits +{ + static std::string type_string() { return "UNKNOWN TYPE"; } +}; + +template<> +struct Traits<int> +{ + static std::string type_string() { return "INT"; } +}; + +template<> +struct Traits<short int> +{ + static std::string type_string() { return "SHORT INT"; } +}; + +template<> +struct Traits<unsigned> +{ + static std::string type_string() { return "UNSIGNED INT"; } +}; + +template<> +struct Traits<short unsigned> +{ + static std::string type_string() { return "SHORT UNSIGNED INT"; } +}; + +template<> +struct Traits<float> +{ + static std::string type_string() { return "FLOAT"; } +}; + +template<> +struct Traits<double> +{ + static std::string type_string() { return "DOUBLE"; } +}; + +template<> +struct Traits<std::string> +{ + static std::string type_string() { return "STRING"; } +}; + + +struct BasicOption +{ + BasicOption(char s_, + std::string l_, + std::string default_, + std::string type_, + std::string help_): + s(s_), l(l_), d(default_), t(type_), help(help_) {} + + int long_size() const { return l.size() + 1 + t.size(); } + + void output(std::ostream& out, int max_long) const + { + out << " "; + if (s) + out << '-' << s << ", "; + else + out << " "; + + out << "--" << l << ' '; + + if (!t.empty()) + out << t; + + for (int i = long_size(); i < max_long; ++i) + out << ' '; + + out << " " << help; + + if (!d.empty()) + { + out << " [default: " << d << "]"; + } + out << '\n'; + } + + char s; + std::string l; + std::string d; + std::string t; + std::string help; +}; + +// Option +template<class T> +struct OptionContainer: public BasicOption +{ + OptionContainer(char s_, + const std::string& l_, + T& var_, + const std::string& help_, + const std::string& type_ = Traits<T>::type_string()): + BasicOption(s_, l_, default_value(var_), type_, help_), + var(&var_) {} + + static + std::string default_value(const T& def) + { + std::ostringstream oss; + oss << def; + return oss.str(); + } + + void parse(std::list<std::string>& args) const + { + std::string short_opt = "-"; short_opt += s; + std::string long_opt = "--" + l; + for (std::list<std::string>::iterator cur = args.begin(); cur != args.end(); ++cur) + { + if (*cur == short_opt || *cur == long_opt) + { + cur = args.erase(cur); + if (cur != args.end()) + { + *var = Converter<T>::convert(*cur); + cur = args.erase(cur); + break; // finds first occurrence + } + else + break; // if the last option's value is missing, it remains default + + } + } + } + + T* var; +}; + +template<class T> +struct OptionContainer< std::vector<T> >: public BasicOption +{ + OptionContainer(char s_, + const std::string& l_, + std::vector<T>& var_, + const std::string& help_, + const std::string& type_ = "SEQUENCE"): + BasicOption(s_, l_, default_value(var_), type_, help_), + var(&var_) { } + + static + std::string default_value(const std::vector<T>& def) + { + std::ostringstream oss; + oss << "("; + if (def.size()) + oss << def[0]; + for (int i = 1; i < def.size(); ++i) + oss << ", " << def[i]; + oss << ")"; + return oss.str(); + } + + void parse(std::list<std::string>& args) const + { + std::string short_opt = "-"; short_opt += s; + std::string long_opt = "--" + l; + for (std::list<std::string>::iterator cur = args.begin(); cur != args.end(); ++cur) + { + if (*cur == short_opt || *cur == long_opt) + { + cur = args.erase(cur); + if (cur != args.end()) + { + var->push_back(Converter<T>::convert(*cur)); + cur = args.erase(cur); + } + --cur; + } + } + } + + std::vector<T>* var; +}; + + +template<class T> +OptionContainer<T> +Option(char s, const std::string& l, T& var, const std::string& help) { return OptionContainer<T>(s, l, var, help); } + +template<class T> +OptionContainer<T> +Option(char s, const std::string& l, T& var, + const std::string& type, const std::string& help) { return OptionContainer<T>(s, l, var, help, type); } + +template<class T> +OptionContainer<T> +Option(const std::string& l, T& var, const std::string& help) { return OptionContainer<T>(0, l, var, help); } + +template<class T> +OptionContainer<T> +Option(const std::string& l, T& var, + const std::string& type, const std::string& help) { return OptionContainer<T>(0, l, var, help, type); } + +// Present +struct PresentContainer: public BasicOption +{ + PresentContainer(char s, const std::string& l, const std::string& help): + BasicOption(s,l,"","",help) {} +}; + +inline +PresentContainer +Present(char s, const std::string& l, const std::string& help) { return PresentContainer(s, l, help); } + +inline +PresentContainer +Present(const std::string& l, const std::string& help) { return PresentContainer(0, l, help); } + +// PosOption +template<class T> +struct PosOptionContainer +{ + PosOptionContainer(T& var_): + var(&var_) {} + + bool parse(std::list<std::string>& args) const + { + if (args.empty()) + return false; + + *var = Converter<T>::convert(args.front()); + args.pop_front(); + return true; + } + + T* var; +}; + +template<class T> +PosOptionContainer<T> +PosOption(T& var) { return PosOptionContainer<T>(var); } + + +// Options +struct Options +{ + Options(int argc_, char** argv_): + args(argv_ + 1, argv_ + argc_), + failed(false) {} + + template<class T> + Options& operator>>(const OptionContainer<T>& oc); + bool operator>>(const PresentContainer& pc); + template<class T> + Options& operator>>(const PosOptionContainer<T>& poc); + + operator bool() { return !failed; } + + + friend + std::ostream& + operator<<(std::ostream& out, const Options& ops) + { + int max_long = 0; + for (std::list<BasicOption>::const_iterator cur = ops.options.begin(); + cur != ops.options.end(); + ++cur) + { + int cur_long = cur->long_size(); + if (cur_long > max_long) + max_long = cur_long; + } + + out << "Options:\n"; + for (std::list<BasicOption>::const_iterator cur = ops.options.begin(); + cur != ops.options.end(); + ++cur) + cur->output(out, max_long); + + return out; + } + + + private: + std::list<std::string> args; + std::list<BasicOption> options; + bool failed; +}; + +template<class T> +Options& +Options::operator>>(const OptionContainer<T>& oc) +{ + options.push_back(oc); + oc.parse(args); + return *this; +} + +inline +bool +Options::operator>>(const PresentContainer& pc) +{ + options.push_back(pc); + + for(std::list<std::string>::iterator cur = args.begin(); cur != args.end(); ++cur) + { + std::string short_opt = "-"; short_opt += pc.s; + std::string long_opt = "--" + pc.l; + if (*cur == short_opt || *cur == long_opt) + { + args.erase(cur); + return true; + } + } + return false; +} + +template<class T> +Options& +Options::operator>>(const PosOptionContainer<T>& poc) +{ + failed = !poc.parse(args); + return *this; +} + +} + +#endif diff --git a/geom_matching/wasserstein/include/wasserstein.h b/geom_matching/wasserstein/include/wasserstein.h index a24bada..35d0bf6 100644 --- a/geom_matching/wasserstein/include/wasserstein.h +++ b/geom_matching/wasserstein/include/wasserstein.h @@ -212,7 +212,7 @@ namespace ws template<class RealType> inline RealType wasserstein_cost_vec(const std::vector<DiagramPoint<RealType>>& A, const std::vector<DiagramPoint<RealType>>& B, - const AuctionParams<RealType>& params, + AuctionParams<RealType>& params, const std::string& _log_filename_prefix) { if (params.wasserstein_power < 1.0) { @@ -237,6 +237,7 @@ namespace ws AuctionRunnerGS<RealType> auction(A, B, params, _log_filename_prefix); auction.run_auction(); result = auction.get_wasserstein_cost(); + params.final_relative_error = auction.get_relative_error(); return result; } @@ -248,7 +249,7 @@ template<class PairContainer> inline typename DiagramTraits<PairContainer>::RealType wasserstein_cost(const PairContainer& A, const PairContainer& B, - const AuctionParams< typename DiagramTraits<PairContainer>::RealType >& params, + AuctionParams< typename DiagramTraits<PairContainer>::RealType >& params, const std::string& _log_filename_prefix = "") { using Traits = DiagramTraits<PairContainer>; @@ -335,7 +336,7 @@ template<class PairContainer> inline typename DiagramTraits<PairContainer>::RealType wasserstein_dist(const PairContainer& A, const PairContainer& B, - const AuctionParams<typename DiagramTraits<PairContainer>::RealType> params, + AuctionParams<typename DiagramTraits<PairContainer>::RealType>& params, const std::string& _log_filename_prefix = "") { using Real = typename DiagramTraits<PairContainer>::RealType; |