From 5ebf7142b00554b3f5d151c8b4e81b746962a5b8 Mon Sep 17 00:00:00 2001 From: Arnur Nigmetov Date: Fri, 6 Mar 2020 18:29:25 +0100 Subject: Reorganize matching_dist code, minor fixes. --- matching/example/matching_dist.cpp | 396 +++++++++++++++++++++++++++++++++++++ 1 file changed, 396 insertions(+) create mode 100644 matching/example/matching_dist.cpp (limited to 'matching/example') diff --git a/matching/example/matching_dist.cpp b/matching/example/matching_dist.cpp new file mode 100644 index 0000000..13e5c6d --- /dev/null +++ b/matching/example/matching_dist.cpp @@ -0,0 +1,396 @@ +#include "common_defs.h" + +#include +#include +#include +#include + +#ifdef MD_EXPERIMENTAL_TIMING +#include +#endif + +#include "opts/opts.h" +#include "spdlog/spdlog.h" +#include "spdlog/fmt/ostr.h" + +//#include "persistence_module.h" +#include "bifiltration.h" +#include "box.h" +#include "matching_distance.h" + +using Real = double; + +using namespace md; + +namespace fs = std::experimental::filesystem; + +#ifdef PRINT_HEAT_MAP +void print_heat_map(const md::HeatMaps& hms, std::string fname, const CalculationParams& params) +{ + spd::debug("Entered print_heat_map"); + std::set mu_vals, lambda_vals; + auto hm_iter = hms.end(); + --hm_iter; + int max_level = hm_iter->first; + + int level_cardinality = 4; + for(int i = 0; i < params.initialization_depth; ++i) { + level_cardinality *= 4; + } + for(int i = params.initialization_depth + 1; i <= max_level; ++i) { + spd::debug("hms.at({}).size = {}, must be {}", i, hms.at(i).size(), level_cardinality); + assert(static_cast(hms.at(i).size()) == level_cardinality); + level_cardinality *= 4; + } + + std::map, Real> hm_x_flat, hm_x_steep, hm_y_flat, hm_y_steep; + + for(const auto& dual_point_value_pair : hms.at(max_level)) { + const DualPoint& k = dual_point_value_pair.first; + spd::debug("HM DP: {}", k); + mu_vals.insert(k.mu()); + lambda_vals.insert(k.lambda()); + } + + std::vector lambda_vals_vec(lambda_vals.begin(), lambda_vals.end()); + 
std::vector mu_vals_vec(mu_vals.begin(), mu_vals.end()); + + std::ofstream ofs {fname}; + if (not ofs.good()) { + std::cerr << "Cannot write heat map to file " << fname << std::endl; + throw std::runtime_error("Cannot open file for writing heat map"); + } + + std::vector> heatmap_to_print(2 * mu_vals_vec.size(), + std::vector(2 * lambda_vals_vec.size(), 0.0)); + + for(auto axis_type : {AxisType::x_type, AxisType::y_type}) { + bool is_x_type = axis_type == AxisType::x_type; + for(auto angle_type : {AngleType::flat, AngleType::steep}) { + bool is_flat = angle_type == AngleType::flat; + + int mu_idx_begin, mu_idx_end; + + if (is_x_type) { + mu_idx_begin = mu_vals_vec.size() - 1; + mu_idx_end = -1; + } else { + mu_idx_begin = 0; + mu_idx_end = mu_vals_vec.size(); + } + + int lambda_idx_begin, lambda_idx_end; + + if (is_flat) { + lambda_idx_begin = 0; + lambda_idx_end = lambda_vals_vec.size(); + } else { + lambda_idx_begin = lambda_vals_vec.size() - 1; + lambda_idx_end = -1; + } + + int mu_idx_final = is_x_type ? 0 : mu_vals_vec.size(); + + for(int mu_idx = mu_idx_begin; mu_idx != mu_idx_end; (mu_idx_begin < mu_idx_end) ? mu_idx++ : mu_idx--) { + Real mu = mu_vals_vec.at(mu_idx); + + if (mu == 0.0 and axis_type == AxisType::x_type) + continue; + + int lambda_idx_final = is_flat ? 0 : lambda_vals_vec.size(); + + for(int lambda_idx = lambda_idx_begin; + lambda_idx != lambda_idx_end; + (lambda_idx_begin < lambda_idx_end) ? 
lambda_idx++ : lambda_idx--) { + + Real lambda = lambda_vals_vec.at(lambda_idx); + + if (lambda == 0.0 and angle_type == AngleType::flat) + continue; + + DualPoint dp(axis_type, angle_type, lambda, mu); + Real dist_value = hms.at(max_level).at(dp); + + heatmap_to_print.at(mu_idx_final).at(lambda_idx_final) = dist_value; + +// fmt::print("HM, dp = {}, mu_idx_final = {}, lambda_idx_final = {}, value = {}\n", dp, mu_idx_final, +// lambda_idx_final, dist_value); + + lambda_idx_final++; + } + mu_idx_final++; + } + } + } + + for(size_t m_idx = 0; m_idx < heatmap_to_print.size(); ++m_idx) { + for(size_t l_idx = 0; l_idx < heatmap_to_print.at(m_idx).size(); ++l_idx) { + ofs << heatmap_to_print.at(m_idx).at(l_idx) << " "; + } + ofs << std::endl; + } + + ofs.close(); + spd::debug("Exit print_heat_map"); +} +#endif + +int main(int argc, char** argv) +{ + spdlog::set_level(spdlog::level::info); + + using opts::Option; + using opts::PosOption; + opts::Options ops; + + CalculationParams params; + + bool help = false; + +#ifdef MD_EXPERIMENTAL_TIMING + bool no_stop_asap = false; +#endif + + +#ifdef PRINT_HEAT_MAP + bool heatmap_only = false; +#endif + + // default values are the best for practical use + // if compiled with MD_EXPERIMENTAL_TIMING, user can supply + // different bounds and traverse strategies + // separated by commas, all combinations will be timed. + // See corresponding >> operators in matching_distance.h + // for possible values. 
+ std::string bounds_list_str = "local_combined"; + std::string traverse_list_str = "BFS"; + + ops >> Option('m', "max-iterations", params.max_depth, "maximal number of iterations (refinements)") + >> Option('e', "max-error", params.delta, "error threshold") + >> Option('d', "dim", params.dim, "dim") + >> Option('i', "initial-depth", params.initialization_depth, "initialization depth") +#ifdef MD_EXPERIMENTAL_TIMING + >> Option("no-stop-asap", no_stop_asap, + "don't stop looping over points, if cell cannot be pruned (asap is on by default)") + >> Option("bounds", bounds_list_str, "bounds to use, separated by ,") + >> Option("traverse", traverse_list_str, "traverse strategy to use, separated by ,") +#endif +#ifdef PRINT_HEAT_MAP + >> Option("heatmap-only", heatmap_only, "only save heatmap (bruteforce)") +#endif + >> Option('h', "help", help, "show help message"); + + std::string fname_a; + std::string fname_b; + + if (!ops.parse(argc, argv) || help || !(ops >> PosOption(fname_a) >> PosOption(fname_b))) { + std::cerr << "Usage: " << argv[0] << " bifiltration-file-1 bifiltration-file-2\n" << ops << std::endl; + return 1; + } + +#ifdef MD_EXPERIMENTAL_TIMING + params.stop_asap = not no_stop_asap; +#else + params.stop_asap = true; +#endif + + auto bounds_list = split_by_delim(bounds_list_str, ','); + auto traverse_list = split_by_delim(traverse_list_str, ','); + + Bifiltration bif_a(fname_a); + Bifiltration bif_b(fname_b); + + bif_a.sanity_check(); + bif_b.sanity_check(); + + spd::debug("Read bifiltrations {} {}", fname_a, fname_b); + + std::vector bound_strategies; + std::vector traverse_strategies; + + for(std::string s : bounds_list) { + bound_strategies.push_back(bs_from_string(s)); + } + + for(std::string s : traverse_list) { + traverse_strategies.push_back(ts_from_string(s)); + } + + +#ifdef MD_EXPERIMENTAL_TIMING + + for(auto bs : bound_strategies) { + for(auto ts : traverse_strategies) { + spd::info("Will test combination {} {}", bs, ts); + } + } + + struct 
ExperimentResult { + CalculationParams params {CalculationParams()}; + int n_hera_calls {0}; + double total_milliseconds_elapsed {0}; + Real distance {0}; + Real actual_error {std::numeric_limits::max()}; + int actual_max_depth {0}; + + int x_wins {0}; + int y_wins {0}; + int ad_wins {0}; + + int seconds_elapsed() const + { + return static_cast(total_milliseconds_elapsed / 1000); + } + + double savings_ratio_old() const + { + long int max_possible_calls = 0; + long int calls_on_level = 4; + for(int i = 0; i <= actual_max_depth; ++i) { + max_possible_calls += calls_on_level; + calls_on_level *= 4; + } + return static_cast(n_hera_calls) / static_cast(max_possible_calls); + } + + double savings_ratio() const + { + return static_cast(n_hera_calls) / calls_on_actual_max_depth(); + } + + long long int calls_on_actual_max_depth() const + { + long long int result = 1; + for(int i = 0; i < actual_max_depth; ++i) { + result *= 4; + } + return result; + } + + ExperimentResult() { } + + ExperimentResult(CalculationParams p, int nhc, double tme, double d) + : + params(p), n_hera_calls(nhc), total_milliseconds_elapsed(tme), distance(d) { } + }; + + const int n_repetitions = 1; + +#ifdef PRINT_HEAT_MAP + if (heatmap_only) { + bound_strategies.clear(); + bound_strategies.push_back(BoundStrategy::bruteforce); + traverse_strategies.clear(); + traverse_strategies.push_back(TraverseStrategy::breadth_first); + } +#endif + + std::map, ExperimentResult> results; + for(BoundStrategy bound_strategy : bound_strategies) { + for(TraverseStrategy traverse_strategy : traverse_strategies) { + CalculationParams params_experiment; + params_experiment.bound_strategy = bound_strategy; + params_experiment.traverse_strategy = traverse_strategy; + params_experiment.max_depth = params.max_depth; + params_experiment.initialization_depth = params.initialization_depth; + params_experiment.delta = params.delta; + params_experiment.dim = params.dim; + params_experiment.hera_epsilon = params.hera_epsilon; + 
params_experiment.stop_asap = params.stop_asap; + + if (traverse_strategy == TraverseStrategy::depth_first and bound_strategy == BoundStrategy::bruteforce) + continue; + + // if bruteforce, clamp max iterations number to 7, + // save user-provided max_iters in user_max_iters and restore it later. + // remember: params is passed by reference to return real relative error and heat map + + int user_max_iters = params.max_depth; +#ifdef PRINT_HEAT_MAP + if (bound_strategy == BoundStrategy::bruteforce and not heatmap_only) { + params_experiment.max_depth = std::min(7, params.max_depth); + } +#endif + double total_milliseconds_elapsed = 0; + int total_n_hera_calls = 0; + Real dist; + for(int i = 0; i < n_repetitions; ++i) { + spd::debug("Processing bound_strategy {}, traverse_strategy {}, iteration = {}", bound_strategy, + traverse_strategy, i); + auto t1 = std::chrono::high_resolution_clock().now(); + dist = matching_distance(bif_a, bif_b, params_experiment); + auto t2 = std::chrono::high_resolution_clock().now(); + total_milliseconds_elapsed += std::chrono::duration_cast( + t2 - t1).count(); + total_n_hera_calls += params_experiment.n_hera_calls; + } + + auto key = std::make_tuple(bound_strategy, traverse_strategy); + results[key].params = params_experiment; + results[key].n_hera_calls = total_n_hera_calls / n_repetitions; + results[key].total_milliseconds_elapsed = total_milliseconds_elapsed / n_repetitions; + results[key].distance = dist; + results[key].actual_error = params_experiment.actual_error; + results[key].actual_max_depth = params_experiment.actual_max_depth; + + spd::info( + "Done (bound = {}, traverse = {}), n_hera_calls = {}, time = {} sec, d = {}, error = {}, savings = {}, max_depth = {}", + bound_strategy, traverse_strategy, results[key].n_hera_calls, results[key].seconds_elapsed(), + dist, + params_experiment.actual_error, results[key].savings_ratio(), results[key].actual_max_depth); + + if (bound_strategy == BoundStrategy::bruteforce) { 
params_experiment.max_depth = user_max_iters; } + +#ifdef PRINT_HEAT_MAP + if (bound_strategy == BoundStrategy::bruteforce) { + fs::path fname_a_path {fname_a.c_str()}; + fs::path fname_b_path {fname_b.c_str()}; + fs::path fname_a_wo = fname_a_path.filename(); + fs::path fname_b_wo = fname_b_path.filename(); + std::string heat_map_fname = fmt::format("{0}_{1}_dim_{2}_weighted_values_xyp.txt", fname_a_wo.string(), + fname_b_wo.string(), params_experiment.dim); + fs::path path_hm = fname_a_path.replace_filename(fs::path(heat_map_fname.c_str())); + spd::debug("Saving heatmap to {}", heat_map_fname); + print_heat_map(params_experiment.heat_maps, path_hm.string(), params); + } +#endif + spd::debug("Finished processing bound_strategy {}", bound_strategy); + } + } + +// std::cout << "File_1;File_2;Boundstrategy;TraverseStrategy;InitalDepth;NHeraCalls;SavingsRatio;Time;Distance;Error;PushStrategy;MaxDepth;CallsOnMaxDepth;Delta;Dimension" << std::endl; + for(auto bs : bound_strategies) { + for(auto ts : traverse_strategies) { + auto key = std::make_tuple(bs, ts); + + fs::path fname_a_path {fname_a.c_str()}; + fs::path fname_b_path {fname_b.c_str()}; + fs::path fname_a_wo = fname_a_path.filename(); + fs::path fname_b_wo = fname_b_path.filename(); + + std::cout << fname_a_wo.string() << ";" << fname_b_wo.string() << ";" << bs << ";" << ts << ";"; + std::cout << results[key].params.initialization_depth << ";"; + std::cout << results[key].n_hera_calls << ";" + << results[key].savings_ratio() << ";" + << results[key].total_milliseconds_elapsed << ";" + << results[key].distance << ";" + << results[key].actual_error << ";" + << "xyp" << ";" + << results[key].actual_max_depth << ";" + << results[key].calls_on_actual_max_depth() << ";" + << params.delta << ";" + << params.dim + << std::endl; + } + } +#else + params.bound_strategy = bound_strategies.back(); + params.traverse_strategy = traverse_strategies.back(); + + spd::debug("Will use {} bound, {} traverse strategy", 
params.bound_strategy, params.traverse_strategy); + + Real dist = matching_distance(bif_a, bif_b, params); + std::cout << dist << std::endl; +#endif + return 0; +} -- cgit v1.2.3 From 490fed367bb97a96b90caa6ef04265c063d91df1 Mon Sep 17 00:00:00 2001 From: Arnur Nigmetov Date: Sat, 7 Mar 2020 08:44:20 +0100 Subject: Fix multiple bugs in matching distance for modules, add example. --- README.txt | 15 +++++-- bottleneck/README | 2 +- matching/CMakeLists.txt | 4 ++ matching/README.md | 72 +++++++++++++++++++++++++++++++-- matching/example/matching_dist.cpp | 3 -- matching/include/matching_distance.h | 2 +- matching/include/matching_distance.hpp | 1 + matching/include/persistence_module.h | 2 +- matching/include/persistence_module.hpp | 19 ++++++--- 9 files changed, 101 insertions(+), 19 deletions(-) (limited to 'matching/example') diff --git a/README.txt b/README.txt index ae3f6a0..48a77c0 100644 --- a/README.txt +++ b/README.txt @@ -1,5 +1,6 @@ This repository contains software to compute bottleneck and Wasserstein -distances between persistence diagrams. +distances between persistence diagrams, and matching distance +between 2-parameter persistence modules and (1-critical) bi-filtrations. The software is licensed under BSD license, see license.txt file. @@ -9,13 +10,14 @@ you probably do not need to worry about that. See README files in subdirectories for usage and building. If you use Hera in your project, we would appreciate if you -cite the corresponding paper: +cite the corresponding paper. + +Bottleneck or Wasserstein distance: + Michael Kerber, Dmitriy Morozov, and Arnur Nigmetov, "Geometry Helps to Compare Persistence Diagrams.", Journal of Experimental Algorithmics, vol. 22, 2017, pp. 1--20. (conference version: ALENEX 2016). 
-The BibTeX is below: - @article{jea_hera, title={Geometry helps to compare persistence diagrams}, @@ -26,3 +28,8 @@ The BibTeX is below: year={2017}, publisher={ACM New York, NY, USA} } + +Matching distance: + +Michael Kerber, Arnur Nigmetov, "Efficient Approximation of the Matching +Distance for 2-parameter persistence.", SoCG 2020 diff --git a/bottleneck/README b/bottleneck/README index 8b368af..c04390b 100644 --- a/bottleneck/README +++ b/bottleneck/README @@ -1,6 +1,6 @@ Accompanying paper: M. Kerber, D. Morozov, A. Nigmetov. Geometry Helps To Compare Persistence Diagrams (ALENEX 2016, http://www.geometrie.tugraz.at/nigmetov/geom_dist.pdf) -Bug reports can be sent to "nigmetov EMAIL SIGN tugraz DOT at". +Bug reports can be sent to "anigmetov EMAIL SIGN lbl DOT gov". # Dependencies diff --git a/matching/CMakeLists.txt b/matching/CMakeLists.txt index 3ee0f6b..a391d84 100644 --- a/matching/CMakeLists.txt +++ b/matching/CMakeLists.txt @@ -48,5 +48,9 @@ endif() add_executable(matching_dist "example/matching_dist.cpp" ${MD_HEADERS} ${BT_HEADERS} ) target_link_libraries(matching_dist PUBLIC ${libraries}) +add_executable(module_example "example/module_example.cpp" ${MD_HEADERS} ${BT_HEADERS} ) +target_link_libraries(module_example PUBLIC ${libraries}) + + add_executable(matching_distance_test ${SRC_TEST_FILES} ${BT_HEADERS} ${MD_HEADERS}) target_link_libraries(matching_distance_test PUBLIC ${libraries}) diff --git a/matching/README.md b/matching/README.md index 7dc1874..fc97441 100644 --- a/matching/README.md +++ b/matching/README.md @@ -1,5 +1,69 @@ -Matching distance between bifiltrations. +# Matching distance between bifiltrations and 2-persistence modules. -Currently supports only 1-critical bi-filtrations -in PHAT-like format: boundary matrix + critical values -of a simplex in each row +## Accompanying paper +M. Kerber, A. Nigmetov, +Efficient Approximation of the Matching Distance for 2-parameter persistence. +SoCG 2020. 
+ +Bug reports can be sent to "anigmetov EMAIL SIGN lbl DOT gov". + +## Dependencies + +* Your compiler must support C++11. +* Boost. + +## Usage: + +1. To use a standalone command-line utility matching_dist: + +`./matching_dist -d dimension -e relative_error file1 file2` + +If `relative_error` is not specified, the default 0.1 is used. +If `dimension` is not specified, the default value is 0. +Run `./matching_dist` without parameters to see other options. + +The output is an approximation of the exact distance (which is assumed to be non-zero). +Precisely: if *d_exact* is the true distance and *d_approx* is the output, then + +> | d_exact - d_approx | / d_exact < relative_error. + +Files file1 and file2 must contain 1-critical bi-filtrations in a plain text format which is similar to PHAT. The first line of a file must say *bifiltration_phat_like*. The second line contains the total number of simplices *N*. The next *N* lines contain simplices in the format *dim x y boundary*. +* *dim*: the dimension of the simplex +* *x, y*: coordinates of the critical value +* *boundary*: indices of simplices forming the boundary of the current simplex. Indices are separated by space. +* Simplices are indexed starting from 0. + +For example, the bi-filtration of a segment with vertices appearing at (0,0) and the 1-segment appearing at (3,4) shall be written as: + +> bifiltration_phat_like +> 3 +> \# lines starting with \# are ignored +> \# vertex A has dimension 0, hence no boundary, its index is 0 +> 0 0 0 +> \# vertex B has index 1 +> 0 0 0 +> \# 1-dimensional simplex {A, B} +> 1 3 4 0 1 + +2. To use from your code. + +Here you can compute the matching distance either between bi-filtrations or between persistence modules. +First, you need to include the header: `#include "matching_distance.h"`. Practically every class you need is parameterized by Real type, which should be either float or double. 
The header provides two functions called `matching_distance`. +See `example/module_example.cpp` for additional details. + +## License + +See `license.txt` in the repository root folder. + +## Building + +CMakeLists.txt can be used to build the command-line utility in the standard +way. On Linux/Mac/Windows with Cygwin: +> `mkdir build` +> `cd build` +> `cmake ..` +> `make` + +On Windows with Visual Studio: use `cmake-gui` to create the solution in build directory and build it with VS. + +The library itself is header-only and does not require separate compilation. diff --git a/matching/example/matching_dist.cpp b/matching/example/matching_dist.cpp index 13e5c6d..734f6cc 100644 --- a/matching/example/matching_dist.cpp +++ b/matching/example/matching_dist.cpp @@ -13,9 +13,6 @@ #include "spdlog/spdlog.h" #include "spdlog/fmt/ostr.h" -//#include "persistence_module.h" -#include "bifiltration.h" -#include "box.h" #include "matching_distance.h" using Real = double; diff --git a/matching/include/matching_distance.h b/matching/include/matching_distance.h index 276cc3a..e82a97c 100644 --- a/matching/include/matching_distance.h +++ b/matching/include/matching_distance.h @@ -152,7 +152,7 @@ namespace md { int max_depth {6}; // maximal number of refinenemnts int initialization_depth {3}; int dim {0}; // in which dim to calculate the distance; use ALL_DIMENSIONS to get max over all dims - BoundStrategy bound_strategy {BoundStrategy::bruteforce}; + BoundStrategy bound_strategy {BoundStrategy::local_combined}; TraverseStrategy traverse_strategy {TraverseStrategy::breadth_first}; bool tolerate_max_iter_exceeded {true}; Real actual_error {std::numeric_limits::max()}; diff --git a/matching/include/matching_distance.hpp b/matching/include/matching_distance.hpp index 48c8464..7cff073 100644 --- a/matching/include/matching_distance.hpp +++ b/matching/include/matching_distance.hpp @@ -193,6 +193,7 @@ namespace md { // make all coordinates non-negative auto min_coord = 
std::min(module_a_.minimal_coordinate(), module_b_.minimal_coordinate()); + spd::debug("in DistanceCalculator ctor, min_coord = {}", min_coord); if (min_coord < 0) { module_a_.translate(-min_coord); module_b_.translate(-min_coord); diff --git a/matching/include/persistence_module.h b/matching/include/persistence_module.h index e99771f..4a261bb 100644 --- a/matching/include/persistence_module.h +++ b/matching/include/persistence_module.h @@ -47,7 +47,7 @@ namespace md { IndexVec components_; Relation() {} - Relation(const Point& _pos, const IndexVec& _components); + Relation(const Point& _pos, const IndexVec& _components) : position_(_pos), components_(_components) {} Real get_x() const { return position_.x; } Real get_y() const { return position_.y; } diff --git a/matching/include/persistence_module.hpp b/matching/include/persistence_module.hpp index 6e49b2e..128fed9 100644 --- a/matching/include/persistence_module.hpp +++ b/matching/include/persistence_module.hpp @@ -34,10 +34,10 @@ namespace md { template void ModulePresentation::init_boundaries() { - max_x_ = std::numeric_limits::max(); - max_y_ = std::numeric_limits::max(); - min_x_ = -std::numeric_limits::max(); - min_y_ = -std::numeric_limits::max(); + max_x_ = -std::numeric_limits::max(); + max_y_ = -std::numeric_limits::max(); + min_x_ = std::numeric_limits::max(); + min_y_ = std::numeric_limits::max(); for(const auto& gen : positions_) { min_x_ = std::min(gen.x, min_x_); @@ -55,6 +55,7 @@ namespace md { generators_(_generators), relations_(_relations) { + positions_ = concat_gen_and_rel_positions(generators_, relations_); init_boundaries(); } @@ -85,6 +86,7 @@ namespace md { void ModulePresentation::project_generators(const DualPoint& slice, IndexVec& sorted_indices, RealVec& projections) const { + spd::debug("Enter project_generators, slice = {}", slice); size_t num_gens = generators_.size(); RealVec gen_values; @@ -97,6 +99,7 @@ namespace md { projections.reserve(num_gens); for(auto i : sorted_indices) { 
projections.push_back(gen_values[i]); + spd::debug("added push = {}", gen_values[i]); } } @@ -104,6 +107,8 @@ namespace md { void ModulePresentation::project_relations(const DualPoint& slice, IndexVec& sorted_rel_indices, RealVec& projections) const { + + spd::debug("Enter project_relations, slice = {}", slice); size_t num_rels = relations_.size(); RealVec rel_values; @@ -116,12 +121,14 @@ namespace md { projections.reserve(num_rels); for(auto i : sorted_rel_indices) { projections.push_back(rel_values[i]); + spd::debug("added push = {}", rel_values[i]); } } template Diagram ModulePresentation::weighted_slice_diagram(const DualPoint& slice) const { + spd::debug("Enter weighted_slice_diagram, slice = {}", slice); IndexVec sorted_gen_indices, sorted_rel_indices; RealVec gen_projections, rel_projections; @@ -138,7 +145,8 @@ namespace md { j = sorted_gen_indices[j]; } std::sort(current_relation.begin(), current_relation.end()); - phat_matrix.set_dim(i, current_relation.size()); + // modules do not have dimension, set all to 0 + phat_matrix.set_dim(i, 0); phat_matrix.set_col(i, current_relation); } @@ -154,6 +162,7 @@ namespace md { bool is_finite_pair = new_pair.second != phat::k_infinity_index; Real birth = gen_projections.at(new_pair.first); Real death = is_finite_pair ? rel_projections.at(new_pair.second) : real_inf; + spd::debug("i = {}, birth = {}, death = {}", i, new_pair.first, new_pair.second); if (birth != death) { dgm.emplace_back(birth, death); } -- cgit v1.2.3 From 14e91d6c3ad81a1ec763d75a28f20fb689e5166e Mon Sep 17 00:00:00 2001 From: Arnur Nigmetov Date: Mon, 9 Mar 2020 06:26:49 +0100 Subject: Add tests for module slice restriction. 
--- matching/example/module_example.cpp | 68 ++++++++++++++++++ matching/include/matching_distance.h | 17 +++-- matching/include/persistence_module.h | 7 ++ matching/include/persistence_module.hpp | 24 +++++-- matching/tests/test_matching_distance.cpp | 7 +- matching/tests/test_module.cpp | 114 ++++++++++++++++++++++++++++++ 6 files changed, 223 insertions(+), 14 deletions(-) create mode 100644 matching/example/module_example.cpp create mode 100644 matching/tests/test_module.cpp (limited to 'matching/example') diff --git a/matching/example/module_example.cpp b/matching/example/module_example.cpp new file mode 100644 index 0000000..c160c21 --- /dev/null +++ b/matching/example/module_example.cpp @@ -0,0 +1,68 @@ +#include +#include "matching_distance.h" + +using namespace md; + +int main(int argc, char** argv) +{ + // create generators. + // A generator is a point in plane, + // generators are stored in a vector of points: + PointVec gens_a; + + // module A will have one generator that appears at point (0, 0) + gens_a.emplace_back(0, 0); + + // relations are stored in a vector of relations + using RelationVec = ModulePresentation::RelVec; + RelationVec rels_a; + + // A relation is a struct with position and column + using Relation = ModulePresentation::Relation; + + // at this point the relation rel_a_1 will appear: + Point rel_a_1_position { 1, 1 }; + + // vector IndexVec contains non-zero indices of the corresponding relation + // (we work over Z/2). Since we have one generator only, the relation + // contains only one entry, 0 + IndexVec rel_a_1_components { 0 }; + + // construct a relation from position and column: + Relation rel_a_1 { rel_a_1_position, rel_a_1_components }; + + // and add it to a vector of relations + rels_a.push_back(rel_a_1); + + // after populating vectors of generators and relations + // construct a module: + ModulePresentation module_a { gens_a, rels_a }; + + + // same for module_b. 
It will also have just one + // generator and one relation, but at different positions. + + PointVec gens_b; + gens_b.emplace_back(1, 1); + + RelationVec rels_b; + + Point rel_b_1_position { 2, 2 }; + IndexVec rel_b_1_components { 0 }; + + rels_b.emplace_back(rel_b_1_position, rel_b_1_components); + + ModulePresentation module_b { gens_b, rels_b }; + + // create CalculationParams + CalculationParams params; + // set relative error to 10 % : + params.delta = 0.1; + // go at most 8 levels deep in quadtree: + params.max_depth = 8; + + double dist = matching_distance(module_a, module_b, params); + std::cout << "dist = " << dist << std::endl; + + return 0; +} diff --git a/matching/include/matching_distance.h b/matching/include/matching_distance.h index e82a97c..e1679dc 100644 --- a/matching/include/matching_distance.h +++ b/matching/include/matching_distance.h @@ -149,20 +149,20 @@ namespace md { Real hera_epsilon {0.001}; // relative error in hera call Real delta {0.1}; // relative error for matching distance - int max_depth {6}; // maximal number of refinenemnts - int initialization_depth {3}; + int max_depth {8}; // maximal number of refinenemnts + int initialization_depth {2}; int dim {0}; // in which dim to calculate the distance; use ALL_DIMENSIONS to get max over all dims BoundStrategy bound_strategy {BoundStrategy::local_combined}; TraverseStrategy traverse_strategy {TraverseStrategy::breadth_first}; - bool tolerate_max_iter_exceeded {true}; + bool tolerate_max_iter_exceeded {false}; Real actual_error {std::numeric_limits::max()}; int actual_max_depth {0}; int n_hera_calls {0}; // for experiments only; is set in matching_distance function, input value is ignored // stop looping over points immediately, if current point's displacement is too large // to prune the cell - // if true, cells are pruned immediately, and bounds may be unreliable - // (we just return something large enough to prune the cell) + // if true, cells are pruned immediately, and bounds may 
increase + // (just return something large enough to not prune the cell) bool stop_asap { true }; // print statistics on each quad-tree level @@ -188,8 +188,11 @@ namespace md { Real distance(); int get_hera_calls_number() const; -// for tests - make everything public -// private: + +#ifndef MD_TEST_CODE + private: +#endif + DiagramProvider module_a_; DiagramProvider module_b_; diff --git a/matching/include/persistence_module.h b/matching/include/persistence_module.h index 4a261bb..b68c21e 100644 --- a/matching/include/persistence_module.h +++ b/matching/include/persistence_module.h @@ -80,7 +80,9 @@ namespace md { PointVec positions() const; +#ifndef MD_TEST_CODE private: +#endif PointVec generators_; std::vector relations_; @@ -94,8 +96,13 @@ namespace md { Box bounding_box_; void init_boundaries(); + void project_generators(const DualPoint& slice, IndexVec& sorted_indices, RealVec& projections) const; void project_relations(const DualPoint& slice, IndexVec& sorted_indices, RealVec& projections) const; + + void get_slice_projection_matrix(const DualPoint& slice, phat::boundary_matrix<>& phat_matrix, + RealVec& gen_projections, RealVec& rel_projections) const; + }; } // namespace md diff --git a/matching/include/persistence_module.hpp b/matching/include/persistence_module.hpp index 128fed9..233a70d 100644 --- a/matching/include/persistence_module.hpp +++ b/matching/include/persistence_module.hpp @@ -116,7 +116,11 @@ namespace md { for(const auto& rel : relations_) { rel_values.push_back(slice.weighted_push(rel.position_)); } + sorted_rel_indices = get_sorted_indices(rel_values); + + spd::debug("rel_values = {}, sorted_rel_indices = {}", container_to_string(rel_values), container_to_string(sorted_rel_indices)); + projections.clear(); projections.reserve(num_rels); for(auto i : sorted_rel_indices) { @@ -125,18 +129,18 @@ namespace md { } } + template - Diagram ModulePresentation::weighted_slice_diagram(const DualPoint& slice) const + void 
ModulePresentation::get_slice_projection_matrix(const DualPoint& slice, + phat::boundary_matrix<>& phat_matrix, + RealVec& gen_projections, RealVec& rel_projections) const { spd::debug("Enter weighted_slice_diagram, slice = {}", slice); IndexVec sorted_gen_indices, sorted_rel_indices; - RealVec gen_projections, rel_projections; project_generators(slice, sorted_gen_indices, gen_projections); project_relations(slice, sorted_rel_indices, rel_projections); - phat::boundary_matrix<> phat_matrix; - phat_matrix.set_num_cols(relations_.size()); for(Index i = 0; i < (Index) relations_.size(); i++) { @@ -149,6 +153,18 @@ namespace md { phat_matrix.set_dim(i, 0); phat_matrix.set_col(i, current_relation); } + } + + + template + Diagram ModulePresentation::weighted_slice_diagram(const DualPoint& slice) const + { + spd::debug("Enter weighted_slice_diagram, slice = {}", slice); + + RealVec gen_projections, rel_projections; + phat::boundary_matrix<> phat_matrix; + + get_slice_projection_matrix(slice, phat_matrix, gen_projections, rel_projections); phat::persistence_pairs phat_persistence_pairs; phat::compute_persistence_pairs(phat_persistence_pairs, phat_matrix); diff --git a/matching/tests/test_matching_distance.cpp b/matching/tests/test_matching_distance.cpp index 82da530..aa08cfe 100644 --- a/matching/tests/test_matching_distance.cpp +++ b/matching/tests/test_matching_distance.cpp @@ -7,6 +7,8 @@ #include "spdlog/spdlog.h" #include "spdlog/fmt/ostr.h" +#define MD_TEST_CODE + #include "common_util.h" #include "simplex.h" #include "matching_distance.h" @@ -130,8 +132,8 @@ TEST_CASE("Bifiltrations from file", "[matching_distance][small_example][lesnick { std::string fname_a, fname_b; - fname_a = "../src/tests/prism_1.bif"; - fname_b = "../src/tests/prism_2.bif"; + fname_a = "../tests/prism_1.bif"; + fname_b = "../tests/prism_2.bif"; Bifiltration bif_a(fname_a); Bifiltration bif_b(fname_b); @@ -164,4 +166,3 @@ TEST_CASE("Bifiltrations from file", 
"[matching_distance][small_example][lesnick } } } - diff --git a/matching/tests/test_module.cpp b/matching/tests/test_module.cpp new file mode 100644 index 0000000..0f239d5 --- /dev/null +++ b/matching/tests/test_module.cpp @@ -0,0 +1,114 @@ +#include "catch/catch.hpp" + +#include +#include +#include + +#include "spdlog/spdlog.h" +#include "spdlog/fmt/ostr.h" + +#define MD_TEST_CODE + +#include "common_util.h" +#include "persistence_module.h" +#include "matching_distance.h" + +using Real = double; +using Point = md::Point; +using Bifiltration = md::Bifiltration; +using BifiltrationProxy = md::BifiltrationProxy; +using CalculationParams = md::CalculationParams; +using CellWithValue = md::CellWithValue; +using DualPoint = md::DualPoint; +using DualBox = md::DualBox; +using BoundStrategy = md::BoundStrategy; +using TraverseStrategy = md::TraverseStrategy; +using AxisType = md::AxisType; +using AngleType = md::AngleType; +using ValuePoint = md::ValuePoint; +using Column = md::Column; +using PointVec = md::PointVec; +using Module = md::ModulePresentation; +using Relation = Module::Relation; +using RelationVec = Module::RelVec; +using IndexVec = md::IndexVec; + +using md::k_corner_vps; + +namespace spd = spdlog; + +TEST_CASE("Module projection", "[module][projection]") +{ + PointVec gens; + gens.emplace_back(1, 1); // A + gens.emplace_back(2, 3); // B + gens.emplace_back(3, 2); // C + + RelationVec rels; + + Point rel_x_position { 3.98, 2.47 }; + IndexVec rel_x_components { 0, 2 }; // X: A + C = 0 + + Point rel_y_position { 2.5, 4 }; + IndexVec rel_y_components { 0, 1 }; // Y: A + B = 0 + + Point rel_z_position { 5, 5 }; + IndexVec rel_z_components { 1 }; // Z: B = 0 + + + rels.emplace_back(rel_x_position, rel_x_components); + rels.emplace_back(rel_y_position, rel_y_components); + rels.emplace_back(rel_z_position, rel_z_components); + + Module module { gens, rels }; + + { + DualPoint slice_1(AxisType::x_type, AngleType::flat, 6.0 / 7.0, 3.0); + std::vector gen_ps_1, 
rel_ps_1; + phat::boundary_matrix<> matr_1; + + module.get_slice_projection_matrix(slice_1, matr_1, gen_ps_1, rel_ps_1); + + phat::column c_1_0, c_1_1, c_1_2; + + matr_1.get_col(0, c_1_0); + matr_1.get_col(1, c_1_1); + matr_1.get_col(2, c_1_2); + + phat::column c_1_0_correct { 0, 1}; + phat::column c_1_1_correct { 0, 2}; + phat::column c_1_2_correct { 2 }; + + REQUIRE(c_1_0 == c_1_0_correct); + REQUIRE(c_1_1 == c_1_1_correct); + REQUIRE(c_1_2 == c_1_2_correct); + } + + { + + DualPoint slice_2(AxisType::y_type, AngleType::flat, 2.0 / 9.0, 5.0); + std::vector gen_ps_2, rel_ps_2; + phat::boundary_matrix<> matr_2; + + module.get_slice_projection_matrix(slice_2, matr_2, gen_ps_2, rel_ps_2); + + phat::column c_2_0, c_2_1, c_2_2; + + matr_2.get_col(0, c_2_0); + matr_2.get_col(1, c_2_1); + matr_2.get_col(2, c_2_2); + + phat::column c_2_0_correct { 0, 1}; + phat::column c_2_1_correct { 0, 2}; + phat::column c_2_2_correct { 1 }; + + //std::cerr << "gen_ps_2: " << md::container_to_string(gen_ps_2) << std::endl; + //std::cerr << "rel_ps_2: " << md::container_to_string(rel_ps_2) << std::endl; + + REQUIRE(c_2_0 == c_2_0_correct); + REQUIRE(c_2_1 == c_2_1_correct); + REQUIRE(c_2_2 == c_2_2_correct); + } + + +} -- cgit v1.2.3