From 5ebf7142b00554b3f5d151c8b4e81b746962a5b8 Mon Sep 17 00:00:00 2001 From: Arnur Nigmetov Date: Fri, 6 Mar 2020 18:29:25 +0100 Subject: Reorganize matching_dist code, minor fixes. --- matching/CMakeLists.txt | 26 +- matching/README | 5 - matching/README.md | 5 + matching/example/matching_dist.cpp | 396 +++++++++++++++++++++ matching/include/matching_distance.h | 2 +- matching/src/main.cpp | 382 -------------------- matching/src/test_generator.cpp | 211 ----------- matching/src/tests/prism_1.bif | 28 -- matching/src/tests/prism_2.bif | 29 -- matching/src/tests/test_bifiltration.cpp | 36 -- matching/src/tests/test_bifiltration_1.txt | 1 - .../tests/test_bifiltration_full_triangle_rene.txt | 1 - matching/src/tests/test_common.cpp | 156 -------- matching/src/tests/test_list.txt | 1 - matching/src/tests/test_matching_distance.cpp | 167 --------- matching/src/tests/tests_main.cpp | 2 - matching/tests/prism_1.bif | 28 ++ matching/tests/prism_2.bif | 29 ++ matching/tests/test_bifiltration.cpp | 36 ++ matching/tests/test_bifiltration_1.txt | 1 + .../tests/test_bifiltration_full_triangle_rene.txt | 1 + matching/tests/test_common.cpp | 156 ++++++++ matching/tests/test_list.txt | 1 + matching/tests/test_matching_distance.cpp | 167 +++++++++ matching/tests/tests_main.cpp | 2 + 25 files changed, 831 insertions(+), 1038 deletions(-) delete mode 100644 matching/README create mode 100644 matching/README.md create mode 100644 matching/example/matching_dist.cpp delete mode 100644 matching/src/main.cpp delete mode 100644 matching/src/test_generator.cpp delete mode 100644 matching/src/tests/prism_1.bif delete mode 100644 matching/src/tests/prism_2.bif delete mode 100644 matching/src/tests/test_bifiltration.cpp delete mode 120000 matching/src/tests/test_bifiltration_1.txt delete mode 120000 matching/src/tests/test_bifiltration_full_triangle_rene.txt delete mode 100644 matching/src/tests/test_common.cpp delete mode 100644 matching/src/tests/test_list.txt delete mode 100644 matching/src/tests/test_matching_distance.cpp delete mode 100644 matching/src/tests/tests_main.cpp create mode 100644 matching/tests/prism_1.bif create mode 100644 matching/tests/prism_2.bif create mode 100644 matching/tests/test_bifiltration.cpp create mode 120000 matching/tests/test_bifiltration_1.txt create mode 120000 matching/tests/test_bifiltration_full_triangle_rene.txt create mode 100644 matching/tests/test_common.cpp create mode 100644 matching/tests/test_list.txt create mode 100644 matching/tests/test_matching_distance.cpp create mode 100644 matching/tests/tests_main.cpp diff --git a/matching/CMakeLists.txt b/matching/CMakeLists.txt index 121e25c..3ee0f6b 100644 --- a/matching/CMakeLists.txt +++ b/matching/CMakeLists.txt @@ -22,7 +22,6 @@ set(CMAKE_CXX_STANDARD 14) if (NOT WIN32) set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") - #set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS} -Wall pedantic -Wextra ") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS} -Wall -Wextra ") set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -ggdb -D_GLIBCXX_DEBUG") set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELEASE} -O2 -g -ggdb") @@ -31,32 +30,23 @@ endif (NOT WIN32) file(GLOB BT_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/../bottleneck/include/*.h ${CMAKE_CURRENT_SOURCE_DIR}/../bottleneck/include/*.hpp) file(GLOB MD_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/include/*.h ${CMAKE_CURRENT_SOURCE_DIR}/include/*.hpp) -file(GLOB SRC_FILES ${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp) - -file(GLOB SRC_TEST_FILES ${CMAKE_CURRENT_SOURCE_DIR}/src/tests/*.cpp) - +file(GLOB SRC_TEST_FILES ${CMAKE_CURRENT_SOURCE_DIR}/tests/*.cpp) + find_package(Threads) set(libraries ${libraries} "stdc++fs" ${CMAKE_THREAD_LIBS_INIT}) find_package(OpenMP) if (OPENMP_FOUND) -set(libraries ${libraries} ${OpenMP_CXX_LIBRARIES}) -set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") -set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") -set (CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}") + set(libraries ${libraries} ${OpenMP_CXX_LIBRARIES}) + set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") + set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") + set (CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}") endif() -add_executable(matching_distance "src/main.cpp" ${MD_HEADERS} ${BT_HEADERS} ) -target_link_libraries(matching_distance PUBLIC ${libraries}) +add_executable(matching_dist "example/matching_dist.cpp" ${MD_HEADERS} ${BT_HEADERS} ) +target_link_libraries(matching_dist PUBLIC ${libraries}) add_executable(matching_distance_test ${SRC_TEST_FILES} ${BT_HEADERS} ${MD_HEADERS}) target_link_libraries(matching_distance_test PUBLIC ${libraries}) - -add_executable(test_generator "src/test_generator.cpp" - ${MD_HEADERS} - ${BT_HEADERS}) -target_link_libraries(test_generator PUBLIC ${libraries}) - -#add_executable(matching_distance "include/main.cpp" "src/box.cpp" "src/common_util.cpp" "src/line.cpp" "src/persistence_module.hpp" ${BT_HEADERS} ${MD_HEADERS}) diff --git a/matching/README b/matching/README deleted file mode 100644 index 7dc1874..0000000 --- a/matching/README +++ /dev/null @@ -1,5 +0,0 @@ -Matching distance between bifiltrations. - -Currently supports only 1-critical bi-filtrations -in PHAT-like format: boundary matrix + critical values -of a simplex in each row diff --git a/matching/README.md b/matching/README.md new file mode 100644 index 0000000..7dc1874 --- /dev/null +++ b/matching/README.md @@ -0,0 +1,5 @@ +Matching distance between bifiltrations. + +Currently supports only 1-critical bi-filtrations +in PHAT-like format: boundary matrix + critical values +of a simplex in each row diff --git a/matching/example/matching_dist.cpp b/matching/example/matching_dist.cpp new file mode 100644 index 0000000..13e5c6d --- /dev/null +++ b/matching/example/matching_dist.cpp @@ -0,0 +1,396 @@ +#include "common_defs.h" + +#include +#include +#include +#include + +#ifdef MD_EXPERIMENTAL_TIMING +#include +#endif + +#include "opts/opts.h" +#include "spdlog/spdlog.h" +#include "spdlog/fmt/ostr.h" + +//#include "persistence_module.h" +#include "bifiltration.h" +#include "box.h" +#include "matching_distance.h" + +using Real = double; + +using namespace md; + +namespace fs = std::experimental::filesystem; + +#ifdef PRINT_HEAT_MAP +void print_heat_map(const md::HeatMaps& hms, std::string fname, const CalculationParams& params) +{ + spd::debug("Entered print_heat_map"); + std::set mu_vals, lambda_vals; + auto hm_iter = hms.end(); + --hm_iter; + int max_level = hm_iter->first; + + int level_cardinality = 4; + for(int i = 0; i < params.initialization_depth; ++i) { + level_cardinality *= 4; + } + for(int i = params.initialization_depth + 1; i <= max_level; ++i) { + spd::debug("hms.at({}).size = {}, must be {}", i, hms.at(i).size(), level_cardinality); + assert(static_cast(hms.at(i).size()) == level_cardinality); + level_cardinality *= 4; + } + + std::map, Real> hm_x_flat, hm_x_steep, hm_y_flat, hm_y_steep; + + for(const auto& dual_point_value_pair : hms.at(max_level)) { + const DualPoint& k = dual_point_value_pair.first; + spd::debug("HM DP: {}", k); + mu_vals.insert(k.mu()); + lambda_vals.insert(k.lambda()); + } + + std::vector lambda_vals_vec(lambda_vals.begin(), lambda_vals.end()); + std::vector mu_vals_vec(mu_vals.begin(), mu_vals.end()); + + std::ofstream ofs {fname}; + if (not ofs.good()) { + std::cerr << "Cannot write heat map to file " << fname << std::endl; + throw std::runtime_error("Cannot open file for writing heat map"); + } + + std::vector> heatmap_to_print(2 * mu_vals_vec.size(), + std::vector(2 * lambda_vals_vec.size(), 0.0)); + + for(auto axis_type : {AxisType::x_type, AxisType::y_type}) { + bool is_x_type = axis_type == AxisType::x_type; + for(auto angle_type : {AngleType::flat, AngleType::steep}) { + bool is_flat = angle_type == AngleType::flat; + + int mu_idx_begin, mu_idx_end; + + if (is_x_type) { + mu_idx_begin = mu_vals_vec.size() - 1; + mu_idx_end = -1; + } else { + mu_idx_begin = 0; + mu_idx_end = mu_vals_vec.size(); + } + + int lambda_idx_begin, lambda_idx_end; + + if (is_flat) { + lambda_idx_begin = 0; + lambda_idx_end = lambda_vals_vec.size(); + } else { + lambda_idx_begin = lambda_vals_vec.size() - 1; + lambda_idx_end = -1; + } + + int mu_idx_final = is_x_type ? 0 : mu_vals_vec.size(); + + for(int mu_idx = mu_idx_begin; mu_idx != mu_idx_end; (mu_idx_begin < mu_idx_end) ? mu_idx++ : mu_idx--) { + Real mu = mu_vals_vec.at(mu_idx); + + if (mu == 0.0 and axis_type == AxisType::x_type) + continue; + + int lambda_idx_final = is_flat ? 0 : lambda_vals_vec.size(); + + for(int lambda_idx = lambda_idx_begin; + lambda_idx != lambda_idx_end; + (lambda_idx_begin < lambda_idx_end) ? lambda_idx++ : lambda_idx--) { + + Real lambda = lambda_vals_vec.at(lambda_idx); + + if (lambda == 0.0 and angle_type == AngleType::flat) + continue; + + DualPoint dp(axis_type, angle_type, lambda, mu); + Real dist_value = hms.at(max_level).at(dp); + + heatmap_to_print.at(mu_idx_final).at(lambda_idx_final) = dist_value; + +// fmt::print("HM, dp = {}, mu_idx_final = {}, lambda_idx_final = {}, value = {}\n", dp, mu_idx_final, +// lambda_idx_final, dist_value); + + lambda_idx_final++; + } + mu_idx_final++; + } + } + } + + for(size_t m_idx = 0; m_idx < heatmap_to_print.size(); ++m_idx) { + for(size_t l_idx = 0; l_idx < heatmap_to_print.at(m_idx).size(); ++l_idx) { + ofs << heatmap_to_print.at(m_idx).at(l_idx) << " "; + } + ofs << std::endl; + } + + ofs.close(); + spd::debug("Exit print_heat_map"); +} +#endif + +int main(int argc, char** argv) +{ + spdlog::set_level(spdlog::level::info); + + using opts::Option; + using opts::PosOption; + opts::Options ops; + + CalculationParams params; + + bool help = false; + +#ifdef MD_EXPERIMENTAL_TIMING + bool no_stop_asap = false; +#endif + + +#ifdef PRINT_HEAT_MAP + bool heatmap_only = false; +#endif + + // default values are the best for practical use + // if compiled with MD_EXPERIMENTAL_TIMING, user can supply + // different bounds and traverse strategies + // separated by commas, all combinations will be timed. + // See corresponding >> operators in matching_distance.h + // for possible values. + std::string bounds_list_str = "local_combined"; + std::string traverse_list_str = "BFS"; + + ops >> Option('m', "max-iterations", params.max_depth, "maximal number of iterations (refinements)") + >> Option('e', "max-error", params.delta, "error threshold") + >> Option('d', "dim", params.dim, "dim") + >> Option('i', "initial-depth", params.initialization_depth, "initialization depth") +#ifdef MD_EXPERIMENTAL_TIMING + >> Option("no-stop-asap", no_stop_asap, + "don't stop looping over points, if cell cannot be pruned (asap is on by default)") + >> Option("bounds", bounds_list_str, "bounds to use, separated by ,") + >> Option("traverse", traverse_list_str, "traverse strategy to use, separated by ,") +#endif +#ifdef PRINT_HEAT_MAP + >> Option("heatmap-only", heatmap_only, "only save heatmap (bruteforce)") +#endif + >> Option('h', "help", help, "show help message"); + + std::string fname_a; + std::string fname_b; + + if (!ops.parse(argc, argv) || help || !(ops >> PosOption(fname_a) >> PosOption(fname_b))) { + std::cerr << "Usage: " << argv[0] << " bifiltration-file-1 bifiltration-file-2\n" << ops << std::endl; + return 1; + } + +#ifdef MD_EXPERIMENTAL_TIMING + params.stop_asap = not no_stop_asap; +#else + params.stop_asap = true; +#endif + + auto bounds_list = split_by_delim(bounds_list_str, ','); + auto traverse_list = split_by_delim(traverse_list_str, ','); + + Bifiltration bif_a(fname_a); + Bifiltration bif_b(fname_b); + + bif_a.sanity_check(); + bif_b.sanity_check(); + + spd::debug("Read bifiltrations {} {}", fname_a, fname_b); + + std::vector bound_strategies; + std::vector traverse_strategies; + + for(std::string s : bounds_list) { + bound_strategies.push_back(bs_from_string(s)); + } + + for(std::string s : traverse_list) { + traverse_strategies.push_back(ts_from_string(s)); + } + + +#ifdef MD_EXPERIMENTAL_TIMING + + for(auto bs : bound_strategies) { + for(auto ts : traverse_strategies) { + spd::info("Will test combination {} {}", bs, ts); + } + } + + struct ExperimentResult { + CalculationParams params {CalculationParams()}; + int n_hera_calls {0}; + double total_milliseconds_elapsed {0}; + Real distance {0}; + Real actual_error {std::numeric_limits::max()}; + int actual_max_depth {0}; + + int x_wins {0}; + int y_wins {0}; + int ad_wins {0}; + + int seconds_elapsed() const + { + return static_cast(total_milliseconds_elapsed / 1000); + } + + double savings_ratio_old() const + { + long int max_possible_calls = 0; + long int calls_on_level = 4; + for(int i = 0; i <= actual_max_depth; ++i) { + max_possible_calls += calls_on_level; + calls_on_level *= 4; + } + return static_cast(n_hera_calls) / static_cast(max_possible_calls); + } + + double savings_ratio() const + { + return static_cast(n_hera_calls) / calls_on_actual_max_depth(); + } + + long long int calls_on_actual_max_depth() const + { + long long int result = 1; + for(int i = 0; i < actual_max_depth; ++i) { + result *= 4; + } + return result; + } + + ExperimentResult() { } + + ExperimentResult(CalculationParams p, int nhc, double tme, double d) + : + params(p), n_hera_calls(nhc), total_milliseconds_elapsed(tme), distance(d) { } + }; + + const int n_repetitions = 1; + +#ifdef PRINT_HEAT_MAP + if (heatmap_only) { + bound_strategies.clear(); + bound_strategies.push_back(BoundStrategy::bruteforce); + traverse_strategies.clear(); + traverse_strategies.push_back(TraverseStrategy::breadth_first); + } +#endif + + std::map, ExperimentResult> results; + for(BoundStrategy bound_strategy : bound_strategies) { + for(TraverseStrategy traverse_strategy : traverse_strategies) { + CalculationParams params_experiment; + params_experiment.bound_strategy = bound_strategy; + params_experiment.traverse_strategy = traverse_strategy; + params_experiment.max_depth = params.max_depth; + params_experiment.initialization_depth = params.initialization_depth; + params_experiment.delta = params.delta; + params_experiment.dim = params.dim; + params_experiment.hera_epsilon = params.hera_epsilon; + params_experiment.stop_asap = params.stop_asap; + + if (traverse_strategy == TraverseStrategy::depth_first and bound_strategy == BoundStrategy::bruteforce) + continue; + + // if bruteforce, clamp max iterations number to 7, + // save user-provided max_iters in user_max_iters and restore it later. + // remember: params is passed by reference to return real relative error and heat map + + int user_max_iters = params.max_depth; +#ifdef PRINT_HEAT_MAP + if (bound_strategy == BoundStrategy::bruteforce and not heatmap_only) { + params_experiment.max_depth = std::min(7, params.max_depth); + } +#endif + double total_milliseconds_elapsed = 0; + int total_n_hera_calls = 0; + Real dist; + for(int i = 0; i < n_repetitions; ++i) { + spd::debug("Processing bound_strategy {}, traverse_strategy {}, iteration = {}", bound_strategy, + traverse_strategy, i); + auto t1 = std::chrono::high_resolution_clock().now(); + dist = matching_distance(bif_a, bif_b, params_experiment); + auto t2 = std::chrono::high_resolution_clock().now(); + total_milliseconds_elapsed += std::chrono::duration_cast( + t2 - t1).count(); + total_n_hera_calls += params_experiment.n_hera_calls; + } + + auto key = std::make_tuple(bound_strategy, traverse_strategy); + results[key].params = params_experiment; + results[key].n_hera_calls = total_n_hera_calls / n_repetitions; + results[key].total_milliseconds_elapsed = total_milliseconds_elapsed / n_repetitions; + results[key].distance = dist; + results[key].actual_error = params_experiment.actual_error; + results[key].actual_max_depth = params_experiment.actual_max_depth; + + spd::info( + "Done (bound = {}, traverse = {}), n_hera_calls = {}, time = {} sec, d = {}, error = {}, savings = {}, max_depth = {}", + bound_strategy, traverse_strategy, results[key].n_hera_calls, results[key].seconds_elapsed(), + dist, + params_experiment.actual_error, results[key].savings_ratio(), results[key].actual_max_depth); + + if (bound_strategy == BoundStrategy::bruteforce) { params_experiment.max_depth = user_max_iters; } + +#ifdef PRINT_HEAT_MAP + if (bound_strategy == BoundStrategy::bruteforce) { + fs::path fname_a_path {fname_a.c_str()}; + fs::path fname_b_path {fname_b.c_str()}; + fs::path fname_a_wo = fname_a_path.filename(); + fs::path fname_b_wo = fname_b_path.filename(); + std::string heat_map_fname = fmt::format("{0}_{1}_dim_{2}_weighted_values_xyp.txt", fname_a_wo.string(), + fname_b_wo.string(), params_experiment.dim); + fs::path path_hm = fname_a_path.replace_filename(fs::path(heat_map_fname.c_str())); + spd::debug("Saving heatmap to {}", heat_map_fname); + print_heat_map(params_experiment.heat_maps, path_hm.string(), params); + } +#endif + spd::debug("Finished processing bound_strategy {}", bound_strategy); + } + } + +// std::cout << "File_1;File_2;Boundstrategy;TraverseStrategy;InitalDepth;NHeraCalls;SavingsRatio;Time;Distance;Error;PushStrategy;MaxDepth;CallsOnMaxDepth;Delta;Dimension" << std::endl; + for(auto bs : bound_strategies) { + for(auto ts : traverse_strategies) { + auto key = std::make_tuple(bs, ts); + + fs::path fname_a_path {fname_a.c_str()}; + fs::path fname_b_path {fname_b.c_str()}; + fs::path fname_a_wo = fname_a_path.filename(); + fs::path fname_b_wo = fname_b_path.filename(); + + std::cout << fname_a_wo.string() << ";" << fname_b_wo.string() << ";" << bs << ";" << ts << ";"; + std::cout << results[key].params.initialization_depth << ";"; + std::cout << results[key].n_hera_calls << ";" + << results[key].savings_ratio() << ";" + << results[key].total_milliseconds_elapsed << ";" + << results[key].distance << ";" + << results[key].actual_error << ";" + << "xyp" << ";" + << results[key].actual_max_depth << ";" + << results[key].calls_on_actual_max_depth() << ";" + << params.delta << ";" + << params.dim + << std::endl; + } + } +#else + params.bound_strategy = bound_strategies.back(); + params.traverse_strategy = traverse_strategies.back(); + + spd::debug("Will use {} bound, {} traverse strategy", params.bound_strategy, params.traverse_strategy); + + Real dist = matching_distance(bif_a, bif_b, params); + std::cout << dist << std::endl; +#endif + return 0; +} diff --git a/matching/include/matching_distance.h b/matching/include/matching_distance.h index 5be34c7..276cc3a 100644 --- a/matching/include/matching_distance.h +++ b/matching/include/matching_distance.h @@ -169,7 +169,7 @@ namespace md { bool print_stats { false }; #ifdef MD_PRINT_HEAT_MAP - HeatMaps heat_maps; + HeatMaps heat_maps; #endif }; diff --git a/matching/src/main.cpp b/matching/src/main.cpp deleted file mode 100644 index 2093457..0000000 --- a/matching/src/main.cpp +++ /dev/null @@ -1,382 +0,0 @@ -#include "common_defs.h" - -#include -#include -#include -#include - -#ifdef EXPERIMENTAL_TIMING -#include -#endif - -#include "opts/opts.h" -#include "spdlog/spdlog.h" -#include "spdlog/fmt/ostr.h" - -//#include "persistence_module.h" -#include "bifiltration.h" -#include "box.h" -#include "matching_distance.h" - -using Real = double; - -using namespace md; - -namespace fs = std::experimental::filesystem; - -void force_instantiation() -{ - DualBox db; - std::cout << db; -} - -#ifdef PRINT_HEAT_MAP -void print_heat_map(const md::HeatMaps& hms, std::string fname, const CalculationParams& params) -{ - spd::debug("Entered print_heat_map"); - std::set mu_vals, lambda_vals; - auto hm_iter = hms.end(); - --hm_iter; - int max_level = hm_iter->first; - - int level_cardinality = 4; - for(int i = 0; i < params.initialization_depth; ++i) { - level_cardinality *= 4; - } - for(int i = params.initialization_depth + 1; i <= max_level; ++i) { - spd::debug("hms.at({}).size = {}, must be {}", i, hms.at(i).size(), level_cardinality); - assert(static_cast(hms.at(i).size()) == level_cardinality); - level_cardinality *= 4; - } - - std::map, Real> hm_x_flat, hm_x_steep, hm_y_flat, hm_y_steep; - - for(const auto& dual_point_value_pair : hms.at(max_level)) { - const DualPoint& k = dual_point_value_pair.first; - spd::debug("HM DP: {}", k); - mu_vals.insert(k.mu()); - lambda_vals.insert(k.lambda()); - } - - std::vector lambda_vals_vec(lambda_vals.begin(), lambda_vals.end()); - std::vector mu_vals_vec(mu_vals.begin(), mu_vals.end()); - - std::ofstream ofs {fname}; - if (not ofs.good()) { - std::cerr << "Cannot write heat map to file " << fname << std::endl; - throw std::runtime_error("Cannot open file for writing heat map"); - } - - std::vector> heatmap_to_print(2 * mu_vals_vec.size(), - std::vector(2 * lambda_vals_vec.size(), 0.0)); - - for(auto axis_type : {AxisType::x_type, AxisType::y_type}) { - bool is_x_type = axis_type == AxisType::x_type; - for(auto angle_type : {AngleType::flat, AngleType::steep}) { - bool is_flat = angle_type == AngleType::flat; - - int mu_idx_begin, mu_idx_end; - - if (is_x_type) { - mu_idx_begin = mu_vals_vec.size() - 1; - mu_idx_end = -1; - } else { - mu_idx_begin = 0; - mu_idx_end = mu_vals_vec.size(); - } - - int lambda_idx_begin, lambda_idx_end; - - if (is_flat) { - lambda_idx_begin = 0; - lambda_idx_end = lambda_vals_vec.size(); - } else { - lambda_idx_begin = lambda_vals_vec.size() - 1; - lambda_idx_end = -1; - } - - int mu_idx_final = is_x_type ? 0 : mu_vals_vec.size(); - - for(int mu_idx = mu_idx_begin; mu_idx != mu_idx_end; (mu_idx_begin < mu_idx_end) ? mu_idx++ : mu_idx--) { - Real mu = mu_vals_vec.at(mu_idx); - - if (mu == 0.0 and axis_type == AxisType::x_type) - continue; - - int lambda_idx_final = is_flat ? 0 : lambda_vals_vec.size(); - - for(int lambda_idx = lambda_idx_begin; - lambda_idx != lambda_idx_end; - (lambda_idx_begin < lambda_idx_end) ? lambda_idx++ : lambda_idx--) { - - Real lambda = lambda_vals_vec.at(lambda_idx); - - if (lambda == 0.0 and angle_type == AngleType::flat) - continue; - - DualPoint dp(axis_type, angle_type, lambda, mu); - Real dist_value = hms.at(max_level).at(dp); - - heatmap_to_print.at(mu_idx_final).at(lambda_idx_final) = dist_value; - -// fmt::print("HM, dp = {}, mu_idx_final = {}, lambda_idx_final = {}, value = {}\n", dp, mu_idx_final, -// lambda_idx_final, dist_value); - - lambda_idx_final++; - } - mu_idx_final++; - } - } - } - - for(size_t m_idx = 0; m_idx < heatmap_to_print.size(); ++m_idx) { - for(size_t l_idx = 0; l_idx < heatmap_to_print.at(m_idx).size(); ++l_idx) { - ofs << heatmap_to_print.at(m_idx).at(l_idx) << " "; - } - ofs << std::endl; - } - - ofs.close(); - spd::debug("Exit print_heat_map"); -} -#endif - -int main(int argc, char** argv) -{ - spdlog::set_level(spdlog::level::info); - - using opts::Option; - using opts::PosOption; - opts::Options ops; - - bool help = false; - bool no_stop_asap = false; - CalculationParams params; - -#ifdef PRINT_HEAT_MAP - bool heatmap_only = false; -#endif - - std::string bounds_list_str = "local_combined"; - std::string traverse_list_str = "BFS"; - - ops >> Option('m', "max-iterations", params.max_depth, "maximal number of iterations (refinements)") - >> Option('e', "max-error", params.delta, "error threshold") - >> Option('d', "dim", params.dim, "dim") - >> Option('i', "initial-depth", params.initialization_depth, "initialization depth") - >> Option("no-stop-asap", no_stop_asap, - "don't stop looping over points, if cell cannot be pruned (asap is on by default)") - >> Option("bounds", bounds_list_str, "bounds to use, separated by ,") - >> Option("traverse", traverse_list_str, "traverse to use, separated by ,") -#ifdef PRINT_HEAT_MAP - >> Option("heatmap-only", heatmap_only, "only save heatmap (bruteforce)") -#endif - >> Option('h', "help", help, "show help message"); - - std::string fname_a; - std::string fname_b; - - if (!ops.parse(argc, argv) || help || !(ops >> PosOption(fname_a) >> PosOption(fname_b))) { - std::cerr << "Usage: " << argv[0] << "bifiltration-file-1 bifiltration-file-2\n" << ops << std::endl; - return 1; - } - - params.stop_asap = not no_stop_asap; - - auto bounds_list = split_by_delim(bounds_list_str, ','); - auto traverse_list = split_by_delim(traverse_list_str, ','); - - Bifiltration bif_a(fname_a); - Bifiltration bif_b(fname_b); - - bif_a.sanity_check(); - bif_b.sanity_check(); - - spd::debug("Read bifiltrations {} {}", fname_a, fname_b); - - std::vector bound_strategies; - std::vector traverse_strategies; - - for(std::string s : bounds_list) { - bound_strategies.push_back(bs_from_string(s)); - } - - for(std::string s : traverse_list) { - traverse_strategies.push_back(ts_from_string(s)); - } - - -#ifdef EXPERIMENTAL_TIMING - - for(auto bs : bound_strategies) { - for(auto ts : traverse_strategies) { - spd::info("Will test combination {} {}", bs, ts); - } - } - - struct ExperimentResult { - CalculationParams params {CalculationParams()}; - int n_hera_calls {0}; - double total_milliseconds_elapsed {0}; - Real distance {0}; - Real actual_error {std::numeric_limits::max()}; - int actual_max_depth {0}; - - int x_wins {0}; - int y_wins {0}; - int ad_wins {0}; - - int seconds_elapsed() const - { - return static_cast(total_milliseconds_elapsed / 1000); - } - - double savings_ratio_old() const - { - long int max_possible_calls = 0; - long int calls_on_level = 4; - for(int i = 0; i <= actual_max_depth; ++i) { - max_possible_calls += calls_on_level; - calls_on_level *= 4; - } - return static_cast(n_hera_calls) / static_cast(max_possible_calls); - } - - double savings_ratio() const - { - return static_cast(n_hera_calls) / calls_on_actual_max_depth(); - } - - long long int calls_on_actual_max_depth() const - { - long long int result = 1; - for(int i = 0; i < actual_max_depth; ++i) { - result *= 4; - } - return result; - } - - ExperimentResult() { } - - ExperimentResult(CalculationParams p, int nhc, double tme, double d) - : - params(p), n_hera_calls(nhc), total_milliseconds_elapsed(tme), distance(d) { } - }; - - const int n_repetitions = 1; - - if (heatmap_only) { - bound_strategies.clear(); - bound_strategies.push_back(BoundStrategy::bruteforce); - traverse_strategies.clear(); - traverse_strategies.push_back(TraverseStrategy::breadth_first); - } - - std::map, ExperimentResult> results; - for(BoundStrategy bound_strategy : bound_strategies) { - for(TraverseStrategy traverse_strategy : traverse_strategies) { - CalculationParams params_experiment; - params_experiment.bound_strategy = bound_strategy; - params_experiment.traverse_strategy = traverse_strategy; - params_experiment.max_depth = params.max_depth; - params_experiment.initialization_depth = params.initialization_depth; - params_experiment.delta = params.delta; - params_experiment.dim = params.dim; - params_experiment.hera_epsilon = params.hera_epsilon; - params_experiment.stop_asap = params.stop_asap; - - if (traverse_strategy == TraverseStrategy::depth_first and bound_strategy == BoundStrategy::bruteforce) - continue; - - // if bruteforce, clamp max iterations number to 7, - // save user-provided max_iters in user_max_iters and restore it later. - // remember: params is passed by reference to return real relative error and heat map - - int user_max_iters = params.max_depth; - if (bound_strategy == BoundStrategy::bruteforce and not heatmap_only) { - params_experiment.max_depth = std::min(7, params.max_depth); - } - double total_milliseconds_elapsed = 0; - int total_n_hera_calls = 0; - Real dist; - for(int i = 0; i < n_repetitions; ++i) { - spd::debug("Processing bound_strategy {}, traverse_strategy {}, iteration = {}", bound_strategy, - traverse_strategy, i); - auto t1 = std::chrono::high_resolution_clock().now(); - dist = matching_distance(bif_a, bif_b, params_experiment); - auto t2 = std::chrono::high_resolution_clock().now(); - total_milliseconds_elapsed += std::chrono::duration_cast( - t2 - t1).count(); - total_n_hera_calls += params_experiment.n_hera_calls; - } - - auto key = std::make_tuple(bound_strategy, traverse_strategy); - results[key].params = params_experiment; - results[key].n_hera_calls = total_n_hera_calls / n_repetitions; - results[key].total_milliseconds_elapsed = total_milliseconds_elapsed / n_repetitions; - results[key].distance = dist; - results[key].actual_error = params_experiment.actual_error; - results[key].actual_max_depth = params_experiment.actual_max_depth; - - spd::info( - "Done (bound = {}, traverse = {}), n_hera_calls = {}, time = {} sec, d = {}, error = {}, savings = {}, max_depth = {}", - bound_strategy, traverse_strategy, results[key].n_hera_calls, results[key].seconds_elapsed(), - dist, - params_experiment.actual_error, results[key].savings_ratio(), results[key].actual_max_depth); - - if (bound_strategy == BoundStrategy::bruteforce) { params_experiment.max_depth = user_max_iters; } - -#ifdef PRINT_HEAT_MAP - if (bound_strategy == BoundStrategy::bruteforce) { - fs::path fname_a_path {fname_a.c_str()}; - fs::path fname_b_path {fname_b.c_str()}; - fs::path fname_a_wo = fname_a_path.filename(); - fs::path fname_b_wo = fname_b_path.filename(); - std::string heat_map_fname = fmt::format("{0}_{1}_dim_{2}_weighted_values_xyp.txt", fname_a_wo.string(), - fname_b_wo.string(), params_experiment.dim); - fs::path path_hm = fname_a_path.replace_filename(fs::path(heat_map_fname.c_str())); - spd::debug("Saving heatmap to {}", heat_map_fname); - print_heat_map(params_experiment.heat_maps, path_hm.string(), params); - } -#endif - spd::debug("Finished processing bound_strategy {}", bound_strategy); - } - } - -// std::cout << "File_1;File_2;Boundstrategy;TraverseStrategy;InitalDepth;NHeraCalls;SavingsRatio;Time;Distance;Error;PushStrategy;MaxDepth;CallsOnMaxDepth;Delta;Dimension" << std::endl; - for(auto bs : bound_strategies) { - for(auto ts : traverse_strategies) { - auto key = std::make_tuple(bs, ts); - - fs::path fname_a_path {fname_a.c_str()}; - fs::path fname_b_path {fname_b.c_str()}; - fs::path fname_a_wo = fname_a_path.filename(); - fs::path fname_b_wo = fname_b_path.filename(); - - std::cout << fname_a_wo.string() << ";" << fname_b_wo.string() << ";" << bs << ";" << ts << ";"; - std::cout << results[key].params.initialization_depth << ";"; - std::cout << results[key].n_hera_calls << ";" - << results[key].savings_ratio() << ";" - << results[key].total_milliseconds_elapsed << ";" - << results[key].distance << ";" - << results[key].actual_error << ";" - << "xyp" << ";" - << results[key].actual_max_depth << ";" - << results[key].calls_on_actual_max_depth() << ";" - << params.delta << ";" - << params.dim - << std::endl; - } - } -#else - params.bound_strategy = bound_strategies.back(); - params.traverse_strategy = traverse_strategies.back(); - - spd::debug("Will use {} bound, {} traverse strategy", params.bound_strategy, params.traverse_strategy); - - Real dist = matching_distance(bif_a, bif_b, params); - std::cout << dist << std::endl; -#endif - force_instantiation(); - return 0; -} diff --git a/matching/src/test_generator.cpp b/matching/src/test_generator.cpp deleted file mode 100644 index a2f0625..0000000 --- a/matching/src/test_generator.cpp +++ /dev/null @@ -1,211 +0,0 @@ -#include -#include -#include -#include -#include - -#include "opts/opts.h" -#include "spdlog/spdlog.h" -#include "spdlog/fmt/ostr.h" - -#include "common_util.h" -#include "bifiltration.h" - -using Real = double; -using Index = md::Index; -using Point = md::Point; -using Bifiltration = md::Bifiltration; -using Column = md::Column; -using Simplex = md::Simplex; - -int g_max_coord = 100; - -using ASimplex = md::AbstractSimplex; - -using ASimplexToBirthMap = std::map; - -namespace spd = spdlog; - -// random generator is global -std::random_device rd; -std::mt19937_64 gen(rd()); - -//std::mt19937_64 gen(42); - -Point get_random_position(int max_coord) -{ - assert(max_coord > 0); - std::uniform_int_distribution distr(0, max_coord); - return Point(distr(gen), distr(gen)); - -} - -Point get_random_position_less_than(Point ub) -{ - std::uniform_int_distribution distr_x(0, ub.x); - std::uniform_int_distribution distr_y(0, ub.y); - return Point(distr_x(gen), distr_y(gen)); -} - -Point get_random_position_greater_than(Point lb) -{ - std::uniform_int_distribution distr_x(lb.x, g_max_coord); - std::uniform_int_distribution distr_y(lb.y, g_max_coord); - return Point(distr_x(gen), distr_y(gen)); -} - -// non-proper faces (empty and simplex itself) are also faces -bool is_face(const ASimplex& face_candidate, const ASimplex& coface_candidate) -{ - return std::includes(coface_candidate.begin(), coface_candidate.end(), face_candidate.begin(), - face_candidate.end()); -} - -bool is_top_simplex(const ASimplex& candidate_simplex, const std::vector& current_top_simplices) -{ - return std::none_of(current_top_simplices.begin(), current_top_simplices.end(), - [&candidate_simplex](const ASimplex& ts) { return is_face(candidate_simplex, ts); }); -} - -void add_if_top(const ASimplex& candidate_simplex, std::vector& current_top_simplices) -{ - // check that candidate simplex is not face of someone in current_top_simplices - if (!is_top_simplex(candidate_simplex, current_top_simplices)) - return; - - spd::debug("candidate_simplex is top, will be added to top_simplices"); - // remove s from currrent_top_simplices, if s is face of candidate_simplex - current_top_simplices.erase(std::remove_if(current_top_simplices.begin(), current_top_simplices.end(), - [candidate_simplex](const ASimplex& s) { return is_face(s, candidate_simplex); }), - current_top_simplices.end()); - - current_top_simplices.push_back(candidate_simplex); -} - -ASimplex get_random_simplex(int n_vertices, int dim) -{ - std::vector all_vertices(n_vertices, 0); - // fill in with 0..n-1 - std::iota(all_vertices.begin(), all_vertices.end(), 0); - std::shuffle(all_vertices.begin(), all_vertices.end(), gen); - return ASimplex(all_vertices.begin(), all_vertices.begin() + dim + 1, true); -} - -void generate_positions(const ASimplex& s, ASimplexToBirthMap& simplex_to_birth, Point upper_bound) -{ - auto pos = get_random_position_less_than(upper_bound); - auto curr_pos_iter = simplex_to_birth.find(s); - if (curr_pos_iter != simplex_to_birth.end()) - pos = md::greatest_lower_bound(pos, curr_pos_iter->second); - simplex_to_birth[s] = pos; - for(const ASimplex& facet : s.facets()) { - generate_positions(facet, simplex_to_birth, pos); - } -} - -Bifiltration get_random_bifiltration(int n_vertices, int max_dim, int n_top_simplices) -{ - ASimplexToBirthMap simplex_to_birth; - - // generate vertices - for(int i = 0; i < n_vertices; ++i) { - Point vertex_birth = get_random_position(g_max_coord / 10); - ASimplex vertex; - vertex.push_back(i); - simplex_to_birth[vertex] = vertex_birth; - } - - std::vector top_simplices; - // generate top simplices - while((int)top_simplices.size() < n_top_simplices) { - std::uniform_int_distribution dimension_distr(1, max_dim); - int dim = dimension_distr(gen); - auto candidate_simplex = get_random_simplex(n_vertices, dim); - spd::debug("candidate_simplex = {}", candidate_simplex); - add_if_top(candidate_simplex, top_simplices); - } - - Point upper_bound{static_cast(g_max_coord), static_cast(g_max_coord)}; - for(const auto& top_simplex : top_simplices) { - generate_positions(top_simplex, simplex_to_birth, upper_bound); - } - - std::vector> simplex_birth_pairs{simplex_to_birth.begin(), simplex_to_birth.end()}; - std::vector boundaries{simplex_to_birth.size(), Column()}; - -// assign ids and save boundaries - int id = 0; - - for(int dim = 0; dim <= max_dim; ++dim) { - for(int i = 0; i < (int) simplex_birth_pairs.size(); ++i) { - ASimplex& simplex = simplex_birth_pairs[i].first; - if (simplex.dim() == dim) { - simplex.id = id++; - Column bdry; - for(auto& facet : simplex.facets()) { - auto facet_iter = std::find_if(simplex_birth_pairs.begin(), simplex_birth_pairs.end(), - [facet](const std::pair& sbp) { return facet == sbp.first; }); - assert(facet_iter != simplex_birth_pairs.end()); - assert(facet_iter->first.id >= 0); - bdry.push_back(facet_iter->first.id); - } - std::sort(bdry.begin(), bdry.end()); - boundaries[i] = bdry; - } - } - } - -// create vector of Simplex-es - std::vector simplices; - for(int i = 0; i < (int) simplex_birth_pairs.size(); ++i) { - int id = simplex_birth_pairs[i].first.id; - int dim = simplex_birth_pairs[i].first.dim(); - Point birth = simplex_birth_pairs[i].second; - Column bdry = boundaries[i]; - simplices.emplace_back(id, birth, dim, bdry); - } - -// sort by id - std::sort(simplices.begin(), simplices.end(), - [](const Simplex& s1, const Simplex& s2) { return s1.id() < s2.id(); }); - for(int i = 0; i < (int)simplices.size(); ++i) { - assert(simplices[i].id() == i); - assert(i == 0 || simplices[i].dim() >= simplices[i - 1].dim()); - } - - return Bifiltration(simplices.begin(), simplices.end()); -} - -int main(int argc, char** argv) -{ - spd::set_level(spd::level::info); - int n_vertices; - int max_dim; - int n_top_simplices; - std::string fname; - - using opts::Option; - using opts::PosOption; - opts::Options ops; - - bool help = false; - - ops >> Option('v', "n-vertices", n_vertices, "number of vertices") - >> Option('d', "max-dim", max_dim, "maximal dim") - >> Option('m', "max-coord", g_max_coord, "maximal coordinate") - >> Option('t', "n-top-simplices", n_top_simplices, "number of top simplices") - >> Option('h', "help", help, "show help message"); - - if (!ops.parse(argc, argv) || help || !(ops >> PosOption(fname))) { - std::cerr << "Usage: " << argv[0] << "\n" << ops << std::endl; - return 1; - } - - - auto bif1 = get_random_bifiltration(n_vertices, max_dim, n_top_simplices); - std::cout << "Generated bifiltration." << std::endl; - bif1.save(fname, md::BifiltrationFormat::phat_like); - std::cout << "Saved to file " << fname << std::endl; - return 0; -} - diff --git a/matching/src/tests/prism_1.bif b/matching/src/tests/prism_1.bif deleted file mode 100644 index b37e807..0000000 --- a/matching/src/tests/prism_1.bif +++ /dev/null @@ -1,28 +0,0 @@ -bifiltration_phat_like -26 -0 0 0 -0 0 0 -0 0 0 -0 0 0 -0 0 0 -0 0 0 -1 1 0 1 0 -1 1 0 2 1 -1 0 0 3 1 -1 0 1 5 4 -1 0 0 4 1 -1 0 0 3 0 -1 0 0 5 0 -1 0 0 4 2 -1 0 0 5 2 -1 0 1 4 3 -1 1 0 2 0 -1 0 1 5 3 -2 1 1 13 10 7 -2 1 1 9 14 13 -2 1 1 8 11 6 -2 1 1 15 10 8 -2 1 1 14 12 16 -2 1 1 17 12 11 -2 4 0 7 16 6 -2 0 4 9 17 15 diff --git a/matching/src/tests/prism_2.bif b/matching/src/tests/prism_2.bif deleted file mode 100644 index 49885e6..0000000 --- a/matching/src/tests/prism_2.bif +++ /dev/null @@ -1,29 +0,0 @@ -bifiltration_phat_like -27 -0 0 0 -0 0 0 -0 0 0 -0 0 0 -0 0 0 -0 0 0 -1 0 0 2 1 -1 0 0 1 0 -1 0 0 5 3 -1 0 0 3 1 -1 0 0 5 4 -1 0 0 2 0 -1 0 0 4 1 -1 0 0 3 2 -1 0 0 5 2 -1 0 0 4 3 -1 0 0 4 2 -1 0 0 3 0 -2 0 0 16 12 6 -2 0 0 10 14 16 -2 0 0 9 17 7 -2 0 0 15 12 9 -2 0 0 13 17 11 -2 0 0 8 14 13 -2 4 0 6 11 7 -2 0 4 10 8 15 -2 3 3 15 16 13 diff --git a/matching/src/tests/test_bifiltration.cpp b/matching/src/tests/test_bifiltration.cpp deleted file mode 100644 index 742dab8..0000000 --- a/matching/src/tests/test_bifiltration.cpp +++ /dev/null @@ -1,36 +0,0 @@ -#include "catch/catch.hpp" - -#include -#include - -#include "common_util.h" -#include "box.h" -#include "bifiltration.h" - -using namespace md; - -//TEST_CASE("Small check", "[bifiltration][dim2]") -//{ -// Bifiltration bif("/home/narn/code/matching_distance/code/src/tests/test_bifiltration_full_triangle_phat_like.txt", BifiltrationFormat::phat_like); -// auto simplices = bif.simplices(); -// bif.sanity_check(); -// -// REQUIRE( simplices.size() == 7 ); -// -// REQUIRE( simplices[0].dim() == 0 ); -// REQUIRE( simplices[1].dim() == 0 ); -// REQUIRE( simplices[2].dim() == 0 ); -// REQUIRE( simplices[3].dim() == 1 ); -// REQUIRE( simplices[4].dim() == 1 ); -// REQUIRE( simplices[5].dim() == 1 ); -// REQUIRE( simplices[6].dim() == 2); -// -// REQUIRE( simplices[0].position() == Point(0, 0)); -// REQUIRE( simplices[1].position() == Point(0, 0)); -// REQUIRE( simplices[2].position() == Point(0, 0)); -// REQUIRE( simplices[3].position() == Point(3, 1)); -// REQUIRE( simplices[6].position() == Point(30, 40)); -// -// Line line_1(Line::pi / 2.0, 0.0); -// auto dgm = bif.slice_diagram(line_1); -//} diff --git a/matching/src/tests/test_bifiltration_1.txt b/matching/src/tests/test_bifiltration_1.txt deleted file mode 120000 index ddd23e9..0000000 --- a/matching/src/tests/test_bifiltration_1.txt +++ /dev/null @@ -1 +0,0 @@ -../../data/test_bifiltration_1.txt \ No newline at end of file diff --git a/matching/src/tests/test_bifiltration_full_triangle_rene.txt b/matching/src/tests/test_bifiltration_full_triangle_rene.txt deleted file mode 120000 index 47f49fd..0000000 --- a/matching/src/tests/test_bifiltration_full_triangle_rene.txt +++ /dev/null @@ -1 +0,0 @@ -../../data/test_bifiltration_full_triangle_rene.txt \ No newline at end of file diff --git a/matching/src/tests/test_common.cpp b/matching/src/tests/test_common.cpp deleted file mode 100644 index 9079a56..0000000 --- a/matching/src/tests/test_common.cpp +++ /dev/null @@ -1,156 +0,0 @@ -#include "catch/catch.hpp" - -#include -#include -#include - -#include "common_util.h" -#include "simplex.h" -#include "matching_distance.h" - -//using namespace md; -using Real = double; -using Point = md::Point; -using Bifiltration = md::Bifiltration; -using BifiltrationProxy = md::BifiltrationProxy; -using CalculationParams = md::CalculationParams; -using CellWithValue = md::CellWithValue; -using DualPoint = md::DualPoint; -using DualBox = md::DualBox; -using Simplex = md::Simplex; -using AbstractSimplex = md::AbstractSimplex; -using BoundStrategy = md::BoundStrategy; -using TraverseStrategy = md::TraverseStrategy; -using AxisType = md::AxisType; -using AngleType = md::AngleType; -using ValuePoint = md::ValuePoint; -using Column = md::Column; - - -TEST_CASE("AbstractSimplex", "[abstract_simplex]") -{ - AbstractSimplex as; - REQUIRE(as.dim() == -1); - - as.push_back(1); - REQUIRE(as.dim() == 0); - REQUIRE(as.facets().size() == 0); - - as.push_back(0); - REQUIRE(as.dim() == 1); - REQUIRE(as.facets().size() == 2); - REQUIRE(as.facets()[0].dim() == 0); - REQUIRE(as.facets()[1].dim() == 0); - -} - -TEST_CASE("Vertical line", "[vertical_line]") -{ - // line x = 1 - DualPoint l_vertical(AxisType::x_type, AngleType::steep, 0, 1); - - REQUIRE(l_vertical.is_vertical()); - REQUIRE(l_vertical.is_steep()); - - Point p_1(0.5, 0.5); - Point p_2(1.5, 0.5); - Point p_3(1.5, 1.5); - Point p_4(0.5, 1.5); - Point p_5(1, 10); - - REQUIRE(l_vertical.x_from_y(10) == 1); - REQUIRE(l_vertical.x_from_y(-10) == 1); - REQUIRE(l_vertical.x_from_y(0) == 1); - - REQUIRE(not l_vertical.contains(p_1)); - REQUIRE(not l_vertical.contains(p_2)); - REQUIRE(not l_vertical.contains(p_3)); - REQUIRE(not l_vertical.contains(p_4)); - REQUIRE(l_vertical.contains(p_5)); - - REQUIRE(l_vertical.goes_below(p_1)); - REQUIRE(not l_vertical.goes_below(p_2)); - REQUIRE(not l_vertical.goes_below(p_3)); - REQUIRE(l_vertical.goes_below(p_4)); - - REQUIRE(not l_vertical.goes_above(p_1)); - REQUIRE(l_vertical.goes_above(p_2)); - REQUIRE(l_vertical.goes_above(p_3)); - REQUIRE(not l_vertical.goes_above(p_4)); - -} - -TEST_CASE("Horizontal line", "[horizontal_line]") -{ - // line y = 1 - DualPoint l_horizontal(AxisType::y_type, AngleType::flat, 0, 1); - - REQUIRE(l_horizontal.is_horizontal()); - REQUIRE(l_horizontal.is_flat()); - REQUIRE(l_horizontal.y_slope() == 0); - REQUIRE(l_horizontal.y_intercept() == 1); - - Point p_1(0.5, 0.5); - Point p_2(1.5, 0.5); - Point p_3(1.5, 1.5); - Point p_4(0.5, 1.5); - Point p_5(2, 1); - - REQUIRE((not l_horizontal.contains(p_1) and - not l_horizontal.contains(p_2) and - not l_horizontal.contains(p_3) and - not l_horizontal.contains(p_4) and - l_horizontal.contains(p_5))); - - REQUIRE(not l_horizontal.goes_below(p_1)); - REQUIRE(not l_horizontal.goes_below(p_2)); - REQUIRE(l_horizontal.goes_below(p_3)); - REQUIRE(l_horizontal.goes_below(p_4)); - REQUIRE(l_horizontal.goes_below(p_5)); - - REQUIRE(l_horizontal.goes_above(p_1)); - REQUIRE(l_horizontal.goes_above(p_2)); - REQUIRE(not l_horizontal.goes_above(p_3)); - REQUIRE(not l_horizontal.goes_above(p_4)); - REQUIRE(l_horizontal.goes_above(p_5)); -} - -TEST_CASE("Flat Line with positive slope", "[flat_line]") -{ - // line y = x / 2 + 1 - DualPoint l_flat(AxisType::y_type, AngleType::flat, 0.5, 1); - - REQUIRE(not l_flat.is_horizontal()); - REQUIRE(l_flat.is_flat()); - REQUIRE(l_flat.y_slope() == 0.5); - REQUIRE(l_flat.y_intercept() == 1); - - REQUIRE(l_flat.y_from_x(0) == 1); - REQUIRE(l_flat.y_from_x(1) == 1.5); - REQUIRE(l_flat.y_from_x(2) == 2); - - Point p_1(3, 2); - Point p_2(-2, 0.01); - Point p_3(0, 1.25); - Point p_4(6, 4.5); - Point p_5(2, 2); - - REQUIRE((not l_flat.contains(p_1) and - not l_flat.contains(p_2) and - not l_flat.contains(p_3) and - not l_flat.contains(p_4) and - l_flat.contains(p_5))); - - REQUIRE(not l_flat.goes_below(p_1)); - REQUIRE(l_flat.goes_below(p_2)); - REQUIRE(l_flat.goes_below(p_3)); - REQUIRE(l_flat.goes_below(p_4)); - REQUIRE(l_flat.goes_below(p_5)); - - REQUIRE(l_flat.goes_above(p_1)); - REQUIRE(not l_flat.goes_above(p_2)); - REQUIRE(not l_flat.goes_above(p_3)); - REQUIRE(not l_flat.goes_above(p_4)); - REQUIRE(l_flat.goes_above(p_5)); - -} diff --git a/matching/src/tests/test_list.txt b/matching/src/tests/test_list.txt deleted file mode 100644 index 1984606..0000000 --- a/matching/src/tests/test_list.txt +++ /dev/null @@ -1 +0,0 @@ -prism_lesnick_1.bif prism_lesnick_2.bif 1.0 diff --git a/matching/src/tests/test_matching_distance.cpp b/matching/src/tests/test_matching_distance.cpp deleted file mode 100644 index 82da530..0000000 --- a/matching/src/tests/test_matching_distance.cpp +++ /dev/null @@ -1,167 +0,0 @@ -#include "catch/catch.hpp" - -#include -#include -#include - -#include "spdlog/spdlog.h" -#include "spdlog/fmt/ostr.h" - -#include "common_util.h" -#include "simplex.h" -#include "matching_distance.h" - -using Real = double; -using Point = md::Point; -using Bifiltration = md::Bifiltration; -using BifiltrationProxy = md::BifiltrationProxy; -using CalculationParams = md::CalculationParams; -using CellWithValue = md::CellWithValue; -using DualPoint = md::DualPoint; -using DualBox = md::DualBox; -using Simplex = md::Simplex; -using AbstractSimplex = md::AbstractSimplex; -using BoundStrategy = md::BoundStrategy; -using TraverseStrategy = md::TraverseStrategy; -using AxisType = md::AxisType; -using AngleType = md::AngleType; -using ValuePoint = md::ValuePoint; -using Column = md::Column; - -using md::k_corner_vps; - -namespace spd = spdlog; - -TEST_CASE("Different bounds", "[bounds]") -{ - std::vector simplices; - std::vector points; - - Real max_x = 10; - Real max_y = 20; - - int simplex_id = 0; - for(int i = 0; i <= max_x; ++i) { - for(int j = 0; j <= max_y; ++j) { - Point p(i, j); - simplices.emplace_back(simplex_id++, p, 0, Column()); - points.push_back(p); - } - } - - Bifiltration bif_a(simplices.begin(), simplices.end()); - Bifiltration bif_b(simplices.begin(), simplices.end()); - - CalculationParams params; - params.initialization_depth = 2; - - BifiltrationProxy bifp_a(bif_a, params.dim); - BifiltrationProxy bifp_b(bif_b, params.dim); - - md::DistanceCalculator calc(bifp_a, bifp_b, params); - -// REQUIRE(calc.max_x_ == Approx(max_x)); -// REQUIRE(calc.max_y_ == Approx(max_y)); - - std::vector boxes; - - for(CellWithValue c : calc.get_refined_grid(5, false, false)) { - boxes.push_back(c.dual_box()); - } - - // fill in boxes and points - - for(DualBox db : boxes) { - Real local_bound = calc.get_local_dual_bound(db); - Real local_bound_refined = calc.get_local_refined_bound(db); - REQUIRE(local_bound >= local_bound_refined); - for(Point p : points) { - for(ValuePoint vp_a : k_corner_vps) { - CellWithValue dual_cell(db, 1); - DualPoint corner_a = dual_cell.value_point(vp_a); - Real wp_a = corner_a.weighted_push(p); - dual_cell.set_value_at(vp_a, wp_a); - Real point_bound = calc.get_max_displacement_single_point(dual_cell, vp_a, p); - for(ValuePoint vp_b : k_corner_vps) { - if (vp_b <= vp_a) - continue; - DualPoint corner_b = dual_cell.value_point(vp_b); - Real wp_b = corner_b.weighted_push(p); - Real diff = fabs(wp_a - wp_b); - if (not(point_bound <= Approx(local_bound_refined))) { - std::cerr << "ERROR point: " << p << ", box = " << db << ", point bound = " << point_bound - << ", refined local = " << local_bound_refined << std::endl; - spd::set_level(spd::level::debug); - calc.get_max_displacement_single_point(dual_cell, vp_a, p); - calc.get_local_refined_bound(db); - spd::set_level(spd::level::info); - } - - REQUIRE(point_bound <= Approx(local_bound_refined)); - REQUIRE(diff <= Approx(point_bound)); - REQUIRE(diff <= Approx(local_bound_refined)); - } - - for(DualPoint l_random : db.random_points(100)) { - Real wp_random = l_random.weighted_push(p); - Real diff = fabs(wp_a - wp_random); - if (not(diff <= Approx(point_bound))) { - if (db.critical_points(p).size() > 4) { - std::cerr << "ERROR interesting case" << std::endl; - } else { - std::cerr << "ERROR boring case" << std::endl; - } - spd::set_level(spd::level::debug); - l_random.weighted_push(p); - spd::set_level(spd::level::info); - std::cerr << "ERROR point: " << p << ", box = " << db << ", point bound = " << point_bound - << ", refined local = " << local_bound_refined; - std::cerr << ", random_dual_point = " << l_random << ", wp_random = " << wp_random - << ", diff = " << diff << std::endl; - } - REQUIRE(diff <= Approx(point_bound)); - } - } - } - } -} - -TEST_CASE("Bifiltrations from file", "[matching_distance][small_example][lesnick]") -{ - std::string fname_a, fname_b; - - fname_a = "../src/tests/prism_1.bif"; - fname_b = "../src/tests/prism_2.bif"; - - Bifiltration bif_a(fname_a); - Bifiltration bif_b(fname_b); - - CalculationParams params; - - std::vector bound_strategies {BoundStrategy::local_combined, - BoundStrategy::local_dual_bound_refined}; - - std::vector traverse_strategies {TraverseStrategy::breadth_first, TraverseStrategy::depth_first}; - - std::vector scaling_factors {10, 1.0}; - - for(auto bs : bound_strategies) { - for(auto ts : traverse_strategies) { - for(double lambda : scaling_factors) { - Bifiltration bif_a_copy(bif_a); - Bifiltration bif_b_copy(bif_b); - bif_a_copy.scale(lambda); - bif_b_copy.scale(lambda); - params.bound_strategy = bs; - params.traverse_strategy = ts; - params.max_depth = 7; - params.delta = 0.1; - params.dim = 1; - Real answer = matching_distance(bif_a_copy, bif_b_copy, params); - Real correct_answer = lambda * 1.0; - REQUIRE(fabs(answer - correct_answer) / correct_answer < params.delta); - } - } - } -} - diff --git a/matching/src/tests/tests_main.cpp b/matching/src/tests/tests_main.cpp deleted file mode 100644 index 1c77b13..0000000 --- a/matching/src/tests/tests_main.cpp +++ /dev/null @@ -1,2 +0,0 @@ -#define CATCH_CONFIG_MAIN -#include "catch/catch.hpp" diff --git a/matching/tests/prism_1.bif b/matching/tests/prism_1.bif new file mode 100644 index 0000000..b37e807 --- /dev/null +++ b/matching/tests/prism_1.bif @@ -0,0 +1,28 @@ +bifiltration_phat_like +26 +0 0 0 +0 0 0 +0 0 0 +0 0 0 +0 0 0 +0 0 0 +1 1 0 1 0 +1 1 0 2 1 +1 0 0 3 1 +1 0 1 5 4 +1 0 0 4 1 +1 0 0 3 0 +1 0 0 5 0 +1 0 0 4 2 +1 0 0 5 2 +1 0 1 4 3 +1 1 0 2 0 +1 0 1 5 3 +2 1 1 13 10 7 +2 1 1 9 14 13 +2 1 1 8 11 6 +2 1 1 15 10 8 +2 1 1 14 12 16 +2 1 1 17 12 11 +2 4 0 7 16 6 +2 0 4 9 17 15 diff --git a/matching/tests/prism_2.bif b/matching/tests/prism_2.bif new file mode 100644 index 0000000..49885e6 --- /dev/null +++ b/matching/tests/prism_2.bif @@ -0,0 +1,29 @@ +bifiltration_phat_like +27 +0 0 0 +0 0 0 +0 0 0 +0 0 0 +0 0 0 +0 0 0 +1 0 0 2 1 +1 0 0 1 0 +1 0 0 5 3 +1 0 0 3 1 +1 0 0 5 4 +1 0 0 2 0 +1 0 0 4 1 +1 0 0 3 2 +1 0 0 5 2 +1 0 0 4 3 +1 0 0 4 2 +1 0 0 3 0 +2 0 0 16 12 6 +2 0 0 10 14 16 +2 0 0 9 17 7 +2 0 0 15 12 9 +2 0 0 13 17 11 +2 0 0 8 14 13 +2 4 0 6 11 7 +2 0 4 10 8 15 +2 3 3 15 16 13 diff --git a/matching/tests/test_bifiltration.cpp b/matching/tests/test_bifiltration.cpp new file mode 100644 index 0000000..742dab8 --- /dev/null +++ b/matching/tests/test_bifiltration.cpp @@ -0,0 +1,36 @@ +#include "catch/catch.hpp" + +#include +#include + +#include "common_util.h" +#include "box.h" +#include "bifiltration.h" + +using namespace md; + +//TEST_CASE("Small check", "[bifiltration][dim2]") +//{ +// Bifiltration bif("/home/narn/code/matching_distance/code/src/tests/test_bifiltration_full_triangle_phat_like.txt", BifiltrationFormat::phat_like); +// auto simplices = bif.simplices(); +// bif.sanity_check(); +// +// REQUIRE( simplices.size() == 7 ); +// +// REQUIRE( simplices[0].dim() == 0 ); +// REQUIRE( simplices[1].dim() == 0 ); +// REQUIRE( simplices[2].dim() == 0 ); +// REQUIRE( simplices[3].dim() == 1 ); +// REQUIRE( simplices[4].dim() == 1 ); +// REQUIRE( simplices[5].dim() == 1 ); +// REQUIRE( simplices[6].dim() == 2); +// +// REQUIRE( simplices[0].position() == Point(0, 0)); +// REQUIRE( simplices[1].position() == Point(0, 0)); +// REQUIRE( simplices[2].position() == Point(0, 0)); +// REQUIRE( simplices[3].position() == Point(3, 1)); +// REQUIRE( simplices[6].position() == Point(30, 40)); +// +// Line line_1(Line::pi / 2.0, 0.0); +// auto dgm = bif.slice_diagram(line_1); +//} diff --git a/matching/tests/test_bifiltration_1.txt b/matching/tests/test_bifiltration_1.txt new file mode 120000 index 0000000..ddd23e9 --- /dev/null +++ b/matching/tests/test_bifiltration_1.txt @@ -0,0 +1 @@ +../../data/test_bifiltration_1.txt \ No newline at end of file diff --git a/matching/tests/test_bifiltration_full_triangle_rene.txt b/matching/tests/test_bifiltration_full_triangle_rene.txt new file mode 120000 index 0000000..47f49fd --- /dev/null +++ b/matching/tests/test_bifiltration_full_triangle_rene.txt @@ -0,0 +1 @@ +../../data/test_bifiltration_full_triangle_rene.txt \ No newline at end of file diff --git a/matching/tests/test_common.cpp b/matching/tests/test_common.cpp new file mode 100644 index 0000000..9079a56 --- /dev/null +++ b/matching/tests/test_common.cpp @@ -0,0 +1,156 @@ +#include "catch/catch.hpp" + +#include +#include +#include + +#include "common_util.h" +#include "simplex.h" +#include "matching_distance.h" + +//using namespace md; +using Real = double; +using Point = md::Point; +using Bifiltration = md::Bifiltration; +using BifiltrationProxy = md::BifiltrationProxy; +using CalculationParams = md::CalculationParams; +using CellWithValue = md::CellWithValue; +using DualPoint = md::DualPoint; +using DualBox = md::DualBox; +using Simplex = md::Simplex; +using AbstractSimplex = md::AbstractSimplex; +using BoundStrategy = md::BoundStrategy; +using TraverseStrategy = md::TraverseStrategy; +using AxisType = md::AxisType; +using AngleType = md::AngleType; +using ValuePoint = md::ValuePoint; +using Column = md::Column; + + +TEST_CASE("AbstractSimplex", "[abstract_simplex]") +{ + AbstractSimplex as; + REQUIRE(as.dim() == -1); + + as.push_back(1); + REQUIRE(as.dim() == 0); + REQUIRE(as.facets().size() == 0); + + as.push_back(0); + REQUIRE(as.dim() == 1); + REQUIRE(as.facets().size() == 2); + REQUIRE(as.facets()[0].dim() == 0); + REQUIRE(as.facets()[1].dim() == 0); + +} + +TEST_CASE("Vertical line", "[vertical_line]") +{ + // line x = 1 + DualPoint l_vertical(AxisType::x_type, AngleType::steep, 0, 1); + + REQUIRE(l_vertical.is_vertical()); + REQUIRE(l_vertical.is_steep()); + + Point p_1(0.5, 0.5); + Point p_2(1.5, 0.5); + Point p_3(1.5, 1.5); + Point p_4(0.5, 1.5); + Point p_5(1, 10); + + REQUIRE(l_vertical.x_from_y(10) == 1); + REQUIRE(l_vertical.x_from_y(-10) == 1); + REQUIRE(l_vertical.x_from_y(0) == 1); + + REQUIRE(not l_vertical.contains(p_1)); + REQUIRE(not l_vertical.contains(p_2)); + REQUIRE(not l_vertical.contains(p_3)); + REQUIRE(not l_vertical.contains(p_4)); + REQUIRE(l_vertical.contains(p_5)); + + REQUIRE(l_vertical.goes_below(p_1)); + REQUIRE(not l_vertical.goes_below(p_2)); + REQUIRE(not l_vertical.goes_below(p_3)); + REQUIRE(l_vertical.goes_below(p_4)); + + REQUIRE(not l_vertical.goes_above(p_1)); + REQUIRE(l_vertical.goes_above(p_2)); + REQUIRE(l_vertical.goes_above(p_3)); + REQUIRE(not l_vertical.goes_above(p_4)); + +} + +TEST_CASE("Horizontal line", "[horizontal_line]") +{ + // line y = 1 + DualPoint l_horizontal(AxisType::y_type, AngleType::flat, 0, 1); + + REQUIRE(l_horizontal.is_horizontal()); + REQUIRE(l_horizontal.is_flat()); + REQUIRE(l_horizontal.y_slope() == 0); + REQUIRE(l_horizontal.y_intercept() == 1); + + Point p_1(0.5, 0.5); + Point p_2(1.5, 0.5); + Point p_3(1.5, 1.5); + Point p_4(0.5, 1.5); + Point p_5(2, 1); + + REQUIRE((not l_horizontal.contains(p_1) and + not l_horizontal.contains(p_2) and + not l_horizontal.contains(p_3) and + not l_horizontal.contains(p_4) and + l_horizontal.contains(p_5))); + + REQUIRE(not l_horizontal.goes_below(p_1)); + REQUIRE(not l_horizontal.goes_below(p_2)); + REQUIRE(l_horizontal.goes_below(p_3)); + REQUIRE(l_horizontal.goes_below(p_4)); + REQUIRE(l_horizontal.goes_below(p_5)); + + REQUIRE(l_horizontal.goes_above(p_1)); + REQUIRE(l_horizontal.goes_above(p_2)); + REQUIRE(not l_horizontal.goes_above(p_3)); + REQUIRE(not l_horizontal.goes_above(p_4)); + REQUIRE(l_horizontal.goes_above(p_5)); +} + +TEST_CASE("Flat Line with positive slope", "[flat_line]") +{ + // line y = x / 2 + 1 + DualPoint l_flat(AxisType::y_type, AngleType::flat, 0.5, 1); + + REQUIRE(not l_flat.is_horizontal()); + REQUIRE(l_flat.is_flat()); + REQUIRE(l_flat.y_slope() == 0.5); + REQUIRE(l_flat.y_intercept() == 1); + + REQUIRE(l_flat.y_from_x(0) == 1); + REQUIRE(l_flat.y_from_x(1) == 1.5); + REQUIRE(l_flat.y_from_x(2) == 2); + + Point p_1(3, 2); + Point p_2(-2, 0.01); + Point p_3(0, 1.25); + Point p_4(6, 4.5); + Point p_5(2, 2); + + REQUIRE((not l_flat.contains(p_1) and + not l_flat.contains(p_2) and + not l_flat.contains(p_3) and + not l_flat.contains(p_4) and + l_flat.contains(p_5))); + + REQUIRE(not l_flat.goes_below(p_1)); + REQUIRE(l_flat.goes_below(p_2)); + REQUIRE(l_flat.goes_below(p_3)); + REQUIRE(l_flat.goes_below(p_4)); + REQUIRE(l_flat.goes_below(p_5)); + + REQUIRE(l_flat.goes_above(p_1)); + REQUIRE(not l_flat.goes_above(p_2)); + REQUIRE(not l_flat.goes_above(p_3)); + REQUIRE(not l_flat.goes_above(p_4)); + REQUIRE(l_flat.goes_above(p_5)); + +} diff --git a/matching/tests/test_list.txt b/matching/tests/test_list.txt new file mode 100644 index 0000000..1984606 --- /dev/null +++ b/matching/tests/test_list.txt @@ -0,0 +1 @@ +prism_lesnick_1.bif prism_lesnick_2.bif 1.0 diff --git a/matching/tests/test_matching_distance.cpp b/matching/tests/test_matching_distance.cpp new file mode 100644 index 0000000..82da530 --- /dev/null +++ b/matching/tests/test_matching_distance.cpp @@ -0,0 +1,167 @@ +#include "catch/catch.hpp" + +#include +#include +#include + +#include "spdlog/spdlog.h" +#include "spdlog/fmt/ostr.h" + +#include "common_util.h" +#include "simplex.h" +#include "matching_distance.h" + +using Real = double; +using Point = md::Point; +using Bifiltration = md::Bifiltration; +using BifiltrationProxy = md::BifiltrationProxy; +using CalculationParams = md::CalculationParams; +using CellWithValue = md::CellWithValue; +using DualPoint = md::DualPoint; +using DualBox = md::DualBox; +using Simplex = md::Simplex; +using AbstractSimplex = md::AbstractSimplex; +using BoundStrategy = md::BoundStrategy; +using TraverseStrategy = md::TraverseStrategy; +using AxisType = md::AxisType; +using AngleType = md::AngleType; +using ValuePoint = md::ValuePoint; +using Column = md::Column; + +using md::k_corner_vps; + +namespace spd = spdlog; + +TEST_CASE("Different bounds", "[bounds]") +{ + std::vector simplices; + std::vector points; + + Real max_x = 10; + Real max_y = 20; + + int simplex_id = 0; + for(int i = 0; i <= max_x; ++i) { + for(int j = 0; j <= max_y; ++j) { + Point p(i, j); + simplices.emplace_back(simplex_id++, p, 0, Column()); + points.push_back(p); + } + } + + Bifiltration bif_a(simplices.begin(), simplices.end()); + Bifiltration bif_b(simplices.begin(), simplices.end()); + + CalculationParams params; + params.initialization_depth = 2; + + BifiltrationProxy bifp_a(bif_a, params.dim); + BifiltrationProxy bifp_b(bif_b, params.dim); + + md::DistanceCalculator calc(bifp_a, bifp_b, params); + +// REQUIRE(calc.max_x_ == Approx(max_x)); +// REQUIRE(calc.max_y_ == Approx(max_y)); + + std::vector boxes; + + for(CellWithValue c : calc.get_refined_grid(5, false, false)) { + boxes.push_back(c.dual_box()); + } + + // fill in boxes and points + + for(DualBox db : boxes) { + Real local_bound = calc.get_local_dual_bound(db); + Real local_bound_refined = calc.get_local_refined_bound(db); + REQUIRE(local_bound >= local_bound_refined); + for(Point p : points) { + for(ValuePoint vp_a : k_corner_vps) { + CellWithValue dual_cell(db, 1); + DualPoint corner_a = dual_cell.value_point(vp_a); + Real wp_a = corner_a.weighted_push(p); + dual_cell.set_value_at(vp_a, wp_a); + Real point_bound = calc.get_max_displacement_single_point(dual_cell, vp_a, p); + for(ValuePoint vp_b : k_corner_vps) { + if (vp_b <= vp_a) + continue; + DualPoint corner_b = dual_cell.value_point(vp_b); + Real wp_b = corner_b.weighted_push(p); + Real diff = fabs(wp_a - wp_b); + if (not(point_bound <= Approx(local_bound_refined))) { + std::cerr << "ERROR point: " << p << ", box = " << db << ", point bound = " << point_bound + << ", refined local = " << local_bound_refined << std::endl; + spd::set_level(spd::level::debug); + calc.get_max_displacement_single_point(dual_cell, vp_a, p); + calc.get_local_refined_bound(db); + spd::set_level(spd::level::info); + } + + REQUIRE(point_bound <= Approx(local_bound_refined)); + REQUIRE(diff <= Approx(point_bound)); + REQUIRE(diff <= Approx(local_bound_refined)); + } + + for(DualPoint l_random : db.random_points(100)) { + Real wp_random = l_random.weighted_push(p); + Real diff = fabs(wp_a - wp_random); + if (not(diff <= Approx(point_bound))) { + if (db.critical_points(p).size() > 4) { + std::cerr << "ERROR interesting case" << std::endl; + } else { + std::cerr << "ERROR boring case" << std::endl; + } + spd::set_level(spd::level::debug); + l_random.weighted_push(p); + spd::set_level(spd::level::info); + std::cerr << "ERROR point: " << p << ", box = " << db << ", point bound = " << point_bound + << ", refined local = " << local_bound_refined; + std::cerr << ", random_dual_point = " << l_random << ", wp_random = " << wp_random + << ", diff = " << diff << std::endl; + } + REQUIRE(diff <= Approx(point_bound)); + } + } + } + } +} + +TEST_CASE("Bifiltrations from file", "[matching_distance][small_example][lesnick]") +{ + std::string fname_a, fname_b; + + fname_a = "../src/tests/prism_1.bif"; + fname_b = "../src/tests/prism_2.bif"; + + Bifiltration bif_a(fname_a); + Bifiltration bif_b(fname_b); + + CalculationParams params; + + std::vector bound_strategies {BoundStrategy::local_combined, + BoundStrategy::local_dual_bound_refined}; + + std::vector traverse_strategies {TraverseStrategy::breadth_first, TraverseStrategy::depth_first}; + + std::vector scaling_factors {10, 1.0}; + + for(auto bs : bound_strategies) { + for(auto ts : traverse_strategies) { + for(double lambda : scaling_factors) { + Bifiltration bif_a_copy(bif_a); + Bifiltration bif_b_copy(bif_b); + bif_a_copy.scale(lambda); + bif_b_copy.scale(lambda); + params.bound_strategy = bs; + params.traverse_strategy = ts; + params.max_depth = 7; + params.delta = 0.1; + params.dim = 1; + Real answer = matching_distance(bif_a_copy, bif_b_copy, params); + Real correct_answer = lambda * 1.0; + REQUIRE(fabs(answer - correct_answer) / correct_answer < params.delta); + } + } + } +} + diff --git a/matching/tests/tests_main.cpp b/matching/tests/tests_main.cpp new file mode 100644 index 0000000..1c77b13 --- /dev/null +++ b/matching/tests/tests_main.cpp @@ -0,0 +1,2 @@ +#define CATCH_CONFIG_MAIN +#include "catch/catch.hpp" -- cgit v1.2.3