summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorArnur Nigmetov <nigmetov@tugraz.at>2019-02-06 22:45:18 +0100
committerArnur Nigmetov <nigmetov@tugraz.at>2019-02-06 22:45:18 +0100
commitba25f264b5d309efcf77a6b72d1b784ae97f741f (patch)
treefb536fc0d0e9f241e6c7b37b45a94e31d25f1a8f
parent657f73321f04d5d1c4cec8085ec43a73633b96af (diff)
Switch to opts, tolerate max-iterations-exceeded.
1. Use opts.h for command-line parsing in wasserstein_dist. 2. Real relative error at the end of auction is stored in params, params is passed by reference. 3. If -e command line option is given to wasserstein_dist, relative error will be printed. 4. If -t option is given to wasserstein_dist, no exception will be thrown, if maximum number of iterations is exceeded. 5. Run wasserstein_dist -h to see all options.
-rw-r--r--geom_matching/wasserstein/example/wasserstein_dist.cpp67
-rw-r--r--geom_matching/wasserstein/example/wasserstein_dist_dipha.cpp2
-rw-r--r--geom_matching/wasserstein/example/wasserstein_dist_point_cloud.cpp2
-rw-r--r--geom_matching/wasserstein/include/auction_runner_gs.h1
-rw-r--r--geom_matching/wasserstein/include/auction_runner_gs.hpp3
-rw-r--r--geom_matching/wasserstein/include/basic_defs_ws.h4
-rwxr-xr-xgeom_matching/wasserstein/include/opts/opts.h353
-rw-r--r--geom_matching/wasserstein/include/wasserstein.h7
8 files changed, 411 insertions, 28 deletions
diff --git a/geom_matching/wasserstein/example/wasserstein_dist.cpp b/geom_matching/wasserstein/example/wasserstein_dist.cpp
index fcbc641..25e1f68 100644
--- a/geom_matching/wasserstein/example/wasserstein_dist.cpp
+++ b/geom_matching/wasserstein/example/wasserstein_dist.cpp
@@ -31,6 +31,8 @@ derivative works thereof, in binary and source code form.
#include <iomanip>
#include <vector>
+#include "opts/opts.h"
+
//#define LOG_AUCTION
//#include "auction_runner_fr.h"
@@ -47,23 +49,54 @@ int main(int argc, char* argv[])
PairVector diagramA, diagramB;
hera::AuctionParams<double> params;
+ params.max_num_phases = 800;
+
+ opts::Options ops(argc, argv);
+ ops >> opts::Option('q', "degree", params.wasserstein_power, "Wasserstein degree")
+ >> opts::Option('d', "error", params.delta, "Relative error")
+ >> opts::Option('p', "internal-p", params.internal_p, "Internal norm")
+ >> opts::Option("initial-epsilon", params.initial_epsilon, "Initial epsilon")
+ >> opts::Option("epsilon-factor", params.epsilon_common_ratio, "Epsilon factor")
+ >> opts::Option("max-bids-per-round", params.max_bids_per_round, "Maximal number of bids per round")
+ >> opts::Option('m', "max-rounds", params.max_num_phases, "Maximal number of iterations");
+
+
+ bool print_relative_error = ops >> opts::Present('e', "--print-error", "Print real relative error");
+
+ params.tolerate_max_iter_exceeded = ops >> opts::Present('t', "tolerate", "Suppress max-iterations-exceeded error and print the best result.");
+
+ std::string dgm_fname_1, dgm_fname_2;
+ bool dgm_1_given = (ops >> opts::PosOption(dgm_fname_1));
+ bool dgm_2_given = (ops >> opts::PosOption(dgm_fname_2));
- if (argc < 3 ) {
- std::cerr << "Usage: " << argv[0] << " file1 file2 [wasserstein_degree] [relative_error] [internal norm] [initial epsilon] [epsilon_factor] [max_bids_per_round] [gamma_threshold][log_filename_prefix]. By default power is 1.0, relative error is 0.01, internal norm is l_infinity, initall epsilon is chosen automatically, epsilon factor is 5.0, Jacobi variant is used (max bids per round is maximal), gamma_threshold = 0.0." << std::endl;
+ //std::cout << "q = " << params.wasserstein_power << ", delta = " << params.delta << ", p = " << params.internal_p << ", max_round = " << params.max_num_phases << std::endl;
+ //std::cout << "print relative error: " << print_relative_error << std::endl;
+ //std::cout << "dgm1: " << dgm_fname_1 << std::endl;
+ //std::cout << "dgm2: " << dgm_fname_2 << std::endl;
+
+ if (not dgm_1_given or not dgm_2_given) {
+ std::cerr << "Usage: " << argv[0] << " file1 file2 " << std::endl;
+ std::cerr << "compute Wasserstein distance between persistence diagrams in file1 and file2.\n";
+ std::cerr << ops << std::endl;
return 1;
}
- if (!hera::read_diagram_point_set<double, PairVector>(argv[1], diagramA)) {
+ if (ops >> opts::Present('h', "help", "show help message")) {
+ std::cout << "Usage: " << argv[0] << " file1 file2 " << std::endl;
+ std::cout << "compute Wasserstein distance between persistence diagrams in file1 and file2.\n";
+ std::cout << ops << std::endl;
+ }
+
+ if (!hera::read_diagram_point_set<double, PairVector>(dgm_fname_1, diagramA)) {
std::exit(1);
}
- if (!hera::read_diagram_point_set(argv[2], diagramB)) {
+ if (!hera::read_diagram_point_set(dgm_fname_2, diagramB)) {
std::exit(1);
}
- params.wasserstein_power = (4 <= argc) ? atof(argv[3]) : 1.0;
if (params.wasserstein_power < 1.0) {
- std::cerr << "The third argument (wasserstein_degree) was \"" << argv[3] << "\", must be a number >= 1.0. Cannot proceed. " << std::endl;
+ std::cerr << "Wasserstein_degree was \"" << params.wasserstein_power << "\", must be a number >= 1.0. Cannot proceed. " << std::endl;
std::exit(1);
}
@@ -72,50 +105,40 @@ int main(int argc, char* argv[])
}
//default relative error: 1%
- params.delta = (5 <= argc) ? atof(argv[4]) : 0.01;
if ( params.delta <= 0.0) {
- std::cerr << "The 4th argument (relative error) was \"" << argv[4] << "\", must be a number > 0.0. Cannot proceed. " << std::endl;
+ std::cerr << "relative error was \"" << params.delta << "\", must be a number > 0.0. Cannot proceed. " << std::endl;
std::exit(1);
}
// default for internal metric is l_infinity
- params.internal_p = ( 6 <= argc ) ? atof(argv[5]) : hera::get_infinity<double>();
if (std::isinf(params.internal_p)) {
params.internal_p = hera::get_infinity<double>();
}
if (not hera::is_p_valid_norm<double>(params.internal_p)) {
- std::cerr << "The 5th argument (internal norm) was \"" << argv[5] << "\", must be a number >= 1.0 or inf. Cannot proceed. " << std::endl;
+ std::cerr << "internal-p was \"" << params.internal_p << "\", must be a number >= 1.0 or inf. Cannot proceed. " << std::endl;
std::exit(1);
}
// if you want to specify initial value for epsilon and the factor
// for epsilon-scaling
- params.initial_epsilon= ( 7 <= argc ) ? atof(argv[6]) : 0.0 ;
-
if (params.initial_epsilon < 0.0) {
- std::cerr << "The 6th argument (initial epsilon) was \"" << argv[6] << "\", must be a non-negative number. Cannot proceed." << std::endl;
+ std::cerr << "initial-epsilon was \"" << params.initial_epsilon << "\", must be a non-negative number. Cannot proceed." << std::endl;
std::exit(1);
}
- params.epsilon_common_ratio = ( 8 <= argc ) ? atof(argv[7]) : 0.0 ;
if (params.epsilon_common_ratio <= 1.0 and params.epsilon_common_ratio != 0.0) {
- std::cerr << "The 7th argument (epsilon factor) was \"" << argv[7] << "\", must be a number greater than 1. Cannot proceed." << std::endl;
+ std::cerr << "The 7th argument (epsilon factor) was \"" << params.epsilon_common_ratio << "\", must be a number greater than 1. Cannot proceed." << std::endl;
std::exit(1);
}
-
- params.max_bids_per_round = ( 9 <= argc ) ? atoi(argv[8]) : 0;
if (params.max_bids_per_round == 0)
- params.max_bids_per_round = std::numeric_limits<size_t>::max();
-
+ params.max_bids_per_round = std::numeric_limits<decltype(params.max_bids_per_round)>::max();
- params.gamma_threshold = (10 <= argc) ? atof(argv[9]) : 0.0;
std::string log_filename_prefix = ( 11 <= argc ) ? argv[10] : "";
- params.max_num_phases = 800;
#ifdef LOG_AUCTION
spdlog::set_level(spdlog::level::info);
@@ -124,6 +147,8 @@ int main(int argc, char* argv[])
double res = hera::wasserstein_dist(diagramA, diagramB, params, log_filename_prefix);
std::cout << std::setprecision(15) << res << std::endl;
+ if (print_relative_error)
+ std::cout << "Relative error: " << params.final_relative_error << std::endl;
return 0;
diff --git a/geom_matching/wasserstein/example/wasserstein_dist_dipha.cpp b/geom_matching/wasserstein/example/wasserstein_dist_dipha.cpp
index cd8c61a..2ed9c2c 100644
--- a/geom_matching/wasserstein/example/wasserstein_dist_dipha.cpp
+++ b/geom_matching/wasserstein/example/wasserstein_dist_dipha.cpp
@@ -111,7 +111,7 @@ int main(int argc, char* argv[])
params.max_bids_per_round = ( 10 <= argc ) ? atoi(argv[9]) : 0;
if (params.max_bids_per_round == 0)
- params.max_bids_per_round = std::numeric_limits<size_t>::max();
+ params.max_bids_per_round = std::numeric_limits<decltype(params.max_bids_per_round)>::max();
params.gamma_threshold = (11 <= argc) ? atof(argv[10]) : 0.0;
diff --git a/geom_matching/wasserstein/example/wasserstein_dist_point_cloud.cpp b/geom_matching/wasserstein/example/wasserstein_dist_point_cloud.cpp
index 2f9718e..ab7ff4f 100644
--- a/geom_matching/wasserstein/example/wasserstein_dist_point_cloud.cpp
+++ b/geom_matching/wasserstein/example/wasserstein_dist_point_cloud.cpp
@@ -154,7 +154,7 @@ int main(int argc, char* argv[])
params.max_bids_per_round = ( 9 <= argc ) ? atoi(argv[8]) : 0;
if (params.max_bids_per_round == 0)
- params.max_bids_per_round = std::numeric_limits<size_t>::max();
+ params.max_bids_per_round = std::numeric_limits<decltype(params.max_bids_per_round)>::max();
params.gamma_threshold = (10 <= argc) ? atof(argv[9]) : 0.0;
diff --git a/geom_matching/wasserstein/include/auction_runner_gs.h b/geom_matching/wasserstein/include/auction_runner_gs.h
index fc76987..f8b5a8b 100644
--- a/geom_matching/wasserstein/include/auction_runner_gs.h
+++ b/geom_matching/wasserstein/include/auction_runner_gs.h
@@ -73,6 +73,7 @@ public:
Real initial_epsilon;
Real epsilon_common_ratio; // next epsilon = current epsilon / epsilon_common_ratio
const int max_num_phases; // maximal number of iterations of epsilon-scaling
+ bool tolerate_max_iter_exceeded;
Real weight_adj_const;
Real wasserstein_cost;
Real relative_error;
diff --git a/geom_matching/wasserstein/include/auction_runner_gs.hpp b/geom_matching/wasserstein/include/auction_runner_gs.hpp
index 141cb2c..4ef94db 100644
--- a/geom_matching/wasserstein/include/auction_runner_gs.hpp
+++ b/geom_matching/wasserstein/include/auction_runner_gs.hpp
@@ -68,6 +68,7 @@ AuctionRunnerGS<R, AO, PC>::AuctionRunnerGS(const PC& A,
initial_epsilon(params.initial_epsilon),
epsilon_common_ratio(params.epsilon_common_ratio == 0.0 ? 5.0 : params.epsilon_common_ratio),
max_num_phases(params.max_num_phases),
+ tolerate_max_iter_exceeded(params.tolerate_max_iter_exceeded),
dimension(params.dim),
oracle(bidders, items, params)
#ifdef LOG_AUCTION
@@ -294,7 +295,7 @@ void AuctionRunnerGS<R, AO, PC>::run_auction()
double init_eps = ( initial_epsilon > 0.0 ) ? initial_epsilon : oracle.max_val_ / 4.0 ;
run_auction_phases(max_num_phases, init_eps);
is_distance_computed = true;
- if (relative_error > delta) {
+ if (relative_error > delta and not tolerate_max_iter_exceeded) {
#ifndef FOR_R_TDA
std::cerr << "Maximum iteration number exceeded, exiting. Current result is: ";
std::cerr << pow(wasserstein_cost, 1.0/wasserstein_power) << std::endl;
diff --git a/geom_matching/wasserstein/include/basic_defs_ws.h b/geom_matching/wasserstein/include/basic_defs_ws.h
index 1c5928f..1712ccf 100644
--- a/geom_matching/wasserstein/include/basic_defs_ws.h
+++ b/geom_matching/wasserstein/include/basic_defs_ws.h
@@ -88,8 +88,10 @@ struct AuctionParams
Real epsilon_common_ratio { 5.0 };
Real gamma_threshold { 0.0 }; // for experiments, not in use now
int max_num_phases { std::numeric_limits<decltype(max_num_phases)>::max() };
- size_t max_bids_per_round { 1 }; // imitate Gauss-Seidel is default behaviour
+ int max_bids_per_round { 1 }; // imitate Gauss-Seidel is default behaviour
unsigned int dim { 2 }; // for pure geometric version only; ignored in persistence diagrams
+ Real final_relative_error; // out parameter - after auction terminates, contains the real relative error
+ bool tolerate_max_iter_exceeded { false }; // whether auction should throw an exception on max. iterations exceeded
};
namespace ws
diff --git a/geom_matching/wasserstein/include/opts/opts.h b/geom_matching/wasserstein/include/opts/opts.h
new file mode 100755
index 0000000..74e788b
--- /dev/null
+++ b/geom_matching/wasserstein/include/opts/opts.h
@@ -0,0 +1,353 @@
+/**
+ * Author: Dmitriy Morozov <dmitriy@mrzv.org>
+ * The interface is heavily influenced by GetOptPP (https://code.google.com/p/getoptpp/).
+ */
+
+#ifndef OPTS_OPTS_H
+#define OPTS_OPTS_H
+
+#include <iostream>
+#include <sstream>
+#include <string>
+#include <list>
+#include <vector>
+
+namespace opts {
+
+// Converters
+template<class T>
+struct Converter
+{
+ Converter() {}
+ static
+ T convert(const std::string& val) { std::istringstream iss(val); T res; iss >> res; return res; }
+};
+
+// Type
+template<class T>
+struct Traits
+{
+ static std::string type_string() { return "UNKNOWN TYPE"; }
+};
+
+template<>
+struct Traits<int>
+{
+ static std::string type_string() { return "INT"; }
+};
+
+template<>
+struct Traits<short int>
+{
+ static std::string type_string() { return "SHORT INT"; }
+};
+
+template<>
+struct Traits<unsigned>
+{
+ static std::string type_string() { return "UNSIGNED INT"; }
+};
+
+template<>
+struct Traits<short unsigned>
+{
+ static std::string type_string() { return "SHORT UNSIGNED INT"; }
+};
+
+template<>
+struct Traits<float>
+{
+ static std::string type_string() { return "FLOAT"; }
+};
+
+template<>
+struct Traits<double>
+{
+ static std::string type_string() { return "DOUBLE"; }
+};
+
+template<>
+struct Traits<std::string>
+{
+ static std::string type_string() { return "STRING"; }
+};
+
+
+struct BasicOption
+{
+ BasicOption(char s_,
+ std::string l_,
+ std::string default_,
+ std::string type_,
+ std::string help_):
+ s(s_), l(l_), d(default_), t(type_), help(help_) {}
+
+ int long_size() const { return l.size() + 1 + t.size(); }
+
+ void output(std::ostream& out, int max_long) const
+ {
+ out << " ";
+ if (s)
+ out << '-' << s << ", ";
+ else
+ out << " ";
+
+ out << "--" << l << ' ';
+
+ if (!t.empty())
+ out << t;
+
+ for (int i = long_size(); i < max_long; ++i)
+ out << ' ';
+
+ out << " " << help;
+
+ if (!d.empty())
+ {
+ out << " [default: " << d << "]";
+ }
+ out << '\n';
+ }
+
+ char s;
+ std::string l;
+ std::string d;
+ std::string t;
+ std::string help;
+};
+
+// Option
+template<class T>
+struct OptionContainer: public BasicOption
+{
+ OptionContainer(char s_,
+ const std::string& l_,
+ T& var_,
+ const std::string& help_,
+ const std::string& type_ = Traits<T>::type_string()):
+ BasicOption(s_, l_, default_value(var_), type_, help_),
+ var(&var_) {}
+
+ static
+ std::string default_value(const T& def)
+ {
+ std::ostringstream oss;
+ oss << def;
+ return oss.str();
+ }
+
+ void parse(std::list<std::string>& args) const
+ {
+ std::string short_opt = "-"; short_opt += s;
+ std::string long_opt = "--" + l;
+ for (std::list<std::string>::iterator cur = args.begin(); cur != args.end(); ++cur)
+ {
+ if (*cur == short_opt || *cur == long_opt)
+ {
+ cur = args.erase(cur);
+ if (cur != args.end())
+ {
+ *var = Converter<T>::convert(*cur);
+ cur = args.erase(cur);
+ break; // finds first occurrence
+ }
+ else
+ break; // if the last option's value is missing, it remains default
+
+ }
+ }
+ }
+
+ T* var;
+};
+
+template<class T>
+struct OptionContainer< std::vector<T> >: public BasicOption
+{
+ OptionContainer(char s_,
+ const std::string& l_,
+ std::vector<T>& var_,
+ const std::string& help_,
+ const std::string& type_ = "SEQUENCE"):
+ BasicOption(s_, l_, default_value(var_), type_, help_),
+ var(&var_) { }
+
+ static
+ std::string default_value(const std::vector<T>& def)
+ {
+ std::ostringstream oss;
+ oss << "(";
+ if (def.size())
+ oss << def[0];
+ for (int i = 1; i < def.size(); ++i)
+ oss << ", " << def[i];
+ oss << ")";
+ return oss.str();
+ }
+
+ void parse(std::list<std::string>& args) const
+ {
+ std::string short_opt = "-"; short_opt += s;
+ std::string long_opt = "--" + l;
+ for (std::list<std::string>::iterator cur = args.begin(); cur != args.end(); ++cur)
+ {
+ if (*cur == short_opt || *cur == long_opt)
+ {
+ cur = args.erase(cur);
+ if (cur != args.end())
+ {
+ var->push_back(Converter<T>::convert(*cur));
+ cur = args.erase(cur);
+ }
+ --cur;
+ }
+ }
+ }
+
+ std::vector<T>* var;
+};
+
+
+template<class T>
+OptionContainer<T>
+Option(char s, const std::string& l, T& var, const std::string& help) { return OptionContainer<T>(s, l, var, help); }
+
+template<class T>
+OptionContainer<T>
+Option(char s, const std::string& l, T& var,
+ const std::string& type, const std::string& help) { return OptionContainer<T>(s, l, var, help, type); }
+
+template<class T>
+OptionContainer<T>
+Option(const std::string& l, T& var, const std::string& help) { return OptionContainer<T>(0, l, var, help); }
+
+template<class T>
+OptionContainer<T>
+Option(const std::string& l, T& var,
+ const std::string& type, const std::string& help) { return OptionContainer<T>(0, l, var, help, type); }
+
+// Present
+struct PresentContainer: public BasicOption
+{
+ PresentContainer(char s, const std::string& l, const std::string& help):
+ BasicOption(s,l,"","",help) {}
+};
+
+inline
+PresentContainer
+Present(char s, const std::string& l, const std::string& help) { return PresentContainer(s, l, help); }
+
+inline
+PresentContainer
+Present(const std::string& l, const std::string& help) { return PresentContainer(0, l, help); }
+
+// PosOption
+template<class T>
+struct PosOptionContainer
+{
+ PosOptionContainer(T& var_):
+ var(&var_) {}
+
+ bool parse(std::list<std::string>& args) const
+ {
+ if (args.empty())
+ return false;
+
+ *var = Converter<T>::convert(args.front());
+ args.pop_front();
+ return true;
+ }
+
+ T* var;
+};
+
+template<class T>
+PosOptionContainer<T>
+PosOption(T& var) { return PosOptionContainer<T>(var); }
+
+
+// Options
+struct Options
+{
+ Options(int argc_, char** argv_):
+ args(argv_ + 1, argv_ + argc_),
+ failed(false) {}
+
+ template<class T>
+ Options& operator>>(const OptionContainer<T>& oc);
+ bool operator>>(const PresentContainer& pc);
+ template<class T>
+ Options& operator>>(const PosOptionContainer<T>& poc);
+
+ operator bool() { return !failed; }
+
+
+ friend
+ std::ostream&
+ operator<<(std::ostream& out, const Options& ops)
+ {
+ int max_long = 0;
+ for (std::list<BasicOption>::const_iterator cur = ops.options.begin();
+ cur != ops.options.end();
+ ++cur)
+ {
+ int cur_long = cur->long_size();
+ if (cur_long > max_long)
+ max_long = cur_long;
+ }
+
+ out << "Options:\n";
+ for (std::list<BasicOption>::const_iterator cur = ops.options.begin();
+ cur != ops.options.end();
+ ++cur)
+ cur->output(out, max_long);
+
+ return out;
+ }
+
+
+ private:
+ std::list<std::string> args;
+ std::list<BasicOption> options;
+ bool failed;
+};
+
+template<class T>
+Options&
+Options::operator>>(const OptionContainer<T>& oc)
+{
+ options.push_back(oc);
+ oc.parse(args);
+ return *this;
+}
+
+inline
+bool
+Options::operator>>(const PresentContainer& pc)
+{
+ options.push_back(pc);
+
+ for(std::list<std::string>::iterator cur = args.begin(); cur != args.end(); ++cur)
+ {
+ std::string short_opt = "-"; short_opt += pc.s;
+ std::string long_opt = "--" + pc.l;
+ if (*cur == short_opt || *cur == long_opt)
+ {
+ args.erase(cur);
+ return true;
+ }
+ }
+ return false;
+}
+
+template<class T>
+Options&
+Options::operator>>(const PosOptionContainer<T>& poc)
+{
+ failed = !poc.parse(args);
+ return *this;
+}
+
+}
+
+#endif
diff --git a/geom_matching/wasserstein/include/wasserstein.h b/geom_matching/wasserstein/include/wasserstein.h
index a24bada..35d0bf6 100644
--- a/geom_matching/wasserstein/include/wasserstein.h
+++ b/geom_matching/wasserstein/include/wasserstein.h
@@ -212,7 +212,7 @@ namespace ws
template<class RealType>
inline RealType wasserstein_cost_vec(const std::vector<DiagramPoint<RealType>>& A,
const std::vector<DiagramPoint<RealType>>& B,
- const AuctionParams<RealType>& params,
+ AuctionParams<RealType>& params,
const std::string& _log_filename_prefix)
{
if (params.wasserstein_power < 1.0) {
@@ -237,6 +237,7 @@ namespace ws
AuctionRunnerGS<RealType> auction(A, B, params, _log_filename_prefix);
auction.run_auction();
result = auction.get_wasserstein_cost();
+ params.final_relative_error = auction.get_relative_error();
return result;
}
@@ -248,7 +249,7 @@ template<class PairContainer>
inline typename DiagramTraits<PairContainer>::RealType
wasserstein_cost(const PairContainer& A,
const PairContainer& B,
- const AuctionParams< typename DiagramTraits<PairContainer>::RealType >& params,
+ AuctionParams< typename DiagramTraits<PairContainer>::RealType >& params,
const std::string& _log_filename_prefix = "")
{
using Traits = DiagramTraits<PairContainer>;
@@ -335,7 +336,7 @@ template<class PairContainer>
inline typename DiagramTraits<PairContainer>::RealType
wasserstein_dist(const PairContainer& A,
const PairContainer& B,
- const AuctionParams<typename DiagramTraits<PairContainer>::RealType> params,
+ AuctionParams<typename DiagramTraits<PairContainer>::RealType>& params,
const std::string& _log_filename_prefix = "")
{
using Real = typename DiagramTraits<PairContainer>::RealType;