From 14e91d6c3ad81a1ec763d75a28f20fb689e5166e Mon Sep 17 00:00:00 2001 From: Arnur Nigmetov Date: Mon, 9 Mar 2020 06:26:49 +0100 Subject: Add tests for module slice restriction. --- matching/example/module_example.cpp | 68 ++++++++++++++++++ matching/include/matching_distance.h | 17 +++-- matching/include/persistence_module.h | 7 ++ matching/include/persistence_module.hpp | 24 +++++-- matching/tests/test_matching_distance.cpp | 7 +- matching/tests/test_module.cpp | 114 ++++++++++++++++++++++++++++++ 6 files changed, 223 insertions(+), 14 deletions(-) create mode 100644 matching/example/module_example.cpp create mode 100644 matching/tests/test_module.cpp (limited to 'matching') diff --git a/matching/example/module_example.cpp b/matching/example/module_example.cpp new file mode 100644 index 0000000..c160c21 --- /dev/null +++ b/matching/example/module_example.cpp @@ -0,0 +1,68 @@ +#include +#include "matching_distance.h" + +using namespace md; + +int main(int argc, char** argv) +{ + // create generators. + // A generator is a point in plane, + // generators are stored in a vector of points: + PointVec gens_a; + + // module A will have one generator that appears at point (0, 0) + gens_a.emplace_back(0, 0); + + // relations are stored in a vector of relations + using RelationVec = ModulePresentation::RelVec; + RelationVec rels_a; + + // A relation is a struct with position and column + using Relation = ModulePresentation::Relation; + + // at this point the relation rel_a_1 will appear: + Point rel_a_1_position { 1, 1 }; + + // vector IndexVec contains non-zero indices of the corresponding relation + // (we work over Z/2). Since we have one generator only, the relation + // contains only one entry, 0 + IndexVec rel_a_1_components { 0 }; + + // construct a relation from position and column: + Relation rel_a_1 { rel_a_1_position, rel_a_1_components }; + + // and add it to a vector of relations + rels_a.push_back(rel_a_1); + + // after populating vectors of generators and relations + // construct a module: + ModulePresentation module_a { gens_a, rels_a }; + + + // same for module_b. It will also have just one + // generator and one relation, but at different positions. + + PointVec gens_b; + gens_b.emplace_back(1, 1); + + RelationVec rels_b; + + Point rel_b_1_position { 2, 2 }; + IndexVec rel_b_1_components { 0 }; + + rels_b.emplace_back(rel_b_1_position, rel_b_1_components); + + ModulePresentation module_b { gens_b, rels_b }; + + // create CalculationParams + CalculationParams params; + // set relative error to 10 % : + params.delta = 0.1; + // go at most 8 levels deep in quadtree: + params.max_depth = 8; + + double dist = matching_distance(module_a, module_b, params); + std::cout << "dist = " << dist << std::endl; + + return 0; +} diff --git a/matching/include/matching_distance.h b/matching/include/matching_distance.h index e82a97c..e1679dc 100644 --- a/matching/include/matching_distance.h +++ b/matching/include/matching_distance.h @@ -149,20 +149,20 @@ namespace md { Real hera_epsilon {0.001}; // relative error in hera call Real delta {0.1}; // relative error for matching distance - int max_depth {6}; // maximal number of refinenemnts - int initialization_depth {3}; + int max_depth {8}; // maximal number of refinenemnts + int initialization_depth {2}; int dim {0}; // in which dim to calculate the distance; use ALL_DIMENSIONS to get max over all dims BoundStrategy bound_strategy {BoundStrategy::local_combined}; TraverseStrategy traverse_strategy {TraverseStrategy::breadth_first}; - bool tolerate_max_iter_exceeded {true}; + bool tolerate_max_iter_exceeded {false}; Real actual_error {std::numeric_limits::max()}; int actual_max_depth {0}; int n_hera_calls {0}; // for experiments only; is set in matching_distance function, input value is ignored // stop looping over points immediately, if current point's displacement is too large // to prune the cell - // if true, cells are pruned immediately, and bounds may be unreliable - // (we just return something large enough to prune the cell) + // if true, cells are pruned immediately, and bounds may increase + // (just return something large enough to not prune the cell) bool stop_asap { true }; // print statistics on each quad-tree level @@ -188,8 +188,11 @@ namespace md { Real distance(); int get_hera_calls_number() const; -// for tests - make everything public -// private: + +#ifndef MD_TEST_CODE + private: +#endif + DiagramProvider module_a_; DiagramProvider module_b_; diff --git a/matching/include/persistence_module.h b/matching/include/persistence_module.h index 4a261bb..b68c21e 100644 --- a/matching/include/persistence_module.h +++ b/matching/include/persistence_module.h @@ -80,7 +80,9 @@ namespace md { PointVec positions() const; +#ifndef MD_TEST_CODE private: +#endif PointVec generators_; std::vector relations_; @@ -94,8 +96,13 @@ namespace md { Box bounding_box_; void init_boundaries(); + void project_generators(const DualPoint& slice, IndexVec& sorted_indices, RealVec& projections) const; void project_relations(const DualPoint& slice, IndexVec& sorted_indices, RealVec& projections) const; + + void get_slice_projection_matrix(const DualPoint& slice, phat::boundary_matrix<>& phat_matrix, + RealVec& gen_projections, RealVec& rel_projections) const; + }; } // namespace md diff --git a/matching/include/persistence_module.hpp b/matching/include/persistence_module.hpp index 128fed9..233a70d 100644 --- a/matching/include/persistence_module.hpp +++ b/matching/include/persistence_module.hpp @@ -116,7 +116,11 @@ namespace md { for(const auto& rel : relations_) { rel_values.push_back(slice.weighted_push(rel.position_)); } + sorted_rel_indices = get_sorted_indices(rel_values); + + spd::debug("rel_values = {}, sorted_rel_indices = {}", container_to_string(rel_values), container_to_string(sorted_rel_indices)); + projections.clear(); projections.reserve(num_rels); for(auto i : sorted_rel_indices) { @@ -125,18 +129,18 @@ namespace md { } } + template - Diagram ModulePresentation::weighted_slice_diagram(const DualPoint& slice) const + void ModulePresentation::get_slice_projection_matrix(const DualPoint& slice, + phat::boundary_matrix<>& phat_matrix, + RealVec& gen_projections, RealVec& rel_projections) const { spd::debug("Enter weighted_slice_diagram, slice = {}", slice); IndexVec sorted_gen_indices, sorted_rel_indices; - RealVec gen_projections, rel_projections; project_generators(slice, sorted_gen_indices, gen_projections); project_relations(slice, sorted_rel_indices, rel_projections); - phat::boundary_matrix<> phat_matrix; - phat_matrix.set_num_cols(relations_.size()); for(Index i = 0; i < (Index) relations_.size(); i++) { @@ -149,6 +153,18 @@ namespace md { phat_matrix.set_dim(i, 0); phat_matrix.set_col(i, current_relation); } + } + + + template + Diagram ModulePresentation::weighted_slice_diagram(const DualPoint& slice) const + { + spd::debug("Enter weighted_slice_diagram, slice = {}", slice); + + RealVec gen_projections, rel_projections; + phat::boundary_matrix<> phat_matrix; + + get_slice_projection_matrix(slice, phat_matrix, gen_projections, rel_projections); phat::persistence_pairs phat_persistence_pairs; phat::compute_persistence_pairs(phat_persistence_pairs, phat_matrix); diff --git a/matching/tests/test_matching_distance.cpp b/matching/tests/test_matching_distance.cpp index 82da530..aa08cfe 100644 --- a/matching/tests/test_matching_distance.cpp +++ b/matching/tests/test_matching_distance.cpp @@ -7,6 +7,8 @@ #include "spdlog/spdlog.h" #include "spdlog/fmt/ostr.h" +#define MD_TEST_CODE + #include "common_util.h" #include "simplex.h" #include "matching_distance.h" @@ -130,8 +132,8 @@ TEST_CASE("Bifiltrations from file", "[matching_distance][small_example][lesnick { std::string fname_a, fname_b; - fname_a = "../src/tests/prism_1.bif"; - fname_b = "../src/tests/prism_2.bif"; + fname_a = "../tests/prism_1.bif"; + fname_b = "../tests/prism_2.bif"; Bifiltration bif_a(fname_a); Bifiltration bif_b(fname_b); @@ -164,4 +166,3 @@ TEST_CASE("Bifiltrations from file", "[matching_distance][small_example][lesnick } } } - diff --git a/matching/tests/test_module.cpp b/matching/tests/test_module.cpp new file mode 100644 index 0000000..0f239d5 --- /dev/null +++ b/matching/tests/test_module.cpp @@ -0,0 +1,114 @@ +#include "catch/catch.hpp" + +#include +#include +#include + +#include "spdlog/spdlog.h" +#include "spdlog/fmt/ostr.h" + +#define MD_TEST_CODE + +#include "common_util.h" +#include "persistence_module.h" +#include "matching_distance.h" + +using Real = double; +using Point = md::Point; +using Bifiltration = md::Bifiltration; +using BifiltrationProxy = md::BifiltrationProxy; +using CalculationParams = md::CalculationParams; +using CellWithValue = md::CellWithValue; +using DualPoint = md::DualPoint; +using DualBox = md::DualBox; +using BoundStrategy = md::BoundStrategy; +using TraverseStrategy = md::TraverseStrategy; +using AxisType = md::AxisType; +using AngleType = md::AngleType; +using ValuePoint = md::ValuePoint; +using Column = md::Column; +using PointVec = md::PointVec; +using Module = md::ModulePresentation; +using Relation = Module::Relation; +using RelationVec = Module::RelVec; +using IndexVec = md::IndexVec; + +using md::k_corner_vps; + +namespace spd = spdlog; + +TEST_CASE("Module projection", "[module][projection]") +{ + PointVec gens; + gens.emplace_back(1, 1); // A + gens.emplace_back(2, 3); // B + gens.emplace_back(3, 2); // C + + RelationVec rels; + + Point rel_x_position { 3.98, 2.47 }; + IndexVec rel_x_components { 0, 2 }; // X: A + C = 0 + + Point rel_y_position { 2.5, 4 }; + IndexVec rel_y_components { 0, 1 }; // Y: A + B = 0 + + Point rel_z_position { 5, 5 }; + IndexVec rel_z_components { 1 }; // Z: B = 0 + + + rels.emplace_back(rel_x_position, rel_x_components); + rels.emplace_back(rel_y_position, rel_y_components); + rels.emplace_back(rel_z_position, rel_z_components); + + Module module { gens, rels }; + + { + DualPoint slice_1(AxisType::x_type, AngleType::flat, 6.0 / 7.0, 3.0); + std::vector gen_ps_1, rel_ps_1; + phat::boundary_matrix<> matr_1; + + module.get_slice_projection_matrix(slice_1, matr_1, gen_ps_1, rel_ps_1); + + phat::column c_1_0, c_1_1, c_1_2; + + matr_1.get_col(0, c_1_0); + matr_1.get_col(1, c_1_1); + matr_1.get_col(2, c_1_2); + + phat::column c_1_0_correct { 0, 1}; + phat::column c_1_1_correct { 0, 2}; + phat::column c_1_2_correct { 2 }; + + REQUIRE(c_1_0 == c_1_0_correct); + REQUIRE(c_1_1 == c_1_1_correct); + REQUIRE(c_1_2 == c_1_2_correct); + } + + { + + DualPoint slice_2(AxisType::y_type, AngleType::flat, 2.0 / 9.0, 5.0); + std::vector gen_ps_2, rel_ps_2; + phat::boundary_matrix<> matr_2; + + module.get_slice_projection_matrix(slice_2, matr_2, gen_ps_2, rel_ps_2); + + phat::column c_2_0, c_2_1, c_2_2; + + matr_2.get_col(0, c_2_0); + matr_2.get_col(1, c_2_1); + matr_2.get_col(2, c_2_2); + + phat::column c_2_0_correct { 0, 1}; + phat::column c_2_1_correct { 0, 2}; + phat::column c_2_2_correct { 1 }; + + //std::cerr << "gen_ps_2: " << md::container_to_string(gen_ps_2) << std::endl; + //std::cerr << "rel_ps_2: " << md::container_to_string(rel_ps_2) << std::endl; + + REQUIRE(c_2_0 == c_2_0_correct); + REQUIRE(c_2_1 == c_2_1_correct); + REQUIRE(c_2_2 == c_2_2_correct); + } + + +} -- cgit v1.2.3