From ba25f264b5d309efcf77a6b72d1b784ae97f741f Mon Sep 17 00:00:00 2001 From: Arnur Nigmetov Date: Wed, 6 Feb 2019 22:45:18 +0100 Subject: Switch to opts, tolerate max-iterations-exceeded. 1. Use opts.h for command-line parsing in wasserstein_dist. 2. Real relative error at the end of auction is stored in params, params is passed by reference. 3. If -e command line option is given to wasserstein_dist, relative error will be printed. 4. If -t option is given to wasserstein_dist, no exception will be thrown, if maximum number of iterations is exceeded. 5. Run wasserstein_dist -h to see all options. --- .../wasserstein/example/wasserstein_dist.cpp | 67 +++++++++++++++------- .../wasserstein/example/wasserstein_dist_dipha.cpp | 2 +- .../example/wasserstein_dist_point_cloud.cpp | 2 +- 3 files changed, 48 insertions(+), 23 deletions(-) (limited to 'geom_matching/wasserstein/example') diff --git a/geom_matching/wasserstein/example/wasserstein_dist.cpp b/geom_matching/wasserstein/example/wasserstein_dist.cpp index fcbc641..25e1f68 100644 --- a/geom_matching/wasserstein/example/wasserstein_dist.cpp +++ b/geom_matching/wasserstein/example/wasserstein_dist.cpp @@ -31,6 +31,8 @@ derivative works thereof, in binary and source code form. #include #include +#include "opts/opts.h" + //#define LOG_AUCTION //#include "auction_runner_fr.h" @@ -47,23 +49,54 @@ int main(int argc, char* argv[]) PairVector diagramA, diagramB; hera::AuctionParams params; + params.max_num_phases = 800; + + opts::Options ops(argc, argv); + ops >> opts::Option('q', "degree", params.wasserstein_power, "Wasserstein degree") + >> opts::Option('d', "error", params.delta, "Relative error") + >> opts::Option('p', "internal-p", params.internal_p, "Internal norm") + >> opts::Option("initial-epsilon", params.initial_epsilon, "Initial epsilon") + >> opts::Option("epsilon-factor", params.epsilon_common_ratio, "Epsilon factor") + >> opts::Option("max-bids-per-round", params.max_bids_per_round, "Maximal number of bids per round") + >> opts::Option('m', "max-rounds", params.max_num_phases, "Maximal number of iterations"); + + + bool print_relative_error = ops >> opts::Present('e', "--print-error", "Print real relative error"); + + params.tolerate_max_iter_exceeded = ops >> opts::Present('t', "tolerate", "Suppress max-iterations-exceeded error and print the best result."); + + std::string dgm_fname_1, dgm_fname_2; + bool dgm_1_given = (ops >> opts::PosOption(dgm_fname_1)); + bool dgm_2_given = (ops >> opts::PosOption(dgm_fname_2)); - if (argc < 3 ) { - std::cerr << "Usage: " << argv[0] << " file1 file2 [wasserstein_degree] [relative_error] [internal norm] [initial epsilon] [epsilon_factor] [max_bids_per_round] [gamma_threshold][log_filename_prefix]. By default power is 1.0, relative error is 0.01, internal norm is l_infinity, initall epsilon is chosen automatically, epsilon factor is 5.0, Jacobi variant is used (max bids per round is maximal), gamma_threshold = 0.0." << std::endl; + //std::cout << "q = " << params.wasserstein_power << ", delta = " << params.delta << ", p = " << params.internal_p << ", max_round = " << params.max_num_phases << std::endl; + //std::cout << "print relative error: " << print_relative_error << std::endl; + //std::cout << "dgm1: " << dgm_fname_1 << std::endl; + //std::cout << "dgm2: " << dgm_fname_2 << std::endl; + + if (not dgm_1_given or not dgm_2_given) { + std::cerr << "Usage: " << argv[0] << " file1 file2 " << std::endl; + std::cerr << "compute Wasserstein distance between persistence diagrams in file1 and file2.\n"; + std::cerr << ops << std::endl; return 1; } - if (!hera::read_diagram_point_set(argv[1], diagramA)) { + if (ops >> opts::Present('h', "help", "show help message")) { + std::cout << "Usage: " << argv[0] << " file1 file2 " << std::endl; + std::cout << "compute Wasserstein distance between persistence diagrams in file1 and file2.\n"; + std::cout << ops << std::endl; + } + + if (!hera::read_diagram_point_set(dgm_fname_1, diagramA)) { std::exit(1); } - if (!hera::read_diagram_point_set(argv[2], diagramB)) { + if (!hera::read_diagram_point_set(dgm_fname_2, diagramB)) { std::exit(1); } - params.wasserstein_power = (4 <= argc) ? atof(argv[3]) : 1.0; if (params.wasserstein_power < 1.0) { - std::cerr << "The third argument (wasserstein_degree) was \"" << argv[3] << "\", must be a number >= 1.0. Cannot proceed. " << std::endl; + std::cerr << "Wasserstein_degree was \"" << params.wasserstein_power << "\", must be a number >= 1.0. Cannot proceed. " << std::endl; std::exit(1); } @@ -72,50 +105,40 @@ int main(int argc, char* argv[]) } //default relative error: 1% - params.delta = (5 <= argc) ? atof(argv[4]) : 0.01; if ( params.delta <= 0.0) { - std::cerr << "The 4th argument (relative error) was \"" << argv[4] << "\", must be a number > 0.0. Cannot proceed. " << std::endl; + std::cerr << "relative error was \"" << params.delta << "\", must be a number > 0.0. Cannot proceed. " << std::endl; std::exit(1); } // default for internal metric is l_infinity - params.internal_p = ( 6 <= argc ) ? atof(argv[5]) : hera::get_infinity(); if (std::isinf(params.internal_p)) { params.internal_p = hera::get_infinity(); } if (not hera::is_p_valid_norm(params.internal_p)) { - std::cerr << "The 5th argument (internal norm) was \"" << argv[5] << "\", must be a number >= 1.0 or inf. Cannot proceed. " << std::endl; + std::cerr << "internal-p was \"" << params.internal_p << "\", must be a number >= 1.0 or inf. Cannot proceed. " << std::endl; std::exit(1); } // if you want to specify initial value for epsilon and the factor // for epsilon-scaling - params.initial_epsilon= ( 7 <= argc ) ? atof(argv[6]) : 0.0 ; - if (params.initial_epsilon < 0.0) { - std::cerr << "The 6th argument (initial epsilon) was \"" << argv[6] << "\", must be a non-negative number. Cannot proceed." << std::endl; + std::cerr << "initial-epsilon was \"" << params.initial_epsilon << "\", must be a non-negative number. Cannot proceed." << std::endl; std::exit(1); } - params.epsilon_common_ratio = ( 8 <= argc ) ? atof(argv[7]) : 0.0 ; if (params.epsilon_common_ratio <= 1.0 and params.epsilon_common_ratio != 0.0) { - std::cerr << "The 7th argument (epsilon factor) was \"" << argv[7] << "\", must be a number greater than 1. Cannot proceed." << std::endl; + std::cerr << "The 7th argument (epsilon factor) was \"" << params.epsilon_common_ratio << "\", must be a number greater than 1. Cannot proceed." << std::endl; std::exit(1); } - - params.max_bids_per_round = ( 9 <= argc ) ? atoi(argv[8]) : 0; if (params.max_bids_per_round == 0) - params.max_bids_per_round = std::numeric_limits::max(); - + params.max_bids_per_round = std::numeric_limits::max(); - params.gamma_threshold = (10 <= argc) ? atof(argv[9]) : 0.0; std::string log_filename_prefix = ( 11 <= argc ) ? argv[10] : ""; - params.max_num_phases = 800; #ifdef LOG_AUCTION spdlog::set_level(spdlog::level::info); @@ -124,6 +147,8 @@ int main(int argc, char* argv[]) double res = hera::wasserstein_dist(diagramA, diagramB, params, log_filename_prefix); std::cout << std::setprecision(15) << res << std::endl; + if (print_relative_error) + std::cout << "Relative error: " << params.final_relative_error << std::endl; return 0; diff --git a/geom_matching/wasserstein/example/wasserstein_dist_dipha.cpp b/geom_matching/wasserstein/example/wasserstein_dist_dipha.cpp index cd8c61a..2ed9c2c 100644 --- a/geom_matching/wasserstein/example/wasserstein_dist_dipha.cpp +++ b/geom_matching/wasserstein/example/wasserstein_dist_dipha.cpp @@ -111,7 +111,7 @@ int main(int argc, char* argv[]) params.max_bids_per_round = ( 10 <= argc ) ? atoi(argv[9]) : 0; if (params.max_bids_per_round == 0) - params.max_bids_per_round = std::numeric_limits::max(); + params.max_bids_per_round = std::numeric_limits::max(); params.gamma_threshold = (11 <= argc) ? atof(argv[10]) : 0.0; diff --git a/geom_matching/wasserstein/example/wasserstein_dist_point_cloud.cpp b/geom_matching/wasserstein/example/wasserstein_dist_point_cloud.cpp index 2f9718e..ab7ff4f 100644 --- a/geom_matching/wasserstein/example/wasserstein_dist_point_cloud.cpp +++ b/geom_matching/wasserstein/example/wasserstein_dist_point_cloud.cpp @@ -154,7 +154,7 @@ int main(int argc, char* argv[]) params.max_bids_per_round = ( 9 <= argc ) ? atoi(argv[8]) : 0; if (params.max_bids_per_round == 0) - params.max_bids_per_round = std::numeric_limits::max(); + params.max_bids_per_round = std::numeric_limits::max(); params.gamma_threshold = (10 <= argc) ? atof(argv[9]) : 0.0; -- cgit v1.2.3