summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorArnur Nigmetov <nigmetov@tugraz.at>2020-03-09 06:26:49 +0100
committerArnur Nigmetov <nigmetov@tugraz.at>2020-03-09 06:26:49 +0100
commit14e91d6c3ad81a1ec763d75a28f20fb689e5166e (patch)
tree54dcb2399a2386bc035de932e74e933b8b7537fe
parent490fed367bb97a96b90caa6ef04265c063d91df1 (diff)
Add tests for module slice restriction.
-rw-r--r--matching/example/module_example.cpp68
-rw-r--r--matching/include/matching_distance.h17
-rw-r--r--matching/include/persistence_module.h7
-rw-r--r--matching/include/persistence_module.hpp24
-rw-r--r--matching/tests/test_matching_distance.cpp7
-rw-r--r--matching/tests/test_module.cpp114
6 files changed, 223 insertions, 14 deletions
diff --git a/matching/example/module_example.cpp b/matching/example/module_example.cpp
new file mode 100644
index 0000000..c160c21
--- /dev/null
+++ b/matching/example/module_example.cpp
@@ -0,0 +1,68 @@
+#include <iostream>
+#include "matching_distance.h"
+
+using namespace md;
+
+int main(int argc, char** argv)
+{
+ // create generators.
+ // A generator is a point in plane,
+ // generators are stored in a vector of points:
+ PointVec<double> gens_a;
+
+ // module A will have one generator that appears at point (0, 0)
+ gens_a.emplace_back(0, 0);
+
+ // relations are stored in a vector of relations
+ using RelationVec = ModulePresentation<double>::RelVec;
+ RelationVec rels_a;
+
+ // A relation is a struct with position and column
+ using Relation = ModulePresentation<double>::Relation;
+
+ // at this point the relation rel_a_1 will appear:
+ Point<double> rel_a_1_position { 1, 1 };
+
+ // vector IndexVec contains non-zero indices of the corresponding relation
+ // (we work over Z/2). Since we have one generator only, the relation
+ // contains only one entry, 0
+ IndexVec rel_a_1_components { 0 };
+
+ // construct a relation from position and column:
+ Relation rel_a_1 { rel_a_1_position, rel_a_1_components };
+
+ // and add it to a vector of relations
+ rels_a.push_back(rel_a_1);
+
+ // after populating vectors of generators and relations
+ // construct a module:
+ ModulePresentation<double> module_a { gens_a, rels_a };
+
+
+ // same for module_b. It will also have just one
+ // generator and one relation, but at different positions.
+
+ PointVec<double> gens_b;
+ gens_b.emplace_back(1, 1);
+
+ RelationVec rels_b;
+
+ Point<double> rel_b_1_position { 2, 2 };
+ IndexVec rel_b_1_components { 0 };
+
+ rels_b.emplace_back(rel_b_1_position, rel_b_1_components);
+
+ ModulePresentation<double> module_b { gens_b, rels_b };
+
+ // create CalculationParams
+ CalculationParams<double> params;
+ // set relative error to 10 % :
+ params.delta = 0.1;
+ // go at most 8 levels deep in quadtree:
+ params.max_depth = 8;
+
+ double dist = matching_distance(module_a, module_b, params);
+ std::cout << "dist = " << dist << std::endl;
+
+ return 0;
+}
diff --git a/matching/include/matching_distance.h b/matching/include/matching_distance.h
index e82a97c..e1679dc 100644
--- a/matching/include/matching_distance.h
+++ b/matching/include/matching_distance.h
@@ -149,20 +149,20 @@ namespace md {
Real hera_epsilon {0.001}; // relative error in hera call
Real delta {0.1}; // relative error for matching distance
- int max_depth {6}; // maximal number of refinenemnts
- int initialization_depth {3};
+ int max_depth {8}; // maximal number of refinenemnts
+ int initialization_depth {2};
int dim {0}; // in which dim to calculate the distance; use ALL_DIMENSIONS to get max over all dims
BoundStrategy bound_strategy {BoundStrategy::local_combined};
TraverseStrategy traverse_strategy {TraverseStrategy::breadth_first};
- bool tolerate_max_iter_exceeded {true};
+ bool tolerate_max_iter_exceeded {false};
Real actual_error {std::numeric_limits<Real>::max()};
int actual_max_depth {0};
int n_hera_calls {0}; // for experiments only; is set in matching_distance function, input value is ignored
// stop looping over points immediately, if current point's displacement is too large
// to prune the cell
- // if true, cells are pruned immediately, and bounds may be unreliable
- // (we just return something large enough to prune the cell)
+ // if true, cells are pruned immediately, and bounds may increase
+ // (just return something large enough to not prune the cell)
bool stop_asap { true };
// print statistics on each quad-tree level
@@ -188,8 +188,11 @@ namespace md {
Real distance();
int get_hera_calls_number() const;
-// for tests - make everything public
-// private:
+
+#ifndef MD_TEST_CODE
+ private:
+#endif
+
DiagramProvider module_a_;
DiagramProvider module_b_;
diff --git a/matching/include/persistence_module.h b/matching/include/persistence_module.h
index 4a261bb..b68c21e 100644
--- a/matching/include/persistence_module.h
+++ b/matching/include/persistence_module.h
@@ -80,7 +80,9 @@ namespace md {
PointVec<Real> positions() const;
+#ifndef MD_TEST_CODE
private:
+#endif
PointVec<Real> generators_;
std::vector<Relation> relations_;
@@ -94,8 +96,13 @@ namespace md {
Box<Real> bounding_box_;
void init_boundaries();
+
void project_generators(const DualPoint<Real>& slice, IndexVec& sorted_indices, RealVec& projections) const;
void project_relations(const DualPoint<Real>& slice, IndexVec& sorted_indices, RealVec& projections) const;
+
+ void get_slice_projection_matrix(const DualPoint<Real>& slice, phat::boundary_matrix<>& phat_matrix,
+ RealVec& gen_projections, RealVec& rel_projections) const;
+
};
} // namespace md
diff --git a/matching/include/persistence_module.hpp b/matching/include/persistence_module.hpp
index 128fed9..233a70d 100644
--- a/matching/include/persistence_module.hpp
+++ b/matching/include/persistence_module.hpp
@@ -116,7 +116,11 @@ namespace md {
for(const auto& rel : relations_) {
rel_values.push_back(slice.weighted_push(rel.position_));
}
+
sorted_rel_indices = get_sorted_indices(rel_values);
+
+ spd::debug("rel_values = {}, sorted_rel_indices = {}", container_to_string(rel_values), container_to_string(sorted_rel_indices));
+
projections.clear();
projections.reserve(num_rels);
for(auto i : sorted_rel_indices) {
@@ -125,18 +129,18 @@ namespace md {
}
}
+
template<class Real>
- Diagram<Real> ModulePresentation<Real>::weighted_slice_diagram(const DualPoint<Real>& slice) const
+ void ModulePresentation<Real>::get_slice_projection_matrix(const DualPoint<Real>& slice,
+ phat::boundary_matrix<>& phat_matrix,
+ RealVec& gen_projections, RealVec& rel_projections) const
{
spd::debug("Enter weighted_slice_diagram, slice = {}", slice);
IndexVec sorted_gen_indices, sorted_rel_indices;
- RealVec gen_projections, rel_projections;
project_generators(slice, sorted_gen_indices, gen_projections);
project_relations(slice, sorted_rel_indices, rel_projections);
- phat::boundary_matrix<> phat_matrix;
-
phat_matrix.set_num_cols(relations_.size());
for(Index i = 0; i < (Index) relations_.size(); i++) {
@@ -149,6 +153,18 @@ namespace md {
phat_matrix.set_dim(i, 0);
phat_matrix.set_col(i, current_relation);
}
+ }
+
+
+ template<class Real>
+ Diagram<Real> ModulePresentation<Real>::weighted_slice_diagram(const DualPoint<Real>& slice) const
+ {
+ spd::debug("Enter weighted_slice_diagram, slice = {}", slice);
+
+ RealVec gen_projections, rel_projections;
+ phat::boundary_matrix<> phat_matrix;
+
+ get_slice_projection_matrix(slice, phat_matrix, gen_projections, rel_projections);
phat::persistence_pairs phat_persistence_pairs;
phat::compute_persistence_pairs<phat::twist_reduction>(phat_persistence_pairs, phat_matrix);
diff --git a/matching/tests/test_matching_distance.cpp b/matching/tests/test_matching_distance.cpp
index 82da530..aa08cfe 100644
--- a/matching/tests/test_matching_distance.cpp
+++ b/matching/tests/test_matching_distance.cpp
@@ -7,6 +7,8 @@
#include "spdlog/spdlog.h"
#include "spdlog/fmt/ostr.h"
+#define MD_TEST_CODE
+
#include "common_util.h"
#include "simplex.h"
#include "matching_distance.h"
@@ -130,8 +132,8 @@ TEST_CASE("Bifiltrations from file", "[matching_distance][small_example][lesnick
{
std::string fname_a, fname_b;
- fname_a = "../src/tests/prism_1.bif";
- fname_b = "../src/tests/prism_2.bif";
+ fname_a = "../tests/prism_1.bif";
+ fname_b = "../tests/prism_2.bif";
Bifiltration bif_a(fname_a);
Bifiltration bif_b(fname_b);
@@ -164,4 +166,3 @@ TEST_CASE("Bifiltrations from file", "[matching_distance][small_example][lesnick
}
}
}
-
diff --git a/matching/tests/test_module.cpp b/matching/tests/test_module.cpp
new file mode 100644
index 0000000..0f239d5
--- /dev/null
+++ b/matching/tests/test_module.cpp
@@ -0,0 +1,114 @@
+#include "catch/catch.hpp"
+
+#include <sstream>
+#include <iostream>
+#include <string>
+
+#include "spdlog/spdlog.h"
+#include "spdlog/fmt/ostr.h"
+
+#define MD_TEST_CODE
+
+#include "common_util.h"
+#include "persistence_module.h"
+#include "matching_distance.h"
+
+using Real = double;
+using Point = md::Point<Real>;
+using Bifiltration = md::Bifiltration<Real>;
+using BifiltrationProxy = md::BifiltrationProxy<Real>;
+using CalculationParams = md::CalculationParams<Real>;
+using CellWithValue = md::CellWithValue<Real>;
+using DualPoint = md::DualPoint<Real>;
+using DualBox = md::DualBox<Real>;
+using BoundStrategy = md::BoundStrategy;
+using TraverseStrategy = md::TraverseStrategy;
+using AxisType = md::AxisType;
+using AngleType = md::AngleType;
+using ValuePoint = md::ValuePoint;
+using Column = md::Column;
+using PointVec = md::PointVec<Real>;
+using Module = md::ModulePresentation<Real>;
+using Relation = Module::Relation;
+using RelationVec = Module::RelVec;
+using IndexVec = md::IndexVec;
+
+using md::k_corner_vps;
+
+namespace spd = spdlog;
+
+TEST_CASE("Module projection", "[module][projection]")
+{
+ PointVec gens;
+ gens.emplace_back(1, 1); // A
+ gens.emplace_back(2, 3); // B
+ gens.emplace_back(3, 2); // C
+
+ RelationVec rels;
+
+ Point rel_x_position { 3.98, 2.47 };
+ IndexVec rel_x_components { 0, 2 }; // X: A + C = 0
+
+ Point rel_y_position { 2.5, 4 };
+ IndexVec rel_y_components { 0, 1 }; // Y: A + B = 0
+
+ Point rel_z_position { 5, 5 };
+ IndexVec rel_z_components { 1 }; // Z: B = 0
+
+
+ rels.emplace_back(rel_x_position, rel_x_components);
+ rels.emplace_back(rel_y_position, rel_y_components);
+ rels.emplace_back(rel_z_position, rel_z_components);
+
+ Module module { gens, rels };
+
+ {
+ DualPoint slice_1(AxisType::x_type, AngleType::flat, 6.0 / 7.0, 3.0);
+ std::vector<Real> gen_ps_1, rel_ps_1;
+ phat::boundary_matrix<> matr_1;
+
+ module.get_slice_projection_matrix(slice_1, matr_1, gen_ps_1, rel_ps_1);
+
+ phat::column c_1_0, c_1_1, c_1_2;
+
+ matr_1.get_col(0, c_1_0);
+ matr_1.get_col(1, c_1_1);
+ matr_1.get_col(2, c_1_2);
+
+ phat::column c_1_0_correct { 0, 1};
+ phat::column c_1_1_correct { 0, 2};
+ phat::column c_1_2_correct { 2 };
+
+ REQUIRE(c_1_0 == c_1_0_correct);
+ REQUIRE(c_1_1 == c_1_1_correct);
+ REQUIRE(c_1_2 == c_1_2_correct);
+ }
+
+ {
+
+ DualPoint slice_2(AxisType::y_type, AngleType::flat, 2.0 / 9.0, 5.0);
+ std::vector<Real> gen_ps_2, rel_ps_2;
+ phat::boundary_matrix<> matr_2;
+
+ module.get_slice_projection_matrix(slice_2, matr_2, gen_ps_2, rel_ps_2);
+
+ phat::column c_2_0, c_2_1, c_2_2;
+
+ matr_2.get_col(0, c_2_0);
+ matr_2.get_col(1, c_2_1);
+ matr_2.get_col(2, c_2_2);
+
+ phat::column c_2_0_correct { 0, 1};
+ phat::column c_2_1_correct { 0, 2};
+ phat::column c_2_2_correct { 1 };
+
+ //std::cerr << "gen_ps_2: " << md::container_to_string(gen_ps_2) << std::endl;
+ //std::cerr << "rel_ps_2: " << md::container_to_string(rel_ps_2) << std::endl;
+
+ REQUIRE(c_2_0 == c_2_0_correct);
+ REQUIRE(c_2_1 == c_2_1_correct);
+ REQUIRE(c_2_2 == c_2_2_correct);
+ }
+
+
+}