summaryrefslogtreecommitdiff
path: root/matching
diff options
context:
space:
mode:
Diffstat (limited to 'matching')
-rw-r--r--matching/CMakeLists.txt56
-rw-r--r--matching/README.md69
-rw-r--r--matching/example/matching_dist.cpp366
-rw-r--r--matching/example/module_example.cpp68
-rw-r--r--matching/include/bifiltration.h136
-rw-r--r--matching/include/bifiltration.hpp412
-rw-r--r--matching/include/box.h59
-rw-r--r--matching/include/box.hpp52
-rw-r--r--matching/include/cell_with_value.h141
-rw-r--r--matching/include/cell_with_value.hpp225
-rw-r--r--matching/include/common_defs.h10
-rw-r--r--matching/include/common_util.h183
-rw-r--r--matching/include/common_util.hpp93
-rw-r--r--matching/include/dual_box.h79
-rw-r--r--matching/include/dual_box.hpp188
-rw-r--r--matching/include/dual_point.h107
-rw-r--r--matching/include/dual_point.hpp299
-rw-r--r--matching/include/matching_distance.h290
-rw-r--r--matching/include/matching_distance.hpp722
-rw-r--r--matching/include/opts/opts.h499
-rw-r--r--matching/include/persistence_module.h111
-rw-r--r--matching/include/persistence_module.hpp192
-rw-r--r--matching/include/phat/algorithms/chunk_reduction.h223
-rw-r--r--matching/include/phat/algorithms/row_reduction.h56
-rw-r--r--matching/include/phat/algorithms/spectral_sequence_reduction.h80
-rw-r--r--matching/include/phat/algorithms/standard_reduction.h47
-rw-r--r--matching/include/phat/algorithms/twist_reduction.h51
-rw-r--r--matching/include/phat/boundary_matrix.h343
-rw-r--r--matching/include/phat/compute_persistence_pairs.h137
-rw-r--r--matching/include/phat/helpers/dualize.h74
-rw-r--r--matching/include/phat/helpers/misc.h78
-rw-r--r--matching/include/phat/helpers/thread_local_storage.h52
-rw-r--r--matching/include/phat/persistence_pairs.h155
-rw-r--r--matching/include/phat/representations/abstract_pivot_column.h102
-rw-r--r--matching/include/phat/representations/bit_tree_pivot_column.h165
-rw-r--r--matching/include/phat/representations/full_pivot_column.h100
-rw-r--r--matching/include/phat/representations/heap_pivot_column.h126
-rw-r--r--matching/include/phat/representations/sparse_pivot_column.h79
-rw-r--r--matching/include/phat/representations/vector_heap.h170
-rw-r--r--matching/include/phat/representations/vector_list.h101
-rw-r--r--matching/include/phat/representations/vector_set.h99
-rw-r--r--matching/include/phat/representations/vector_vector.h107
-rw-r--r--matching/include/simplex.h163
-rw-r--r--matching/include/simplex.hpp79
-rw-r--r--matching/tests/prism_1.bif28
-rw-r--r--matching/tests/prism_2.bif29
-rw-r--r--matching/tests/test_bifiltration.cpp36
-rw-r--r--matching/tests/test_common.cpp156
-rw-r--r--matching/tests/test_list.txt1
-rw-r--r--matching/tests/test_matching_distance.cpp159
-rw-r--r--matching/tests/test_module.cpp109
-rw-r--r--matching/tests/tests_main.cpp2
52 files changed, 7464 insertions, 0 deletions
diff --git a/matching/CMakeLists.txt b/matching/CMakeLists.txt
new file mode 100644
index 0000000..a391d84
--- /dev/null
+++ b/matching/CMakeLists.txt
@@ -0,0 +1,56 @@
+project(matching_distance)
+cmake_minimum_required(VERSION 3.5.1)
+
+set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
+
+include(GenerateExportHeader)
+find_package(Boost REQUIRED)
+
+# Default to Release
+
+if (NOT CMAKE_BUILD_TYPE)
+ set(CMAKE_BUILD_TYPE "Release" CACHE STRING "Choose the type of build, options are: Debug Release RelWithDebInfo MinSizeRel." FORCE)
+ set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
+endif (NOT CMAKE_BUILD_TYPE)
+
+include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include
+ SYSTEM ${BOOST_INCLUDE_DIR}
+ ${CMAKE_CURRENT_SOURCE_DIR}/../bottleneck/include)
+
+set(CMAKE_CXX_STANDARD 14)
+
+
+if (NOT WIN32)
+ set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS} -Wall -Wextra ")
+ set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -ggdb -D_GLIBCXX_DEBUG")
+ set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELEASE} -O2 -g -ggdb")
+endif (NOT WIN32)
+
+file(GLOB BT_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/../bottleneck/include/*.h ${CMAKE_CURRENT_SOURCE_DIR}/../bottleneck/include/*.hpp)
+file(GLOB MD_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/include/*.h ${CMAKE_CURRENT_SOURCE_DIR}/include/*.hpp)
+
+file(GLOB SRC_TEST_FILES ${CMAKE_CURRENT_SOURCE_DIR}/tests/*.cpp)
+
+find_package(Threads)
+set(libraries ${libraries} "stdc++fs" ${CMAKE_THREAD_LIBS_INIT})
+
+
+find_package(OpenMP)
+if (OPENMP_FOUND)
+ set(libraries ${libraries} ${OpenMP_CXX_LIBRARIES})
+ set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
+ set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
+ set (CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}")
+endif()
+
+
+add_executable(matching_dist "example/matching_dist.cpp" ${MD_HEADERS} ${BT_HEADERS} )
+target_link_libraries(matching_dist PUBLIC ${libraries})
+
+add_executable(module_example "example/module_example.cpp" ${MD_HEADERS} ${BT_HEADERS} )
+target_link_libraries(module_example PUBLIC ${libraries})
+
+
+add_executable(matching_distance_test ${SRC_TEST_FILES} ${BT_HEADERS} ${MD_HEADERS})
+target_link_libraries(matching_distance_test PUBLIC ${libraries})
diff --git a/matching/README.md b/matching/README.md
new file mode 100644
index 0000000..ac3220e
--- /dev/null
+++ b/matching/README.md
@@ -0,0 +1,69 @@
+# Matching distance between bifiltrations and 2-persistence modules.
+
+## Accompanying paper
+M. Kerber, A. Nigmetov,
+Efficient Approximation of the Matching Distance for 2-parameter persistence.
+SoCG 2020.
+
+Bug reports can be sent to "anigmetov EMAIL SIGN lbl DOT gov".
+
+## Dependencies
+
+* Your compiler must support C++11.
+* Boost.
+
+## Usage:
+
+1. To use a standalone command-line utility matching_dist:
+
+`./matching_dist -d dimension -e relative_error file1 file2`
+
+If `relative_error` is not specified, the default 0.1 is used.
+If `dimension` is not specified, the default value is 0.
+Run `./matching_dist` without parameters to see other options.
+
+The output is an approximation of the exact distance (which is assumed to be non-zero).
+Precisely: if *d_exact* is the true distance and *d_approx* is the output, then
+
+> | d_exact - d_approx | / d_exact < relative_error.
+
+Files file1 and file2 must contain 1-critical bi-filtrations in a plain text format which is similar to PHAT. The first line of a file must say *bifiltration_phat_like*. The second line contains the total number of simplices *N*. The next *N* lines contain simplices in the format *dim x y boundary*.
+* *dim*: the dimension of the simplex
+* *x, y*: coordinates of the critical value
+* *boundary*: indices of simplices forming the boundary of the current simplex. Indices are separated by space.
+* Simplices are indexed starting from 0.
+
+For example, the bi-filtration of a segment with vertices appearing at (0,0) and the 1-segment appearing at (3,4) shall be written as:
+
+> bifiltration_phat_like
+> 3
+> \# lines starting with \# are ignored
+> \# vertex A has dimension 0, hence no boundary, its index is 0
+> 0 0 0
+> \# vertex B has index 1
+> 0 0 0
+> \# 1-dimensional simplex {A, B}
+> 1 3 4 0 1
+
+2. To use from your code.
+
+Here you can compute the matching distance either between bi-filtrations or between persistence modules.
+First, you need to include `#include "matching_distance.h"` Practically every class you need is parameterized by Real type, which should be either float or double. The header provides two functions called `matching_distance.`
+See `example/module_example.cpp` for additional details.
+
+## License
+
+See `licence.txt` in the repository root folder.
+
+## Building
+
+CMakeLists.txt can be used to build the command-line utility in the standard
+way. On Linux/Mac/Windows with Cygwin:
+> `mkdir build`
+> `cd build`
+> `cmake ..`
+> `make`
+
+On Windows with Visual Studio: use `cmake-gui` to create the solution in build directory and build it with VS.
+
+The library itself is header-only and does not require separate compilation.
diff --git a/matching/example/matching_dist.cpp b/matching/example/matching_dist.cpp
new file mode 100644
index 0000000..2e0f1eb
--- /dev/null
+++ b/matching/example/matching_dist.cpp
@@ -0,0 +1,366 @@
+#include "common_defs.h"
+
+#include <iostream>
+#include <string>
+#include <cassert>
+#include <experimental/filesystem>
+
+#ifdef MD_EXPERIMENTAL_TIMING
+#include <chrono>
+#endif
+
+#include "opts/opts.h"
+
+#include "matching_distance.h"
+
+using Real = double;
+
+using namespace md;
+
+namespace fs = std::experimental::filesystem;
+
+#ifdef PRINT_HEAT_MAP
+void print_heat_map(const md::HeatMaps<Real>& hms, std::string fname, const CalculationParams<Real>& params)
+{
+ std::set<Real> mu_vals, lambda_vals;
+ auto hm_iter = hms.end();
+ --hm_iter;
+ int max_level = hm_iter->first;
+
+ int level_cardinality = 4;
+ for(int i = 0; i < params.initialization_depth; ++i) {
+ level_cardinality *= 4;
+ }
+ for(int i = params.initialization_depth + 1; i <= max_level; ++i) {
+ assert(static_cast<decltype(level_cardinality)>(hms.at(i).size()) == level_cardinality);
+ level_cardinality *= 4;
+ }
+
+ std::map<std::pair<Real, Real>, Real> hm_x_flat, hm_x_steep, hm_y_flat, hm_y_steep;
+
+ for(const auto& dual_point_value_pair : hms.at(max_level)) {
+ const DualPoint& k = dual_point_value_pair.first;
+ mu_vals.insert(k.mu());
+ lambda_vals.insert(k.lambda());
+ }
+
+ std::vector<Real> lambda_vals_vec(lambda_vals.begin(), lambda_vals.end());
+ std::vector<Real> mu_vals_vec(mu_vals.begin(), mu_vals.end());
+
+ std::ofstream ofs {fname};
+ if (not ofs.good()) {
+ std::cerr << "Cannot write heat map to file " << fname << std::endl;
+ throw std::runtime_error("Cannot open file for writing heat map");
+ }
+
+ std::vector<std::vector<Real>> heatmap_to_print(2 * mu_vals_vec.size(),
+ std::vector<Real>(2 * lambda_vals_vec.size(), 0.0));
+
+ for(auto axis_type : {AxisType::x_type, AxisType::y_type}) {
+ bool is_x_type = axis_type == AxisType::x_type;
+ for(auto angle_type : {AngleType::flat, AngleType::steep}) {
+ bool is_flat = angle_type == AngleType::flat;
+
+ int mu_idx_begin, mu_idx_end;
+
+ if (is_x_type) {
+ mu_idx_begin = mu_vals_vec.size() - 1;
+ mu_idx_end = -1;
+ } else {
+ mu_idx_begin = 0;
+ mu_idx_end = mu_vals_vec.size();
+ }
+
+ int lambda_idx_begin, lambda_idx_end;
+
+ if (is_flat) {
+ lambda_idx_begin = 0;
+ lambda_idx_end = lambda_vals_vec.size();
+ } else {
+ lambda_idx_begin = lambda_vals_vec.size() - 1;
+ lambda_idx_end = -1;
+ }
+
+ int mu_idx_final = is_x_type ? 0 : mu_vals_vec.size();
+
+ for(int mu_idx = mu_idx_begin; mu_idx != mu_idx_end; (mu_idx_begin < mu_idx_end) ? mu_idx++ : mu_idx--) {
+ Real mu = mu_vals_vec.at(mu_idx);
+
+ if (mu == 0.0 and axis_type == AxisType::x_type)
+ continue;
+
+ int lambda_idx_final = is_flat ? 0 : lambda_vals_vec.size();
+
+ for(int lambda_idx = lambda_idx_begin;
+ lambda_idx != lambda_idx_end;
+ (lambda_idx_begin < lambda_idx_end) ? lambda_idx++ : lambda_idx--) {
+
+ Real lambda = lambda_vals_vec.at(lambda_idx);
+
+ if (lambda == 0.0 and angle_type == AngleType::flat)
+ continue;
+
+ DualPoint dp(axis_type, angle_type, lambda, mu);
+ Real dist_value = hms.at(max_level).at(dp);
+
+ heatmap_to_print.at(mu_idx_final).at(lambda_idx_final) = dist_value;
+
+// fmt::print("HM, dp = {}, mu_idx_final = {}, lambda_idx_final = {}, value = {}\n", dp, mu_idx_final,
+// lambda_idx_final, dist_value);
+
+ lambda_idx_final++;
+ }
+ mu_idx_final++;
+ }
+ }
+ }
+
+ for(size_t m_idx = 0; m_idx < heatmap_to_print.size(); ++m_idx) {
+ for(size_t l_idx = 0; l_idx < heatmap_to_print.at(m_idx).size(); ++l_idx) {
+ ofs << heatmap_to_print.at(m_idx).at(l_idx) << " ";
+ }
+ ofs << std::endl;
+ }
+
+ ofs.close();
+}
+#endif
+
+int main(int argc, char** argv)
+{
+ using opts::Option;
+ using opts::PosOption;
+ opts::Options ops;
+
+ CalculationParams<Real> params;
+
+ bool help = false;
+
+#ifdef MD_EXPERIMENTAL_TIMING
+ bool no_stop_asap = false;
+#endif
+
+
+#ifdef PRINT_HEAT_MAP
+ bool heatmap_only = false;
+#endif
+
+ // default values are the best for practical use
+ // if compiled with MD_EXPERIMENTAL_TIMING, user can supply
+ // different bounds and traverse strategies
+ // separated by commas, all combinations will be timed.
+ // See corresponding >> operators in matching_distance.h
+ // for possible values.
+ std::string bounds_list_str = "local_combined";
+ std::string traverse_list_str = "BFS";
+
+ ops >> Option('m', "max-iterations", params.max_depth, "maximal number of iterations (refinements)")
+ >> Option('e', "max-error", params.delta, "error threshold")
+ >> Option('d', "dim", params.dim, "dim")
+ >> Option('i', "initial-depth", params.initialization_depth, "initialization depth")
+#ifdef MD_EXPERIMENTAL_TIMING
+ >> Option("no-stop-asap", no_stop_asap,
+ "don't stop looping over points, if cell cannot be pruned (asap is on by default)")
+ >> Option("bounds", bounds_list_str, "bounds to use, separated by ,")
+ >> Option("traverse", traverse_list_str, "traverse strategy to use, separated by ,")
+#endif
+#ifdef PRINT_HEAT_MAP
+ >> Option("heatmap-only", heatmap_only, "only save heatmap (bruteforce)")
+#endif
+ >> Option('h', "help", help, "show help message");
+
+ std::string fname_a;
+ std::string fname_b;
+
+ if (!ops.parse(argc, argv) || help || !(ops >> PosOption(fname_a) >> PosOption(fname_b))) {
+ std::cerr << "Usage: " << argv[0] << " bifiltration-file-1 bifiltration-file-2\n" << ops << std::endl;
+ return 1;
+ }
+
+#ifdef MD_EXPERIMENTAL_TIMING
+ params.stop_asap = not no_stop_asap;
+#else
+ params.stop_asap = true;
+#endif
+
+ auto bounds_list = split_by_delim(bounds_list_str, ',');
+ auto traverse_list = split_by_delim(traverse_list_str, ',');
+
+ Bifiltration<Real> bif_a(fname_a);
+ Bifiltration<Real> bif_b(fname_b);
+
+ bif_a.sanity_check();
+ bif_b.sanity_check();
+
+ std::vector<BoundStrategy> bound_strategies;
+ std::vector<TraverseStrategy> traverse_strategies;
+
+ for(std::string s : bounds_list) {
+ bound_strategies.push_back(bs_from_string(s));
+ }
+
+ for(std::string s : traverse_list) {
+ traverse_strategies.push_back(ts_from_string(s));
+ }
+
+
+#ifdef MD_EXPERIMENTAL_TIMING
+
+ struct ExperimentResult {
+ CalculationParams<Real> params {CalculationParams<Real>()};
+ int n_hera_calls {0};
+ double total_milliseconds_elapsed {0};
+ Real distance {0};
+ Real actual_error {std::numeric_limits<double>::max()};
+ int actual_max_depth {0};
+
+ int x_wins {0};
+ int y_wins {0};
+ int ad_wins {0};
+
+ int seconds_elapsed() const
+ {
+ return static_cast<int>(total_milliseconds_elapsed / 1000);
+ }
+
+ double savings_ratio_old() const
+ {
+ long int max_possible_calls = 0;
+ long int calls_on_level = 4;
+ for(int i = 0; i <= actual_max_depth; ++i) {
+ max_possible_calls += calls_on_level;
+ calls_on_level *= 4;
+ }
+ return static_cast<double>(n_hera_calls) / static_cast<double>(max_possible_calls);
+ }
+
+ double savings_ratio() const
+ {
+ return static_cast<double>(n_hera_calls) / calls_on_actual_max_depth();
+ }
+
+ long long int calls_on_actual_max_depth() const
+ {
+ long long int result = 1;
+ for(int i = 0; i < actual_max_depth; ++i) {
+ result *= 4;
+ }
+ return result;
+ }
+
+ ExperimentResult() { }
+
+ ExperimentResult(CalculationParams<Real> p, int nhc, double tme, double d)
+ :
+ params(p), n_hera_calls(nhc), total_milliseconds_elapsed(tme), distance(d) { }
+ };
+
+ const int n_repetitions = 1;
+
+#ifdef PRINT_HEAT_MAP
+ if (heatmap_only) {
+ bound_strategies.clear();
+ bound_strategies.push_back(BoundStrategy::bruteforce);
+ traverse_strategies.clear();
+ traverse_strategies.push_back(TraverseStrategy::breadth_first);
+ }
+#endif
+
+ std::map<std::tuple<BoundStrategy, TraverseStrategy>, ExperimentResult> results;
+ for(BoundStrategy bound_strategy : bound_strategies) {
+ for(TraverseStrategy traverse_strategy : traverse_strategies) {
+ CalculationParams<Real> params_experiment;
+ params_experiment.bound_strategy = bound_strategy;
+ params_experiment.traverse_strategy = traverse_strategy;
+ params_experiment.max_depth = params.max_depth;
+ params_experiment.initialization_depth = params.initialization_depth;
+ params_experiment.delta = params.delta;
+ params_experiment.dim = params.dim;
+ params_experiment.hera_epsilon = params.hera_epsilon;
+ params_experiment.stop_asap = params.stop_asap;
+
+ if (traverse_strategy == TraverseStrategy::depth_first and bound_strategy == BoundStrategy::bruteforce)
+ continue;
+
+ // if bruteforce, clamp max iterations number to 7,
+ // save user-provided max_iters in user_max_iters and restore it later.
+ // remember: params is passed by reference to return real relative error and heat map
+
+ int user_max_iters = params.max_depth;
+#ifdef PRINT_HEAT_MAP
+ if (bound_strategy == BoundStrategy::bruteforce and not heatmap_only) {
+ params_experiment.max_depth = std::min(7, params.max_depth);
+ }
+#endif
+ double total_milliseconds_elapsed = 0;
+ int total_n_hera_calls = 0;
+ Real dist;
+ for(int i = 0; i < n_repetitions; ++i) {
+ auto t1 = std::chrono::high_resolution_clock().now();
+ dist = matching_distance(bif_a, bif_b, params_experiment);
+ auto t2 = std::chrono::high_resolution_clock().now();
+ total_milliseconds_elapsed += std::chrono::duration_cast<std::chrono::milliseconds>(
+ t2 - t1).count();
+ total_n_hera_calls += params_experiment.n_hera_calls;
+ }
+
+ auto key = std::make_tuple(bound_strategy, traverse_strategy);
+ results[key].params = params_experiment;
+ results[key].n_hera_calls = total_n_hera_calls / n_repetitions;
+ results[key].total_milliseconds_elapsed = total_milliseconds_elapsed / n_repetitions;
+ results[key].distance = dist;
+ results[key].actual_error = params_experiment.actual_error;
+ results[key].actual_max_depth = params_experiment.actual_max_depth;
+
+ if (bound_strategy == BoundStrategy::bruteforce) { params_experiment.max_depth = user_max_iters; }
+
+#ifdef PRINT_HEAT_MAP
+ if (bound_strategy == BoundStrategy::bruteforce) {
+ fs::path fname_a_path {fname_a.c_str()};
+ fs::path fname_b_path {fname_b.c_str()};
+ fs::path fname_a_wo = fname_a_path.filename();
+ fs::path fname_b_wo = fname_b_path.filename();
+ std::string heat_map_fname = fmt::format("{0}_{1}_dim_{2}_weighted_values_xyp.txt", fname_a_wo.string(),
+ fname_b_wo.string(), params_experiment.dim);
+ fs::path path_hm = fname_a_path.replace_filename(fs::path(heat_map_fname.c_str()));
+ std::cout << "Saving heatmap to " << heat_map_fname;
+ print_heat_map(params_experiment.heat_maps, path_hm.string(), params);
+ }
+#endif
+ }
+ }
+
+// std::cout << "File_1;File_2;Boundstrategy;TraverseStrategy;InitalDepth;NHeraCalls;SavingsRatio;Time;Distance;Error;PushStrategy;MaxDepth;CallsOnMaxDepth;Delta;Dimension" << std::endl;
+ for(auto bs : bound_strategies) {
+ for(auto ts : traverse_strategies) {
+ auto key = std::make_tuple(bs, ts);
+
+ fs::path fname_a_path {fname_a.c_str()};
+ fs::path fname_b_path {fname_b.c_str()};
+ fs::path fname_a_wo = fname_a_path.filename();
+ fs::path fname_b_wo = fname_b_path.filename();
+
+ std::cout << fname_a_wo.string() << ";" << fname_b_wo.string() << ";" << bs << ";" << ts << ";";
+ std::cout << results[key].params.initialization_depth << ";";
+ std::cout << results[key].n_hera_calls << ";"
+ << results[key].savings_ratio() << ";"
+ << results[key].total_milliseconds_elapsed << ";"
+ << results[key].distance << ";"
+ << results[key].actual_error << ";"
+ << "xyp" << ";"
+ << results[key].actual_max_depth << ";"
+ << results[key].calls_on_actual_max_depth() << ";"
+ << params.delta << ";"
+ << params.dim
+ << std::endl;
+ }
+ }
+#else
+ params.bound_strategy = bound_strategies.back();
+ params.traverse_strategy = traverse_strategies.back();
+
+ Real dist = matching_distance<Real>(bif_a, bif_b, params);
+ std::cout << dist << std::endl;
+#endif
+ return 0;
+}
diff --git a/matching/example/module_example.cpp b/matching/example/module_example.cpp
new file mode 100644
index 0000000..aec38c2
--- /dev/null
+++ b/matching/example/module_example.cpp
@@ -0,0 +1,68 @@
+#include <iostream>
+#include "matching_distance.h"
+
+using namespace md;
+
+int main(int /*argc*/, char** /*argv*/)
+{
+ // create generators.
+ // A generator is a point in plane,
+ // generators are stored in a vector of points:
+ PointVec<double> gens_a;
+
+ // module A will have one generator that appears at point (0, 0)
+ gens_a.emplace_back(0, 0);
+
+ // relations are stored in a vector of relations
+ using RelationVec = ModulePresentation<double>::RelVec;
+ RelationVec rels_a;
+
+ // A relation is a struct with position and column
+ using Relation = ModulePresentation<double>::Relation;
+
+ // at this point the relation rel_a_1 will appear:
+ Point<double> rel_a_1_position { 1, 1 };
+
+ // vector IndexVec contains non-zero indices of the corresponding relation
+ // (we work over Z/2). Since we have one generator only, the relation
+ // contains only one entry, 0
+ IndexVec rel_a_1_components { 0 };
+
+ // construct a relation from position and column:
+ Relation rel_a_1 { rel_a_1_position, rel_a_1_components };
+
+ // and add it to a vector of relations
+ rels_a.push_back(rel_a_1);
+
+ // after populating vectors of generators and relations
+ // construct a module:
+ ModulePresentation<double> module_a { gens_a, rels_a };
+
+
+ // same for module_b. It will also have just one
+ // generator and one relation, but at different positions.
+
+ PointVec<double> gens_b;
+ gens_b.emplace_back(1, 1);
+
+ RelationVec rels_b;
+
+ Point<double> rel_b_1_position { 2, 2 };
+ IndexVec rel_b_1_components { 0 };
+
+ rels_b.emplace_back(rel_b_1_position, rel_b_1_components);
+
+ ModulePresentation<double> module_b { gens_b, rels_b };
+
+ // create CalculationParams
+ CalculationParams<double> params;
+ // set relative error to 10 % :
+ params.delta = 0.1;
+ // go at most 8 levels deep in quadtree:
+ params.max_depth = 8;
+
+ double dist = matching_distance(module_a, module_b, params);
+ std::cout << "dist = " << dist << std::endl;
+
+ return 0;
+}
diff --git a/matching/include/bifiltration.h b/matching/include/bifiltration.h
new file mode 100644
index 0000000..5b188d4
--- /dev/null
+++ b/matching/include/bifiltration.h
@@ -0,0 +1,136 @@
+#ifndef MATCHING_DISTANCE_BIFILTRATION_H
+#define MATCHING_DISTANCE_BIFILTRATION_H
+
+#include <string>
+#include <ostream>
+#include <iostream>
+#include <fstream>
+#include <sstream>
+#include <cassert>
+
+#include "common_util.h"
+#include "box.h"
+#include "simplex.h"
+#include "dual_point.h"
+#include "phat/boundary_matrix.h"
+#include "phat/compute_persistence_pairs.h"
+
+#include "common_util.h"
+
+namespace md {
+
+ template<class Real>
+ class Bifiltration {
+ public:
+ using SimplexVector = std::vector<Simplex<Real>>;
+
+ Bifiltration() = default;
+
+ Bifiltration(const Bifiltration&) = default;
+
+ Bifiltration(Bifiltration&&) = default;
+
+ Bifiltration& operator=(const Bifiltration& other)& = default;
+
+ Bifiltration& operator=(Bifiltration&& other) = default;
+
+ Bifiltration(const std::string& fname); // read from file
+
+ template<class Iter>
+ Bifiltration(Iter begin, Iter end)
+ : simplices_(begin, end)
+ {
+ init();
+ }
+
+ Diagram<Real> weighted_slice_diagram(const DualPoint<Real>& line, int dim) const;
+
+ SimplexVector simplices() const { return simplices_; }
+
+ // translate all points by vector (a,a)
+ void translate(Real a);
+
+ // return minimal value of x- and y-coordinates
+ // among all simplices
+ Real minimal_coordinate() const;
+
+ // return box that contains positions of all simplices
+ Box<Real> bounding_box() const;
+
+ void sanity_check() const;
+
+ int maximal_dim() const { return maximal_dim_; }
+
+ Real max_x() const;
+
+ Real max_y() const;
+
+ Real min_x() const;
+
+ Real min_y() const;
+
+ void add_simplex(Index _id, Point<Real> birth, int _dim, const Column& _bdry);
+
+ void save(const std::string& filename, BifiltrationFormat format = BifiltrationFormat::rivet); // save to file
+
+ void scale(Real lambda);
+
+ private:
+ SimplexVector simplices_;
+
+ Box<Real> bounding_box_;
+ int maximal_dim_ {-1};
+
+ void init();
+
+ void rivet_format_reader(std::ifstream&);
+
+ void phat_like_format_reader(std::ifstream&);
+
+ // in Rene format each simplex knows IDs of its boundary facets
+ // postprocess_phat_like_format fills vector of IDs of boundary facets
+ // in each simplex
+ void postprocess_phat_like_format();
+
+ // in Rivet format each simplex knows its vertices,
+ // postprocess_rivet_format fills vector of IDs of boundary facets
+ // in each simplex
+ void postprocess_rivet_format();
+
+ };
+
+ template<class Real>
+ std::ostream& operator<<(std::ostream& os, const Bifiltration<Real>& bif);
+
+ template<class Real>
+ class BifiltrationProxy {
+ public:
+ BifiltrationProxy(const Bifiltration<Real>& bif, int dim = 0);
+ // return critical values of simplices that are important for current dimension (dim and dim+1)
+ PointVec<Real> positions() const;
+ // set current dimension
+ int set_dim(int new_dim);
+
+ // wrappers of Bifiltration
+ int maximal_dim() const;
+ void translate(Real a);
+ Real minimal_coordinate() const;
+ Box<Real> bounding_box() const;
+ Real max_x() const;
+ Real max_y() const;
+ Real min_x() const;
+ Real min_y() const;
+ Diagram<Real> weighted_slice_diagram(const DualPoint<Real>& slice) const;
+
+ private:
+ int dim_ { 0 };
+ mutable PointVec<Real> cached_positions_;
+ Bifiltration<Real> bif_;
+
+ void cache_positions() const;
+ };
+}
+
+#include "bifiltration.hpp"
+
+#endif //MATCHING_DISTANCE_BIFILTRATION_H
diff --git a/matching/include/bifiltration.hpp b/matching/include/bifiltration.hpp
new file mode 100644
index 0000000..3d20516
--- /dev/null
+++ b/matching/include/bifiltration.hpp
@@ -0,0 +1,412 @@
+namespace md {
+
+ template<class Real>
+ void Bifiltration<Real>::init()
+ {
+ auto lower_left = max_point<Real>();
+ auto upper_right = min_point<Real>();
+ for(const auto& simplex : simplices_) {
+ lower_left = greatest_lower_bound<>(lower_left, simplex.position());
+ upper_right = least_upper_bound<>(upper_right, simplex.position());
+ maximal_dim_ = std::max(maximal_dim_, simplex.dim());
+ }
+ bounding_box_ = Box<Real>(lower_left, upper_right);
+ }
+
+ template<class Real>
+ Bifiltration<Real>::Bifiltration(const std::string& fname)
+ {
+ std::ifstream ifstr {fname.c_str()};
+ if (!ifstr.good()) {
+ std::string error_message = "Cannot open file " + fname;
+ std::cerr << error_message << std::endl;
+ throw std::runtime_error(error_message);
+ }
+
+ BifiltrationFormat input_format;
+
+ std::string s;
+
+ while(ignore_line(s)) {
+ std::getline(ifstr, s);
+ }
+
+ if (s == "bifiltration") {
+ input_format = BifiltrationFormat::rivet;
+ } else if (s == "bifiltration_phat_like") {
+ input_format = BifiltrationFormat::phat_like;
+ } else {
+ std::cerr << "Unknown format: '" << s << "' in file " << fname << std::endl;
+ throw std::runtime_error("unknown bifiltration format");
+ }
+
+ switch(input_format) {
+ case BifiltrationFormat::rivet :
+ rivet_format_reader(ifstr);
+ break;
+ case BifiltrationFormat::phat_like :
+ phat_like_format_reader(ifstr);
+ break;
+ }
+
+ ifstr.close();
+
+ init();
+ }
+
+ template<class Real>
+ void Bifiltration<Real>::rivet_format_reader(std::ifstream& ifstr)
+ {
+ std::string s;
+ // read axes names, ignore them
+ std::getline(ifstr, s);
+ std::getline(ifstr, s);
+
+ Index index = 0;
+ while(std::getline(ifstr, s)) {
+ if (!ignore_line(s)) {
+ simplices_.emplace_back(index++, s, BifiltrationFormat::rivet);
+ }
+ }
+ }
+
+ template<class Real>
+ void Bifiltration<Real>::phat_like_format_reader(std::ifstream& ifstr)
+ {
+ // read stream line by line; do not use >> operator
+ std::string s;
+ std::getline(ifstr, s);
+
+ // first line contains number of simplices
+ long n_simplices = std::stol(s);
+
+ // all other lines represent a simplex
+ Index index = 0;
+ while(index < n_simplices) {
+ std::getline(ifstr, s);
+ if (!ignore_line(s)) {
+ simplices_.emplace_back(index++, s, BifiltrationFormat::phat_like);
+ }
+ }
+ }
+
+ template<class Real>
+ void Bifiltration<Real>::scale(Real lambda)
+ {
+ for(auto& s : simplices_) {
+ s.scale(lambda);
+ }
+ init();
+ }
+
+ template<class Real>
+ void Bifiltration<Real>::sanity_check() const
+ {
+#ifdef DEBUG
+ // check that boundary has correct number of simplices,
+ // each bounding simplex has correct dim
+ // and appears in the filtration before the simplex it bounds
+ for(const auto& s : simplices_) {
+ assert(s.dim() >= 0);
+ assert(s.dim() == 0 or s.dim() + 1 == (int) s.boundary().size());
+ for(auto bdry_idx : s.boundary()) {
+ Simplex bdry_simplex = simplices()[bdry_idx];
+ assert(bdry_simplex.dim() == s.dim() - 1);
+ assert(bdry_simplex.position().is_less(s.position(), false));
+ }
+ }
+#endif
+ }
+
+ template<class Real>
+ Diagram<Real> Bifiltration<Real>::weighted_slice_diagram(const DualPoint<Real>& line, int dim) const
+ {
+ DiagramKeeper<Real> dgm;
+
+ // make a copy for now; I want slice_diagram to be const
+ std::vector<Simplex<Real>> simplices(simplices_);
+
+// std::vector<Simplex> simplices;
+// simplices.reserve(simplices_.size() / 2);
+// for(const auto& s : simplices_) {
+// if (s.dim() <= dim + 1 and s.dim() >= dim)
+// simplices.emplace_back(s);
+// }
+
+ for(auto& simplex : simplices) {
+ Real value = line.weighted_push(simplex.position());
+ simplex.set_value(value);
+ }
+
+ std::sort(simplices.begin(), simplices.end(),
+ [](const Simplex<Real>& a, const Simplex<Real>& b) { return a.value() < b.value(); });
+ std::map<Index, Index> index_map;
+ for(Index i = 0; i < (int) simplices.size(); i++) {
+ index_map[simplices[i].id()] = i;
+ }
+
+ phat::boundary_matrix<> phat_matrix;
+ phat_matrix.set_num_cols(simplices.size());
+ std::vector<Index> bd_in_slice_filtration;
+ for(Index i = 0; i < (int) simplices.size(); i++) {
+ phat_matrix.set_dim(i, simplices[i].dim());
+ bd_in_slice_filtration.clear();
+ //std::cout << "new col" << i << std::endl;
+ for(int j = 0; j < (int) simplices[i].boundary().size(); j++) {
+ // F[i] contains the indices of its facet wrt to the
+ // original filtration. We have to express it, however,
+ // wrt to the filtration along the slice. That is why
+ // we need the index_map
+ //std::cout << "Found " << F[i].bd[j] << ", returning " << index_map[F[i].bd[j]] << std::endl;
+ bd_in_slice_filtration.push_back(index_map[simplices[i].boundary()[j]]);
+ }
+ std::sort(bd_in_slice_filtration.begin(), bd_in_slice_filtration.end());
+ phat_matrix.set_col(i, bd_in_slice_filtration);
+ }
+ phat::persistence_pairs phat_persistence_pairs;
+ phat::compute_persistence_pairs<phat::twist_reduction>(phat_persistence_pairs, phat_matrix);
+
+ dgm.clear();
+ constexpr Real real_inf = std::numeric_limits<Real>::infinity();
+ for(long i = 0; i < (long) phat_persistence_pairs.get_num_pairs(); i++) {
+ std::pair<phat::index, phat::index> new_pair = phat_persistence_pairs.get_pair(i);
+ bool is_finite_pair = new_pair.second != phat::k_infinity_index;
+ Real birth = simplices.at(new_pair.first).value();
+ Real death = is_finite_pair ? simplices.at(new_pair.second).value() : real_inf;
+ int dim = simplices[new_pair.first].dim();
+ assert(dim + 1 == simplices[new_pair.second].dim());
+ if (birth != death) {
+ dgm.add_point(dim, birth, death);
+ }
+ }
+
+ return dgm.get_diagram(dim);
+ }
+
+ template<class Real>
+ Box<Real> Bifiltration<Real>::bounding_box() const
+ {
+ return bounding_box_;
+ }
+
+ template<class Real>
+ Real Bifiltration<Real>::minimal_coordinate() const
+ {
+ return std::min(bounding_box_.lower_left().x, bounding_box_.lower_left().y);
+ }
+
+ template<class Real>
+ void Bifiltration<Real>::translate(Real a)
+ {
+ bounding_box_.translate(a);
+ for(auto& simplex : simplices_) {
+ simplex.translate(a);
+ }
+ }
+
+ template<class Real>
+ Real Bifiltration<Real>::max_x() const
+ {
+ if (simplices_.empty())
+ return 1;
+ auto me = std::max_element(simplices_.cbegin(), simplices_.cend(),
+ [](const auto& s_a, const auto& s_b) { return s_a.position().x < s_b.position().x; });
+ assert(me != simplices_.cend());
+ return me->position().x;
+ }
+
+ template<class Real>
+ Real Bifiltration<Real>::max_y() const
+ {
+ if (simplices_.empty())
+ return 1;
+ auto me = std::max_element(simplices_.cbegin(), simplices_.cend(),
+ [](const auto& s_a, const auto& s_b) { return s_a.position().y < s_b.position().y; });
+ assert(me != simplices_.cend());
+ return me->position().y;
+ }
+
+ template<class Real>
+ Real Bifiltration<Real>::min_x() const
+ {
+ if (simplices_.empty())
+ return 0;
+ auto me = std::min_element(simplices_.cbegin(), simplices_.cend(),
+ [](const auto& s_a, const auto& s_b) { return s_a.position().x < s_b.position().x; });
+ assert(me != simplices_.cend());
+ return me->position().x;
+ }
+
+ template<class Real>
+ Real Bifiltration<Real>::min_y() const
+ {
+ if (simplices_.empty())
+ return 0;
+ auto me = std::min_element(simplices_.cbegin(), simplices_.cend(),
+ [](const auto& s_a, const auto& s_b) { return s_a.position().y < s_b.position().y; });
+ assert(me != simplices_.cend());
+ return me->position().y;
+ }
+
+ template<class Real>
+ void Bifiltration<Real>::add_simplex(Index _id, Point<Real> birth, int _dim, const Column& _bdry)
+ {
+ simplices_.emplace_back(_id, birth, _dim, _bdry);
+ }
+
+ template<class Real>
+ void Bifiltration<Real>::save(const std::string& filename, md::BifiltrationFormat format)
+ {
+ switch(format) {
+ case BifiltrationFormat::rivet:
+ throw std::runtime_error("Not implemented");
+ break;
+ case BifiltrationFormat::phat_like: {
+ std::ofstream f(filename);
+ if (not f.good()) {
+ std::cerr << "Bifiltration::save: cannot open file " << filename << std::endl;
+ throw std::runtime_error("Cannot open file for writing ");
+ }
+ f << simplices_.size() << "\n";
+
+ for(const auto& s : simplices_) {
+ f << s.dim() << " " << s.position().x << " " << s.position().y << " ";
+ for(int b : s.boundary()) {
+ f << b << " ";
+ }
+ f << std::endl;
+ }
+
+ }
+ break;
+ }
+ }
+
+ template<class Real>
+ void Bifiltration<Real>::postprocess_rivet_format()
+ {
+ std::map<Column, Index> facets_to_ids;
+
+ // fill the map
+ for(Index i = 0; i < (Index) simplices_.size(); ++i) {
+ assert(simplices_[i].id() == i);
+ facets_to_ids[simplices_[i].vertices_] = i;
+ }
+
+// for(const auto& s : simplices_) {
+// facets_to_ids[s] = s.id();
+// }
+
+ // main loop
+ for(auto& s : simplices_) {
+ assert(not s.vertices_.empty());
+ assert(s.facet_indices_.empty());
+ Column facet_indices;
+ for(Index i = 0; i <= s.dim(); ++i) {
+ Column facet;
+ for(Index j : s.vertices_) {
+ if (j != i)
+ facet.push_back(j);
+ }
+ auto facet_index = facets_to_ids.at(facet);
+ facet_indices.push_back(facet_index);
+ }
+ s.facet_indices_ = facet_indices;
+ } // loop over simplices
+ }
+
+ template<class Real>
+ std::ostream& operator<<(std::ostream& os, const Bifiltration<Real>& bif)
+ {
+ os << "Bifiltration [" << std::endl;
+ for(const auto& s : bif.simplices()) {
+ os << s << std::endl;
+ }
+ os << "]" << std::endl;
+ return os;
+ }
+
+ template<class Real>
+ BifiltrationProxy<Real>::BifiltrationProxy(const Bifiltration<Real>& bif, int dim)
+ :
+ dim_(dim),
+ bif_(bif)
+ {
+ cache_positions();
+ }
+
+ template<class Real>
+ void BifiltrationProxy<Real>::cache_positions() const
+ {
+ cached_positions_.clear();
+ for(const auto& simplex : bif_.simplices()) {
+ if (simplex.dim() == dim_ or simplex.dim() == dim_ + 1)
+ cached_positions_.push_back(simplex.position());
+ }
+ }
+
+ template<class Real>
+ PointVec<Real>
+ BifiltrationProxy<Real>::positions() const
+ {
+ if (cached_positions_.empty()) {
+ cache_positions();
+ }
+ return cached_positions_;
+ }
+
+ // translate all points by vector (a,a)
+ template<class Real>
+ void BifiltrationProxy<Real>::translate(Real a)
+ {
+ bif_.translate(a);
+ }
+
+ // return minimal value of x- and y-coordinates
+ // among all simplices
+ template<class Real>
+ Real BifiltrationProxy<Real>::minimal_coordinate() const
+ {
+ return bif_.minimal_coordinate();
+ }
+
+ // return box that contains positions of all simplices
+ template<class Real>
+ Box<Real> BifiltrationProxy<Real>::bounding_box() const
+ {
+ return bif_.bounding_box();
+ }
+
+ template<class Real>
+ Real BifiltrationProxy<Real>::max_x() const
+ {
+ return bif_.max_x();
+ }
+
+ template<class Real>
+ Real BifiltrationProxy<Real>::max_y() const
+ {
+ return bif_.max_y();
+ }
+
+ template<class Real>
+ Real BifiltrationProxy<Real>::min_x() const
+ {
+ return bif_.min_x();
+ }
+
+ template<class Real>
+ Real BifiltrationProxy<Real>::min_y() const
+ {
+ return bif_.min_y();
+ }
+
+
+ template<class Real>
+ Diagram<Real> BifiltrationProxy<Real>::weighted_slice_diagram(const DualPoint<Real>& slice) const
+ {
+ return bif_.weighted_slice_diagram(slice, dim_);
+ }
+
+}
+
diff --git a/matching/include/box.h b/matching/include/box.h
new file mode 100644
index 0000000..4243667
--- /dev/null
+++ b/matching/include/box.h
@@ -0,0 +1,59 @@
+#ifndef MATCHING_DISTANCE_BOX_H
+#define MATCHING_DISTANCE_BOX_H
+
+#include <cassert>
+#include <limits>
+
+#include "common_util.h"
+
+namespace md {
+
+ template<class Real_>
+ struct Box {
+ public:
+ using Real = Real_;
+ private:
+ Point<Real> ll;
+ Point<Real> ur;
+
+ public:
+ Box(Point<Real> ll = Point<Real>(), Point<Real> ur = Point<Real>())
+ :ll(ll), ur(ur)
+ {
+ }
+
+ Box(Point<Real> center, Real width, Real height) :
+ ll(Point<Real>(center.x - 0.5 * width, center.y - 0.5 * height)),
+ ur(Point<Real>(center.x + 0.5 * width, center.y + 0.5 * height))
+ {
+ }
+
+
+ inline double width() const { return ur.x - ll.x; }
+
+ inline double height() const { return ur.y - ll.y; }
+
+ inline Point<Real> lower_left() const { return ll; }
+ inline Point<Real> upper_right() const { return ur; }
+ inline Point<Real> center() const { return Point<Real>((ll.x + ur.x) / 2, (ll.y + ur.y) / 2); }
+
+ inline bool operator==(const Box& p)
+ {
+ return this->ll == p.ll && this->ur == p.ur;
+ }
+
+ std::vector<Box> refine() const;
+
+ std::vector<Point<Real>> corners() const;
+
+ void translate(Real a);
+ };
+
+ template<class Real>
+ std::ostream& operator<<(std::ostream& os, const Box<Real>& box);
+
+} // namespace md
+
+#include "box.hpp"
+
+#endif //MATCHING_DISTANCE_BOX_H
diff --git a/matching/include/box.hpp b/matching/include/box.hpp
new file mode 100644
index 0000000..f551d84
--- /dev/null
+++ b/matching/include/box.hpp
@@ -0,0 +1,52 @@
+namespace md {
+
+ template<class Real>
+ std::ostream& operator<<(std::ostream& os, const Box<Real>& box)
+ {
+ os << "Box(lower_left = " << box.lower_left() << ", upper_right = " << box.upper_right() << ")";
+ return os;
+ }
+
+ template<class Real>
+ void Box<Real>::translate(Real a)
+ {
+ ll.x += a;
+ ll.y += a;
+ ur.x += a;
+ ur.y += a;
+ }
+
+ template<class Real>
+ std::vector<Box<Real>> Box<Real>::refine() const
+ {
+ std::vector<Box<Real>> result;
+
+// 1 | 2
+// 0 | 3
+
+ Point<Real> new_ll = lower_left();
+ Point<Real> new_ur = center();
+ result.emplace_back(new_ll, new_ur);
+
+ new_ll.y = center().y;
+ new_ur.y = ur.y;
+ result.emplace_back(new_ll, new_ur);
+
+ new_ll = center();
+ new_ur = upper_right();
+ result.emplace_back(new_ll, new_ur);
+
+ new_ll.y = ll.y;
+ new_ur.y = center().y;
+ result.emplace_back(new_ll, new_ur);
+
+ return result;
+ }
+
+ template<class Real>
+ std::vector<Point<Real>> Box<Real>::corners() const
+ {
+ return {ll, Point<Real>(ll.x, ur.y), ur, Point<Real>(ur.x, ll.y)};
+ };
+
+}
diff --git a/matching/include/cell_with_value.h b/matching/include/cell_with_value.h
new file mode 100644
index 0000000..3548a11
--- /dev/null
+++ b/matching/include/cell_with_value.h
@@ -0,0 +1,141 @@
+#ifndef MATCHING_DISTANCE_CELL_WITH_VALUE_H
+#define MATCHING_DISTANCE_CELL_WITH_VALUE_H
+
+#include <algorithm>
+
+#include "common_defs.h"
+#include "common_util.h"
+#include "dual_box.h"
+
+namespace md {
+
+ enum class ValuePoint {
+ center,
+ lower_left,
+ lower_right,
+ upper_left,
+ upper_right
+ };
+
+ inline std::ostream& operator<<(std::ostream& os, const ValuePoint& vp)
+ {
+ switch(vp) {
+ case ValuePoint::upper_left :
+ os << "upper_left";
+ break;
+ case ValuePoint::upper_right :
+ os << "upper_right";
+ break;
+ case ValuePoint::lower_left :
+ os << "lower_left";
+ break;
+ case ValuePoint::lower_right :
+ os << "lower_right";
+ break;
+ case ValuePoint::center:
+ os << "center";
+ break;
+ default:
+ os << "FORGOTTEN ValuePoint";
+ }
+ return os;
+ }
+
+ const std::vector<ValuePoint> k_all_vps = {ValuePoint::center, ValuePoint::lower_left, ValuePoint::upper_left,
+ ValuePoint::upper_right, ValuePoint::lower_right};
+
+ const std::vector<ValuePoint> k_corner_vps = {ValuePoint::lower_left, ValuePoint::upper_left,
+ ValuePoint::upper_right, ValuePoint::lower_right};
+
+ // represents a cell in the dual space with the value
+ // of the weighted bottleneck distance
+ template<class Real_>
+ class CellWithValue {
+ public:
+ using Real = Real_;
+
+ CellWithValue() = default;
+
+ CellWithValue(const CellWithValue&) = default;
+
+ CellWithValue(CellWithValue&&) = default;
+
+ CellWithValue& operator=(const CellWithValue& other)& = default;
+
+ CellWithValue& operator=(CellWithValue&& other) = default;
+
+ CellWithValue(const DualBox<Real>& b, int level)
+ :dual_box_(b), level_(level) { }
+
+ DualBox<Real> dual_box() const { return dual_box_; }
+
+ DualPoint<Real> center() const { return dual_box_.center(); }
+
+ Real value_at(ValuePoint vp) const;
+
+ bool has_value_at(ValuePoint vp) const;
+
+ DualPoint<Real> value_point(ValuePoint vp) const;
+
+ int level() const { return level_; }
+
+ void set_value_at(ValuePoint vp, Real new_value);
+
+ bool has_corner_value() const;
+
+ Real stored_upper_bound() const;
+
+ Real max_corner_value() const;
+
+ Real min_value() const;
+
+ bool has_max_possible_value() const { return has_max_possible_value_; }
+
+ std::vector<CellWithValue> get_refined_cells() const;
+
+ void set_max_possible_value(Real new_upper_bound);
+
+ int num_values() const;
+
+#ifdef MD_DEBUG
+ long long int id { 0 };
+
+ static long long int max_id;
+
+ std::vector<long long int> parent_ids;
+#endif
+
+ private:
+
+ bool has_central_value() const { return central_value_ >= 0; }
+
+ bool has_lower_left_value() const { return lower_left_value_ >= 0; }
+
+ bool has_lower_right_value() const { return lower_right_value_ >= 0; }
+
+ bool has_upper_left_value() const { return upper_left_value_ >= 0; }
+
+ bool has_upper_right_value() const { return upper_right_value_ >= 0; }
+
+
+ DualBox<Real> dual_box_;
+ Real central_value_ {-1.0};
+ Real lower_left_value_ {-1.0};
+ Real lower_right_value_ {-1.0};
+ Real upper_left_value_ {-1.0};
+ Real upper_right_value_ {-1.0};
+
+ Real max_possible_value_ {0.0};
+
+ int level_ {0};
+
+ bool has_max_possible_value_ {false};
+ };
+
+ template<class Real>
+ std::ostream& operator<<(std::ostream& os, const CellWithValue<Real>& cell);
+} // namespace md
+
+#include "cell_with_value.hpp"
+
+#endif //MATCHING_DISTANCE_CELL_WITH_VALUE_H
diff --git a/matching/include/cell_with_value.hpp b/matching/include/cell_with_value.hpp
new file mode 100644
index 0000000..4079b1b
--- /dev/null
+++ b/matching/include/cell_with_value.hpp
@@ -0,0 +1,225 @@
+namespace md {
+
+#ifdef MD_DEBUG
+ long long int CellWithValue<Real>::max_id = 0;
+#endif
+
+ template<class Real>
+ Real CellWithValue<Real>::value_at(ValuePoint vp) const
+ {
+ switch(vp) {
+ case ValuePoint::upper_left :
+ return upper_left_value_;
+ case ValuePoint::upper_right :
+ return upper_right_value_;
+ case ValuePoint::lower_left :
+ return lower_left_value_;
+ case ValuePoint::lower_right :
+ return lower_right_value_;
+ case ValuePoint::center:
+ return central_value_;
+ }
+ // to shut up compiler warning
+ return 1.0 / 0.0;
+ }
+
+ template<class Real>
+ bool CellWithValue<Real>::has_value_at(ValuePoint vp) const
+ {
+ switch(vp) {
+ case ValuePoint::upper_left :
+ return upper_left_value_ >= 0;
+ case ValuePoint::upper_right :
+ return upper_right_value_ >= 0;
+ case ValuePoint::lower_left :
+ return lower_left_value_ >= 0;
+ case ValuePoint::lower_right :
+ return lower_right_value_ >= 0;
+ case ValuePoint::center:
+ return central_value_ >= 0;
+ }
+ // to shut up compiler warning
+ return 1.0 / 0.0;
+ }
+
+ template<class Real>
+ DualPoint<Real> CellWithValue<Real>::value_point(md::ValuePoint vp) const
+ {
+ switch(vp) {
+ case ValuePoint::upper_left :
+ return dual_box().upper_left();
+ case ValuePoint::upper_right :
+ return dual_box().upper_right();
+ case ValuePoint::lower_left :
+ return dual_box().lower_left();
+ case ValuePoint::lower_right :
+ return dual_box().lower_right();
+ case ValuePoint::center:
+ return dual_box().center();
+ }
+ // to shut up compiler warning
+ return DualPoint<Real>();
+ }
+
+ template<class Real>
+ bool CellWithValue<Real>::has_corner_value() const
+ {
+ return has_lower_left_value() or has_lower_right_value() or has_upper_left_value()
+ or has_upper_right_value();
+ }
+
+ template<class Real>
+ Real CellWithValue<Real>::stored_upper_bound() const
+ {
+ assert(has_max_possible_value_);
+ return max_possible_value_;
+ }
+
+ template<class Real>
+ Real CellWithValue<Real>::max_corner_value() const
+ {
+ return std::max({lower_left_value_, lower_right_value_, upper_left_value_, upper_right_value_});
+ }
+
+ template<class Real>
+ Real CellWithValue<Real>::min_value() const
+ {
+ Real result = std::numeric_limits<Real>::max();
+ for(auto vp : k_all_vps) {
+ if (not has_value_at(vp)) {
+ continue;
+ }
+ result = std::min(result, value_at(vp));
+ }
+ return result;
+ }
+
+ template<class Real>
+ std::vector<CellWithValue<Real>> CellWithValue<Real>::get_refined_cells() const
+ {
+ std::vector<CellWithValue<Real>> result;
+ result.reserve(4);
+ for(const auto& refined_box : dual_box_.refine()) {
+
+ CellWithValue<Real> refined_cell(refined_box, level() + 1);
+
+#ifdef MD_DEBUG
+ refined_cell.parent_ids = parent_ids;
+ refined_cell.parent_ids.push_back(id);
+ refined_cell.id = ++max_id;
+#endif
+
+ if (refined_box.lower_left() == dual_box_.lower_left()) {
+ // _|_
+ // H|_
+
+ refined_cell.set_value_at(ValuePoint::lower_left, lower_left_value_);
+ refined_cell.set_value_at(ValuePoint::upper_right, central_value_);
+
+ } else if (refined_box.upper_right() == dual_box_.upper_right()) {
+ // _|H
+ // _|_
+
+ refined_cell.set_value_at(ValuePoint::lower_left, central_value_);
+ refined_cell.set_value_at(ValuePoint::upper_right, upper_right_value_);
+
+ } else if (refined_box.lower_right() == dual_box_.lower_right()) {
+ // _|_
+ // _|H
+
+ refined_cell.set_value_at(ValuePoint::lower_right, lower_right_value_);
+ refined_cell.set_value_at(ValuePoint::upper_left, central_value_);
+
+ } else if (refined_box.upper_left() == dual_box_.upper_left()) {
+
+ // H|_
+ // _|_
+
+ refined_cell.set_value_at(ValuePoint::lower_right, central_value_);
+ refined_cell.set_value_at(ValuePoint::upper_left, upper_left_value_);
+ }
+ result.emplace_back(refined_cell);
+ }
+ return result;
+ }
+
+ template<class Real>
+ void CellWithValue<Real>::set_value_at(ValuePoint vp, Real new_value)
+ {
+ if (has_value_at(vp)) {
+ std::cerr << "CellWithValue<Real>: trying to re-assign value, vp = " << vp;
+ }
+
+ switch(vp) {
+ case ValuePoint::upper_left :
+ upper_left_value_ = new_value;
+ break;
+ case ValuePoint::upper_right :
+ upper_right_value_ = new_value;
+ break;
+ case ValuePoint::lower_left :
+ lower_left_value_ = new_value;
+ break;
+ case ValuePoint::lower_right :
+ lower_right_value_ = new_value;
+ break;
+ case ValuePoint::center:
+ central_value_ = new_value;
+ break;
+ }
+ }
+
+ template<class Real>
+ int CellWithValue<Real>::num_values() const
+ {
+ int result = 0;
+ for(ValuePoint vp : k_all_vps) {
+ result += has_value_at(vp);
+ }
+ return result;
+ }
+
+
+ template<class Real>
+ void CellWithValue<Real>::set_max_possible_value(Real new_upper_bound)
+ {
+ assert(new_upper_bound >= central_value_);
+ assert(new_upper_bound >= lower_left_value_);
+ assert(new_upper_bound >= lower_right_value_);
+ assert(new_upper_bound >= upper_left_value_);
+ assert(new_upper_bound >= upper_right_value_);
+ has_max_possible_value_ = true;
+ max_possible_value_ = new_upper_bound;
+ }
+
+
+ template<class Real>
+ std::ostream& operator<<(std::ostream& os, const CellWithValue<Real>& cell)
+ {
+ os << "CellWithValue(box = " << cell.dual_box() << ", ";
+
+#ifdef MD_DEBUG
+ os << "id = " << cell.id;
+ if (not cell.parent_ids.empty())
+ os << ", parent_ids = " << container_to_string(cell.parent_ids) << ", ";
+#endif
+
+ for(ValuePoint vp : k_all_vps) {
+ if (cell.has_value_at(vp)) {
+ os << "value = " << cell.value_at(vp);
+ os << ", at " << vp << " " << cell.value_point(vp);
+ }
+ }
+
+ os << ", max_corner_value = ";
+ if (cell.has_max_possible_value()) {
+ os << cell.stored_upper_bound();
+ } else {
+ os << "-";
+ }
+
+ os << ", level = " << cell.level() << ")";
+ return os;
+ }
+
+} // namespace md
diff --git a/matching/include/common_defs.h b/matching/include/common_defs.h
new file mode 100644
index 0000000..8d01325
--- /dev/null
+++ b/matching/include/common_defs.h
@@ -0,0 +1,10 @@
+#ifndef MATCHING_DISTANCE_DEF_DEBUG_H
+#define MATCHING_DISTANCE_DEF_DEBUG_H
+
+//#define MD_EXPERIMENTAL_TIMING
+//#define MD_PRINT_HEAT_MAP
+//#define MD_DEBUG
+//#define MD_DO_CHECKS
+//#define MD_DO_FULL_CHECK
+
+#endif //MATCHING_DISTANCE_DEF_DEBUG_H
diff --git a/matching/include/common_util.h b/matching/include/common_util.h
new file mode 100644
index 0000000..20a151a
--- /dev/null
+++ b/matching/include/common_util.h
@@ -0,0 +1,183 @@
+#ifndef MATCHING_DISTANCE_COMMON_UTIL_H
+#define MATCHING_DISTANCE_COMMON_UTIL_H
+
+#include <cassert>
+#include <vector>
+#include <utility>
+#include <cmath>
+#include <ostream>
+#include <sstream>
+#include <string>
+#include <map>
+#include <functional>
+
+#include "common_defs.h"
+#include "phat/helpers/misc.h"
+
+namespace md {
+
+ using Index = phat::index;
+ using IndexVec = std::vector<Index>;
+
+ //static constexpr Real pi = M_PI;
+
+ using Column = std::vector<Index>;
+
+ template<class Real>
+ struct Point {
+ Real x;
+ Real y;
+
+ Point(Real x = 0, Real y = 0)
+ :x(x), y(y) { }
+
+ inline Real norm() const { return sqrt(x * x + y * y); }
+
+ inline void normalize()
+ {
+ Real nor = norm();
+ x /= nor;
+ y /= nor;
+ }
+
+ inline void translate(Real a)
+ {
+ x += a;
+ y += a;
+ }
+
+ inline bool operator==(const Point& v) const
+ {
+ return this->x == v.x && this->y == v.y;
+ }
+
+ // compare both coordinates, as needed in persistence
+ // do not overload operator<, requirements are not satisfied
+ inline bool is_less(const Point& other, bool strict = false) const
+ {
+ if (x <= other.x and y <= other.y) {
+ if (strict) {
+ return x != other.x or y != other.y;
+ }
+ else {
+ return true;
+ }
+ }
+ return false;
+ }
+ };
+
+
+ template<class Real>
+ using PointVec = std::vector<Point<Real>>;
+
+ template<class Real>
+ Point<Real> operator+(const Point<Real>& u, const Point<Real>& v);
+
+ template<class Real>
+ Point<Real> operator-(const Point<Real>& u, const Point<Real>& v);
+
+
+ template<class Real>
+ Point<Real> least_upper_bound(const Point<Real>& u, const Point<Real>& v);
+
+ template<class Real>
+ Point<Real> greatest_lower_bound(const Point<Real>& u, const Point<Real>& v);
+
+ template<class Real>
+ Point<Real> max_point();
+
+ template<class Real>
+ Point<Real> min_point();
+
+ template<class Real>
+ std::ostream& operator<<(std::ostream& ostr, const Point<Real>& vec);
+
+ template<class Real>
+ using DiagramPoint = std::pair<Real, Real>;
+
+ template<class Real>
+ using Diagram = std::vector<DiagramPoint<Real>>;
+
+
+ // to keep diagrams in all dimensions
+ // TODO: store in Hera format?
+ template<class Real>
+ class DiagramKeeper {
+ public:
+
+ DiagramKeeper() { };
+
+ void add_point(int dim, Real birth, Real death);
+
+ Diagram<Real> get_diagram(int dim) const;
+
+ void clear() { data_.clear(); }
+
+ private:
+ std::map<int, Diagram<Real>> data_;
+ };
+
+ template<typename C>
+ std::string container_to_string(const C& cont)
+ {
+ std::stringstream ss;
+ ss << "[";
+ int i = 0;
+ for (const auto& x : cont) {
+ i++;
+ ss << x;
+ if (i != (int) cont.size())
+ ss << ", ";
+ }
+ ss << "]";
+ return ss.str();
+ }
+
+ // return true, if s is empty or starts with # (commented out line)
+ // whitespaces in the beginning of s are ignored
+ inline bool ignore_line(const std::string& s)
+ {
+ for(auto c : s) {
+ if (isspace(c))
+ continue;
+ return (c == '#');
+ }
+ return true;
+ }
+
+
+ // split string by delimeter
+ template<typename Out>
+ void split_by_delim(const std::string &s, char delim, Out result) {
+ std::stringstream ss(s);
+ std::string item;
+ while (std::getline(ss, item, delim)) {
+ *(result++) = item;
+ }
+ }
+
+ inline std::vector<std::string> split_by_delim(const std::string &s, char delim) {
+ std::vector<std::string> elems;
+ split_by_delim(s, delim, std::back_inserter(elems));
+ return elems;
+ }
+}
+
+namespace std {
+ template<class Real>
+ struct hash<md::Point<Real>>
+ {
+ std::size_t operator()(const md::Point<Real>& p) const
+ {
+ auto hx = std::hash<decltype(p.x)>()(p.x);
+ auto hy = std::hash<decltype(p.y)>()(p.y);
+ return hx ^ (hy << 1);
+ }
+ };
+};
+
+#include "common_util.hpp"
+
+
+#endif //MATCHING_DISTANCE_COMMON_UTIL_H
diff --git a/matching/include/common_util.hpp b/matching/include/common_util.hpp
new file mode 100644
index 0000000..9c7df37
--- /dev/null
+++ b/matching/include/common_util.hpp
@@ -0,0 +1,93 @@
+#include <vector>
+#include <utility>
+#include <cmath>
+#include <ostream>
+#include <limits>
+#include <algorithm>
+
+#include <common_util.h>
+
+namespace md {
+
+ template<class Real>
+ Point<Real> operator+(const Point<Real>& u, const Point<Real>& v)
+ {
+ return Point<Real>(u.x + v.x, u.y + v.y);
+ }
+
+ template<class Real>
+ Point<Real> operator-(const Point<Real>& u, const Point<Real>& v)
+ {
+ return Point<Real>(u.x - v.x, u.y - v.y);
+ }
+
+ template<class Real>
+ Point<Real> least_upper_bound(const Point<Real>& u, const Point<Real>& v)
+ {
+ return Point<Real>(std::max(u.x, v.x), std::max(u.y, v.y));
+ }
+
+ template<class Real>
+ Point<Real> greatest_lower_bound(const Point<Real>& u, const Point<Real>& v)
+ {
+ return Point<Real>(std::min(u.x, v.x), std::min(u.y, v.y));
+ }
+
+ template<class Real>
+ Point<Real> max_point()
+ {
+ return Point<Real>(std::numeric_limits<Real>::max(), std::numeric_limits<Real>::min());
+ }
+
+ template<class Real>
+ Point<Real> min_point()
+ {
+ return Point<Real>(-std::numeric_limits<Real>::max(), -std::numeric_limits<Real>::min());
+ }
+
+ template<class Real>
+ std::ostream& operator<<(std::ostream& ostr, const Point<Real>& vec)
+ {
+ ostr << "(" << vec.x << ", " << vec.y << ")";
+ return ostr;
+ }
+
+ template<class Real>
+ Real l_infty_norm(const Point<Real>& v)
+ {
+ return std::max(std::abs(v.x), std::abs(v.y));
+ }
+
+ template<class Real>
+ Real l_2_norm(const Point<Real>& v)
+ {
+ return v.norm();
+ }
+
+ template<class Real>
+ Real l_2_dist(const Point<Real>& x, const Point<Real>& y)
+ {
+ return l_2_norm(x - y);
+ }
+
+ template<class Real>
+ Real l_infty_dist(const Point<Real>& x, const Point<Real>& y)
+ {
+ return l_infty_norm(x - y);
+ }
+
+ template<class Real>
+ void DiagramKeeper<Real>::add_point(int dim, Real birth, Real death)
+ {
+ data_[dim].emplace_back(birth, death);
+ }
+
+ template<class Real>
+ Diagram<Real> DiagramKeeper<Real>::get_diagram(int dim) const
+ {
+ if (data_.count(dim) == 1)
+ return data_.at(dim);
+ else
+ return Diagram<Real>();
+ }
+}
diff --git a/matching/include/dual_box.h b/matching/include/dual_box.h
new file mode 100644
index 0000000..4bbb639
--- /dev/null
+++ b/matching/include/dual_box.h
@@ -0,0 +1,79 @@
+#ifndef MATCHING_DISTANCE_DUAL_BOX_H
+#define MATCHING_DISTANCE_DUAL_BOX_H
+
+#include <ostream>
+#include <limits>
+#include <vector>
+#include <random>
+
+#include "common_util.h"
+#include "dual_point.h"
+
+namespace md {
+
+
+ template<class Real>
+ class DualBox {
+ public:
+
+ DualBox(DualPoint<Real> ll, DualPoint<Real> ur);
+
+ DualBox() = default;
+ DualBox(const DualBox&) = default;
+ DualBox(DualBox&&) = default;
+
+ DualBox& operator=(const DualBox& other) & = default;
+ DualBox& operator=(DualBox&& other) = default;
+
+
+ DualPoint<Real> center() const { return midpoint(lower_left_, upper_right_); }
+ DualPoint<Real> lower_left() const { return lower_left_; }
+ DualPoint<Real> upper_right() const { return upper_right_; }
+
+ DualPoint<Real> lower_right() const;
+ DualPoint<Real> upper_left() const;
+
+ AxisType axis_type() const { return lower_left_.axis_type(); }
+ AngleType angle_type() const { return lower_left_.angle_type(); }
+
+ Real mu_min() const { return lower_left_.mu(); }
+ Real mu_max() const { return upper_right_.mu(); }
+ Real lambda_min() const { return lower_left_.lambda(); }
+ Real lambda_max() const { return upper_right_.lambda(); }
+
+ // return true, if all lines in dual_box are flat
+ bool is_flat() const { return upper_right_.is_flat(); }
+ bool is_steep() const { return lower_left_.is_steep(); }
+
+ std::vector<DualBox> refine() const;
+ std::vector<DualPoint<Real>> corners() const;
+ std::vector<DualPoint<Real>> critical_points(const Point<Real>& p) const;
+ // sample n points from the box uniformly; for tests
+ std::vector<DualPoint<Real>> random_points(int n) const;
+
+ // return 2 dual points at the boundary
+ // where push changes from horizontal to vertical
+ std::vector<DualPoint<Real>> push_change_points(const Point<Real>& p) const;
+
+ // check that a has same sign, angles are all flat or all steep
+ bool sanity_check() const;
+ bool contains(const DualPoint<Real>& dp) const;
+
+ bool operator==(const DualBox& other) const;
+
+ private:
+ DualPoint<Real> lower_left_;
+ DualPoint<Real> upper_right_;
+ };
+
+ template<class Real>
+ std::ostream& operator<<(std::ostream& os, const DualBox<Real>& db)
+ {
+ os << "DualBox(" << db.lower_left() << ", " << db.upper_right() << ")";
+ return os;
+ }
+}
+
+#include "dual_box.hpp"
+
+#endif //MATCHING_DISTANCE_DUAL_BOX_H
diff --git a/matching/include/dual_box.hpp b/matching/include/dual_box.hpp
new file mode 100644
index 0000000..d7e6deb
--- /dev/null
+++ b/matching/include/dual_box.hpp
@@ -0,0 +1,188 @@
+namespace md {
+
+ template<class Real>
+ DualBox<Real>::DualBox(DualPoint<Real> ll, DualPoint<Real> ur)
+ :lower_left_(ll), upper_right_(ur)
+ {
+ }
+
+ template<class Real>
+ std::vector<DualPoint<Real>> DualBox<Real>::corners() const
+ {
+ return {lower_left_,
+ DualPoint<Real>(axis_type(), angle_type(), lower_left_.lambda(), upper_right_.mu()),
+ upper_right_,
+ DualPoint<Real>(axis_type(), angle_type(), upper_right_.lambda(), lower_left_.mu())};
+ }
+
+ template<class Real>
+ std::vector<DualPoint<Real>> DualBox<Real>::push_change_points(const Point<Real>& p) const
+ {
+ std::vector<DualPoint<Real>> result;
+ result.reserve(2);
+
+ bool is_y_type = lower_left_.is_y_type();
+ bool is_flat = lower_left_.is_flat();
+
+ auto mu_from_lambda = [p, is_y_type, is_flat](Real lambda) {
+ bool is_x_type = not is_y_type, is_steep = not is_flat;
+ if (is_y_type && is_flat) {
+ return p.y - lambda * p.x;
+ } else if (is_y_type && is_steep) {
+ return p.y - p.x / lambda;
+ } else if (is_x_type && is_flat) {
+ return p.x - p.y / lambda;
+ } else if (is_x_type && is_steep) {
+ return p.x - lambda * p.y;
+ }
+ // to shut up compiler warning
+ return static_cast<Real>(1.0 / 0.0);
+ };
+
+ auto lambda_from_mu = [p, is_y_type, is_flat](Real mu) {
+ bool is_x_type = not is_y_type, is_steep = not is_flat;
+ if (is_y_type && is_flat) {
+ return (p.y - mu) / p.x;
+ } else if (is_y_type && is_steep) {
+ return p.x / (p.y - mu);
+ } else if (is_x_type && is_flat) {
+ return p.y / (p.x - mu);
+ } else if (is_x_type && is_steep) {
+ return (p.x - mu) / p.y;
+ }
+ // to shut up compiler warning
+ return static_cast<Real>(1.0 / 0.0);
+ };
+
+ // all inequalities below are strict: equality means it is a corner
+ // && critical_points() returns corners anyway
+
+ Real mu_intersect_min = mu_from_lambda(lambda_min());
+
+ if (mu_min() < mu_intersect_min && mu_intersect_min < mu_max())
+ result.emplace_back(axis_type(), angle_type(), lambda_min(), mu_intersect_min);
+
+ Real mu_intersect_max = mu_from_lambda(lambda_max());
+
+ if (mu_max() < mu_intersect_max && mu_intersect_max < mu_max())
+ result.emplace_back(axis_type(), angle_type(), lambda_max(), mu_intersect_max);
+
+ Real lambda_intersect_min = lambda_from_mu(mu_min());
+
+ if (lambda_min() < lambda_intersect_min && lambda_intersect_min < lambda_max())
+ result.emplace_back(axis_type(), angle_type(), lambda_intersect_min, mu_min());
+
+ Real lambda_intersect_max = lambda_from_mu(mu_max());
+ if (lambda_min() < lambda_intersect_max && lambda_intersect_max < lambda_max())
+ result.emplace_back(axis_type(), angle_type(), lambda_intersect_max, mu_max());
+
+ assert(result.size() <= 2);
+
+ if (result.size() > 2) {
+ throw std::runtime_error("push_change_points returned more than 2 points");
+ }
+
+ return result;
+ }
+
+ template<class Real>
+ std::vector<DualPoint<Real>> DualBox<Real>::critical_points(const Point<Real>& /*p*/) const
+ {
+ // maximal difference is attained at corners
+ return corners();
+// std::vector<DualPoint<Real>> result;
+// result.reserve(6);
+// for(auto dp : corners()) result.push_back(dp);
+// for(auto dp : push_change_points(p)) result.push_back(dp);
+// return result;
+ }
+
+ template<class Real>
+ std::vector<DualPoint<Real>> DualBox<Real>::random_points(int n) const
+ {
+ assert(n >= 0);
+ std::mt19937_64 gen(1);
+ std::vector<DualPoint<Real>> result;
+ result.reserve(n);
+ std::uniform_real_distribution<Real> mu_distr(mu_min(), mu_max());
+ std::uniform_real_distribution<Real> lambda_distr(lambda_min(), lambda_max());
+ for(int i = 0; i < n; ++i) {
+ result.emplace_back(axis_type(), angle_type(), lambda_distr(gen), mu_distr(gen));
+ }
+ return result;
+ }
+
+ template<class Real>
+ bool DualBox<Real>::sanity_check() const
+ {
+ lower_left_.sanity_check();
+ upper_right_.sanity_check();
+
+ if (lower_left_.angle_type() != upper_right_.angle_type())
+ throw std::runtime_error("angle types differ");
+
+ if (lower_left_.axis_type() != upper_right_.axis_type())
+ throw std::runtime_error("axis types differ");
+
+ if (lower_left_.lambda() >= upper_right_.lambda())
+ throw std::runtime_error("lambda of lower_left_ greater than lambda of upper_right ");
+
+ if (lower_left_.mu() >= upper_right_.mu())
+ throw std::runtime_error("mu of lower_left_ greater than mu of upper_right ");
+
+ return true;
+ }
+
+ template<class Real>
+ std::vector<DualBox<Real>> DualBox<Real>::refine() const
+ {
+ std::vector<DualBox<Real>> result;
+
+ result.reserve(4);
+
+ Real lambda_middle = (lower_left().lambda() + upper_right().lambda()) / 2.0;
+ Real mu_middle = (lower_left().mu() + upper_right().mu()) / 2.0;
+
+ DualPoint<Real> refinement_center(axis_type(), angle_type(), lambda_middle, mu_middle);
+
+ result.emplace_back(lower_left_, refinement_center);
+
+ result.emplace_back(DualPoint<Real>(axis_type(), angle_type(), lambda_middle, mu_min()),
+ DualPoint<Real>(axis_type(), angle_type(), lambda_max(), mu_middle));
+
+ result.emplace_back(refinement_center, upper_right_);
+
+ result.emplace_back(DualPoint<Real>(axis_type(), angle_type(), lambda_min(), mu_middle),
+ DualPoint<Real>(axis_type(), angle_type(), lambda_middle, mu_max()));
+ return result;
+ }
+
+ template<class Real>
+ bool DualBox<Real>::operator==(const DualBox& other) const
+ {
+ return lower_left() == other.lower_left() &&
+ upper_right() == other.upper_right();
+ }
+
+ template<class Real>
+ bool DualBox<Real>::contains(const DualPoint<Real>& dp) const
+ {
+ return dp.angle_type() == angle_type() && dp.axis_type() == axis_type() &&
+ mu_max() >= dp.mu() &&
+ mu_min() <= dp.mu() &&
+ lambda_min() <= dp.lambda() &&
+ lambda_max() >= dp.lambda();
+ }
+
+ template<class Real>
+ DualPoint<Real> DualBox<Real>::lower_right() const
+ {
+ return DualPoint<Real>(lower_left_.axis_type(), lower_left_.angle_type(), lambda_max(), mu_min());
+ }
+
+ template<class Real>
+ DualPoint<Real> DualBox<Real>::upper_left() const
+ {
+ return DualPoint<Real>(lower_left_.axis_type(), lower_left_.angle_type(), lambda_min(), mu_max());
+ }
+}
diff --git a/matching/include/dual_point.h b/matching/include/dual_point.h
new file mode 100644
index 0000000..8438860
--- /dev/null
+++ b/matching/include/dual_point.h
@@ -0,0 +1,107 @@
+#ifndef MATCHING_DISTANCE_DUAL_POINT_H
+#define MATCHING_DISTANCE_DUAL_POINT_H
+
+#include <vector>
+#include <ostream>
+#include <tuple>
+
+#include "common_util.h"
+#include "box.h"
+
+namespace md {
+
+ enum class AxisType {
+ x_type, y_type
+ };
+ enum class AngleType {
+ flat, steep
+ };
+
+ // class has two flags of AxisType and AngleType.
+ // ATTENTION. == operator is not overloaded,
+ // so, e.g., line y = x has 4 different non-equal representation.
+ // we are unlikely to ever need this, because 4 cases are
+ // always treated separately.
+ template<class Real_>
+ class DualPoint {
+ public:
+ using Real = Real_;
+
+ DualPoint() = default;
+
+ DualPoint(const DualPoint&) = default;
+
+ DualPoint(DualPoint&&) = default;
+
+ DualPoint& operator=(const DualPoint& other)& = default;
+
+ DualPoint& operator=(DualPoint&& other) = default;
+
+ DualPoint(AxisType axis_type, AngleType angle_type, Real lambda, Real mu);
+
+ Real lambda() const { return lambda_; }
+
+ Real mu() const { return mu_; }
+
+ // angle between line and x-axis
+ Real gamma() const;
+
+ bool is_steep() const { return angle_type_ == AngleType::steep; }
+
+ bool is_flat() const { return angle_type_ == AngleType::flat; }
+
+ bool is_x_type() const { return axis_type_ == AxisType::x_type; }
+
+ bool is_y_type() const { return axis_type_ == AxisType::y_type; }
+
+ bool operator<(const DualPoint& rhs) const;
+
+ AxisType axis_type() const { return axis_type_; }
+ AngleType angle_type() const { return angle_type_; }
+
+ // throw exception, if fields are invalid
+ // return true otherwise
+ bool sanity_check() const;
+
+ Real weighted_push(Point<Real> p) const;
+ Point<Real> push(Point<Real> p) const;
+
+ bool is_horizontal() const;
+ bool is_vertical() const;
+
+ bool goes_below(Point<Real> p) const;
+ bool goes_above(Point<Real> p) const;
+
+ bool contains(Point<Real> p) const;
+
+ Real x_slope() const;
+ Real y_slope() const;
+
+ Real x_intercept() const;
+ Real y_intercept() const;
+
+ Real x_from_y(Real y) const;
+ Real y_from_x(Real x) const;
+
+ Real weight() const;
+
+ bool operator==(const DualPoint& other) const;
+
+ private:
+ AxisType axis_type_ {AxisType::y_type};
+ AngleType angle_type_ {AngleType::flat};
+ // both initial values are invalid: lambda must be between 0 and 1
+ Real lambda_ {-1.0};
+ Real mu_ {-1.0};
+ };
+
+ template<class Real>
+ std::ostream& operator<<(std::ostream& os, const DualPoint<Real>& dp);
+
+ template<class Real>
+ DualPoint<Real> midpoint(DualPoint<Real> x, DualPoint<Real> y);
+};
+
+#include "dual_point.hpp"
+
+#endif //MATCHING_DISTANCE_DUAL_POINT_H
diff --git a/matching/include/dual_point.hpp b/matching/include/dual_point.hpp
new file mode 100644
index 0000000..04e25f2
--- /dev/null
+++ b/matching/include/dual_point.hpp
@@ -0,0 +1,299 @@
+namespace md {
+
+ inline std::ostream& operator<<(std::ostream& os, const AxisType& at)
+ {
+ if (at == AxisType::x_type)
+ os << "x-type";
+ else
+ os << "y-type";
+ return os;
+ }
+
+ inline std::ostream& operator<<(std::ostream& os, const AngleType& at)
+ {
+ if (at == AngleType::flat)
+ os << "flat";
+ else
+ os << "steep";
+ return os;
+ }
+
+ template<class Real>
+ std::ostream& operator<<(std::ostream& os, const DualPoint<Real>& dp)
+ {
+ os << "Line(" << dp.axis_type() << ", ";
+ os << dp.angle_type() << ", ";
+ os << dp.lambda() << ", ";
+ os << dp.mu() << ", equation: ";
+ if (not dp.is_vertical()) {
+ os << "y = " << dp.y_slope() << " x + " << dp.y_intercept();
+ } else {
+ os << "x = " << dp.x_intercept();
+ }
+ os << " )";
+ return os;
+ }
+
+ template<class Real>
+ bool DualPoint<Real>::operator<(const DualPoint<Real>& rhs) const
+ {
+ return std::tie(axis_type_, angle_type_, lambda_, mu_)
+ < std::tie(rhs.axis_type_, rhs.angle_type_, rhs.lambda_, rhs.mu_);
+ }
+
+ template<class Real>
+ DualPoint<Real>::DualPoint(AxisType axis_type, AngleType angle_type, Real lambda, Real mu)
+ :
+ axis_type_(axis_type),
+ angle_type_(angle_type),
+ lambda_(lambda),
+ mu_(mu)
+ {
+ assert(sanity_check());
+ }
+
+ template<class Real>
+ bool DualPoint<Real>::sanity_check() const
+ {
+ if (lambda_ < 0.0)
+ throw std::runtime_error("Invalid line, negative lambda");
+ if (lambda_ > 1.0)
+ throw std::runtime_error("Invalid line, lambda > 1");
+ if (mu_ < 0.0)
+ throw std::runtime_error("Invalid line, negative mu");
+ return true;
+ }
+
+ template<class Real>
+ Real DualPoint<Real>::gamma() const
+ {
+ if (is_steep())
+ return atan(Real(1.0) / lambda_);
+ else
+ return atan(lambda_);
+ }
+
+ template<class Real>
+ DualPoint<Real> midpoint(DualPoint<Real> x, DualPoint<Real> y)
+ {
+ assert(x.angle_type() == y.angle_type() and x.axis_type() == y.axis_type());
+ Real lambda_mid = (x.lambda() + y.lambda()) / 2;
+ Real mu_mid = (x.mu() + y.mu()) / 2;
+ return DualPoint<Real>(x.axis_type(), x.angle_type(), lambda_mid, mu_mid);
+
+ }
+
+ // return k in the line equation y = kx + b
+ template<class Real>
+ Real DualPoint<Real>::y_slope() const
+ {
+ if (is_flat())
+ return lambda();
+ else
+ return Real(1.0) / lambda();
+ }
+
+ // return k in the line equation x = ky + b
+ template<class Real>
+ Real DualPoint<Real>::x_slope() const
+ {
+ if (is_flat())
+ return Real(1.0) / lambda();
+ else
+ return lambda();
+ }
+
+ // return b in the line equation y = kx + b
+ template<class Real>
+ Real DualPoint<Real>::y_intercept() const
+ {
+ if (is_y_type()) {
+ return mu();
+ } else {
+ // x = x_slope * y + mu = x_slope * (y + mu / x_slope)
+ // x-intercept is -mu/x_slope = -mu * y_slope
+ return -mu() * y_slope();
+ }
+ }
+
+ // return k in the line equation x = ky + b
+ template<class Real>
+ Real DualPoint<Real>::x_intercept() const
+ {
+ if (is_x_type()) {
+ return mu();
+ } else {
+ // y = y_slope * x + mu = y_slope (x + mu / y_slope)
+ // x_intercept is -mu/y_slope = -mu * x_slope
+ return -mu() * x_slope();
+ }
+ }
+
+ template<class Real>
+ Real DualPoint<Real>::x_from_y(Real y) const
+ {
+ if (is_horizontal())
+ throw std::runtime_error("x_from_y called on horizontal line");
+ else
+ return x_slope() * y + x_intercept();
+ }
+
+ template<class Real>
+ Real DualPoint<Real>::y_from_x(Real x) const
+ {
+ if (is_vertical())
+ throw std::runtime_error("x_from_y called on horizontal line");
+ else
+ return y_slope() * x + y_intercept();
+ }
+
+ template<class Real>
+ bool DualPoint<Real>::is_horizontal() const
+ {
+ return is_flat() and lambda() == 0;
+ }
+
+ template<class Real>
+ bool DualPoint<Real>::is_vertical() const
+ {
+ return is_steep() and lambda() == 0;
+ }
+
+ template<class Real>
+ bool DualPoint<Real>::contains(Point<Real> p) const
+ {
+ if (is_vertical())
+ return p.x == x_from_y(p.y);
+ else
+ return p.y == y_from_x(p.x);
+ }
+
+ template<class Real>
+ bool DualPoint<Real>::goes_below(Point<Real> p) const
+ {
+ if (is_vertical())
+ return p.x <= x_from_y(p.y);
+ else
+ return p.y >= y_from_x(p.x);
+ }
+
+ template<class Real>
+ bool DualPoint<Real>::goes_above(Point<Real> p) const
+ {
+ if (is_vertical())
+ return p.x >= x_from_y(p.y);
+ else
+ return p.y <= y_from_x(p.x);
+ }
+
+ template<class Real>
+ Point<Real> DualPoint<Real>::push(Point<Real> p) const
+ {
+ Point<Real> result;
+ // if line is below p, we push horizontally
+ bool horizontal_push = goes_below(p);
+ if (is_x_type()) {
+ if (is_flat()) {
+ if (horizontal_push) {
+ result.x = p.y / lambda() + mu();
+ result.y = p.y;
+ } else {
+ // vertical push
+ result.x = p.x;
+ result.y = lambda() * (p.x - mu());
+ }
+ } else {
+ // steep
+ if (horizontal_push) {
+ result.x = lambda() * p.y + mu();
+ result.y = p.y;
+ } else {
+ // vertical push
+ result.x = p.x;
+ result.y = (p.x - mu()) / lambda();
+ }
+ }
+ } else {
+ // y-type
+ if (is_flat()) {
+ if (horizontal_push) {
+ result.x = (p.y - mu()) / lambda();
+ result.y = p.y;
+ } else {
+ // vertical push
+ result.x = p.x;
+ result.y = lambda() * p.x + mu();
+ }
+ } else {
+ // steep
+ if (horizontal_push) {
+ result.x = (p.y - mu()) * lambda();
+ result.y = p.y;
+ } else {
+ // vertical push
+ result.x = p.x;
+ result.y = p.x / lambda() + mu();
+ }
+ }
+ }
+ return result;
+ }
+
+ template<class Real>
+ Real DualPoint<Real>::weighted_push(Point<Real> p) const
+ {
+ // if line is below p, we push horizontally
+ bool horizontal_push = goes_below(p);
+ if (is_x_type()) {
+ if (is_flat()) {
+ if (horizontal_push) {
+ return p.y;
+ } else {
+ // vertical push
+ return lambda() * (p.x - mu());
+ }
+ } else {
+ // steep
+ if (horizontal_push) {
+ return lambda() * p.y;
+ } else {
+ // vertical push
+ return (p.x - mu());
+ }
+ }
+ } else {
+ // y-type
+ if (is_flat()) {
+ if (horizontal_push) {
+ return p.y - mu();
+ } else {
+ // vertical push
+ return lambda() * p.x;
+ }
+ } else {
+ // steep
+ if (horizontal_push) {
+ return lambda() * (p.y - mu());
+ } else {
+ // vertical push
+ return p.x;
+ }
+ }
+ }
+ }
+
+ template<class Real>
+ bool DualPoint<Real>::operator==(const DualPoint<Real>& other) const
+ {
+ return axis_type() == other.axis_type() and
+ angle_type() == other.angle_type() and
+ mu() == other.mu() and
+ lambda() == other.lambda();
+ }
+
+ template<class Real>
+ Real DualPoint<Real>::weight() const
+ {
+ return lambda_ / sqrt(1 + lambda_ * lambda_);
+ }
+} // namespace md
diff --git a/matching/include/matching_distance.h b/matching/include/matching_distance.h
new file mode 100644
index 0000000..618330d
--- /dev/null
+++ b/matching/include/matching_distance.h
@@ -0,0 +1,290 @@
+#pragma once
+
+#include <vector>
+#include <limits>
+#include <utility>
+#include <ostream>
+#include <chrono>
+#include <tuple>
+#include <algorithm>
+
+
+#include "common_defs.h"
+#include "cell_with_value.h"
+#include "box.h"
+#include "dual_point.h"
+#include "dual_box.h"
+#include "persistence_module.h"
+#include "bifiltration.h"
+#include "bottleneck.h"
+
+namespace md {
+
+#ifdef MD_PRINT_HEAT_MAP
+ template<class Real>
+ using HeatMap = std::map<DualPoint<Real>, Real>;
+
+ template<class Real>
+ using HeatMaps = std::map<int, HeatMap<Real>>;
+#endif
+
+ enum class BoundStrategy {
+ bruteforce,
+ local_dual_bound,
+ local_dual_bound_refined,
+ local_dual_bound_for_each_point,
+ local_combined
+ };
+
+ enum class TraverseStrategy {
+ depth_first,
+ breadth_first,
+ breadth_first_value,
+ upper_bound
+ };
+
+ inline std::ostream& operator<<(std::ostream& os, const BoundStrategy& s)
+ {
+ switch(s) {
+ case BoundStrategy::bruteforce :
+ os << "bruteforce";
+ break;
+ case BoundStrategy::local_dual_bound :
+ os << "local_grob";
+ break;
+ case BoundStrategy::local_combined :
+ os << "local_combined";
+ break;
+ case BoundStrategy::local_dual_bound_refined :
+ os << "local_refined";
+ break;
+ case BoundStrategy::local_dual_bound_for_each_point :
+ os << "local_for_each_point";
+ break;
+ default:
+ os << "FORGOTTEN BOUND STRATEGY";
+ }
+ return os;
+ }
+
+ inline std::ostream& operator<<(std::ostream& os, const TraverseStrategy& s)
+ {
+ switch(s) {
+ case TraverseStrategy::depth_first :
+ os << "DFS";
+ break;
+ case TraverseStrategy::breadth_first :
+ os << "BFS";
+ break;
+ case TraverseStrategy::breadth_first_value :
+ os << "BFS-VAL";
+ break;
+ case TraverseStrategy::upper_bound :
+ os << "UB";
+ break;
+ default:
+ os << "FORGOTTEN TRAVERSE STRATEGY";
+ }
+ return os;
+ }
+
+ inline std::istream& operator>>(std::istream& is, TraverseStrategy& s)
+ {
+ std::string ss;
+ is >> ss;
+ if (ss == "DFS") {
+ s = TraverseStrategy::depth_first;
+ } else if (ss == "BFS") {
+ s = TraverseStrategy::breadth_first;
+ } else if (ss == "BFS-VAL") {
+ s = TraverseStrategy::breadth_first_value;
+ } else if (ss == "UB") {
+ s = TraverseStrategy::upper_bound;
+ } else {
+ throw std::runtime_error("UNKNOWN TRAVERSE STRATEGY");
+ }
+ return is;
+ }
+
+
+ inline std::istream& operator>>(std::istream& is, BoundStrategy& s)
+ {
+ std::string ss;
+ is >> ss;
+ if (ss == "bruteforce") {
+ s = BoundStrategy::bruteforce;
+ } else if (ss == "local_grob") {
+ s = BoundStrategy::local_dual_bound;
+ } else if (ss == "local_combined") {
+ s = BoundStrategy::local_combined;
+ } else if (ss == "local_refined") {
+ s = BoundStrategy::local_dual_bound_refined;
+ } else if (ss == "local_for_each_point") {
+ s = BoundStrategy::local_dual_bound_for_each_point;
+ } else {
+ throw std::runtime_error("UNKNOWN BOUND STRATEGY");
+ }
+ return is;
+ }
+
+ inline BoundStrategy bs_from_string(std::string s)
+ {
+ std::stringstream ss(s);
+ BoundStrategy result;
+ ss >> result;
+ return result;
+ }
+
+ inline TraverseStrategy ts_from_string(std::string s)
+ {
+ std::stringstream ss(s);
+ TraverseStrategy result;
+ ss >> result;
+ return result;
+ }
+
+ template<class Real>
+ struct CalculationParams {
+ static constexpr int ALL_DIMENSIONS = -1;
+
+ Real hera_epsilon {0.001}; // relative error in hera call
+ Real delta {0.1}; // relative error for matching distance
+ int max_depth {8}; // maximal number of refinenemnts
+ int initialization_depth {2};
+ int dim {0}; // in which dim to calculate the distance; use ALL_DIMENSIONS to get max over all dims
+ BoundStrategy bound_strategy {BoundStrategy::local_combined};
+ TraverseStrategy traverse_strategy {TraverseStrategy::breadth_first};
+ bool tolerate_max_iter_exceeded {false};
+ Real actual_error {std::numeric_limits<Real>::max()};
+ int actual_max_depth {0};
+ int n_hera_calls {0}; // for experiments only; is set in matching_distance function, input value is ignored
+
+ // stop looping over points immediately, if current point's displacement is too large
+ // to prune the cell
+ // if true, cells are pruned immediately, and bounds may increase
+ // (just return something large enough to not prune the cell)
+ bool stop_asap { true };
+
+ // print statistics on each quad-tree level
+ bool print_stats { false };
+
+#ifdef MD_PRINT_HEAT_MAP
+ HeatMaps<Real> heat_maps;
+#endif
+ };
+
+
+ template<class Real_, class DiagramProvider>
+ class DistanceCalculator {
+
+ using Real = Real_;
+ using CellValueVector = std::vector<CellWithValue<Real>>;
+
+ public:
+ DistanceCalculator(const DiagramProvider& a,
+ const DiagramProvider& b,
+ CalculationParams<Real>& params);
+
+ Real distance();
+
+ int get_hera_calls_number() const;
+
+#ifndef MD_TEST_CODE
+ private:
+#endif
+
+ DiagramProvider module_a_;
+ DiagramProvider module_b_;
+
+ CalculationParams<Real>& params_;
+
+ int n_hera_calls_;
+ std::map<int, int> n_hera_calls_per_level_;
+ Real distance_;
+
+ // if calculate_on_intermediate, then weighted distance
+ // will be calculated on centers of each grid in between
+ CellValueVector get_refined_grid(int init_depth, bool calculate_on_intermediate, bool calculate_on_last = true);
+
+ CellValueVector get_initial_dual_grid(Real& lower_bound);
+
+#ifdef MD_PRINT_HEAT_MAP
+ void heatmap_in_dimension(int dim, int depth);
+#endif
+
+ Real get_max_x(int module) const;
+
+ Real get_max_y(int module) const;
+
+ void set_cell_central_value(CellWithValue<Real>& dual_cell);
+
+ Real get_distance();
+
+ Real get_distance_pq();
+
+ Real get_max_possible_value(const CellWithValue<Real>* first_cell_ptr, int n_cells);
+
+ Real get_upper_bound(const CellWithValue<Real>& dual_cell, Real good_enough_upper_bound) const;
+
+ Real get_single_dgm_bound(const CellWithValue<Real>& dual_cell, ValuePoint vp, int module,
+ Real good_enough_value) const;
+
+ // this bound depends only on dual box
+ Real get_local_dual_bound(int module, const DualBox<Real>& dual_box) const;
+
+ Real get_local_dual_bound(const DualBox<Real>& dual_box) const;
+
+ // this bound depends only on dual box, is more accurate
+ Real get_local_refined_bound(int module, const DualBox<Real>& dual_box) const;
+
+ Real get_local_refined_bound(const DualBox<Real>& dual_box) const;
+
+ Real get_good_enough_upper_bound(Real lower_bound) const;
+
+ Real get_max_displacement_single_point(const CellWithValue<Real>& dual_cell, ValuePoint value_point,
+ const Point<Real>& p) const;
+
+ void check_upper_bound(const CellWithValue<Real>& dual_cell) const;
+
+ Real distance_on_line(DualPoint<Real> line);
+ Real distance_on_line_const(DualPoint<Real> line) const;
+
+ Real current_error(Real lower_bound, Real upper_bound);
+ };
+
+ template<class Real>
+ Real matching_distance(const Bifiltration<Real>& bif_a, const Bifiltration<Real>& bif_b,
+ CalculationParams<Real>& params);
+
+ template<class Real>
+ Real matching_distance(const ModulePresentation<Real>& mod_a, const ModulePresentation<Real>& mod_b,
+ CalculationParams<Real>& params);
+
+ // for upper bound experiment
+ struct UbExperimentRecord {
+ double error;
+ double lower_bound;
+ double upper_bound;
+ CellWithValue<double> cell;
+ long long int time;
+ long long int n_hera_calls;
+ };
+
+ inline std::ostream& operator<<(std::ostream& os, const UbExperimentRecord& r)
+ {
+ os << r.time << "\t" << r.n_hera_calls << "\t" << r.error << "\t" << r.lower_bound << "\t" << r.upper_bound;
+ return os;
+ }
+
+
+ template<class K, class V>
+ void print_map(const std::map<K, V>& dic)
+ {
+ for(const auto kv : dic) {
+ std::cout << kv.first << " -> " << kv.second << "\n";
+ }
+ }
+
+} // namespace md
+
+#include "matching_distance.hpp"
diff --git a/matching/include/matching_distance.hpp b/matching/include/matching_distance.hpp
new file mode 100644
index 0000000..f7f44a5
--- /dev/null
+++ b/matching/include/matching_distance.hpp
@@ -0,0 +1,722 @@
+namespace md {
+
+ template<class R, class T>
+ void DistanceCalculator<R, T>::check_upper_bound(const CellWithValue<R>& dual_cell) const
+ {
+ const int n_samples_lambda = 100;
+ const int n_samples_mu = 100;
+ DualBox<R> db = dual_cell.dual_box();
+ R min_lambda = db.lambda_min();
+ R max_lambda = db.lambda_max();
+ R min_mu = db.mu_min();
+ R max_mu = db.mu_max();
+
+ R h_lambda = (max_lambda - min_lambda) / n_samples_lambda;
+ R h_mu = (max_mu - min_mu) / n_samples_mu;
+ for(int i = 1; i < n_samples_lambda; ++i) {
+ for(int j = 1; j < n_samples_mu; ++j) {
+ R lambda = min_lambda + i * h_lambda;
+ R mu = min_mu + j * h_mu;
+ DualPoint<R> l(db.axis_type(), db.angle_type(), lambda, mu);
+ R other_result = distance_on_line_const(l);
+ R diff = fabs(dual_cell.stored_upper_bound() - other_result);
+ if (other_result > dual_cell.stored_upper_bound()) {
+ throw std::runtime_error("Wrong delta estimate");
+ }
+ }
+ }
+ }
+
+ // for all lines l, l' inside dual box,
+ // find the upper bound on the difference of weighted pushes of p
+ template<class R, class T>
+ R
+ DistanceCalculator<R, T>::get_max_displacement_single_point(const CellWithValue<R>& dual_cell, ValuePoint vp,
+ const Point<R>& p) const
+ {
+ assert(p.x >= 0 && p.y >= 0);
+
+#ifdef MD_DEBUG
+ std::vector<long long int> debug_ids = {3, 13, 54, 218, 350, 382, 484, 795, 2040, 8415, 44076};
+ bool debug = false; // std::find(debug_ids.begin(), debug_ids.end(), dual_cell.id) != debug_ids.end();
+#endif
+ DualPoint<R> line = dual_cell.value_point(vp);
+ const R base_value = line.weighted_push(p);
+
+ R result = 0.0;
+ for(DualPoint<R> dp : dual_cell.dual_box().critical_points(p)) {
+ R dp_value = dp.weighted_push(p);
+ result = std::max(result, fabs(base_value - dp_value));
+ }
+
+#ifdef MD_DO_FULL_CHECK
+ auto db = dual_cell.dual_box();
+ std::uniform_real_distribution<R> dlambda(db.lambda_min(), db.lambda_max());
+ std::uniform_real_distribution<R> dmu(db.mu_min(), db.mu_max());
+ std::mt19937 gen(1);
+ for(int i = 0; i < 1000; ++i) {
+ R lambda = dlambda(gen);
+ R mu = dmu(gen);
+ DualPoint<R> dp_random { db.axis_type(), db.angle_type(), lambda, mu };
+ R dp_value = dp_random.weighted_push(p);
+ if (fabs(base_value - dp_value) > result) {
+ throw std::runtime_error("error in get_max_displacement_single_value");
+ }
+ }
+#endif
+
+ return result;
+ }
+
+ template<class R, class T>
+ typename DistanceCalculator<R, T>::CellValueVector DistanceCalculator<R, T>::get_initial_dual_grid(R& lower_bound)
+ {
+ CellValueVector result = get_refined_grid(params_.initialization_depth, false, true);
+
+ lower_bound = -1;
+ for(const auto& dc : result) {
+ lower_bound = std::max(lower_bound, dc.max_corner_value());
+ }
+
+ assert(lower_bound >= 0);
+
+ for(auto& dual_cell : result) {
+ R good_enough_ub = get_good_enough_upper_bound(lower_bound);
+ R max_value_on_cell = get_upper_bound(dual_cell, good_enough_ub);
+ dual_cell.set_max_possible_value(max_value_on_cell);
+
+#ifdef MD_DO_FULL_CHECK
+ check_upper_bound(dual_cell);
+#endif
+ }
+
+ return result;
+ }
+
+ template<class R, class T>
+ typename DistanceCalculator<R, T>::CellValueVector
+ DistanceCalculator<R, T>::get_refined_grid(int init_depth, bool calculate_on_intermediate, bool calculate_on_last)
+ {
+ const R y_max = std::max(module_a_.max_y(), module_b_.max_y());
+ const R x_max = std::max(module_a_.max_x(), module_b_.max_x());
+
+ const R lambda_min = 0;
+ const R lambda_max = 1;
+
+ const R mu_min = 0;
+
+ DualBox<R> x_flat(DualPoint<R>(AxisType::x_type, AngleType::flat, lambda_min, mu_min),
+ DualPoint<R>(AxisType::x_type, AngleType::flat, lambda_max, x_max));
+
+ DualBox<R> x_steep(DualPoint<R>(AxisType::x_type, AngleType::steep, lambda_min, mu_min),
+ DualPoint<R>(AxisType::x_type, AngleType::steep, lambda_max, x_max));
+
+ DualBox<R> y_flat(DualPoint<R>(AxisType::y_type, AngleType::flat, lambda_min, mu_min),
+ DualPoint<R>(AxisType::y_type, AngleType::flat, lambda_max, y_max));
+
+ DualBox<R> y_steep(DualPoint<R>(AxisType::y_type, AngleType::steep, lambda_min, mu_min),
+ DualPoint<R>(AxisType::y_type, AngleType::steep, lambda_max, y_max));
+
+ CellWithValue<R> x_flat_cell(x_flat, 0);
+ CellWithValue<R> x_steep_cell(x_steep, 0);
+ CellWithValue<R> y_flat_cell(y_flat, 0);
+ CellWithValue<R> y_steep_cell(y_steep, 0);
+
+ if (init_depth == 0) {
+ DualPoint<R> diagonal_x_flat(AxisType::x_type, AngleType::flat, 1, 0);
+
+ R diagonal_value = distance_on_line(diagonal_x_flat);
+ n_hera_calls_per_level_[0]++;
+
+ x_flat_cell.set_value_at(ValuePoint::lower_right, diagonal_value);
+ y_flat_cell.set_value_at(ValuePoint::lower_right, diagonal_value);
+ x_steep_cell.set_value_at(ValuePoint::lower_right, diagonal_value);
+ y_steep_cell.set_value_at(ValuePoint::lower_right, diagonal_value);
+ }
+
+#ifdef MD_DEBUG
+ x_flat_cell.id = 1;
+ x_steep_cell.id = 2;
+ y_flat_cell.id = 3;
+ y_steep_cell.id = 4;
+ CellWithValue<R>::max_id = 4;
+#endif
+
+ CellValueVector result {x_flat_cell, x_steep_cell, y_flat_cell, y_steep_cell};
+
+ if (init_depth == 0) {
+ return result;
+ }
+
+ CellValueVector refined_result;
+
+ for(int i = 1; i <= init_depth; ++i) {
+ refined_result.clear();
+ for(const auto& dual_cell : result) {
+ for(auto refined_cell : dual_cell.get_refined_cells()) {
+ // we calculate for init_dept - 1, not init_depth,
+ // because we want the cells to have value at a corner
+ if ((i == init_depth - 1 and calculate_on_last) or calculate_on_intermediate)
+ set_cell_central_value(refined_cell);
+ refined_result.push_back(refined_cell);
+ }
+ }
+ result = std::move(refined_result);
+ }
+ return result;
+ }
+
+ template<class R, class T>
+ DistanceCalculator<R, T>::DistanceCalculator(const T& a,
+ const T& b,
+ CalculationParams<R>& params)
+ :
+ module_a_(a),
+ module_b_(b),
+ params_(params)
+ {
+ // make all coordinates non-negative
+ auto min_coord = std::min(module_a_.minimal_coordinate(),
+ module_b_.minimal_coordinate());
+ if (min_coord < 0) {
+ module_a_.translate(-min_coord);
+ module_b_.translate(-min_coord);
+ }
+
+ assert(std::min({module_a_.min_x(), module_b_.min_x(), module_a_.min_y(),
+ module_b_.min_y()}) >= 0);
+
+ }
+
+ template<class R, class T>
+ R DistanceCalculator<R, T>::get_max_x(int module) const
+ {
+ return (module == 0) ? module_a_.max_x() : module_b_.max_x();
+ }
+
+ template<class R, class T>
+ R DistanceCalculator<R, T>::get_max_y(int module) const
+ {
+ return (module == 0) ? module_a_.max_y() : module_b_.max_y();
+ }
+
+ template<class R, class T>
+ R
+ DistanceCalculator<R, T>::get_local_refined_bound(const DualBox<R>& dual_box) const
+ {
+ return get_local_refined_bound(0, dual_box) + get_local_refined_bound(1, dual_box);
+ }
+
+ template<class R, class T>
+ R
+ DistanceCalculator<R, T>::get_local_refined_bound(int module, const DualBox<R>& dual_box) const
+ {
+ R d_lambda = dual_box.lambda_max() - dual_box.lambda_min();
+ R d_mu = dual_box.mu_max() - dual_box.mu_min();
+ R result;
+ if (dual_box.axis_type() == AxisType::x_type) {
+ if (dual_box.is_flat()) {
+ result = dual_box.lambda_max() * d_mu + (get_max_x(module) - dual_box.mu_min()) * d_lambda;
+ } else {
+ result = d_mu + get_max_y(module) * d_lambda;
+ }
+ } else {
+ // y-type
+ if (dual_box.is_flat()) {
+ result = d_mu + get_max_x(module) * d_lambda;
+ } else {
+ // steep
+ result = dual_box.lambda_max() * d_mu + (get_max_y(module) - dual_box.mu_min()) * d_lambda;
+ }
+ }
+ return result;
+ }
+
+ template<class R, class T>
+ R DistanceCalculator<R, T>::get_local_dual_bound(int module, const DualBox<R>& dual_box) const
+ {
+ R dlambda = dual_box.lambda_max() - dual_box.lambda_min();
+ R dmu = dual_box.mu_max() - dual_box.mu_min();
+
+ if (dual_box.is_flat()) {
+ return get_max_x(module) * dlambda + dmu;
+ } else {
+ return get_max_y(module) * dlambda + dmu;
+ }
+ }
+
+ template<class R, class T>
+ R DistanceCalculator<R, T>::get_local_dual_bound(const DualBox<R>& dual_box) const
+ {
+ return get_local_dual_bound(0, dual_box) + get_local_dual_bound(1, dual_box);
+ }
+
+ template<class R, class T>
+ R DistanceCalculator<R, T>::get_upper_bound(const CellWithValue<R>& dual_cell, R good_enough_ub) const
+ {
+ assert(good_enough_ub >= 0);
+
+ switch(params_.bound_strategy) {
+ case BoundStrategy::bruteforce:
+ return std::numeric_limits<R>::max();
+
+ case BoundStrategy::local_dual_bound:
+ return dual_cell.min_value() + get_local_dual_bound(dual_cell.dual_box());
+
+ case BoundStrategy::local_dual_bound_refined:
+ return dual_cell.min_value() + get_local_refined_bound(dual_cell.dual_box());
+
+ case BoundStrategy::local_combined: {
+ R cheap_upper_bound = dual_cell.min_value() + get_local_refined_bound(dual_cell.dual_box());
+ if (cheap_upper_bound < good_enough_ub) {
+ return cheap_upper_bound;
+ } else {
+ [[fallthrough]];
+ }
+ }
+
+ case BoundStrategy::local_dual_bound_for_each_point: {
+ R result = std::numeric_limits<R>::max();
+ for(ValuePoint vp : k_corner_vps) {
+ if (not dual_cell.has_value_at(vp)) {
+ continue;
+ }
+
+ R base_value = dual_cell.value_at(vp);
+ R bound_dgm_a = get_single_dgm_bound(dual_cell, vp, 0, good_enough_ub);
+
+ if (params_.stop_asap and bound_dgm_a + base_value >= good_enough_ub) {
+ // we want to return a valid upper bound, not just something that will prevent discarding the cell
+ // and we don't want to compute pushes for points in second bifiltration.
+ // so just return a constant time bound
+ return dual_cell.min_value() + get_local_refined_bound(dual_cell.dual_box());
+ }
+
+ R bound_dgm_b = get_single_dgm_bound(dual_cell, vp, 1,
+ std::max(R(0), good_enough_ub - bound_dgm_a));
+
+ result = std::min(result, base_value + bound_dgm_a + bound_dgm_b);
+
+ if (params_.stop_asap and result < good_enough_ub) {
+ break;
+ }
+ }
+ return result;
+ }
+ }
+ // to suppress compiler warning
+ return std::numeric_limits<R>::max();
+ }
+
+ // find maximal displacement of weighted points of m for all lines in dual_box
+ template<class R, class T>
+ R
+ DistanceCalculator<R, T>::get_single_dgm_bound(const CellWithValue<R>& dual_cell,
+ ValuePoint vp,
+ int module,
+ R good_enough_value) const
+ {
+ R result = 0;
+ Point<R> max_point;
+
+ const T& m = (module == 0) ? module_a_ : module_b_;
+ for(const auto& position : m.positions()) {
+ R x = get_max_displacement_single_point(dual_cell, vp, position);
+
+ if (x > result) {
+ result = x;
+ max_point = position;
+ }
+
+ if (params_.stop_asap and result > good_enough_value) {
+ // we want to return a valid upper bound,
+ // now we just see it is worse than we need, but it may be even more
+ // just return a valid upper bound
+ result = get_local_refined_bound(dual_cell.dual_box());
+ break;
+ }
+ }
+
+ return result;
+ }
+
+ template<class R, class T>
+ R DistanceCalculator<R, T>::distance()
+ {
+ return get_distance_pq();
+ }
+
+ // calculate weighted bottleneneck distance between slices on line
+ // increments hera calls counter
+ template<class R, class T>
+ R DistanceCalculator<R, T>::distance_on_line(DualPoint<R> line)
+ {
+ ++n_hera_calls_;
+ R result = distance_on_line_const(line);
+ return result;
+ }
+
+ template<class R, class T>
+ R DistanceCalculator<R, T>::distance_on_line_const(DualPoint<R> line) const
+ {
+ // TODO: think about this - how to call Hera
+ auto dgm_a = module_a_.weighted_slice_diagram(line);
+ auto dgm_b = module_b_.weighted_slice_diagram(line);
+ R result;
+ if (params_.hera_epsilon > static_cast<R>(0)) {
+ result = hera::bottleneckDistApprox(dgm_a, dgm_b, params_.hera_epsilon) / ( params_.hera_epsilon + 1);
+ } else {
+ result = hera::bottleneckDistExact(dgm_a, dgm_b);
+ }
+ return result;
+ }
+
+ template<class R, class T>
+ R DistanceCalculator<R, T>::get_good_enough_upper_bound(R lower_bound) const
+ {
+ R result;
+ // in upper_bound strategy we only prune cells if they cannot improve the lower bound,
+ // otherwise the experiment is supposed to run indefinitely
+ if (params_.traverse_strategy == TraverseStrategy::upper_bound) {
+ result = lower_bound;
+ } else {
+ result = (1.0 + params_.delta) * lower_bound;
+ }
+ return result;
+ }
+
+ // helper function
+ // calculate weighted bt distance on cell center,
+ // assign distance value to cell, keep it in heat_map, and return
+ template<class R, class T>
+ void DistanceCalculator<R, T>::set_cell_central_value(CellWithValue<R>& dual_cell)
+ {
+ DualPoint<R> central_line {dual_cell.center()};
+
+ R new_value = distance_on_line(central_line);
+ n_hera_calls_per_level_[dual_cell.level() + 1]++;
+ dual_cell.set_value_at(ValuePoint::center, new_value);
+ params_.actual_max_depth = std::max(params_.actual_max_depth, dual_cell.level() + 1);
+
+#ifdef PRINT_HEAT_MAP
+ if (params_.bound_strategy == BoundStrategy::bruteforce) {
+ if (dual_cell.level() > params_.initialization_depth + 1
+ and params_.heat_maps[dual_cell.level()].count(dual_cell.center()) > 0) {
+ auto existing = params_.heat_maps[dual_cell.level()].find(dual_cell.center());
+ }
+ assert(dual_cell.level() <= params_.initialization_depth + 1
+ or params_.heat_maps[dual_cell.level()].count(dual_cell.center()) == 0);
+ params_.heat_maps[dual_cell.level()][dual_cell.center()] = new_value;
+ }
+#endif
+ }
+
+ // quick-and-dirty hack to efficiently traverse priority queue with dual cells
+ // returns maximal possible value on all cells in queue
+ // assumes that the underlying container is vector!
+ // cell_ptr: pointer to the first element in queue
+ // n_cells: queue size
+ template<class R, class T>
+ R DistanceCalculator<R, T>::get_max_possible_value(const CellWithValue<R>* cell_ptr, int n_cells)
+ {
+ R result = (n_cells > 0) ? cell_ptr->stored_upper_bound() : 0;
+ for(int i = 0; i < n_cells; ++i, ++cell_ptr) {
+ result = std::max(result, cell_ptr->stored_upper_bound());
+ }
+ return result;
+ }
+
+ // helper function:
+ // return current error from lower and upper bounds
+ // and save it in params_ (hence not const)
+ template<class R, class T>
+ R DistanceCalculator<R, T>::current_error(R lower_bound, R upper_bound)
+ {
+ R current_error = (lower_bound > 0.0) ? (upper_bound - lower_bound) / lower_bound
+ : std::numeric_limits<R>::max();
+ params_.actual_error = current_error;
+
+ return current_error;
+ }
+
+ // return matching distance
+ // use priority queue to store dual cells
+ // comparison function depends on the strategies in params_
+ // ressets hera calls counter
+ template<class R, class T>
+ R DistanceCalculator<R, T>::get_distance_pq()
+ {
+ std::map<int, long> n_cells_considered;
+ std::map<int, long> n_cells_pushed_into_queue;
+ long int n_too_deep_cells = 0;
+ std::map<int, long> n_cells_discarded;
+ std::map<int, long> n_cells_pruned;
+
+ std::chrono::high_resolution_clock timer;
+ auto start_time = timer.now();
+
+ n_hera_calls_ = 0;
+ n_hera_calls_per_level_.clear();
+
+
+ // if cell is too deep and is not pushed into queue,
+ // we still need to take its max value into account;
+ // the max over such cells is stored in max_result_on_too_fine_cells
+ R upper_bound_on_deep_cells = -1;
+
+ // user-defined less lambda function
+ // to regulate priority queue depending on strategy
+ auto dual_cell_less = [this](const CellWithValue<R>& a, const CellWithValue<R>& b) {
+
+ int a_level = a.level();
+ int b_level = b.level();
+ R a_value = a.max_corner_value();
+ R b_value = b.max_corner_value();
+ R a_ub = a.stored_upper_bound();
+ R b_ub = b.stored_upper_bound();
+ if (this->params_.traverse_strategy == TraverseStrategy::upper_bound and
+ (not a.has_max_possible_value() or not b.has_max_possible_value())) {
+ throw std::runtime_error("no upper bound on cell");
+ }
+ DualPoint<R> a_lower_left = a.dual_box().lower_left();
+ DualPoint<R> b_lower_left = b.dual_box().lower_left();
+
+ switch(this->params_.traverse_strategy) {
+ // in both breadth_first searches we want coarser cells
+ // to be processed first. Cells with smaller level must be larger,
+ // hence the minus in front of level
+ case TraverseStrategy::breadth_first:
+ return std::make_tuple(-a_level, a_lower_left)
+ < std::make_tuple(-b_level, b_lower_left);
+ case TraverseStrategy::breadth_first_value:
+ return std::make_tuple(-a_level, a_value, a_lower_left)
+ < std::make_tuple(-b_level, b_value, b_lower_left);
+ case TraverseStrategy::depth_first:
+ return std::make_tuple(a_value, a_level, a_lower_left)
+ < std::make_tuple(b_value, b_level, b_lower_left);
+ case TraverseStrategy::upper_bound:
+ return std::make_tuple(a_ub, a_level, a_lower_left)
+ < std::make_tuple(b_ub, b_level, b_lower_left);
+ default:
+ throw std::runtime_error("Forgotten case");
+ }
+ };
+
+ std::priority_queue<CellWithValue<R>, CellValueVector, decltype(dual_cell_less)> dual_cells_queue(
+ dual_cell_less);
+
+ // weighted bt distance on the center of current cell
+ R lower_bound = std::numeric_limits<R>::min();
+
+ // init pq and lower bound
+ for(auto& init_cell : get_initial_dual_grid(lower_bound)) {
+ dual_cells_queue.push(init_cell);
+ }
+
+ R upper_bound = get_max_possible_value(&dual_cells_queue.top(), dual_cells_queue.size());
+
+ std::vector<UbExperimentRecord> ub_experiment_results;
+
+ while(not dual_cells_queue.empty()) {
+
+ CellWithValue<R> dual_cell = dual_cells_queue.top();
+ dual_cells_queue.pop();
+ assert(dual_cell.has_corner_value()
+ and dual_cell.has_max_possible_value()
+ and dual_cell.max_corner_value() <= upper_bound);
+
+ n_cells_considered[dual_cell.level()]++;
+
+ bool discard_cell = false;
+
+ if (not params_.stop_asap) {
+ // if stop_asap is on, it is safer to never discard a cell
+ if (params_.bound_strategy == BoundStrategy::bruteforce) {
+ discard_cell = false;
+ } else if (params_.traverse_strategy == TraverseStrategy::upper_bound) {
+ discard_cell = (dual_cell.stored_upper_bound() <= lower_bound);
+ } else {
+ discard_cell = (dual_cell.stored_upper_bound() <= (1.0 + params_.delta) * lower_bound);
+ }
+ }
+
+ if (discard_cell) {
+ n_cells_discarded[dual_cell.level()]++;
+ continue;
+ }
+
+ // until now, dual_cell knows its value in one of its corners
+ // new_value will be the weighted distance at its center
+ set_cell_central_value(dual_cell);
+ R new_value = dual_cell.value_at(ValuePoint::center);
+ lower_bound = std::max(new_value, lower_bound);
+
+ assert(upper_bound >= lower_bound);
+
+ if (current_error(lower_bound, upper_bound) < params_.delta) {
+ break;
+ }
+
+ // refine cell and push 4 smaller cells into queue
+ for(auto refined_cell : dual_cell.get_refined_cells()) {
+
+ if (refined_cell.num_values() == 0)
+ throw std::runtime_error("no value on cell");
+
+ // if delta is smaller than good_enough_value, it allows to prune cell
+ R good_enough_ub = get_good_enough_upper_bound(lower_bound);
+
+ // upper bound of the parent holds for refined_cell
+ // and can sometimes be smaller!
+ R upper_bound_on_refined_cell = std::min(dual_cell.stored_upper_bound(),
+ get_upper_bound(refined_cell, good_enough_ub));
+
+ refined_cell.set_max_possible_value(upper_bound_on_refined_cell);
+
+#ifdef MD_DO_FULL_CHECK
+ check_upper_bound(refined_cell);
+#endif
+
+ bool prune_cell = false;
+
+ if (refined_cell.level() <= params_.max_depth) {
+ // cell might be added to queue; if it is not added, its maximal value can be safely ignored
+ if (params_.traverse_strategy == TraverseStrategy::upper_bound) {
+ prune_cell = (refined_cell.stored_upper_bound() <= lower_bound);
+ } else if (params_.bound_strategy != BoundStrategy::bruteforce) {
+ prune_cell = (refined_cell.stored_upper_bound() <= (1.0 + params_.delta) * lower_bound);
+ }
+ if (prune_cell)
+ n_cells_pruned[refined_cell.level()]++;
+// prune_cell = (max_result_on_refined_cell <= lower_bound);
+ } else {
+ // cell is too deep, it won't be added to queue
+ // we must memorize maximal value on this cell, because we won't see it anymore
+ prune_cell = true;
+ if (refined_cell.stored_upper_bound() > (1 + params_.delta) * lower_bound) {
+ n_too_deep_cells++;
+ }
+ upper_bound_on_deep_cells = std::max(upper_bound_on_deep_cells, refined_cell.stored_upper_bound());
+ }
+
+ if (not prune_cell) {
+ n_cells_pushed_into_queue[refined_cell.level()]++;
+ dual_cells_queue.push(refined_cell);
+ }
+ } // end loop over refined cells
+
+ if (dual_cells_queue.empty())
+ upper_bound = std::max(upper_bound, upper_bound_on_deep_cells);
+ else
+ upper_bound = std::max(upper_bound_on_deep_cells,
+ get_max_possible_value(&dual_cells_queue.top(), dual_cells_queue.size()));
+
+ if (params_.traverse_strategy == TraverseStrategy::upper_bound) {
+ upper_bound = dual_cells_queue.top().stored_upper_bound();
+
+ if (get_hera_calls_number() < 20 || get_hera_calls_number() % 20 == 0) {
+ auto elapsed = timer.now() - start_time;
+ UbExperimentRecord ub_exp_record;
+
+ ub_exp_record.error = current_error(lower_bound, upper_bound);
+ ub_exp_record.lower_bound = lower_bound;
+ ub_exp_record.upper_bound = upper_bound;
+ ub_exp_record.cell = dual_cells_queue.top();
+ ub_exp_record.n_hera_calls = n_hera_calls_;
+ ub_exp_record.time = std::chrono::duration_cast<std::chrono::milliseconds>(elapsed).count();
+
+#ifdef MD_DO_CHECKS
+ if (ub_experiment_results.size() > 0) {
+ auto prev = ub_experiment_results.back();
+ if (upper_bound > prev.upper_bound) {
+ throw std::runtime_error("die");
+ }
+
+ if (lower_bound < prev.lower_bound) {
+ throw std::runtime_error("die");
+ }
+ }
+#endif
+
+ ub_experiment_results.emplace_back(ub_exp_record);
+
+ std::cerr << "[UB_EXPERIMENT]\t" << ub_exp_record << "\n";
+ }
+ }
+
+ assert(upper_bound >= lower_bound);
+
+ if (current_error(lower_bound, upper_bound) < params_.delta) {
+ break;
+ }
+ }
+
+ params_.actual_error = current_error(lower_bound, upper_bound);
+
+ if (n_too_deep_cells > 0) {
+ std::cerr << "Warning: error not guaranteed, there were too deep cells.";
+ std::cerr << " Increase max_depth or delta, actual error = " << params_.actual_error << std::endl;
+ }
+ // otherwise actual_error in params can be larger than delta,
+ // but this is OK
+
+ if (params_.print_stats) {
+ std::cout << "EXIT STATS, cells considered:\n";
+ print_map(n_cells_considered);
+ std::cout << "EXIT STATS, cells discarded:\n";
+ print_map(n_cells_discarded);
+ std::cout << "EXIT STATS, cells pruned:\n";
+ print_map(n_cells_pruned);
+ std::cout << "EXIT STATS, cells pushed:\n";
+ print_map(n_cells_pushed_into_queue);
+ std::cout << "EXIT STATS, hera calls:\n";
+ print_map(n_hera_calls_per_level_);
+ std::cout << "EXIT STATS, too deep cells with high value: " << n_too_deep_cells << "\n";
+ }
+
+ return lower_bound;
+ }
+
+ template<class R, class T>
+ int DistanceCalculator<R, T>::get_hera_calls_number() const
+ {
+ return n_hera_calls_;
+ }
+
+ template<class R>
+ R matching_distance(const Bifiltration<R>& bif_a, const Bifiltration<R>& bif_b,
+ CalculationParams<R>& params)
+ {
+ R result;
+ // compute distance only in one dimension
+ if (params.dim != CalculationParams<R>::ALL_DIMENSIONS) {
+ BifiltrationProxy<R> bifp_a(bif_a, params.dim);
+ BifiltrationProxy<R> bifp_b(bif_b, params.dim);
+ DistanceCalculator<R, BifiltrationProxy<R>> runner(bifp_a, bifp_b, params);
+ result = runner.distance();
+ params.n_hera_calls = runner.get_hera_calls_number();
+ } else {
+ // compute distance in all dimensions, return maximal
+ result = -1;
+ for(int dim = 0; dim < std::max(bif_a.maximal_dim(), bif_b.maximal_dim()); ++dim) {
+ BifiltrationProxy<R> bifp_a(bif_a, params.dim);
+ BifiltrationProxy<R> bifp_b(bif_a, params.dim);
+ DistanceCalculator<R, BifiltrationProxy<R>> runner(bifp_a, bifp_b, params);
+ result = std::max(result, runner.distance());
+ params.n_hera_calls += runner.get_hera_calls_number();
+ }
+ }
+ return result;
+ }
+
+
+ template<class R>
+ R matching_distance(const ModulePresentation<R>& mod_a, const ModulePresentation<R>& mod_b,
+ CalculationParams<R>& params)
+ {
+ DistanceCalculator<R, ModulePresentation<R>> runner(mod_a, mod_b, params);
+ R result = runner.distance();
+ params.n_hera_calls = runner.get_hera_calls_number();
+ return result;
+ }
+} // namespace md
diff --git a/matching/include/opts/opts.h b/matching/include/opts/opts.h
new file mode 100644
index 0000000..1a9bbf7
--- /dev/null
+++ b/matching/include/opts/opts.h
@@ -0,0 +1,499 @@
+/**
+ * Author: Dmitriy Morozov <dmitriy@mrzv.org>
+ * The interface is heavily influenced by GetOptPP (https://code.google.com/p/getoptpp/).
+ * The parsing logic is from ProgramOptions.hxx (https://github.com/Fytch/ProgramOptions.hxx).
+ *
+ * History:
+ * - 2015-06-01: added Traits<...>::type_string() for long, unsigned long
+ * - ...
+ * - 2018-04-27: replace parsing logic with the one from ProgramOptions.hxx to
+ * make the parser compliant with [GNU Program Argument Syntax
+ * Conventions](https://www.gnu.org/software/libc/manual/html_node/Argument-Syntax.html)
+ * - 2018-05-11: add dashed_non_option(), to accept arguments that are negative numbers
+ */
+
+#ifndef OPTS_OPTS_H
+#define OPTS_OPTS_H
+
+#include <iostream>
+#include <sstream>
+#include <string>
+#include <list>
+#include <vector>
+#include <map>
+#include <memory>
+#include <cctype>
+#include <functional>
+
+namespace opts {
+
+// Converters
+template<class T>
+struct Converter
+{
+ Converter() {}
+ static
+ bool convert(const std::string& val, T& res)
+ {
+ std::istringstream iss(val);
+ iss >> res;
+ return !iss.fail() && iss.eof();
+ }
+};
+
+// Type
+template<class T>
+struct Traits
+{
+ static std::string type_string() { return "UNKNOWN TYPE"; }
+};
+
+template<>
+struct Traits<int>
+{
+ static std::string type_string() { return "INT"; }
+};
+
+template<>
+struct Traits<short int>
+{
+ static std::string type_string() { return "SHORT INT"; }
+};
+
+template<>
+struct Traits<long>
+{
+ static std::string type_string() { return "LONG"; }
+};
+
+template<>
+struct Traits<unsigned>
+{
+ static std::string type_string() { return "UNSIGNED INT"; }
+};
+
+template<>
+struct Traits<short unsigned>
+{
+ static std::string type_string() { return "SHORT UNSIGNED INT"; }
+};
+
+template<>
+struct Traits<unsigned long>
+{
+ static std::string type_string() { return "UNSIGNED LONG"; }
+};
+
+template<>
+struct Traits<float>
+{
+ static std::string type_string() { return "FLOAT"; }
+};
+
+template<>
+struct Traits<double>
+{
+ static std::string type_string() { return "DOUBLE"; }
+};
+
+template<>
+struct Traits<std::string>
+{
+ static std::string type_string() { return "STRING"; }
+};
+
+
+struct BasicOption
+{
+ using IsShort = std::function<bool(char)>;
+
+ BasicOption(char s_,
+ std::string l_,
+ std::string default_,
+ std::string type_,
+ std::string help_):
+ s(s_), l(l_), d(default_), t(type_), help(help_) {}
+ virtual ~BasicOption() {}
+
+ int long_size() const { return l.size() + 1 + t.size(); }
+
+ void output(std::ostream& out, int max_long) const
+ {
+ out << " ";
+ if (s)
+ out << '-' << s << ", ";
+ else
+ out << " ";
+
+ out << "--" << l << ' ';
+
+ if (!t.empty())
+ out << t;
+
+ for (int i = long_size(); i < max_long; ++i)
+ out << ' ';
+
+ out << " " << help;
+
+ if (!d.empty())
+ {
+ out << " [default: " << d << "]";
+ }
+ out << '\n';
+ }
+
+ virtual bool flag() const { return false; }
+ virtual bool parse(int argc, char** argv, int& i, int j, IsShort is_short);
+ virtual bool set(std::string arg) =0;
+
+ char s;
+ std::string l;
+ std::string d;
+ std::string t;
+ std::string help;
+};
+
+// Option
+template<class T>
+struct OptionContainer: public BasicOption
+{
+ OptionContainer(char s_,
+ const std::string& l_,
+ T& var_,
+ const std::string& help_,
+ const std::string& type_ = Traits<T>::type_string()):
+ BasicOption(s_, l_, default_value(var_), type_, help_),
+ var(&var_) {}
+
+ static
+ std::string default_value(const T& def)
+ {
+ std::ostringstream oss;
+ oss << def;
+ return oss.str();
+ }
+
+ bool set(std::string s) override { return Converter<T>::convert(s, *var); }
+
+ T* var;
+};
+
+template<>
+struct OptionContainer<bool>: public BasicOption
+{
+ OptionContainer(char s_,
+ const std::string& l_,
+ bool& var_,
+ const std::string& help_):
+ BasicOption(s_, l_, "", "", help_),
+ var(&var_) { *var = false; }
+
+ bool parse(int, char**, int&, int, IsShort) override { *var = true; return true; }
+ bool set(std::string) override { return true; }
+ bool flag() const override { return true; }
+
+ bool* var;
+};
+
+template<class T>
+struct OptionContainer< std::vector<T> >: public BasicOption
+{
+ OptionContainer(char s_,
+ const std::string& l_,
+ std::vector<T>& var_,
+ const std::string& help_,
+ const std::string& type_ = "SEQUENCE"):
+ BasicOption(s_, l_, default_value(var_), type_, help_),
+ var(&var_), first(true) { }
+
+ static
+ std::string default_value(const std::vector<T>& def)
+ {
+ std::ostringstream oss;
+ oss << "(";
+ if (def.size())
+ oss << def[0];
+ for (size_t i = 1; i < def.size(); ++i)
+ oss << ", " << def[i];
+ oss << ")";
+ return oss.str();
+ }
+
+ bool set(std::string s) override
+ {
+ if (first)
+ {
+ var->clear();
+ first = false;
+ }
+
+ T x;
+ bool result = Converter<T>::convert(s,x);
+ var->emplace_back(std::move(x));
+ return result;
+ }
+
+ std::vector<T>* var;
+ mutable bool first;
+};
+
+
+template<class T>
+std::unique_ptr<BasicOption>
+Option(char s, const std::string& l, T& var, const std::string& help) { return std::unique_ptr<BasicOption>{new OptionContainer<T>(s, l, var, help)}; }
+
+template<class T>
+std::unique_ptr<BasicOption>
+Option(char s, const std::string& l, T& var,
+ const std::string& type, const std::string& help) { return std::unique_ptr<BasicOption>{new OptionContainer<T>(s, l, var, help, type)}; }
+
+template<class T>
+std::unique_ptr<BasicOption>
+Option(const std::string& l, T& var, const std::string& help) { return std::unique_ptr<BasicOption>{new OptionContainer<T>(0, l, var, help)}; }
+
+template<class T>
+std::unique_ptr<BasicOption>
+Option(const std::string& l, T& var,
+ const std::string& type, const std::string& help) { return std::unique_ptr<BasicOption>{new OptionContainer<T>(0, l, var, help, type)}; }
+
+// PosOption
+template<class T>
+struct PosOptionContainer
+{
+ PosOptionContainer(T& var_):
+ var(&var_) {}
+
+ bool parse(std::list<std::string>& args) const
+ {
+ if (args.empty())
+ return false;
+
+ bool result = Converter<T>::convert(args.front(), *var);
+ if (!result)
+ std::cerr << "error: failed to parse " << args.front() << '\n';
+ args.pop_front();
+ return result;
+ }
+
+ T* var;
+};
+
+template<class T>
+PosOptionContainer<T>
+PosOption(T& var) { return PosOptionContainer<T>(var); }
+
+
+// Options
+struct Options
+{
+ Options():
+ failed(false) {}
+
+ inline
+ Options& operator>>(std::unique_ptr<BasicOption> opt);
+ template<class T>
+ Options& operator>>(const PosOptionContainer<T>& poc);
+
+ operator bool() { return !failed; }
+
+
+ friend
+ std::ostream&
+ operator<<(std::ostream& out, const Options& ops)
+ {
+ int max_long = 0;
+ for (auto& cur : ops.options)
+ {
+ int cur_long = cur->long_size();
+ if (cur_long > max_long)
+ max_long = cur_long;
+ }
+
+ out << "Options:\n";
+ for (auto& cur : ops.options)
+ cur->output(out, max_long);
+
+ return out;
+ }
+
+ bool parse(int argc, char** argv);
+
+ void unrecognized_option(std::string arg) const
+ {
+ std::cerr << "error: unrecognized option " << arg << '\n';
+ }
+
+ static bool dashed_non_option(char* arg, BasicOption::IsShort is_short)
+ {
+ return arg[ 0 ] == '-'
+ && (std::isdigit(arg[ 1 ]) || arg[ 1 ] == '.')
+ && !is_short(arg[ 1 ]);
+ }
+
+ private:
+ std::list<std::string> args;
+ std::list<std::unique_ptr<BasicOption>> options;
+ bool failed;
+};
+
+bool
+BasicOption::parse(int argc, char** argv, int& i, int j, IsShort is_short)
+{
+ char* argument;
+ char* cur_arg = argv[i];
+ // -v...
+ if (argv[i][j] == '\0')
+ {
+ // -v data
+ if (i + 1 < argc && (argv[i+1][0] != '-' || Options::dashed_non_option(argv[i+1], is_short)))
+ {
+ ++i;
+ argument = argv[i];
+ } else
+ {
+ std::cerr << "error: cannot find the argument; ignoring " << argv[i] << '\n';
+ return false;
+ }
+ } else if (argv[i][j] == '=')
+ {
+ // -v=data
+ argument = &argv[i][j+1];
+ } else if( j == 2 ) { // only for short options
+ // -vdata
+ argument = &argv[i][j];
+ } else
+ {
+ std::cerr << "error: unexpected character \'" << argv[i][j] << "\' ignoring " << argv[i] << '\n';
+ return false;
+ }
+ bool result = set(argument);
+ if (!result)
+ std::cerr << "error: failed to parse " << argument << " in " << cur_arg << '\n';
+ return result;
+}
+
+bool
+Options::parse(int argc, char** argv)
+{
+ std::map<char, BasicOption*> short_opts;
+ std::map<std::string, BasicOption*> long_opts;
+
+ for (auto& opt : options)
+ {
+ if (opt->s)
+ short_opts[opt->s] = opt.get();
+
+ long_opts[opt->l] = opt.get();
+ }
+
+ auto is_short = [&short_opts](char c) -> bool { return short_opts.find(c) != short_opts.end(); };
+
+ for (int i = 1; i < argc; ++i)
+ {
+ if( argv[ i ][ 0 ] == '\0' )
+ continue;
+ if( argv[ i ][ 0 ] != '-' || dashed_non_option(argv[i], is_short))
+ args.push_back(argv[i]);
+ else
+ {
+ // -...
+ if( argv[ i ][ 1 ] == '\0' )
+ {
+ // -
+ args.push_back(argv[i]);
+ } else if( argv[ i ][ 1 ] == '-' )
+ {
+ if( argv[ i ][ 2 ] == '\0' )
+ {
+ // --
+ while( ++i < argc )
+ args.push_back(argv[i]);
+ } else {
+ // --...
+ char* first = &argv[ i ][ 2 ];
+ char* last = first;
+ for(; *last != '=' && *last != '\0'; ++last);
+ if (first == last)
+ {
+ failed = true;
+ unrecognized_option(argv[i]);
+ } else
+ {
+ auto opt_it = long_opts.find(std::string{first,last});
+ if (opt_it == long_opts.end())
+ {
+ failed = true;
+ unrecognized_option(argv[i]);
+ } else
+ {
+ failed |= !opt_it->second->parse(argc, argv, i, last - argv[i], is_short);
+ }
+ }
+ }
+ } else
+ {
+ // -f...
+ auto opt_it = short_opts.find(argv[i][1]);
+ if (opt_it == short_opts.end())
+ {
+ failed = true;
+ unrecognized_option(argv[i]);
+ } else if (opt_it->second->flag())
+ {
+ opt_it->second->parse(argc, argv, i, 0, is_short); // arguments are meaningless; just sets the flag
+
+ // -fgh
+ char c;
+ for(int j = 1; (c = argv[i][j]) != '\0'; ++j)
+ {
+ if (!std::isprint(c) || c == '-')
+ {
+ failed = true;
+ std::cerr << "error: invalid character\'" << c << " ignoring " << &argv[i][j] << '\n';
+ break;
+ }
+ opt_it = short_opts.find(c);
+ if (opt_it == short_opts.end())
+ {
+ failed = true;
+ unrecognized_option("-" + std::string(1, c));
+ continue;
+ }
+ if (!opt_it->second->flag())
+ {
+ failed = true;
+ std::cerr << "error: non-void options not allowed in option packs; ignoring " << c << '\n';
+ continue;
+ }
+ opt_it->second->parse(argc, argv, i, 0, is_short); // arguments are meaningless; just sets the flag
+ }
+ } else
+ {
+ failed |= !opt_it->second->parse(argc, argv, i, 2, is_short);
+ }
+ }
+ }
+ }
+
+ return !failed;
+}
+
+Options&
+Options::operator>>(std::unique_ptr<BasicOption> opt)
+{
+ options.emplace_back(std::move(opt));
+ return *this;
+}
+
+template<class T>
+Options&
+Options::operator>>(const PosOptionContainer<T>& poc)
+{
+ if (!failed)
+ failed = !poc.parse(args);
+ return *this;
+}
+
+}
+
+#endif
diff --git a/matching/include/persistence_module.h b/matching/include/persistence_module.h
new file mode 100644
index 0000000..b68c21e
--- /dev/null
+++ b/matching/include/persistence_module.h
@@ -0,0 +1,111 @@
+#ifndef MATCHING_DISTANCE_PERSISTENCE_MODULE_H
+#define MATCHING_DISTANCE_PERSISTENCE_MODULE_H
+
+#include <cassert>
+#include <vector>
+#include <utility>
+#include <string>
+#include <numeric>
+#include <algorithm>
+#include <unordered_set>
+
+#include "phat/boundary_matrix.h"
+#include "phat/compute_persistence_pairs.h"
+
+#include "common_util.h"
+#include "dual_point.h"
+#include "box.h"
+
+namespace md {
+
+ /* ModulePresentation contains only information needed for matching
+ * distance computation over Z/2.
+ * Generators are represented as points (critical values),
+ * id i of generator g_i = its index in * vector generators_.
+ *
+ * Relations are represented by struct Relation, which has two members:
+ * position_ is a point at which relation appears,
+ * components_ contains indices of generators that sum to zero:
+ * if components_ = [i, ..., j], then the relation is g_i +...+ g_j = 0.
+ *
+ * ModulePresentation has member positions_ that contains all
+ * distinct positions of generators and relations;
+ * this member simplifies computing local linear bound.
+ */
+
+
+ template<class Real>
+ class ModulePresentation {
+ public:
+
+ using RealVec = std::vector<Real>;
+
+ enum Format { rivet_firep };
+
+ struct Relation {
+ Point<Real> position_;
+ IndexVec components_;
+
+ Relation() {}
+ Relation(const Point<Real>& _pos, const IndexVec& _components) : position_(_pos), components_(_components) {}
+
+ Real get_x() const { return position_.x; }
+ Real get_y() const { return position_.y; }
+ };
+
+ using RelVec = std::vector<Relation>;
+
+ ModulePresentation() {}
+
+ ModulePresentation(const PointVec<Real>& _generators, const RelVec& _relations);
+
+ Diagram<Real> weighted_slice_diagram(const DualPoint<Real>& line) const;
+
+ // translate all points by vector (a,a)
+ void translate(Real a);
+
+ // return minimal value of x- and y-coordinates
+ Real minimal_coordinate() const { return std::min(min_x(), min_y()); }
+
+ // return box that contains all positions of all simplices
+ Box<Real> bounding_box() const;
+
+ Real max_x() const { return max_x_; }
+
+ Real max_y() const { return max_y_; }
+
+ Real min_x() const { return min_x_; }
+
+ Real min_y() const { return min_y_; }
+
+ PointVec<Real> positions() const;
+
+#ifndef MD_TEST_CODE
+ private:
+#endif
+
+ PointVec<Real> generators_;
+ std::vector<Relation> relations_;
+ PointVec<Real> positions_;
+
+
+ Real max_x_ { std::numeric_limits<Real>::max() };
+ Real max_y_ { std::numeric_limits<Real>::max() };
+ Real min_x_ { -std::numeric_limits<Real>::max() };
+ Real min_y_ { -std::numeric_limits<Real>::max() };
+ Box<Real> bounding_box_;
+
+ void init_boundaries();
+
+ void project_generators(const DualPoint<Real>& slice, IndexVec& sorted_indices, RealVec& projections) const;
+ void project_relations(const DualPoint<Real>& slice, IndexVec& sorted_indices, RealVec& projections) const;
+
+ void get_slice_projection_matrix(const DualPoint<Real>& slice, phat::boundary_matrix<>& phat_matrix,
+ RealVec& gen_projections, RealVec& rel_projections) const;
+
+ };
+} // namespace md
+
+#include "persistence_module.hpp"
+
+#endif //MATCHING_DISTANCE_PERSISTENCE_MODULE_H
diff --git a/matching/include/persistence_module.hpp b/matching/include/persistence_module.hpp
new file mode 100644
index 0000000..7479e02
--- /dev/null
+++ b/matching/include/persistence_module.hpp
@@ -0,0 +1,192 @@
+namespace md {
+
+ /**
+ *
+ * @param values vector of length n
+ * @return [a_1,...,a_n] such that
+ * 1) values[a_1] <= values[a_2] <= ... <= values[a_n]
+ * 2) a_1,...,a_n is a permutation of 1,..,n
+ */
+
+ template<class T>
+ IndexVec get_sorted_indices(const std::vector<T>& values)
+ {
+ IndexVec result(values.size());
+ std::iota(result.begin(), result.end(), 0);
+ std::sort(result.begin(), result.end(),
+ [&values](size_t a, size_t b) { return values[a] < values[b]; });
+ return result;
+ }
+
+ // helper function to initialize const member positions_ in ModulePresentation
+ template<class Real>
+ PointVec<Real> concat_gen_and_rel_positions(const PointVec<Real>& generators,
+ const typename ModulePresentation<Real>::RelVec& relations)
+ {
+ std::unordered_set<Point<Real>> ps(generators.begin(), generators.end());
+ for(const auto& rel : relations) {
+ ps.insert(rel.position_);
+ }
+ return PointVec<Real>(ps.begin(), ps.end());
+ }
+
+
+ template<class Real>
+ void ModulePresentation<Real>::init_boundaries()
+ {
+ max_x_ = -std::numeric_limits<Real>::max();
+ max_y_ = -std::numeric_limits<Real>::max();
+ min_x_ = std::numeric_limits<Real>::max();
+ min_y_ = std::numeric_limits<Real>::max();
+
+ for(const auto& gen : positions_) {
+ min_x_ = std::min(gen.x, min_x_);
+ min_y_ = std::min(gen.y, min_y_);
+ max_x_ = std::max(gen.x, max_x_);
+ max_y_ = std::max(gen.y, max_y_);
+ }
+
+ bounding_box_ = Box<Real>(Point<Real>(min_x_, min_y_), Point<Real>(max_x_, max_y_));
+ }
+
+
+ template<class Real>
+ ModulePresentation<Real>::ModulePresentation(const PointVec<Real>& _generators, const RelVec& _relations) :
+ generators_(_generators),
+ relations_(_relations)
+ {
+ positions_ = concat_gen_and_rel_positions(generators_, relations_);
+ init_boundaries();
+ }
+
+ template<class Real>
+ void ModulePresentation<Real>::translate(Real a)
+ {
+ for(auto& g : generators_) {
+ g.translate(a);
+ }
+
+ for(auto& r : relations_) {
+ r.position_.translate(a);
+ }
+
+ positions_ = concat_gen_and_rel_positions(generators_, relations_);
+ init_boundaries();
+ }
+
+
+ /**
+ *
+ * @param slice line on which generators are projected
+ * @param sorted_indices [a_1,...,a_n] s.t. wpush(generator[a_1]) <= wpush(generator[a_2]) <= ..
+ * @param projections sorted weighted pushes of generators
+ */
+
+ template<class Real>
+ void ModulePresentation<Real>::project_generators(const DualPoint<Real>& slice,
+ IndexVec& sorted_indices, RealVec& projections) const
+ {
+ size_t num_gens = generators_.size();
+
+ RealVec gen_values;
+ gen_values.reserve(num_gens);
+ for(const auto& pos : generators_) {
+ gen_values.push_back(slice.weighted_push(pos));
+ }
+ sorted_indices = get_sorted_indices(gen_values);
+ projections.clear();
+ projections.reserve(num_gens);
+ for(auto i : sorted_indices) {
+ projections.push_back(gen_values[i]);
+ }
+ }
+
+ template<class Real>
+ void ModulePresentation<Real>::project_relations(const DualPoint<Real>& slice, IndexVec& sorted_rel_indices,
+ RealVec& projections) const
+ {
+
+ size_t num_rels = relations_.size();
+
+ RealVec rel_values;
+ rel_values.reserve(num_rels);
+ for(const auto& rel : relations_) {
+ rel_values.push_back(slice.weighted_push(rel.position_));
+ }
+
+ sorted_rel_indices = get_sorted_indices(rel_values);
+
+ projections.clear();
+ projections.reserve(num_rels);
+ for(auto i : sorted_rel_indices) {
+ projections.push_back(rel_values[i]);
+ }
+ }
+
+
+ template<class Real>
+ void ModulePresentation<Real>::get_slice_projection_matrix(const DualPoint<Real>& slice,
+ phat::boundary_matrix<>& phat_matrix,
+ RealVec& gen_projections, RealVec& rel_projections) const
+ {
+ IndexVec sorted_gen_indices, sorted_rel_indices;
+
+ project_generators(slice, sorted_gen_indices, gen_projections);
+ project_relations(slice, sorted_rel_indices, rel_projections);
+
+ phat_matrix.set_num_cols(relations_.size());
+
+ for(Index i = 0; i < (Index) relations_.size(); i++) {
+ IndexVec current_relation = relations_[sorted_rel_indices[i]].components_;
+ for(auto& j : current_relation) {
+ j = sorted_gen_indices[j];
+ }
+ std::sort(current_relation.begin(), current_relation.end());
+ // modules do not have dimension, set all to 0
+ phat_matrix.set_dim(i, 0);
+ phat_matrix.set_col(i, current_relation);
+ }
+ }
+
+
+ template<class Real>
+ Diagram<Real> ModulePresentation<Real>::weighted_slice_diagram(const DualPoint<Real>& slice) const
+ {
+ RealVec gen_projections, rel_projections;
+ phat::boundary_matrix<> phat_matrix;
+
+ get_slice_projection_matrix(slice, phat_matrix, gen_projections, rel_projections);
+
+ phat::persistence_pairs phat_persistence_pairs;
+ phat::compute_persistence_pairs<phat::twist_reduction>(phat_persistence_pairs, phat_matrix);
+
+ Diagram<Real> dgm;
+
+ constexpr Real real_inf = std::numeric_limits<Real>::infinity();
+
+ for(Index i = 0; i < (Index) phat_persistence_pairs.get_num_pairs(); i++) {
+ std::pair<phat::index, phat::index> new_pair = phat_persistence_pairs.get_pair(i);
+ bool is_finite_pair = new_pair.second != phat::k_infinity_index;
+ Real birth = gen_projections.at(new_pair.first);
+ Real death = is_finite_pair ? rel_projections.at(new_pair.second) : real_inf;
+ if (birth != death) {
+ dgm.emplace_back(birth, death);
+ }
+ }
+
+ return dgm;
+ }
+
+ template<class Real>
+ PointVec<Real> ModulePresentation<Real>::positions() const
+ {
+ return positions_;
+ }
+
+ template<class Real>
+ Box<Real> ModulePresentation<Real>::bounding_box() const
+ {
+ return bounding_box_;
+ }
+
+} // namespace md
diff --git a/matching/include/phat/algorithms/chunk_reduction.h b/matching/include/phat/algorithms/chunk_reduction.h
new file mode 100644
index 0000000..1797023
--- /dev/null
+++ b/matching/include/phat/algorithms/chunk_reduction.h
@@ -0,0 +1,223 @@
+/* Copyright 2013 IST Austria
+ Contributed by: Ulrich Bauer, Michael Kerber, Jan Reininghaus
+
+ This file is part of PHAT.
+
+ PHAT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Lesser General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ PHAT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public License
+ along with PHAT. If not, see <http://www.gnu.org/licenses/>. */
+
+#pragma once
+
+#include <phat/helpers/misc.h>
+#include <phat/boundary_matrix.h>
+
+namespace phat {
+ class chunk_reduction {
+ public:
+ enum column_type { GLOBAL
+ , LOCAL_POSITIVE
+ , LOCAL_NEGATIVE };
+
+ public:
+ template< typename Representation >
+ void operator() ( boundary_matrix< Representation >& boundary_matrix ) {
+
+
+ const index nr_columns = boundary_matrix.get_num_cols();
+ if( omp_get_max_threads( ) > nr_columns )
+ omp_set_num_threads( 1 );
+
+ const dimension max_dim = boundary_matrix.get_max_dim();
+
+ std::vector< index > lowest_one_lookup( nr_columns, -1 );
+ std::vector < column_type > column_type( nr_columns, GLOBAL );
+ std::vector< char > is_active( nr_columns, false );
+
+ const index chunk_size = omp_get_max_threads() == 1 ? (index)sqrt( (double)nr_columns ) : nr_columns / omp_get_max_threads();
+
+ std::vector< index > chunk_boundaries;
+ for( index cur_boundary = 0; cur_boundary < nr_columns; cur_boundary += chunk_size )
+ chunk_boundaries.push_back( cur_boundary );
+ chunk_boundaries.push_back( nr_columns );
+
+ for( dimension cur_dim = max_dim; cur_dim >= 1; cur_dim-- ) {
+ // Phase 1: Reduce chunks locally -- 1st pass
+ #pragma omp parallel for schedule( guided, 1 )
+ for( index chunk_id = 0; chunk_id < (index)chunk_boundaries.size() - 1; chunk_id++ )
+ _local_chunk_reduction( boundary_matrix, lowest_one_lookup, column_type, cur_dim,
+ chunk_boundaries[ chunk_id ], chunk_boundaries[ chunk_id + 1 ], chunk_boundaries[ chunk_id ] );
+ boundary_matrix.sync();
+
+ // Phase 1: Reduce chunks locally -- 2nd pass
+ #pragma omp parallel for schedule( guided, 1 )
+ for( index chunk_id = 1; chunk_id < (index)chunk_boundaries.size( ) - 1; chunk_id++ )
+ _local_chunk_reduction( boundary_matrix, lowest_one_lookup, column_type, cur_dim,
+ chunk_boundaries[ chunk_id ], chunk_boundaries[ chunk_id + 1 ], chunk_boundaries[ chunk_id - 1 ] );
+ boundary_matrix.sync( );
+ }
+
+ // get global columns
+ std::vector< index > global_columns;
+ for( index cur_col_idx = 0; cur_col_idx < nr_columns; cur_col_idx++ )
+ if( column_type[ cur_col_idx ] == GLOBAL )
+ global_columns.push_back( cur_col_idx );
+
+ // get active columns
+ #pragma omp parallel for
+ for( index idx = 0; idx < (index)global_columns.size(); idx++ )
+ is_active[ global_columns[ idx ] ] = true;
+ _get_active_columns( boundary_matrix, lowest_one_lookup, column_type, global_columns, is_active );
+
+ // Phase 2+3: Simplify columns and reduce them
+ for( dimension cur_dim = max_dim; cur_dim >= 1; cur_dim-- ) {
+ // Phase 2: Simplify columns
+ std::vector< index > temp_col;
+ #pragma omp parallel for schedule( guided, 1 ), private( temp_col )
+ for( index idx = 0; idx < (index)global_columns.size(); idx++ )
+ if( boundary_matrix.get_dim( global_columns[ idx ] ) == cur_dim )
+ _global_column_simplification( global_columns[ idx ], boundary_matrix, lowest_one_lookup, column_type, is_active, temp_col );
+ boundary_matrix.sync();
+
+ // Phase 3: Reduce columns
+ for( index idx = 0; idx < (index)global_columns.size(); idx++ ) {
+ index cur_col = global_columns[ idx ];
+ if( boundary_matrix.get_dim( cur_col ) == cur_dim && column_type[ cur_col ] == GLOBAL ) {
+ index lowest_one = boundary_matrix.get_max_index( cur_col );
+ while( lowest_one != -1 && lowest_one_lookup[ lowest_one ] != -1 ) {
+ boundary_matrix.add_to( lowest_one_lookup[ lowest_one ], cur_col );
+ lowest_one = boundary_matrix.get_max_index( cur_col );
+ }
+ if( lowest_one != -1 ) {
+ lowest_one_lookup[ lowest_one ] = cur_col;
+ boundary_matrix.clear( lowest_one );
+ }
+ boundary_matrix.finalize( cur_col );
+ }
+ }
+ }
+
+ boundary_matrix.sync();
+ }
+
+ protected:
+ template< typename Representation >
+ void _local_chunk_reduction( boundary_matrix< Representation >& boundary_matrix
+ , std::vector<index>& lowest_one_lookup
+ , std::vector< column_type >& column_type
+ , const dimension cur_dim
+ , const index chunk_begin
+ , const index chunk_end
+ , const index row_begin ) {
+
+ for( index cur_col = chunk_begin; cur_col < chunk_end; cur_col++ ) {
+ if( column_type[ cur_col ] == GLOBAL && boundary_matrix.get_dim( cur_col ) == cur_dim ) {
+ index lowest_one = boundary_matrix.get_max_index( cur_col );
+ while( lowest_one != -1 && lowest_one >= row_begin && lowest_one_lookup[ lowest_one ] != -1 ) {
+ boundary_matrix.add_to( lowest_one_lookup[ lowest_one ], cur_col );
+ lowest_one = boundary_matrix.get_max_index( cur_col );
+ }
+ if( lowest_one >= row_begin ) {
+ lowest_one_lookup[ lowest_one ] = cur_col;
+ column_type[ cur_col ] = LOCAL_NEGATIVE;
+ column_type[ lowest_one ] = LOCAL_POSITIVE;
+ boundary_matrix.clear( lowest_one );
+ boundary_matrix.finalize( cur_col );
+ }
+ }
+ }
+ }
+
+ template< typename Representation >
+ void _get_active_columns( const boundary_matrix< Representation >& boundary_matrix
+ , const std::vector< index >& lowest_one_lookup
+ , const std::vector< column_type >& column_type
+ , const std::vector< index >& global_columns
+ , std::vector< char >& is_active ) {
+
+ const index nr_columns = boundary_matrix.get_num_cols();
+ std::vector< char > finished( nr_columns, false );
+
+ std::vector< std::pair < index, index > > stack;
+ std::vector< index > cur_col_values;
+ #pragma omp parallel for schedule( guided, 1 ), private( stack, cur_col_values )
+ for( index idx = 0; idx < (index)global_columns.size(); idx++ ) {
+ bool pop_next = false;
+ index start_col = global_columns[ idx ];
+ stack.push_back( std::pair< index, index >( start_col, -1 ) );
+ while( !stack.empty() ) {
+ index cur_col = stack.back().first;
+ index prev_col = stack.back().second;
+ if( pop_next ) {
+ stack.pop_back();
+ pop_next = false;
+ if( prev_col != -1 ) {
+ if( is_active[ cur_col ] ) {
+ is_active[ prev_col ] = true;
+ }
+ if( prev_col == stack.back().first ) {
+ finished[ prev_col ] = true;
+ pop_next = true;
+ }
+ }
+ } else {
+ pop_next = true;
+ boundary_matrix.get_col( cur_col, cur_col_values );
+ for( index idx = 0; idx < (index) cur_col_values.size(); idx++ ) {
+ index cur_row = cur_col_values[ idx ];
+ if( ( column_type[ cur_row ] == GLOBAL ) ) {
+ is_active[ cur_col ] = true;
+ } else if( column_type[ cur_row ] == LOCAL_POSITIVE ) {
+ index next_col = lowest_one_lookup[ cur_row ];
+ if( next_col != cur_col && !finished[ cur_col ] ) {
+ stack.push_back( std::make_pair( next_col, cur_col ) );
+ pop_next = false;
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ template< typename Representation >
+ void _global_column_simplification( const index col_idx
+ , boundary_matrix< Representation >& boundary_matrix
+ , const std::vector< index >& lowest_one_lookup
+ , const std::vector< column_type >& column_type
+ , const std::vector< char >& is_active
+ , std::vector< index >& temp_col )
+ {
+ temp_col.clear();
+ while( !boundary_matrix.is_empty( col_idx ) ) {
+ index cur_row = boundary_matrix.get_max_index( col_idx );
+ switch( column_type[ cur_row ] ) {
+ case GLOBAL:
+ temp_col.push_back( cur_row );
+ boundary_matrix.remove_max( col_idx );
+ break;
+ case LOCAL_NEGATIVE:
+ boundary_matrix.remove_max( col_idx );
+ break;
+ case LOCAL_POSITIVE:
+ if( is_active[ lowest_one_lookup[ cur_row ] ] )
+ boundary_matrix.add_to( lowest_one_lookup[ cur_row ], col_idx );
+ else
+ boundary_matrix.remove_max( col_idx );
+ break;
+ }
+ }
+ std::reverse( temp_col.begin(), temp_col.end() );
+ boundary_matrix.set_col( col_idx, temp_col );
+ }
+ };
+}
diff --git a/matching/include/phat/algorithms/row_reduction.h b/matching/include/phat/algorithms/row_reduction.h
new file mode 100644
index 0000000..cdd1a8f
--- /dev/null
+++ b/matching/include/phat/algorithms/row_reduction.h
@@ -0,0 +1,56 @@
+/* Copyright 2013 IST Austria
+ Contributed by: Ulrich Bauer, Michael Kerber, Jan Reininghaus
+
+ This file is part of PHAT.
+
+ PHAT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Lesser General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ PHAT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public License
+ along with PHAT. If not, see <http://www.gnu.org/licenses/>. */
+
+#pragma once
+
+#include <phat/helpers/misc.h>
+#include <phat/boundary_matrix.h>
+
+namespace phat {
+ class row_reduction {
+ public:
+ template< typename Representation >
+ void operator() ( boundary_matrix< Representation >& boundary_matrix ) {
+
+ const index nr_columns = boundary_matrix.get_num_cols();
+ std::vector< std::vector< index > > lowest_one_lookup( nr_columns );
+
+ for( index cur_col = nr_columns - 1; cur_col >= 0; cur_col-- ) {
+ if( !boundary_matrix.is_empty( cur_col ) )
+ lowest_one_lookup[ boundary_matrix.get_max_index( cur_col ) ].push_back( cur_col );
+
+ if( !lowest_one_lookup[ cur_col ].empty() ) {
+ boundary_matrix.clear( cur_col );
+ boundary_matrix.finalize( cur_col );
+ std::vector< index >& cols_with_cur_lowest = lowest_one_lookup[ cur_col ];
+ index source = *min_element( cols_with_cur_lowest.begin(), cols_with_cur_lowest.end() );
+ for( index idx = 0; idx < (index)cols_with_cur_lowest.size(); idx++ ) {
+ index target = cols_with_cur_lowest[ idx ];
+ if( target != source && !boundary_matrix.is_empty( target ) ) {
+ boundary_matrix.add_to( source, target );
+ if( !boundary_matrix.is_empty( target ) ) {
+ index lowest_one_of_target = boundary_matrix.get_max_index( target );
+ lowest_one_lookup[ lowest_one_of_target ].push_back( target );
+ }
+ }
+ }
+ }
+ }
+ }
+ };
+} \ No newline at end of file
diff --git a/matching/include/phat/algorithms/spectral_sequence_reduction.h b/matching/include/phat/algorithms/spectral_sequence_reduction.h
new file mode 100644
index 0000000..bf442e6
--- /dev/null
+++ b/matching/include/phat/algorithms/spectral_sequence_reduction.h
@@ -0,0 +1,80 @@
+/* Copyright 2013 IST Austria
+ Contributed by: Jan Reininghaus
+
+ This file is part of PHAT.
+
+ PHAT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Lesser General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ PHAT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public License
+ along with PHAT. If not, see <http://www.gnu.org/licenses/>. */
+
+#pragma once
+
+#include <phat/helpers/misc.h>
+#include <phat/boundary_matrix.h>
+
+namespace phat {
+ class spectral_sequence_reduction {
+ public:
+ template< typename Representation >
+ void operator () ( boundary_matrix< Representation >& boundary_matrix ) {
+
+ const index nr_columns = boundary_matrix.get_num_cols();
+ std::vector< index > lowest_one_lookup( nr_columns, -1 );
+
+ //const index num_stripes = (index) sqrt( (double)nr_columns );
+ const index num_stripes = omp_get_max_threads();
+
+ index block_size = ( nr_columns % num_stripes == 0 ) ? nr_columns / num_stripes : block_size = nr_columns / num_stripes + 1;
+
+ std::vector< std::vector< index > > unreduced_cols_cur_pass( num_stripes );
+ std::vector< std::vector< index > > unreduced_cols_next_pass( num_stripes );
+
+ for( index cur_dim = boundary_matrix.get_max_dim(); cur_dim >= 1 ; cur_dim-- ) {
+ #pragma omp parallel for schedule( guided, 1 )
+ for( index cur_stripe = 0; cur_stripe < num_stripes; cur_stripe++ ) {
+ index col_begin = cur_stripe * block_size;
+ index col_end = std::min( (cur_stripe+1) * block_size, nr_columns );
+ for( index cur_col = col_begin; cur_col < col_end; cur_col++ )
+ if( boundary_matrix.get_dim( cur_col ) == cur_dim && boundary_matrix.get_max_index( cur_col ) != -1 )
+ unreduced_cols_cur_pass[ cur_stripe ].push_back( cur_col );
+ }
+ for( index cur_pass = 0; cur_pass < num_stripes; cur_pass++ ) {
+ boundary_matrix.sync();
+ #pragma omp parallel for schedule( guided, 1 )
+ for( int cur_stripe = 0; cur_stripe < num_stripes; cur_stripe++ ) {
+ index row_begin = (cur_stripe - cur_pass) * block_size;
+ index row_end = row_begin + block_size;
+ unreduced_cols_next_pass[ cur_stripe ].clear();
+ for( index idx = 0; idx < (index)unreduced_cols_cur_pass[ cur_stripe ].size(); idx++ ) {
+ index cur_col = unreduced_cols_cur_pass[ cur_stripe ][ idx ];
+ index lowest_one = boundary_matrix.get_max_index( cur_col );
+ while( lowest_one != -1 && lowest_one >= row_begin && lowest_one < row_end && lowest_one_lookup[ lowest_one ] != -1 ) {
+ boundary_matrix.add_to( lowest_one_lookup[ lowest_one ], cur_col );
+ lowest_one = boundary_matrix.get_max_index( cur_col );
+ }
+ if( lowest_one != -1 ) {
+ if( lowest_one >= row_begin && lowest_one < row_end ) {
+ lowest_one_lookup[ lowest_one ] = cur_col;
+ boundary_matrix.clear( lowest_one );
+ boundary_matrix.finalize( cur_col );
+ } else {
+ unreduced_cols_next_pass[ cur_stripe ].push_back( cur_col );
+ }
+ }
+ }
+ unreduced_cols_next_pass[ cur_stripe ].swap( unreduced_cols_cur_pass[ cur_stripe ] );
+ }
+ }
+ }
+ }
+ };
+}
diff --git a/matching/include/phat/algorithms/standard_reduction.h b/matching/include/phat/algorithms/standard_reduction.h
new file mode 100644
index 0000000..e490a5e
--- /dev/null
+++ b/matching/include/phat/algorithms/standard_reduction.h
@@ -0,0 +1,47 @@
+/* Copyright 2013 IST Austria
+ Contributed by: Ulrich Bauer, Michael Kerber, Jan Reininghaus
+
+ This file is part of PHAT.
+
+ PHAT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Lesser General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ PHAT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public License
+ along with PHAT. If not, see <http://www.gnu.org/licenses/>. */
+
+#pragma once
+
+#include <phat/helpers/misc.h>
+#include <phat/boundary_matrix.h>
+
+namespace phat {
+ class standard_reduction {
+ public:
+ template< typename Representation >
+ void operator() ( boundary_matrix< Representation >& boundary_matrix ) {
+
+ const index nr_columns = boundary_matrix.get_num_cols();
+ std::vector< index > lowest_one_lookup( nr_columns, -1 );
+
+ for( index cur_col = 0; cur_col < nr_columns; cur_col++ ) {
+ index lowest_one = boundary_matrix.get_max_index( cur_col );
+ while( lowest_one != -1 && lowest_one_lookup[ lowest_one ] != -1 ) {
+ boundary_matrix.add_to( lowest_one_lookup[ lowest_one ], cur_col );
+ lowest_one = boundary_matrix.get_max_index( cur_col );
+ }
+ if( lowest_one != -1 ) {
+ lowest_one_lookup[ lowest_one ] = cur_col;
+ }
+ boundary_matrix.finalize( cur_col );
+ }
+ }
+ };
+}
+
diff --git a/matching/include/phat/algorithms/twist_reduction.h b/matching/include/phat/algorithms/twist_reduction.h
new file mode 100644
index 0000000..2357df0
--- /dev/null
+++ b/matching/include/phat/algorithms/twist_reduction.h
@@ -0,0 +1,51 @@
+/* Copyright 2013 IST Austria
+ Contributed by: Ulrich Bauer, Michael Kerber, Jan Reininghaus
+
+ This file is part of PHAT.
+
+ PHAT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Lesser General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ PHAT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public License
+ along with PHAT. If not, see <http://www.gnu.org/licenses/>. */
+
+#pragma once
+
+#include <phat/helpers/misc.h>
+#include <phat/boundary_matrix.h>
+
+namespace phat {
+ class twist_reduction {
+ public:
+ template< typename Representation >
+ void operator () ( boundary_matrix< Representation >& boundary_matrix ) {
+
+ const index nr_columns = boundary_matrix.get_num_cols();
+ std::vector< index > lowest_one_lookup( nr_columns, -1 );
+
+ for( index cur_dim = boundary_matrix.get_max_dim(); cur_dim >= 1 ; cur_dim-- ) {
+ for( index cur_col = 0; cur_col < nr_columns; cur_col++ ) {
+ if( boundary_matrix.get_dim( cur_col ) == cur_dim ) {
+ index lowest_one = boundary_matrix.get_max_index( cur_col );
+ while( lowest_one != -1 && lowest_one_lookup[ lowest_one ] != -1 ) {
+ boundary_matrix.add_to( lowest_one_lookup[ lowest_one ], cur_col );
+ lowest_one = boundary_matrix.get_max_index( cur_col );
+ }
+ if( lowest_one != -1 ) {
+ lowest_one_lookup[ lowest_one ] = cur_col;
+ boundary_matrix.clear( lowest_one );
+ }
+ boundary_matrix.finalize( cur_col );
+ }
+ }
+ }
+ }
+ };
+}
diff --git a/matching/include/phat/boundary_matrix.h b/matching/include/phat/boundary_matrix.h
new file mode 100644
index 0000000..10c66cc
--- /dev/null
+++ b/matching/include/phat/boundary_matrix.h
@@ -0,0 +1,343 @@
+/* Copyright 2013 IST Austria
+ Contributed by: Ulrich Bauer, Michael Kerber, Jan Reininghaus
+
+ This file is part of PHAT.
+
+ PHAT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Lesser General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ PHAT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public License
+ along with PHAT. If not, see <http://www.gnu.org/licenses/>. */
+
+#pragma once
+
+#include <phat/helpers/misc.h>
+#include <phat/representations/bit_tree_pivot_column.h>
+
+// interface class for the main data structure -- implementations of the interface can be found in ./representations
+namespace phat {
+ template< class Representation = bit_tree_pivot_column >
+ class boundary_matrix
+ {
+
+ protected:
+ Representation rep;
+
+ // interface functions -- actual implementation and complexity depends on chosen @Representation template
+ public:
+ // get overall number of columns in boundary_matrix
+ index get_num_cols() const { return rep._get_num_cols(); }
+
+ // set overall number of columns in boundary_matrix
+ void set_num_cols( index nr_of_columns ) { rep._set_num_cols( nr_of_columns ); }
+
+ // get dimension of given index
+ dimension get_dim( index idx ) const { return rep._get_dim( idx ); }
+
+ // set dimension of given index
+ void set_dim( index idx, dimension dim ) { rep._set_dim( idx, dim ); }
+
+ // replaces content of @col with boundary of given index
+ void get_col( index idx, column& col ) const { col.clear(); rep._get_col( idx, col ); }
+
+ // set column @idx to the values contained in @col
+ void set_col( index idx, const column& col ) { rep._set_col( idx, col ); }
+
+ // true iff boundary of given column is empty
+ bool is_empty( index idx ) const { return rep._is_empty( idx ); }
+
+ // largest index of given column (new name for lowestOne()) -- NOT thread-safe
+ index get_max_index( index idx ) const { return rep._get_max_index( idx ); }
+
+ // removes maximal index from given column
+ void remove_max( index idx ) { rep._remove_max( idx ); }
+
+ // adds column @source to column @target'
+ void add_to( index source, index target ) { rep._add_to( source, target ); }
+
+ // clears given column
+ void clear( index idx ) { rep._clear( idx ); }
+
+ // finalizes given column
+ void finalize( index idx ) { rep._finalize( idx ); }
+
+ // syncronizes all internal data structures -- has to be called before and after any multithreaded access!
+ void sync() { rep._sync(); }
+
+ // info functions -- independent of chosen 'Representation'
+ public:
+ // maximal dimension
+ dimension get_max_dim() const {
+ dimension cur_max_dim = 0;
+ for( index idx = 0; idx < get_num_cols(); idx++ )
+ cur_max_dim = get_dim( idx ) > cur_max_dim ? get_dim( idx ) : cur_max_dim;
+ return cur_max_dim;
+ }
+
+ // number of nonzero rows for given column @idx
+ index get_num_rows( index idx ) const {
+ column cur_col;
+ get_col( idx, cur_col );
+ return cur_col.size();
+ }
+
+ // maximal number of nonzero rows of all columns
+ index get_max_col_entries() const {
+ index max_col_entries = -1;
+ const index nr_of_columns = get_num_cols();
+ for( index idx = 0; idx < nr_of_columns; idx++ )
+ max_col_entries = get_num_rows( idx ) > max_col_entries ? get_num_rows( idx ) : max_col_entries;
+ return max_col_entries;
+ }
+
+ // maximal number of nonzero cols of all rows
+ index get_max_row_entries() const {
+ size_t max_row_entries = 0;
+ const index nr_of_columns = get_num_cols();
+ std::vector< std::vector< index > > transposed_matrix( nr_of_columns );
+ column temp_col;
+ for( index cur_col = 0; cur_col < nr_of_columns; cur_col++ ) {
+ get_col( cur_col, temp_col );
+ for( index idx = 0; idx < (index)temp_col.size(); idx++)
+ transposed_matrix[ temp_col[ idx ] ].push_back( cur_col );
+ }
+ for( index idx = 0; idx < nr_of_columns; idx++ )
+ max_row_entries = transposed_matrix[ idx ].size() > max_row_entries ? transposed_matrix[ idx ].size() : max_row_entries;
+ return max_row_entries;
+ }
+
+ // overall number of entries in the matrix
+ index get_num_entries() const {
+ index number_of_nonzero_entries = 0;
+ const index nr_of_columns = get_num_cols();
+ for( index idx = 0; idx < nr_of_columns; idx++ )
+ number_of_nonzero_entries += get_num_rows( idx );
+ return number_of_nonzero_entries;
+ }
+
+ // operators / constructors
+ public:
+ boundary_matrix() {};
+
+ template< class OtherRepresentation >
+ boundary_matrix( const boundary_matrix< OtherRepresentation >& other ) {
+ *this = other;
+ }
+
+ template< typename OtherRepresentation >
+ bool operator==( const boundary_matrix< OtherRepresentation >& other_boundary_matrix ) const {
+ const index number_of_columns = this->get_num_cols();
+
+ if( number_of_columns != other_boundary_matrix.get_num_cols() )
+ return false;
+
+ column temp_col;
+ column other_temp_col;
+ for( index idx = 0; idx < number_of_columns; idx++ ) {
+ this->get_col( idx, temp_col );
+ other_boundary_matrix.get_col( idx, other_temp_col );
+ if( temp_col != other_temp_col || this->get_dim( idx ) != other_boundary_matrix.get_dim( idx ) )
+ return false;
+ }
+ return true;
+ }
+
+ template< typename OtherRepresentation >
+ bool operator!=( const boundary_matrix< OtherRepresentation >& other_boundary_matrix ) const {
+ return !( *this == other_boundary_matrix );
+ }
+
+ template< typename OtherRepresentation >
+ boundary_matrix< Representation >& operator=( const boundary_matrix< OtherRepresentation >& other )
+ {
+ const index nr_of_columns = other.get_num_cols();
+ this->set_num_cols( nr_of_columns );
+ column temp_col;
+ for( index cur_col = 0; cur_col < nr_of_columns; cur_col++ ) {
+ this->set_dim( cur_col, other.get_dim( cur_col ) );
+ other.get_col( cur_col, temp_col );
+ this->set_col( cur_col, temp_col );
+ }
+
+ // by convention, always return *this
+ return *this;
+ }
+
+ // I/O -- independent of chosen 'Representation'
+ public:
+
+ // initializes boundary_matrix from (vector<vector>, vector) pair -- untested
+ template< typename index_type, typename dimemsion_type >
+ void load_vector_vector( const std::vector< std::vector< index_type > >& input_matrix, const std::vector< dimemsion_type >& input_dims ) {
+ const index nr_of_columns = (index)input_matrix.size();
+ this->set_num_cols( nr_of_columns );
+ column temp_col;
+ #pragma omp parallel for private( temp_col )
+ for( index cur_col = 0; cur_col < nr_of_columns; cur_col++ ) {
+ this->set_dim( cur_col, (dimension)input_dims[ cur_col ] );
+
+ index num_rows = input_matrix[ cur_col ].size();
+ temp_col.resize( num_rows );
+ for( index cur_row = 0; cur_row < num_rows; cur_row++ )
+ temp_col[ cur_row ] = (index)input_matrix[ cur_col ][ cur_row ];
+ this->set_col( cur_col, temp_col );
+ }
+ }
+
+ template< typename index_type, typename dimemsion_type >
+ void save_vector_vector( std::vector< std::vector< index_type > >& output_matrix, std::vector< dimemsion_type >& output_dims ) {
+ const index nr_of_columns = get_num_cols();
+ output_matrix.resize( nr_of_columns );
+ output_dims.resize( nr_of_columns );
+ column temp_col;
+ for( index cur_col = 0; cur_col < nr_of_columns; cur_col++ ) {
+ output_dims[ cur_col ] = (dimemsion_type)get_dim( cur_col );
+ get_col( cur_col, temp_col );
+ index num_rows = temp_col.size();
+ output_matrix[ cur_col ].clear();
+ output_matrix[ cur_col ].resize( num_rows );
+ for( index cur_row = 0; cur_row < num_rows; cur_row++ )
+ output_matrix[ cur_col ][ cur_row ] = (index_type)temp_col[ cur_row ];
+ }
+ }
+
+
+ // Loads the boundary_matrix from given file in ascii format
+ // Format: each line represents a column, first number is dimension, other numbers are the content of the column.
+ // Ignores empty lines and lines starting with a '#'.
+ bool load_ascii( std::string filename ) {
+ // first count number of columns:
+ std::string cur_line;
+ std::ifstream dummy( filename .c_str() );
+ if( dummy.fail() )
+ return false;
+
+ index number_of_columns = 0;
+ while( getline( dummy, cur_line ) ) {
+ cur_line.erase(cur_line.find_last_not_of(" \t\n\r\f\v") + 1);
+ if( cur_line != "" && cur_line[ 0 ] != '#' )
+ number_of_columns++;
+
+ }
+ this->set_num_cols( number_of_columns );
+ dummy.close();
+
+ std::ifstream input_stream( filename.c_str() );
+ if( input_stream.fail() )
+ return false;
+
+ column temp_col;
+ index cur_col = -1;
+ while( getline( input_stream, cur_line ) ) {
+ cur_line.erase(cur_line.find_last_not_of(" \t\n\r\f\v") + 1);
+ if( cur_line != "" && cur_line[ 0 ] != '#' ) {
+ cur_col++;
+ std::stringstream ss( cur_line );
+
+ int64_t temp_dim;
+ ss >> temp_dim;
+ this->set_dim( cur_col, (dimension) temp_dim );
+
+ int64_t temp_index;
+ temp_col.clear();
+ while( ss.good() ) {
+ ss >> temp_index;
+ temp_col.push_back( (index)temp_index );
+ }
+ std::sort( temp_col.begin(), temp_col.end() );
+ this->set_col( cur_col, temp_col );
+ }
+ }
+
+ input_stream.close();
+ return true;
+ }
+
+ // Saves the boundary_matrix to given file in ascii format
+ // Format: each line represents a column, first number is dimension, other numbers are the content of the column
+ bool save_ascii( std::string filename ) {
+ std::ofstream output_stream( filename.c_str() );
+ if( output_stream.fail() )
+ return false;
+
+ const index nr_columns = this->get_num_cols();
+ column tempCol;
+ for( index cur_col = 0; cur_col < nr_columns; cur_col++ ) {
+ output_stream << (int64_t)this->get_dim( cur_col );
+ this->get_col( cur_col, tempCol );
+ for( index cur_row_idx = 0; cur_row_idx < (index)tempCol.size(); cur_row_idx++ )
+ output_stream << " " << tempCol[ cur_row_idx ];
+ output_stream << std::endl;
+ }
+
+ output_stream.close();
+ return true;
+ }
+
+ // Loads boundary_matrix from given file
+ // Format: nr_columns % dim1 % N1 % row1 row2 % ...% rowN1 % dim2 % N2 % ...
+ bool load_binary( std::string filename )
+ {
+ std::ifstream input_stream( filename.c_str( ), std::ios_base::binary | std::ios_base::in );
+ if( input_stream.fail( ) )
+ return false;
+
+ int64_t nr_columns;
+ input_stream.read( (char*)&nr_columns, sizeof( int64_t ) );
+ this->set_num_cols( (index)nr_columns );
+
+ column temp_col;
+ for( index cur_col = 0; cur_col < nr_columns; cur_col++ ) {
+ int64_t cur_dim;
+ input_stream.read( (char*)&cur_dim, sizeof( int64_t ) );
+ this->set_dim( cur_col, (dimension)cur_dim );
+ int64_t nr_rows;
+ input_stream.read( (char*)&nr_rows, sizeof( int64_t ) );
+ temp_col.resize( ( std::size_t )nr_rows );
+ for( index idx = 0; idx < nr_rows; idx++ ) {
+ int64_t cur_row;
+ input_stream.read( (char*)&cur_row, sizeof( int64_t ) );
+ temp_col[ idx ] = (index)cur_row;
+ }
+ this->set_col( cur_col, temp_col );
+ }
+
+ input_stream.close( );
+ return true;
+ }
+
+ // Saves the boundary_matrix to given file in binary format
+ // Format: nr_columns % dim1 % N1 % row1 row2 % ...% rowN1 % dim2 % N2 % ...
+ bool save_binary( std::string filename )
+ {
+ std::ofstream output_stream( filename.c_str( ), std::ios_base::binary | std::ios_base::out );
+ if( output_stream.fail( ) )
+ return false;
+
+ const int64_t nr_columns = this->get_num_cols( );
+ output_stream.write( (char*)&nr_columns, sizeof( int64_t ) );
+ column tempCol;
+ for( index cur_col = 0; cur_col < nr_columns; cur_col++ ) {
+ int64_t cur_dim = this->get_dim( cur_col );
+ output_stream.write( (char*)&cur_dim, sizeof( int64_t ) );
+ this->get_col( cur_col, tempCol );
+ int64_t cur_nr_rows = tempCol.size( );
+ output_stream.write( (char*)&cur_nr_rows, sizeof( int64_t ) );
+ for( index cur_row_idx = 0; cur_row_idx < (index)tempCol.size( ); cur_row_idx++ ) {
+ int64_t cur_row = tempCol[ cur_row_idx ];
+ output_stream.write( (char*)&cur_row, sizeof( int64_t ) );
+ }
+ }
+
+ output_stream.close( );
+ return true;
+ }
+ };
+}
diff --git a/matching/include/phat/compute_persistence_pairs.h b/matching/include/phat/compute_persistence_pairs.h
new file mode 100644
index 0000000..06f5372
--- /dev/null
+++ b/matching/include/phat/compute_persistence_pairs.h
@@ -0,0 +1,137 @@
+/* Copyright 2013 IST Austria
+ Contributed by: Ulrich Bauer, Michael Kerber, Jan Reininghaus
+
+ This file is part of PHAT.
+
+ PHAT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Lesser General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ PHAT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public License
+ along with PHAT. If not, see <http://www.gnu.org/licenses/>. */
+
+#pragma once
+
+#include <phat/persistence_pairs.h>
+#include <phat/boundary_matrix.h>
+#include <phat/helpers/dualize.h>
+#include <phat/algorithms/twist_reduction.h>
+
+namespace phat {
+ // Extracts persistence pairs in separate dimensions from a reduced
+ // boundary matrix representing ``double`` filtration. The pairs
+ // give persistent relative homology of the pair of filtrations.
+ // TODO: Use it with standard reduction algorithm (no template option).
+ template< typename ReductionAlgorithm, typename Representation >
+ void compute_relative_persistence_pairs(std::vector<persistence_pairs>& pairs, boundary_matrix<Representation>& boundary_matrix, const std::map<int, int>& L) {
+ ReductionAlgorithm reduce;
+ reduce(boundary_matrix);
+ std::map<int, bool> free;
+ std::map<int, int> invL;
+ for (std::map<int, int>::const_iterator it = L.begin(); it != L.end(); ++it) { invL[it->second] = it->first; }
+ for (std::vector<persistence_pairs>::iterator it = pairs.begin(); it != pairs.end(); ++it) { it->clear(); }
+ for (index idx = 0; idx < boundary_matrix.get_num_cols(); ++idx) {
+ int dimension = boundary_matrix.get_dim(idx);
+ if (L.find(idx) != L.end()) { ++dimension; }
+ free[idx] = true;
+ if (!boundary_matrix.is_empty(idx)) {
+ index birth = boundary_matrix.get_max_index(idx);
+ index death = idx;
+ pairs[dimension-1].append_pair(birth, death);
+ free[birth] = false;
+ free[death] = false;
+ } else {
+ // This is an L-simplex and a (dimension+1)-dimensional cycle
+ if (L.find(idx) != L.end()) {
+ assert(dimension < pairs.size());
+ pairs[dimension].append_pair(idx, -1);
+ }
+ }
+ }
+ for (std::map<int, bool>::iterator it = free.begin(); it != free.end(); ++it) {
+ if (it->second) {
+ int dimension = boundary_matrix.get_dim(it->first);
+ if (invL.find(it->first) == invL.end() && L.find(it->first) == L.end()) {
+ assert(dimension < pairs.size());
+ pairs[dimension].append_pair(it->first, -1);
+ }
+ }
+ }
+ }
+
+ // Extracts persistence pairs in separate dimensions; expects a d-dimensional vector of persistent_pairs
+ template< typename ReductionAlgorithm, typename Representation >
+ void compute_persistence_pairs(std::vector<persistence_pairs>& pairs, boundary_matrix<Representation>& boundary_matrix) {
+ ReductionAlgorithm reduce;
+ reduce(boundary_matrix);
+ std::map<int, bool> free;
+ for (std::vector<persistence_pairs>::iterator it = pairs.begin(); it != pairs.end(); ++it) { it->clear(); }
+ for (index idx = 0; idx < boundary_matrix.get_num_cols(); ++idx) {
+ int dimension = boundary_matrix.get_dim(idx);
+ free[idx] = true;
+ if (!boundary_matrix.is_empty(idx)) {
+ index birth = boundary_matrix.get_max_index(idx);
+ index death = idx;
+ pairs[dimension-1].append_pair(birth, death);
+ // Cannot be of the form (a, infinity)
+ free[birth] = false;
+ free[death] = false;
+ }
+ }
+ for (std::map<int, bool>::iterator it = free.begin(); it != free.end(); ++it) {
+ if (it->second) {
+ int dimension = boundary_matrix.get_dim(it->first);
+ pairs[dimension].append_pair(it->first, -1);
+ }
+ }
+ }
+
+ template< typename ReductionAlgorithm, typename Representation >
+ void compute_persistence_pairs( persistence_pairs& pairs, boundary_matrix< Representation >& boundary_matrix ) {
+ ReductionAlgorithm reduce;
+ reduce( boundary_matrix );
+ pairs.clear();
+ std::set<index> max_indices;
+ // finite pairs
+ for( index idx = 0; idx < boundary_matrix.get_num_cols(); idx++ ) {
+ if( !boundary_matrix.is_empty( idx ) ) {
+ index birth = boundary_matrix.get_max_index( idx );
+ max_indices.insert(birth);
+ index death = idx;
+ pairs.append_pair( birth, death );
+ }
+ }
+ // infinite pairs: column idx is 0, and row idx does not contain a lowest one
+ for( index idx = 0; idx < boundary_matrix.get_num_cols(); idx++ ) {
+ if(boundary_matrix.is_empty(idx) && max_indices.count(idx) == 0 ) {
+ pairs.append_pair( idx, k_infinity_index);
+ }
+ }
+ }
+
+ template< typename ReductionAlgorithm, typename Representation >
+ void compute_persistence_pairs_dualized( persistence_pairs& pairs, boundary_matrix< Representation >& boundary_matrix ) {
+
+ dualize( boundary_matrix );
+ compute_persistence_pairs< ReductionAlgorithm >( pairs, boundary_matrix );
+ dualize_persistence_pairs( pairs, boundary_matrix.get_num_cols() );
+ }
+
+ template< typename Representation >
+ void compute_persistence_pairs( persistence_pairs& pairs, boundary_matrix< Representation >& boundary_matrix ) {
+ phat::compute_persistence_pairs< twist_reduction >( pairs, boundary_matrix );
+ }
+
+
+ template< typename Representation >
+ void compute_persistence_pairs_dualized( persistence_pairs& pairs, boundary_matrix< Representation >& boundary_matrix ) {
+ compute_persistence_pairs_dualized< twist_reduction >( pairs, boundary_matrix );
+ }
+
+}
diff --git a/matching/include/phat/helpers/dualize.h b/matching/include/phat/helpers/dualize.h
new file mode 100644
index 0000000..5731408
--- /dev/null
+++ b/matching/include/phat/helpers/dualize.h
@@ -0,0 +1,74 @@
+/* Copyright 2013 IST Austria
+ Contributed by: Ulrich Bauer, Michael Kerber, Jan Reininghaus
+
+ This file is part of PHAT.
+
+ PHAT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Lesser General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ PHAT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public License
+ along with PHAT. If not, see <http://www.gnu.org/licenses/>. */
+
+#pragma once
+
+#include <phat/helpers/misc.h>
+#include <phat/boundary_matrix.h>
+#include <phat/persistence_pairs.h>
+
+
+namespace phat {
+ template< typename Representation >
+ void dualize( boundary_matrix< Representation >& boundary_matrix ) {
+
+ std::vector< dimension > dual_dims;
+ std::vector< std::vector< index > > dual_matrix;
+
+ index nr_of_columns = boundary_matrix.get_num_cols();
+ dual_matrix.resize( nr_of_columns );
+ dual_dims.resize( nr_of_columns );
+
+ std::vector< index > dual_sizes( nr_of_columns, 0 );
+
+ column temp_col;
+ for( index cur_col = 0; cur_col < nr_of_columns; cur_col++ ) {
+ boundary_matrix.get_col( cur_col, temp_col );
+ for( index idx = 0; idx < (index)temp_col.size(); idx++)
+ dual_sizes[ nr_of_columns - 1 - temp_col[ idx ] ]++;
+ }
+
+ #pragma omp parallel for
+ for( index cur_col = 0; cur_col < nr_of_columns; cur_col++ )
+ dual_matrix[cur_col].reserve(dual_sizes[cur_col]);
+
+ for( index cur_col = 0; cur_col < nr_of_columns; cur_col++ ) {
+ boundary_matrix.get_col( cur_col, temp_col );
+ for( index idx = 0; idx < (index)temp_col.size(); idx++)
+ dual_matrix[ nr_of_columns - 1 - temp_col[ idx ] ].push_back( nr_of_columns - 1 - cur_col );
+ }
+
+ const dimension max_dim = boundary_matrix.get_max_dim();
+ #pragma omp parallel for
+ for( index cur_col = 0; cur_col < nr_of_columns; cur_col++ )
+ dual_dims[ nr_of_columns - 1 - cur_col ] = max_dim - boundary_matrix.get_dim( cur_col );
+
+ #pragma omp parallel for
+ for( index cur_col = 0; cur_col < nr_of_columns; cur_col++ )
+ std::reverse( dual_matrix[ cur_col ].begin(), dual_matrix[ cur_col ].end() );
+
+ boundary_matrix.load_vector_vector( dual_matrix, dual_dims );
+ }
+
+ inline void dualize_persistence_pairs( persistence_pairs& pairs, const index n ) {
+ for (index i = 0; i < pairs.get_num_pairs(); ++i) {
+ std::pair< index, index > pair = pairs.get_pair( i );
+ pairs.set_pair( i , n - 1 - pair.second, n - 1 - pair.first);
+ }
+ }
+}
diff --git a/matching/include/phat/helpers/misc.h b/matching/include/phat/helpers/misc.h
new file mode 100644
index 0000000..5a5c682
--- /dev/null
+++ b/matching/include/phat/helpers/misc.h
@@ -0,0 +1,78 @@
+/* Copyright 2013 IST Austria
+ Contributed by: Ulrich Bauer, Michael Kerber, Jan Reininghaus
+
+ This file is part of PHAT.
+
+ PHAT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Lesser General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ PHAT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public License
+ along with PHAT. If not, see <http://www.gnu.org/licenses/>. */
+
+#pragma once
+
+// STL includes
+#include <iostream>
+#include <fstream>
+#include <string>
+#include <vector>
+#include <set>
+#include <list>
+#include <map>
+#include <algorithm>
+#include <queue>
+#include <cassert>
+#include <sstream>
+#include <algorithm>
+#include <iomanip>
+#include <cmath>
+#include <cstdlib>
+#include <iterator>
+#include <limits>
+
+// VS2008 and below unfortunately do not support stdint.h
+#if defined(_MSC_VER)&& _MSC_VER < 1600
+ typedef __int8 int8_t;
+ typedef unsigned __int8 uint8_t;
+ typedef __int16 int16_t;
+ typedef unsigned __int16 uint16_t;
+ typedef __int32 int32_t;
+ typedef unsigned __int32 uint32_t;
+ typedef __int64 int64_t;
+ typedef unsigned __int64 uint64_t;
+#else
+ #include <stdint.h>
+#endif
+
+// basic types. index can be changed to int32_t to save memory on small instances
+namespace phat {
+ typedef int64_t index;
+ typedef int8_t dimension;
+ typedef std::vector< index > column;
+
+ constexpr index k_infinity_index = std::numeric_limits<index>::max();
+}
+
+// OpenMP (proxy) functions
+#if defined _OPENMP
+ #include <omp.h>
+#else
+ #define omp_get_thread_num() 0
+ #define omp_get_max_threads() 1
+ #define omp_get_num_threads() 1
+ inline void omp_set_num_threads( int ) {};
+ #include <time.h>
+ #define omp_get_wtime() (float)clock() / (float)CLOCKS_PER_SEC
+#endif
+
+#include <phat/helpers/thread_local_storage.h>
+
+
+
diff --git a/matching/include/phat/helpers/thread_local_storage.h b/matching/include/phat/helpers/thread_local_storage.h
new file mode 100644
index 0000000..d0b5332
--- /dev/null
+++ b/matching/include/phat/helpers/thread_local_storage.h
@@ -0,0 +1,52 @@
+/* Copyright 2013 IST Austria
+ Contributed by: Ulrich Bauer, Michael Kerber, Jan Reininghaus
+
+ This file is part of PHAT.
+
+ PHAT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Lesser General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ PHAT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public License
+ along with PHAT. If not, see <http://www.gnu.org/licenses/>. */
+
+#pragma once
+
+#include <phat/helpers/misc.h>
+
+// should ideally be equal to the cache line size of the CPU
+#define PHAT_TLS_SPACING_FACTOR 64
+
+// ThreadLocalStorage with some spacing to avoid "false sharing" (see wikipedia)
+template< typename T >
+class thread_local_storage
+{
+public:
+
+ thread_local_storage() : per_thread_storage( omp_get_max_threads() * PHAT_TLS_SPACING_FACTOR ) {};
+
+ T& operator()() {
+ return per_thread_storage[ omp_get_thread_num() * PHAT_TLS_SPACING_FACTOR ];
+ }
+
+ const T& operator()() const {
+ return per_thread_storage[ omp_get_thread_num() * PHAT_TLS_SPACING_FACTOR ];
+ }
+
+ T& operator[]( int tid ) {
+ return per_thread_storage[ tid * PHAT_TLS_SPACING_FACTOR ];
+ }
+
+ const T& operator[]( int tid ) const {
+ return per_thread_storage[ tid * PHAT_TLS_SPACING_FACTOR ];
+ }
+
+protected:
+ std::vector< T > per_thread_storage;
+};
diff --git a/matching/include/phat/persistence_pairs.h b/matching/include/phat/persistence_pairs.h
new file mode 100644
index 0000000..eafc638
--- /dev/null
+++ b/matching/include/phat/persistence_pairs.h
@@ -0,0 +1,155 @@
+/* Copyright 2013 IST Austria
+ Contributed by: Ulrich Bauer, Michael Kerber, Jan Reininghaus
+
+ This file is part of PHAT.
+
+ PHAT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Lesser General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ PHAT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public License
+ along with PHAT. If not, see <http://www.gnu.org/licenses/>. */
+
+#pragma once
+
+#include <phat/helpers/misc.h>
+
+namespace phat {
+ class persistence_pairs {
+
+ protected:
+ std::vector< std::pair< index, index > > pairs;
+
+ public:
+ index get_num_pairs() const {
+ return (index)pairs.size();
+ }
+
+ void append_pair( index birth, index death ) {
+ pairs.push_back( std::make_pair( birth, death ) );
+ }
+
+ std::pair< index, index > get_pair( index idx ) const {
+ return pairs[ idx ];
+ }
+
+ void set_pair( index idx, index birth, index death ) {
+ pairs[ idx ] = std::make_pair( birth, death );
+ }
+
+ void clear() {
+ pairs.clear();
+ }
+
+ void sort() {
+ std::sort( pairs.begin(), pairs.end() );
+ }
+
+ // Loads the persistence pairs from given file in asci format
+ // Format: nr_pairs % newline % birth1 % death1 % newline % birth2 % death2 % newline ...
+ bool load_ascii( std::string filename ) {
+ std::ifstream input_stream( filename.c_str() );
+ if( input_stream.fail() )
+ return false;
+
+ int64_t nr_pairs;
+ input_stream >> nr_pairs;
+ pairs.clear();
+ for( index idx = 0; idx < nr_pairs; idx++ ) {
+ int64_t birth;
+ input_stream >> birth;
+ int64_t death;
+ input_stream >> death;
+ append_pair( (index)birth, (index)death );
+ }
+
+ input_stream.close();
+ return true;
+ }
+
+ // Saves the persistence pairs to given file in binary format
+ // Format: nr_pairs % newline % birth1 % death1 % newline % birth2 % death2 % newline ...
+ bool save_ascii( std::string filename ) {
+ std::ofstream output_stream( filename.c_str() );
+ if( output_stream.fail() )
+ return false;
+
+ this->sort();
+ output_stream << get_num_pairs() << std::endl;
+ for( std::size_t idx = 0; idx < pairs.size(); idx++ ) {
+ output_stream << pairs[idx].first << " " << pairs[idx].second << std::endl;
+ }
+
+ output_stream.close();
+ return true;
+ }
+
+ // Loads the persistence pairs from given file in binary format
+ // Format: nr_pairs % birth1 % death1 % birth2 % death2 ...
+ bool load_binary( std::string filename ) {
+ std::ifstream input_stream( filename.c_str(), std::ios_base::binary | std::ios_base::in );
+ if( input_stream.fail() )
+ return false;
+
+ int64_t nr_pairs;
+ input_stream.read( (char*)&nr_pairs, sizeof( int64_t ) );
+ for( index idx = 0; idx < nr_pairs; idx++ ) {
+ int64_t birth;
+ input_stream.read( (char*)&birth, sizeof( int64_t ) );
+ int64_t death;
+ input_stream.read( (char*)&death, sizeof( int64_t ) );
+ append_pair( (index)birth, (index)death );
+ }
+
+ input_stream.close();
+ return true;
+ }
+
+ // Saves the persistence pairs to given file in binary format
+ // Format: nr_pairs % birth1 % death1 % birth2 % death2 ...
+ bool save_binary( std::string filename ) {
+ std::ofstream output_stream( filename.c_str(), std::ios_base::binary | std::ios_base::out );
+ if( output_stream.fail() )
+ return false;
+
+ this->sort();
+ int64_t nr_pairs = get_num_pairs();
+ output_stream.write( (char*)&nr_pairs, sizeof( int64_t ) );
+ for( std::size_t idx = 0; idx < pairs.size(); idx++ ) {
+ int64_t birth = pairs[ idx ].first;
+ output_stream.write( (char*)&birth, sizeof( int64_t ) );
+ int64_t death = pairs[ idx ].second;
+ output_stream.write( (char*)&death, sizeof( int64_t ) );
+ }
+
+ output_stream.close();
+ return true;
+ }
+
+ bool operator==( persistence_pairs& other_pairs ) {
+ this->sort();
+ other_pairs.sort();
+ if( pairs.size() != (std::size_t)other_pairs.get_num_pairs() )
+ return false;
+
+ for( index idx = 0; idx < (index)pairs.size(); idx++ )
+ if( get_pair( idx ) != other_pairs.get_pair( idx ) )
+ return false;
+
+ return true;
+ }
+
+ bool operator!=( persistence_pairs& other_pairs ) {
+ return !( *this == other_pairs );
+ }
+ };
+
+
+
+}
diff --git a/matching/include/phat/representations/abstract_pivot_column.h b/matching/include/phat/representations/abstract_pivot_column.h
new file mode 100644
index 0000000..e16d7a5
--- /dev/null
+++ b/matching/include/phat/representations/abstract_pivot_column.h
@@ -0,0 +1,102 @@
+/* Copyright 2013 IST Austria
+ Contributed by: Ulrich Bauer, Michael Kerber, Jan Reininghaus
+
+ This file is part of PHAT.
+
+ PHAT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Lesser General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ PHAT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public License
+ along with PHAT. If not, see <http://www.gnu.org/licenses/>. */
+
+#pragma once
+
+#include <phat/helpers/misc.h>
+#include <phat/representations/vector_vector.h>
+
+namespace phat {
+
+ // Note: We could even make the rep generic in the underlying Const representation
+ // But I cannot imagine that anything else than vector<vector<index>> would
+ // make sense
+ template< typename PivotColumn >
+ class abstract_pivot_column : public vector_vector {
+
+ protected:
+ typedef vector_vector Base;
+ typedef PivotColumn pivot_col;
+
+ // For parallization purposes, it could be more than one full column
+ mutable thread_local_storage< pivot_col > pivot_cols;
+ mutable thread_local_storage< index > idx_of_pivot_cols;
+
+ pivot_col& get_pivot_col() const {
+ return pivot_cols();
+ }
+
+ bool is_pivot_col( index idx ) const {
+ return idx_of_pivot_cols() == idx;
+ }
+
+ void release_pivot_col() {
+ index idx = idx_of_pivot_cols();
+ if( idx != -1 ) {
+ this->matrix[ idx ].clear();
+ pivot_cols().get_col_and_clear( this->matrix[ idx ] );
+ }
+ idx_of_pivot_cols() = -1;
+ }
+
+ void make_pivot_col( index idx ) {
+ release_pivot_col();
+ idx_of_pivot_cols() = idx;
+ get_pivot_col().add_col( matrix[ idx ] );
+ }
+
+ public:
+
+ void _set_num_cols( index nr_of_cols ) {
+ #pragma omp parallel for
+ for( int tid = 0; tid < omp_get_num_threads(); tid++ ) {
+ pivot_cols[ tid ].init( nr_of_cols );
+ idx_of_pivot_cols[ tid ] = -1;
+ }
+ Base::_set_num_cols( nr_of_cols );
+ }
+
+ void _add_to( index source, index target ) {
+ if( !is_pivot_col( target ) )
+ make_pivot_col( target );
+ get_pivot_col().add_col( matrix[source] );
+ }
+
+ void _sync() {
+ #pragma omp parallel for
+ for( int tid = 0; tid < omp_get_num_threads(); tid++ )
+ release_pivot_col();
+ }
+
+ void _get_col( index idx, column& col ) const { is_pivot_col( idx ) ? get_pivot_col().get_col( col ) : Base::_get_col( idx, col ); }
+
+ bool _is_empty( index idx ) const { return is_pivot_col( idx ) ? get_pivot_col().is_empty() : Base::_is_empty( idx ); }
+
+ index _get_max_index( index idx ) const { return is_pivot_col( idx ) ? get_pivot_col().get_max_index() : Base::_get_max_index( idx ); }
+
+ void _clear( index idx ) { is_pivot_col( idx ) ? get_pivot_col().clear() : Base::_clear( idx ); }
+
+ void _set_col( index idx, const column& col ) { is_pivot_col( idx ) ? get_pivot_col().set_col( col ) : Base::_set_col( idx, col ); }
+
+ void _remove_max( index idx ) { is_pivot_col( idx ) ? get_pivot_col().remove_max() : Base::_remove_max( idx ); }
+
+ void finalize( index idx ) { Base::_finalize( idx ); }
+ };
+}
+
+
diff --git a/matching/include/phat/representations/bit_tree_pivot_column.h b/matching/include/phat/representations/bit_tree_pivot_column.h
new file mode 100644
index 0000000..4d48e88
--- /dev/null
+++ b/matching/include/phat/representations/bit_tree_pivot_column.h
@@ -0,0 +1,165 @@
+/* Copyright 2013 IST Austria
+ Contributed by: Hubert Wagner
+
+ This file is part of PHAT.
+
+ PHAT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Lesser General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ PHAT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public License
+ along with PHAT. If not, see <http://www.gnu.org/licenses/>. */
+
+#pragma once
+
+#include <phat/helpers/misc.h>
+#include <phat/representations/abstract_pivot_column.h>
+
+namespace phat {
+
+ // This is a bitset indexed with a 64-ary tree. Each node in the index
+ // has 64 bits; i-th bit says that the i-th subtree is non-empty.
+ // Supports practically O(1), inplace, zero-allocation: insert, remove, max_element
+ // and clear in O(number of ones in the bitset).
+ // 'add_index' is still the real bottleneck in practice.
+ class bit_tree_column
+ {
+ protected:
+
+ size_t offset; // data[i + offset] = ith block of the data-bitset
+ typedef uint64_t block_type;
+ std::vector< block_type > data;
+
+
+ size_t debrujin_magic_table[ 64 ];
+
+ enum { block_size_in_bits = 64 };
+ enum { block_shift = 6 };
+
+ // Some magic: http://graphics.stanford.edu/~seander/bithacks.html
+ // Gets the position of the rightmost bit of 'x'. 0 means the most significant bit.
+ // (-x)&x isolates the rightmost bit.
+ // The whole method is much faster than calling log2i, and very comparable to using ScanBitForward/Reverse intrinsic,
+ // which should be one CPU instruction, but is not portable.
+ size_t rightmost_pos( const block_type value ) const {
+ return 64 - 1 - debrujin_magic_table[ ( (value & (-(int64_t)value) ) * 0x07EDD5E59A4E28C2 ) >> 58 ];
+ }
+
+ public:
+
+ void init( index num_cols ) {
+ int64_t n = 1; // in case of overflow
+ int64_t bottom_blocks_needed = ( num_cols + block_size_in_bits - 1 ) / block_size_in_bits;
+ int64_t upper_blocks = 1;
+
+ // How many blocks/nodes of index needed to index the whole bitset?
+ while( n * block_size_in_bits < bottom_blocks_needed ) {
+ n *= block_size_in_bits;
+ upper_blocks += n;
+ }
+
+ offset = upper_blocks;
+ data.resize( upper_blocks + bottom_blocks_needed, 0 );
+
+ std::size_t temp_array[ 64 ] = {
+ 63, 0, 58, 1, 59, 47, 53, 2,
+ 60, 39, 48, 27, 54, 33, 42, 3,
+ 61, 51, 37, 40, 49, 18, 28, 20,
+ 55, 30, 34, 11, 43, 14, 22, 4,
+ 62, 57, 46, 52, 38, 26, 32, 41,
+ 50, 36, 17, 19, 29, 10, 13, 21,
+ 56, 45, 25, 31, 35, 16, 9, 12,
+ 44, 24, 15, 8, 23, 7, 6, 5 };
+
+ std::copy( &temp_array[ 0 ], &temp_array[ 64 ], &debrujin_magic_table[ 0 ] );
+ }
+
+ index get_max_index() const {
+ if( !data[ 0 ] )
+ return -1;
+
+ size_t n = 0;
+ size_t newn = 0;
+ size_t index = 0;
+ while( newn < data.size() ) {
+ n = newn;
+ index = rightmost_pos( data[ n ] );
+ newn = ( n << block_shift ) + index + 1;
+ }
+
+ return ( ( n - offset ) << block_shift ) + index;
+ }
+
+ bool is_empty() const {
+ return data[ 0 ] == 0;
+ }
+
+ void add_index( const size_t entry ) {
+ const block_type ONE = 1;
+ const block_type block_modulo_mask = ( ONE << block_shift ) - 1;
+ size_t index_in_level = entry >> block_shift;
+ size_t address = index_in_level + offset;
+ size_t index_in_block = entry & block_modulo_mask;
+
+ block_type mask = ( ONE << ( block_size_in_bits - index_in_block - 1 ) );
+
+ data[ address ] ^= mask;
+
+ // Check if we reached the root. Also, if anyone else was in this block, we don't need to update the path up.
+ while( address && !( data[ address ] & ~mask ) ) {
+ index_in_block = index_in_level & block_modulo_mask;
+ index_in_level >>= block_shift;
+ --address;
+ address >>= block_shift;
+ mask = ( ONE << ( block_size_in_bits - index_in_block - 1 ) );
+ data[ address ] ^= mask;
+ }
+ }
+
+ void get_col_and_clear( column &out ) {
+ index mx = this->get_max_index();
+ while( mx != -1 ) {
+ out.push_back( mx );
+ add_index( mx );
+ mx = this->get_max_index();
+ }
+
+ std::reverse( out.begin(), out.end() );
+ }
+
+ void add_col(const column &col) {
+ for( size_t i = 0; i < col.size(); ++i )
+ add_index(col[i]);
+ }
+
+ void clear() {
+ index mx = this->get_max_index();
+ while( mx != -1 ) {
+ add_index( mx );
+ mx = this->get_max_index();
+ }
+ }
+
+ void remove_max() {
+ add_index( get_max_index() );
+ }
+
+ void set_col( const column& col ) {
+ clear();
+ add_col( col );
+ }
+
+ void get_col( column& col ) {
+ get_col_and_clear( col );
+ add_col( col );
+ }
+ };
+
+ typedef abstract_pivot_column<bit_tree_column> bit_tree_pivot_column;
+}
diff --git a/matching/include/phat/representations/full_pivot_column.h b/matching/include/phat/representations/full_pivot_column.h
new file mode 100644
index 0000000..c2e9e3c
--- /dev/null
+++ b/matching/include/phat/representations/full_pivot_column.h
@@ -0,0 +1,100 @@
+/* Copyright 2013 IST Austria
+ Contributed by: Ulrich Bauer, Michael Kerber, Jan Reininghaus
+
+ This file is part of PHAT.
+
+ PHAT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Lesser General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ PHAT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public License
+ along with PHAT. If not, see <http://www.gnu.org/licenses/>. */
+
+#pragma once
+
+#include <phat/helpers/misc.h>
+#include <phat/representations/abstract_pivot_column.h>
+
+namespace phat {
+ class full_column {
+
+ protected:
+ std::priority_queue< index > history;
+ std::vector< char > is_in_history;
+ std::vector< char > col_bit_field;
+
+ public:
+ void init( const index total_size ) {
+ col_bit_field.resize( total_size, false );
+ is_in_history.resize( total_size, false );
+ }
+
+ void add_col( const column& col ) {
+ for( index idx = 0; idx < (index) col.size(); idx++ ) {
+ add_index( col[ idx ] );
+ }
+ }
+
+ void add_index( const index idx ) {
+ if( !is_in_history[ idx ] ) {
+ history.push( idx );
+ is_in_history[ idx ] = true;
+ }
+
+ col_bit_field[ idx ] = !col_bit_field[ idx ];
+ }
+
+ index get_max_index() {
+ while( history.size() > 0 ) {
+ index topIndex = history.top();
+ if( col_bit_field[ topIndex ] ) {
+ return topIndex;
+ } else {
+ history.pop();
+ is_in_history[ topIndex ] = false;
+ }
+ }
+
+ return -1;
+ }
+
+ void get_col_and_clear( column& col ) {
+ while( !is_empty() ) {
+ col.push_back( get_max_index() );
+ add_index( get_max_index() );
+ }
+ std::reverse( col.begin(), col.end() );
+ }
+
+ bool is_empty() {
+ return (get_max_index() == -1);
+ }
+
+ void clear() {
+ while( !is_empty() )
+ add_index( get_max_index() );
+ }
+
+ void remove_max() {
+ add_index( get_max_index() );
+ }
+
+ void set_col( const column& col ) {
+ clear();
+ add_col( col );
+ }
+
+ void get_col( column& col ) {
+ get_col_and_clear( col );
+ add_col( col );
+ }
+ };
+
+ typedef abstract_pivot_column< full_column > full_pivot_column;
+}
diff --git a/matching/include/phat/representations/heap_pivot_column.h b/matching/include/phat/representations/heap_pivot_column.h
new file mode 100644
index 0000000..33cd07b
--- /dev/null
+++ b/matching/include/phat/representations/heap_pivot_column.h
@@ -0,0 +1,126 @@
+/* Copyright 2013 IST Austria
+ Contributed by: Ulrich Bauer, Michael Kerber, Jan Reininghaus
+
+ This file is part of PHAT.
+
+ PHAT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Lesser General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ PHAT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public License
+ along with PHAT. If not, see <http://www.gnu.org/licenses/>. */
+
+#pragma once
+
+#include <phat/helpers/misc.h>
+#include <phat/representations/abstract_pivot_column.h>
+
+namespace phat {
+ class heap_column {
+
+ protected:
+ std::priority_queue< index > data;
+
+ column temp_col;
+ index inserts_since_last_prune;
+
+ void prune()
+ {
+ temp_col.clear( );
+ index max_index = pop_max_index( );
+ while( max_index != -1 ) {
+ temp_col.push_back( max_index );
+ max_index = pop_max_index( );
+ }
+
+ for( index idx = 0; idx < (index)temp_col.size( ); idx++ )
+ data.push( temp_col[ idx ] );
+
+ inserts_since_last_prune = 0;
+ }
+
+ index pop_max_index()
+ {
+ if( data.empty( ) )
+ return -1;
+ else {
+ index max_element = data.top( );
+ data.pop();
+ while( !data.empty( ) && data.top( ) == max_element ) {
+ data.pop( );
+ if( data.empty( ) )
+ return -1;
+ else {
+ max_element = data.top( );
+ data.pop( );
+ }
+ }
+ return max_element;
+ }
+ }
+
+ public:
+ void init( const index total_size ) {
+ inserts_since_last_prune = 0;
+ clear();
+ }
+
+ void add_col( const column& col ) {
+ for( index idx = 0; idx < (index) col.size(); idx++ )
+ data.push( col[ idx ] );
+ inserts_since_last_prune += col.size( );
+ if( 2 * inserts_since_last_prune >( index ) data.size( ) )
+ prune();
+ }
+
+ index get_max_index() {
+ index max_element = pop_max_index( );
+ if( max_element == -1 )
+ return -1;
+ else {
+ data.push( max_element );
+ return max_element;
+ }
+ }
+
+ void get_col_and_clear( column& col ) {
+ col.clear();
+ index max_index = pop_max_index( );
+ while( max_index != -1 ) {
+ col.push_back( max_index );
+ max_index = pop_max_index( );
+ }
+ std::reverse( col.begin(), col.end() );
+ }
+
+ bool is_empty() {
+ return get_max_index() == -1;
+ }
+
+ void clear() {
+ data = std::priority_queue< index >();
+ }
+
+ void remove_max() {
+ pop_max_index();
+ }
+
+ void set_col( const column& col ) {
+ clear();
+ add_col( col );
+ }
+
+ void get_col( column& col ) {
+ get_col_and_clear( col );
+ add_col( col );
+ }
+ };
+
+ typedef abstract_pivot_column< heap_column > heap_pivot_column;
+}
diff --git a/matching/include/phat/representations/sparse_pivot_column.h b/matching/include/phat/representations/sparse_pivot_column.h
new file mode 100644
index 0000000..390fd91
--- /dev/null
+++ b/matching/include/phat/representations/sparse_pivot_column.h
@@ -0,0 +1,79 @@
+/* Copyright 2013 IST Austria
+ Contributed by: Ulrich Bauer, Michael Kerber, Jan Reininghaus
+
+ This file is part of PHAT.
+
+ PHAT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Lesser General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ PHAT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public License
+ along with PHAT. If not, see <http://www.gnu.org/licenses/>. */
+
+#pragma once
+
+#include <phat/helpers/misc.h>
+#include <phat/representations/abstract_pivot_column.h>
+
+namespace phat {
+ class sparse_column {
+
+ protected:
+ std::set< index > data;
+
+ void add_index( const index idx ) {
+ std::pair< std::set< index >::iterator, bool > result = data.insert( idx );
+ if( result.second == false )
+ data.erase( result.first );
+ }
+
+ public:
+ void init( const index total_size ) {
+ data.clear();
+ }
+
+ void add_col( const column& col ) {
+ for( index idx = 0; idx < (index) col.size(); idx++ )
+ add_index( col[ idx ] );
+ }
+
+ index get_max_index() {
+ return data.empty() ? -1 : *data.rbegin();
+ }
+
+ void get_col_and_clear( column& col ) {
+ col.assign( data.begin(), data.end() );
+ data.clear();
+ }
+
+ bool is_empty() {
+ return data.empty();
+ }
+
+ void clear() {
+ data.clear();
+ }
+
+ void remove_max() {
+ add_index( get_max_index() );
+ }
+
+ void set_col( const column& col ) {
+ clear();
+ add_col( col );
+ }
+
+ void get_col( column& col ) {
+ get_col_and_clear( col );
+ add_col( col );
+ }
+ };
+
+ typedef abstract_pivot_column< sparse_column > sparse_pivot_column;
+}
diff --git a/matching/include/phat/representations/vector_heap.h b/matching/include/phat/representations/vector_heap.h
new file mode 100644
index 0000000..db0420f
--- /dev/null
+++ b/matching/include/phat/representations/vector_heap.h
@@ -0,0 +1,170 @@
+/* Copyright 2013 IST Austria
+Contributed by: Jan Reininghaus
+
+This file is part of PHAT.
+
+PHAT is free software: you can redistribute it and/or modify
+it under the terms of the GNU Lesser General Public License as published by
+the Free Software Foundation, either version 3 of the License, or
+(at your option) any later version.
+
+PHAT is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU Lesser General Public License for more details.
+
+You should have received a copy of the GNU Lesser General Public License
+along with PHAT. If not, see <http://www.gnu.org/licenses/>. */
+
+#pragma once
+
+#include <phat/helpers/misc.h>
+
+namespace phat {
+ class vector_heap {
+
+ protected:
+ std::vector< dimension > dims;
+ std::vector< column > matrix;
+
+ std::vector< index > inserts_since_last_prune;
+
+ mutable thread_local_storage< column > temp_column_buffer;
+
+ protected:
+ void _prune( index idx )
+ {
+ column& col = matrix[ idx ];
+ column& temp_col = temp_column_buffer();
+ temp_col.clear();
+ index max_index = _pop_max_index( col );
+ while( max_index != -1 ) {
+ temp_col.push_back( max_index );
+ max_index = _pop_max_index( col );
+ }
+ col = temp_col;
+ std::reverse( col.begin( ), col.end( ) );
+ std::make_heap( col.begin( ), col.end( ) );
+ inserts_since_last_prune[ idx ] = 0;
+ }
+
+ index _pop_max_index( index idx )
+ {
+ return _pop_max_index( matrix[ idx ] );
+ }
+
+ index _pop_max_index( column& col ) const
+ {
+ if( col.empty( ) )
+ return -1;
+ else {
+ index max_element = col.front( );
+ std::pop_heap( col.begin( ), col.end( ) );
+ col.pop_back( );
+ while( !col.empty( ) && col.front( ) == max_element ) {
+ std::pop_heap( col.begin( ), col.end( ) );
+ col.pop_back( );
+ if( col.empty( ) )
+ return -1;
+ else {
+ max_element = col.front( );
+ std::pop_heap( col.begin( ), col.end( ) );
+ col.pop_back( );
+ }
+ }
+ return max_element;
+ }
+ }
+
+ public:
+ // overall number of cells in boundary_matrix
+ index _get_num_cols( ) const
+ {
+ return (index)matrix.size( );
+ }
+ void _set_num_cols( index nr_of_columns )
+ {
+ dims.resize( nr_of_columns );
+ matrix.resize( nr_of_columns );
+ inserts_since_last_prune.assign( nr_of_columns, 0 );
+ }
+
+ // dimension of given index
+ dimension _get_dim( index idx ) const
+ {
+ return dims[ idx ];
+ }
+ void _set_dim( index idx, dimension dim )
+ {
+ dims[ idx ] = dim;
+ }
+
+ // replaces(!) content of 'col' with boundary of given index
+ void _get_col( index idx, column& col ) const
+ {
+ temp_column_buffer( ) = matrix[ idx ];
+
+ index max_index = _pop_max_index( temp_column_buffer() );
+ while( max_index != -1 ) {
+ col.push_back( max_index );
+ max_index = _pop_max_index( temp_column_buffer( ) );
+ }
+ std::reverse( col.begin( ), col.end( ) );
+ }
+ void _set_col( index idx, const column& col )
+ {
+ matrix[ idx ] = col;
+ std::make_heap( matrix[ idx ].begin( ), matrix[ idx ].end( ) );
+ }
+
+ // true iff boundary of given idx is empty
+ bool _is_empty( index idx ) const
+ {
+ return _get_max_index( idx ) == -1;
+ }
+
+ // largest row index of given column idx (new name for lowestOne())
+ index _get_max_index( index idx ) const
+ {
+ column& col = const_cast< column& >( matrix[ idx ] );
+ index max_element = _pop_max_index( col );
+ col.push_back( max_element );
+ std::push_heap( col.begin( ), col.end( ) );
+ return max_element;
+ }
+
+ // removes the maximal index of a column
+ void _remove_max( index idx )
+ {
+ _pop_max_index( idx );
+ }
+
+ // clears given column
+ void _clear( index idx )
+ {
+ matrix[ idx ].clear( );
+ }
+
+ // syncronizes all data structures (essential for openmp stuff)
+ void _sync( ) {}
+
+ // adds column 'source' to column 'target'
+ void _add_to( index source, index target )
+ {
+ for( index idx = 0; idx < (index)matrix[ source ].size( ); idx++ ) {
+ matrix[ target ].push_back( matrix[ source ][ idx ] );
+ std::push_heap( matrix[ target ].begin(), matrix[ target ].end() );
+ }
+ inserts_since_last_prune[ target ] += matrix[ source ].size();
+
+ if( 2 * inserts_since_last_prune[ target ] > ( index )matrix[ target ].size() )
+ _prune( target );
+ }
+
+ // finalizes given column
+ void _finalize( index idx ) {
+ _prune( idx );
+ }
+
+ };
+}
diff --git a/matching/include/phat/representations/vector_list.h b/matching/include/phat/representations/vector_list.h
new file mode 100644
index 0000000..ca0b5b8
--- /dev/null
+++ b/matching/include/phat/representations/vector_list.h
@@ -0,0 +1,101 @@
+/* Copyright 2013 IST Austria
+ Contributed by: Jan Reininghaus
+
+ This file is part of PHAT.
+
+ PHAT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Lesser General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ PHAT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public License
+ along with PHAT. If not, see <http://www.gnu.org/licenses/>. */
+
+#pragma once
+
+#include <phat/helpers/misc.h>
+
+namespace phat {
+ class vector_list {
+
+ protected:
+ std::vector< dimension > dims;
+ std::vector< std::list< index > > matrix;
+
+ public:
+ // overall number of cells in boundary_matrix
+ index _get_num_cols() const {
+ return (index)matrix.size();
+ }
+ void _set_num_cols( index nr_of_columns ) {
+ dims.resize( nr_of_columns );
+ matrix.resize( nr_of_columns );
+ }
+
+ // dimension of given index
+ dimension _get_dim( index idx ) const {
+ return dims[ idx ];
+ }
+ void _set_dim( index idx, dimension dim ) {
+ dims[ idx ] = dim;
+ }
+
+ // replaces(!) content of 'col' with boundary of given index
+ void _get_col( index idx, column& col ) const {
+ col.clear();
+ col.reserve( matrix[idx].size() );
+ std::copy (matrix[idx].begin(), matrix[idx].end(), std::back_inserter(col) );
+ }
+
+ void _set_col( index idx, const column& col ) {
+ matrix[ idx ].clear();
+ matrix[ idx ].resize( col.size() );
+ std::copy (col.begin(), col.end(), matrix[ idx ].begin() );
+ }
+
+ // true iff boundary of given idx is empty
+ bool _is_empty( index idx ) const {
+ return matrix[ idx ].empty();
+ }
+
+ // largest row index of given column idx (new name for lowestOne())
+ index _get_max_index( index idx ) const {
+ return matrix[ idx ].empty() ? -1 : *matrix[ idx ].rbegin();
+ }
+
+ // removes the maximal index of a column
+ void _remove_max( index idx ) {
+ std::list< index >::iterator it = matrix[ idx ].end();
+ it--;
+ matrix[ idx ].erase( it );
+ }
+
+ // clears given column
+ void _clear( index idx ) {
+ matrix[ idx ].clear();
+ }
+
+ // syncronizes all data structures (essential for openmp stuff)
+ void _sync() {}
+
+ // adds column 'source' to column 'target'
+ void _add_to( index source, index target ) {
+ std::list< index >& source_col = matrix[ source ];
+ std::list< index >& target_col = matrix[ target ];
+ std::list< index > temp_col;
+ target_col.swap( temp_col );
+ std::set_symmetric_difference( temp_col.begin(), temp_col.end(),
+ source_col.begin(), source_col.end(),
+ std::back_inserter( target_col ) );
+ }
+
+ // finalizes given column
+ void _finalize( index idx ) {
+ }
+ };
+}
diff --git a/matching/include/phat/representations/vector_set.h b/matching/include/phat/representations/vector_set.h
new file mode 100644
index 0000000..6878a27
--- /dev/null
+++ b/matching/include/phat/representations/vector_set.h
@@ -0,0 +1,99 @@
+/* Copyright 2013 IST Austria
+ Contributed by: Ulrich Bauer, Michael Kerber, Jan Reininghaus
+
+ This file is part of PHAT.
+
+ PHAT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Lesser General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ PHAT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public License
+ along with PHAT. If not, see <http://www.gnu.org/licenses/>. */
+
+#pragma once
+
+#include <phat/helpers/misc.h>
+
+namespace phat {
+ class vector_set {
+
+ protected:
+ std::vector< dimension > dims;
+ std::vector< std::set< index > > matrix;
+
+ public:
+ // overall number of cells in boundary_matrix
+ index _get_num_cols() const {
+ return (index)matrix.size();
+ }
+ void _set_num_cols( index nr_of_columns ) {
+ dims.resize( nr_of_columns );
+ matrix.resize( nr_of_columns );
+ }
+
+ // dimension of given index
+ dimension _get_dim( index idx ) const {
+ return dims[ idx ];
+ }
+ void _set_dim( index idx, dimension dim ) {
+ dims[ idx ] = dim;
+ }
+
+ // replaces(!) content of 'col' with boundary of given index
+ void _get_col( index idx, column& col ) const {
+ col.clear();
+ col.reserve( matrix[idx].size() );
+ std::copy (matrix[idx].begin(), matrix[idx].end(), std::back_inserter(col) );
+ }
+ void _set_col( index idx, const column& col ) {
+ matrix[ idx ].clear();
+ matrix[ idx ].insert( col.begin(), col.end() );
+ }
+
+ // true iff boundary of given idx is empty
+ bool _is_empty( index idx ) const {
+ return matrix[ idx ].empty();
+ }
+
+ // largest row index of given column idx (new name for lowestOne())
+ index _get_max_index( index idx ) const {
+ return matrix[ idx ].empty() ? -1 : *matrix[ idx ].rbegin();
+ }
+
+ // removes the maximal index of a column
+ void _remove_max( index idx ) {
+ std::set< index >::iterator it = matrix[ idx ].end();
+ it--;
+ matrix[ idx ].erase( it );
+ }
+
+ // clears given column
+ void _clear( index idx ) {
+ matrix[ idx ].clear();
+ }
+
+ // syncronizes all data structures (essential for openmp stuff)
+ void _sync() {}
+
+ // adds column 'source' to column 'target'
+ void _add_to( index source, index target ) {
+ for( std::set< index >::iterator it = matrix[ source ].begin(); it != matrix[ source ].end(); it++ ) {
+ std::set< index >& col = matrix[ target ];
+ std::pair< std::set< index >::iterator, bool > result = col.insert( *it );
+ if( !result.second )
+ col.erase( result.first );
+ }
+ }
+
+ // finalizes given column
+ void _finalize( index idx ) {
+ }
+
+ };
+}
diff --git a/matching/include/phat/representations/vector_vector.h b/matching/include/phat/representations/vector_vector.h
new file mode 100644
index 0000000..f111d6b
--- /dev/null
+++ b/matching/include/phat/representations/vector_vector.h
@@ -0,0 +1,107 @@
+/* Copyright 2013 IST Austria
+ Contributed by: Ulrich Bauer, Michael Kerber, Jan Reininghaus
+
+ This file is part of PHAT.
+
+ PHAT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Lesser General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ PHAT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public License
+ along with PHAT. If not, see <http://www.gnu.org/licenses/>. */
+
+#pragma once
+
+#include <phat/helpers/misc.h>
+
+namespace phat {
+ class vector_vector {
+
+ protected:
+ std::vector< dimension > dims;
+ std::vector< column > matrix;
+
+ thread_local_storage< column > temp_column_buffer;
+
+ public:
+ // overall number of cells in boundary_matrix
+ index _get_num_cols() const {
+ return (index)matrix.size();
+ }
+ void _set_num_cols( index nr_of_columns ) {
+ dims.resize( nr_of_columns );
+ matrix.resize( nr_of_columns );
+ }
+
+ // dimension of given index
+ dimension _get_dim( index idx ) const {
+ return dims[ idx ];
+ }
+ void _set_dim( index idx, dimension dim ) {
+ dims[ idx ] = dim;
+ }
+
+ // replaces(!) content of 'col' with boundary of given index
+ void _get_col( index idx, column& col ) const {
+ col = matrix[ idx ];
+ }
+ void _set_col( index idx, const column& col ) {
+ matrix[ idx ] = col;
+ }
+
+ // true iff boundary of given idx is empty
+ bool _is_empty( index idx ) const {
+ return matrix[ idx ].empty();
+ }
+
+ // largest row index of given column idx (new name for lowestOne())
+ index _get_max_index( index idx ) const {
+ return matrix[ idx ].empty() ? -1 : matrix[ idx ].back();
+ }
+
+ // removes the maximal index of a column
+ void _remove_max( index idx ) {
+ matrix[ idx ].pop_back();
+ }
+
+ // clears given column
+ void _clear( index idx ) {
+ matrix[ idx ].clear();
+ }
+
+ // syncronizes all data structures (essential for openmp stuff)
+ void _sync() {}
+
+ // adds column 'source' to column 'target'
+ void _add_to( index source, index target ) {
+ column& source_col = matrix[ source ];
+ column& target_col = matrix[ target ];
+ column& temp_col = temp_column_buffer();
+
+
+ size_t new_size = source_col.size() + target_col.size();
+
+ if (new_size > temp_col.size()) temp_col.resize(new_size);
+
+ std::vector<index>::iterator col_end = std::set_symmetric_difference( target_col.begin(), target_col.end(),
+ source_col.begin(), source_col.end(),
+ temp_col.begin() );
+ temp_col.erase(col_end, temp_col.end());
+
+
+ target_col.swap(temp_col);
+ }
+
+ // finalizes given column
+ void _finalize( index idx ) {
+ column& col = matrix[ idx ];
+ column(col.begin(), col.end()).swap(col);
+ }
+ };
+}
diff --git a/matching/include/simplex.h b/matching/include/simplex.h
new file mode 100644
index 0000000..75bbcae
--- /dev/null
+++ b/matching/include/simplex.h
@@ -0,0 +1,163 @@
+#ifndef MATCHING_DISTANCE_SIMPLEX_H
+#define MATCHING_DISTANCE_SIMPLEX_H
+
+#include <algorithm>
+#include <vector>
+#include <ostream>
+
+#include "common_util.h"
+
+namespace md {
+
+ template<class Real>
+ class Bifiltration;
+
+ enum class BifiltrationFormat {
+ phat_like, rivet
+ };
+
+ class AbstractSimplex {
+ private:
+ std::vector<int> vertices_;
+ public:
+
+ // this member is for convenience only;
+ // abstract simplices are identified by their set of vertices
+ mutable int id {-1};
+
+ decltype(auto) begin() { return vertices_.begin(); }
+
+ decltype(auto) end() { return vertices_.end(); }
+
+ decltype(auto) begin() const { return vertices_.begin(); }
+
+ decltype(auto) end() const { return vertices_.end(); }
+
+ decltype(auto) cbegin() const { return vertices_.cbegin(); }
+
+ decltype(auto) cend() const { return vertices_.cend(); }
+
+ int dim() const { return vertices_.size() - 1; }
+
+ void push_back(int v)
+ {
+ vertices_.push_back(v);
+ std::sort(vertices_.begin(), vertices_.end());
+ }
+
+ AbstractSimplex() { }
+
+ AbstractSimplex(std::vector<int> vertices, bool sort = true)
+ :vertices_(vertices)
+ {
+ if (sort)
+ std::sort(vertices_.begin(), vertices_.end());
+ }
+
+
+ template<class Iter>
+ AbstractSimplex(Iter beg_iter, Iter end_iter, bool sort = true)
+ :
+ vertices_(beg_iter, end_iter)
+ {
+ if (sort)
+ std::sort(vertices_.begin(), end());
+ }
+
+ std::vector<AbstractSimplex> facets() const
+ {
+ std::vector<AbstractSimplex> result;
+ for (int i = 0; i < static_cast<int>(vertices_.size()); ++i) {
+ std::vector<int> facet_vertices;
+ facet_vertices.reserve(dim());
+ for (int j = 0; j < static_cast<int>(vertices_.size()); ++j) {
+ if (j != i)
+ facet_vertices.push_back(vertices_[j]);
+ }
+ if (!facet_vertices.empty()) {
+ result.emplace_back(facet_vertices, false);
+ }
+ }
+ return result;
+ }
+
+ friend std::ostream& operator<<(std::ostream& os, const AbstractSimplex& s);
+ // compare by vertices_ only
+ friend bool operator==(const AbstractSimplex& s1, const AbstractSimplex& s2);
+
+ friend bool operator<(const AbstractSimplex&, const AbstractSimplex&);
+ };
+
+ inline std::ostream& operator<<(std::ostream& os, const AbstractSimplex& s)
+ {
+ os << "AbstractSimplex(id = " << s.id << ", vertices_ = " << container_to_string(s.vertices_) << ")";
+ return os;
+ }
+
+ inline bool operator<(const AbstractSimplex& a, const AbstractSimplex& b)
+ {
+ return a.vertices_ < b.vertices_;
+ }
+
+ inline bool operator==(const AbstractSimplex& s1, const AbstractSimplex& s2)
+ {
+ return s1.vertices_ == s2.vertices_;
+ }
+
+ template<class Real>
+ class Simplex {
+ private:
+ Index id_;
+ Point<Real> pos_;
+ int dim_;
+ // in our format we use facet indices,
+ // this is the fastest representation for homology
+ // Rivet format fills vertices_ vector
+ // Simplex alone cannot convert from one representation to the other,
+ // conversion routines are in Bifiltration
+ Column facet_indices_;
+ Column vertices_;
+ Real v {0}; // used when constructed a filtration for a slice
+ public:
+ Simplex(Index _id, std::string s, BifiltrationFormat input_format);
+
+ Simplex(Index _id, Point<Real> birth, int _dim, const Column& _bdry);
+
+ void init_rivet(std::string s);
+
+ void init_phat_like(std::string s);
+
+ Index id() const { return id_; }
+
+ int dim() const { return dim_; }
+
+ Column boundary() const { return facet_indices_; }
+
+ Real value() const { return v; }
+
+ // assumes 1-criticality
+ Point<Real> position() const { return pos_; }
+
+ void set_position(const Point<Real>& new_pos) { pos_ = new_pos; }
+
+ void scale(Real lambda)
+ {
+ pos_.x *= lambda;
+ pos_.y *= lambda;
+ }
+
+ void translate(Real a);
+
+ void set_value(Real new_val) { v = new_val; }
+
+ friend Bifiltration<Real>;
+ };
+
+ template<class Real>
+ std::ostream& operator<<(std::ostream& os, const Simplex<Real>& s);
+
+}
+
+#include "simplex.hpp"
+
+#endif //MATCHING_DISTANCE_SIMPLEX_H
diff --git a/matching/include/simplex.hpp b/matching/include/simplex.hpp
new file mode 100644
index 0000000..ce0e30f
--- /dev/null
+++ b/matching/include/simplex.hpp
@@ -0,0 +1,79 @@
+namespace md {
+
+ template<class Real>
+ Simplex<Real>::Simplex(Index id, Point<Real> birth, int dim, const Column& bdry)
+ :
+ id_(id),
+ pos_(birth),
+ dim_(dim),
+ facet_indices_(bdry) { }
+
+ template<class Real>
+ void Simplex<Real>::translate(Real a)
+ {
+ pos_.translate(a);
+ }
+
+ template<class Real>
+ void Simplex<Real>::init_rivet(std::string s)
+ {
+ auto delim_pos = s.find_first_of(";");
+ assert(delim_pos > 0);
+ std::string vertices_str = s.substr(0, delim_pos);
+ std::string pos_str = s.substr(delim_pos + 1);
+ assert(not vertices_str.empty() and not pos_str.empty());
+ // get vertices
+ std::stringstream vertices_ss(vertices_str);
+ int dim = 0;
+ int vertex;
+ while (vertices_ss >> vertex) {
+ dim++;
+ vertices_.push_back(vertex);
+ }
+ //
+ std::sort(vertices_.begin(), vertices_.end());
+ assert(dim > 0);
+
+ std::stringstream pos_ss(pos_str);
+ // TODO: get rid of 1-criticaltiy assumption
+ pos_ss >> pos_.x >> pos_.y;
+ }
+
+ template<class Real>
+ void Simplex<Real>::init_phat_like(std::string s)
+ {
+ facet_indices_.clear();
+ std::stringstream ss(s);
+ ss >> dim_ >> pos_.x >> pos_.y;
+ if (dim_ > 0) {
+ facet_indices_.reserve(dim_ + 1);
+ for (int j = 0; j <= dim_; j++) {
+ Index k;
+ ss >> k;
+ facet_indices_.push_back(k);
+ }
+ }
+ }
+
+ template<class Real>
+ Simplex<Real>::Simplex(Index _id, std::string s, BifiltrationFormat input_format)
+ :id_(_id)
+ {
+ switch (input_format) {
+ case BifiltrationFormat::phat_like :
+ init_phat_like(s);
+ break;
+ case BifiltrationFormat::rivet :
+ init_rivet(s);
+ break;
+ }
+ }
+
+ template<class Real>
+ std::ostream& operator<<(std::ostream& os, const Simplex<Real>& x)
+ {
+ os << "Simplex<Real>(id = " << x.id() << ", dim = " << x.dim();
+ os << ", boundary = " << container_to_string(x.boundary()) << ", pos = " << x.position() << ")";
+ return os;
+ }
+}
diff --git a/matching/tests/prism_1.bif b/matching/tests/prism_1.bif
new file mode 100644
index 0000000..b37e807
--- /dev/null
+++ b/matching/tests/prism_1.bif
@@ -0,0 +1,28 @@
+bifiltration_phat_like
+26
+0 0 0
+0 0 0
+0 0 0
+0 0 0
+0 0 0
+0 0 0
+1 1 0 1 0
+1 1 0 2 1
+1 0 0 3 1
+1 0 1 5 4
+1 0 0 4 1
+1 0 0 3 0
+1 0 0 5 0
+1 0 0 4 2
+1 0 0 5 2
+1 0 1 4 3
+1 1 0 2 0
+1 0 1 5 3
+2 1 1 13 10 7
+2 1 1 9 14 13
+2 1 1 8 11 6
+2 1 1 15 10 8
+2 1 1 14 12 16
+2 1 1 17 12 11
+2 4 0 7 16 6
+2 0 4 9 17 15
diff --git a/matching/tests/prism_2.bif b/matching/tests/prism_2.bif
new file mode 100644
index 0000000..49885e6
--- /dev/null
+++ b/matching/tests/prism_2.bif
@@ -0,0 +1,29 @@
+bifiltration_phat_like
+27
+0 0 0
+0 0 0
+0 0 0
+0 0 0
+0 0 0
+0 0 0
+1 0 0 2 1
+1 0 0 1 0
+1 0 0 5 3
+1 0 0 3 1
+1 0 0 5 4
+1 0 0 2 0
+1 0 0 4 1
+1 0 0 3 2
+1 0 0 5 2
+1 0 0 4 3
+1 0 0 4 2
+1 0 0 3 0
+2 0 0 16 12 6
+2 0 0 10 14 16
+2 0 0 9 17 7
+2 0 0 15 12 9
+2 0 0 13 17 11
+2 0 0 8 14 13
+2 4 0 6 11 7
+2 0 4 10 8 15
+2 3 3 15 16 13
diff --git a/matching/tests/test_bifiltration.cpp b/matching/tests/test_bifiltration.cpp
new file mode 100644
index 0000000..742dab8
--- /dev/null
+++ b/matching/tests/test_bifiltration.cpp
@@ -0,0 +1,36 @@
+#include "catch/catch.hpp"
+
+#include <sstream>
+#include <iostream>
+
+#include "common_util.h"
+#include "box.h"
+#include "bifiltration.h"
+
+using namespace md;
+
+//TEST_CASE("Small check", "[bifiltration][dim2]")
+//{
+// Bifiltration bif("/home/narn/code/matching_distance/code/src/tests/test_bifiltration_full_triangle_phat_like.txt", BifiltrationFormat::phat_like);
+// auto simplices = bif.simplices();
+// bif.sanity_check();
+//
+// REQUIRE( simplices.size() == 7 );
+//
+// REQUIRE( simplices[0].dim() == 0 );
+// REQUIRE( simplices[1].dim() == 0 );
+// REQUIRE( simplices[2].dim() == 0 );
+// REQUIRE( simplices[3].dim() == 1 );
+// REQUIRE( simplices[4].dim() == 1 );
+// REQUIRE( simplices[5].dim() == 1 );
+// REQUIRE( simplices[6].dim() == 2);
+//
+// REQUIRE( simplices[0].position() == Point(0, 0));
+// REQUIRE( simplices[1].position() == Point(0, 0));
+// REQUIRE( simplices[2].position() == Point(0, 0));
+// REQUIRE( simplices[3].position() == Point(3, 1));
+// REQUIRE( simplices[6].position() == Point(30, 40));
+//
+// Line line_1(Line::pi / 2.0, 0.0);
+// auto dgm = bif.slice_diagram(line_1);
+//}
diff --git a/matching/tests/test_common.cpp b/matching/tests/test_common.cpp
new file mode 100644
index 0000000..9079a56
--- /dev/null
+++ b/matching/tests/test_common.cpp
@@ -0,0 +1,156 @@
+#include "catch/catch.hpp"
+
+#include <sstream>
+#include <iostream>
+#include <string>
+
+#include "common_util.h"
+#include "simplex.h"
+#include "matching_distance.h"
+
+//using namespace md;
+using Real = double;
+using Point = md::Point<Real>;
+using Bifiltration = md::Bifiltration<Real>;
+using BifiltrationProxy = md::BifiltrationProxy<Real>;
+using CalculationParams = md::CalculationParams<Real>;
+using CellWithValue = md::CellWithValue<Real>;
+using DualPoint = md::DualPoint<Real>;
+using DualBox = md::DualBox<Real>;
+using Simplex = md::Simplex<Real>;
+using AbstractSimplex = md::AbstractSimplex;
+using BoundStrategy = md::BoundStrategy;
+using TraverseStrategy = md::TraverseStrategy;
+using AxisType = md::AxisType;
+using AngleType = md::AngleType;
+using ValuePoint = md::ValuePoint;
+using Column = md::Column;
+
+
+TEST_CASE("AbstractSimplex", "[abstract_simplex]")
+{
+ AbstractSimplex as;
+ REQUIRE(as.dim() == -1);
+
+ as.push_back(1);
+ REQUIRE(as.dim() == 0);
+ REQUIRE(as.facets().size() == 0);
+
+ as.push_back(0);
+ REQUIRE(as.dim() == 1);
+ REQUIRE(as.facets().size() == 2);
+ REQUIRE(as.facets()[0].dim() == 0);
+ REQUIRE(as.facets()[1].dim() == 0);
+
+}
+
+TEST_CASE("Vertical line", "[vertical_line]")
+{
+ // line x = 1
+ DualPoint l_vertical(AxisType::x_type, AngleType::steep, 0, 1);
+
+ REQUIRE(l_vertical.is_vertical());
+ REQUIRE(l_vertical.is_steep());
+
+ Point p_1(0.5, 0.5);
+ Point p_2(1.5, 0.5);
+ Point p_3(1.5, 1.5);
+ Point p_4(0.5, 1.5);
+ Point p_5(1, 10);
+
+ REQUIRE(l_vertical.x_from_y(10) == 1);
+ REQUIRE(l_vertical.x_from_y(-10) == 1);
+ REQUIRE(l_vertical.x_from_y(0) == 1);
+
+ REQUIRE(not l_vertical.contains(p_1));
+ REQUIRE(not l_vertical.contains(p_2));
+ REQUIRE(not l_vertical.contains(p_3));
+ REQUIRE(not l_vertical.contains(p_4));
+ REQUIRE(l_vertical.contains(p_5));
+
+ REQUIRE(l_vertical.goes_below(p_1));
+ REQUIRE(not l_vertical.goes_below(p_2));
+ REQUIRE(not l_vertical.goes_below(p_3));
+ REQUIRE(l_vertical.goes_below(p_4));
+
+ REQUIRE(not l_vertical.goes_above(p_1));
+ REQUIRE(l_vertical.goes_above(p_2));
+ REQUIRE(l_vertical.goes_above(p_3));
+ REQUIRE(not l_vertical.goes_above(p_4));
+
+}
+
+TEST_CASE("Horizontal line", "[horizontal_line]")
+{
+ // line y = 1
+ DualPoint l_horizontal(AxisType::y_type, AngleType::flat, 0, 1);
+
+ REQUIRE(l_horizontal.is_horizontal());
+ REQUIRE(l_horizontal.is_flat());
+ REQUIRE(l_horizontal.y_slope() == 0);
+ REQUIRE(l_horizontal.y_intercept() == 1);
+
+ Point p_1(0.5, 0.5);
+ Point p_2(1.5, 0.5);
+ Point p_3(1.5, 1.5);
+ Point p_4(0.5, 1.5);
+ Point p_5(2, 1);
+
+ REQUIRE((not l_horizontal.contains(p_1) and
+ not l_horizontal.contains(p_2) and
+ not l_horizontal.contains(p_3) and
+ not l_horizontal.contains(p_4) and
+ l_horizontal.contains(p_5)));
+
+ REQUIRE(not l_horizontal.goes_below(p_1));
+ REQUIRE(not l_horizontal.goes_below(p_2));
+ REQUIRE(l_horizontal.goes_below(p_3));
+ REQUIRE(l_horizontal.goes_below(p_4));
+ REQUIRE(l_horizontal.goes_below(p_5));
+
+ REQUIRE(l_horizontal.goes_above(p_1));
+ REQUIRE(l_horizontal.goes_above(p_2));
+ REQUIRE(not l_horizontal.goes_above(p_3));
+ REQUIRE(not l_horizontal.goes_above(p_4));
+ REQUIRE(l_horizontal.goes_above(p_5));
+}
+
+TEST_CASE("Flat Line with positive slope", "[flat_line]")
+{
+ // line y = x / 2 + 1
+ DualPoint l_flat(AxisType::y_type, AngleType::flat, 0.5, 1);
+
+ REQUIRE(not l_flat.is_horizontal());
+ REQUIRE(l_flat.is_flat());
+ REQUIRE(l_flat.y_slope() == 0.5);
+ REQUIRE(l_flat.y_intercept() == 1);
+
+ REQUIRE(l_flat.y_from_x(0) == 1);
+ REQUIRE(l_flat.y_from_x(1) == 1.5);
+ REQUIRE(l_flat.y_from_x(2) == 2);
+
+ Point p_1(3, 2);
+ Point p_2(-2, 0.01);
+ Point p_3(0, 1.25);
+ Point p_4(6, 4.5);
+ Point p_5(2, 2);
+
+ REQUIRE((not l_flat.contains(p_1) and
+ not l_flat.contains(p_2) and
+ not l_flat.contains(p_3) and
+ not l_flat.contains(p_4) and
+ l_flat.contains(p_5)));
+
+ REQUIRE(not l_flat.goes_below(p_1));
+ REQUIRE(l_flat.goes_below(p_2));
+ REQUIRE(l_flat.goes_below(p_3));
+ REQUIRE(l_flat.goes_below(p_4));
+ REQUIRE(l_flat.goes_below(p_5));
+
+ REQUIRE(l_flat.goes_above(p_1));
+ REQUIRE(not l_flat.goes_above(p_2));
+ REQUIRE(not l_flat.goes_above(p_3));
+ REQUIRE(not l_flat.goes_above(p_4));
+ REQUIRE(l_flat.goes_above(p_5));
+
+}
diff --git a/matching/tests/test_list.txt b/matching/tests/test_list.txt
new file mode 100644
index 0000000..28ee48b
--- /dev/null
+++ b/matching/tests/test_list.txt
@@ -0,0 +1 @@
+prism_1.bif prism_2.bif 1.0
diff --git a/matching/tests/test_matching_distance.cpp b/matching/tests/test_matching_distance.cpp
new file mode 100644
index 0000000..115c8d9
--- /dev/null
+++ b/matching/tests/test_matching_distance.cpp
@@ -0,0 +1,159 @@
+#include "catch/catch.hpp"
+
+#include <sstream>
+#include <iostream>
+#include <string>
+
+#define MD_TEST_CODE
+
+#include "common_util.h"
+#include "simplex.h"
+#include "matching_distance.h"
+
+using Real = double;
+using Point = md::Point<Real>;
+using Bifiltration = md::Bifiltration<Real>;
+using BifiltrationProxy = md::BifiltrationProxy<Real>;
+using CalculationParams = md::CalculationParams<Real>;
+using CellWithValue = md::CellWithValue<Real>;
+using DualPoint = md::DualPoint<Real>;
+using DualBox = md::DualBox<Real>;
+using Simplex = md::Simplex<Real>;
+using AbstractSimplex = md::AbstractSimplex;
+using BoundStrategy = md::BoundStrategy;
+using TraverseStrategy = md::TraverseStrategy;
+using AxisType = md::AxisType;
+using AngleType = md::AngleType;
+using ValuePoint = md::ValuePoint;
+using Column = md::Column;
+
+using md::k_corner_vps;
+
+TEST_CASE("Different bounds", "[bounds]")
+{
+ std::vector<Simplex> simplices;
+ std::vector<Point> points;
+
+ Real max_x = 10;
+ Real max_y = 20;
+
+ int simplex_id = 0;
+ for(int i = 0; i <= max_x; ++i) {
+ for(int j = 0; j <= max_y; ++j) {
+ Point p(i, j);
+ simplices.emplace_back(simplex_id++, p, 0, Column());
+ points.push_back(p);
+ }
+ }
+
+ Bifiltration bif_a(simplices.begin(), simplices.end());
+ Bifiltration bif_b(simplices.begin(), simplices.end());
+
+ CalculationParams params;
+ params.initialization_depth = 2;
+
+ BifiltrationProxy bifp_a(bif_a, params.dim);
+ BifiltrationProxy bifp_b(bif_b, params.dim);
+
+ md::DistanceCalculator<Real, BifiltrationProxy> calc(bifp_a, bifp_b, params);
+
+// REQUIRE(calc.max_x_ == Approx(max_x));
+// REQUIRE(calc.max_y_ == Approx(max_y));
+
+ std::vector<DualBox> boxes;
+
+ for(CellWithValue c : calc.get_refined_grid(5, false, false)) {
+ boxes.push_back(c.dual_box());
+ }
+
+ // fill in boxes and points
+
+ for(DualBox db : boxes) {
+ Real local_bound = calc.get_local_dual_bound(db);
+ Real local_bound_refined = calc.get_local_refined_bound(db);
+ REQUIRE(local_bound >= local_bound_refined);
+ for(Point p : points) {
+ for(ValuePoint vp_a : k_corner_vps) {
+ CellWithValue dual_cell(db, 1);
+ DualPoint corner_a = dual_cell.value_point(vp_a);
+ Real wp_a = corner_a.weighted_push(p);
+ dual_cell.set_value_at(vp_a, wp_a);
+ Real point_bound = calc.get_max_displacement_single_point(dual_cell, vp_a, p);
+ for(ValuePoint vp_b : k_corner_vps) {
+ if (vp_b <= vp_a)
+ continue;
+ DualPoint corner_b = dual_cell.value_point(vp_b);
+ Real wp_b = corner_b.weighted_push(p);
+ Real diff = fabs(wp_a - wp_b);
+ if (not(point_bound <= Approx(local_bound_refined))) {
+ std::cerr << "ERROR point: " << p << ", box = " << db << ", point bound = " << point_bound
+ << ", refined local = " << local_bound_refined << std::endl;
+ calc.get_max_displacement_single_point(dual_cell, vp_a, p);
+ calc.get_local_refined_bound(db);
+ }
+
+ REQUIRE(point_bound <= Approx(local_bound_refined));
+ REQUIRE(diff <= Approx(point_bound));
+ REQUIRE(diff <= Approx(local_bound_refined));
+ }
+
+ for(DualPoint l_random : db.random_points(100)) {
+ Real wp_random = l_random.weighted_push(p);
+ Real diff = fabs(wp_a - wp_random);
+ if (not(diff <= Approx(point_bound))) {
+ if (db.critical_points(p).size() > 4) {
+ std::cerr << "ERROR interesting case" << std::endl;
+ } else {
+ std::cerr << "ERROR boring case" << std::endl;
+ }
+ l_random.weighted_push(p);
+ std::cerr << "ERROR point: " << p << ", box = " << db << ", point bound = " << point_bound
+ << ", refined local = " << local_bound_refined;
+ std::cerr << ", random_dual_point = " << l_random << ", wp_random = " << wp_random
+ << ", diff = " << diff << std::endl;
+ }
+ REQUIRE(diff <= Approx(point_bound));
+ }
+ }
+ }
+ }
+}
+
+TEST_CASE("Bifiltrations from file", "[matching_distance][small_example]")
+{
+ std::string fname_a, fname_b;
+
+ fname_a = "../tests/prism_1.bif";
+ fname_b = "../tests/prism_2.bif";
+
+ Bifiltration bif_a(fname_a);
+ Bifiltration bif_b(fname_b);
+
+ CalculationParams params;
+
+ std::vector<BoundStrategy> bound_strategies {BoundStrategy::local_combined,
+ BoundStrategy::local_dual_bound_refined};
+
+ std::vector<TraverseStrategy> traverse_strategies {TraverseStrategy::breadth_first, TraverseStrategy::depth_first};
+
+ std::vector<double> scaling_factors {10, 1.0};
+
+ for(auto bs : bound_strategies) {
+ for(auto ts : traverse_strategies) {
+ for(double lambda : scaling_factors) {
+ Bifiltration bif_a_copy(bif_a);
+ Bifiltration bif_b_copy(bif_b);
+ bif_a_copy.scale(lambda);
+ bif_b_copy.scale(lambda);
+ params.bound_strategy = bs;
+ params.traverse_strategy = ts;
+ params.max_depth = 7;
+ params.delta = 0.1;
+ params.dim = 1;
+ Real answer = matching_distance(bif_a_copy, bif_b_copy, params);
+ Real correct_answer = lambda * 1.0;
+ REQUIRE(fabs(answer - correct_answer) / correct_answer < params.delta);
+ }
+ }
+ }
+}
diff --git a/matching/tests/test_module.cpp b/matching/tests/test_module.cpp
new file mode 100644
index 0000000..bae1f78
--- /dev/null
+++ b/matching/tests/test_module.cpp
@@ -0,0 +1,109 @@
+#include "catch/catch.hpp"
+
+#include <sstream>
+#include <iostream>
+#include <string>
+
+#define MD_TEST_CODE
+
+#include "common_util.h"
+#include "persistence_module.h"
+#include "matching_distance.h"
+
+using Real = double;
+using Point = md::Point<Real>;
+using Bifiltration = md::Bifiltration<Real>;
+using BifiltrationProxy = md::BifiltrationProxy<Real>;
+using CalculationParams = md::CalculationParams<Real>;
+using CellWithValue = md::CellWithValue<Real>;
+using DualPoint = md::DualPoint<Real>;
+using DualBox = md::DualBox<Real>;
+using BoundStrategy = md::BoundStrategy;
+using TraverseStrategy = md::TraverseStrategy;
+using AxisType = md::AxisType;
+using AngleType = md::AngleType;
+using ValuePoint = md::ValuePoint;
+using Column = md::Column;
+using PointVec = md::PointVec<Real>;
+using Module = md::ModulePresentation<Real>;
+using Relation = Module::Relation;
+using RelationVec = Module::RelVec;
+using IndexVec = md::IndexVec;
+
+using md::k_corner_vps;
+
+TEST_CASE("Module projection", "[module][projection]")
+{
+ PointVec gens;
+ gens.emplace_back(1, 1); // A
+ gens.emplace_back(2, 3); // B
+ gens.emplace_back(3, 2); // C
+
+ RelationVec rels;
+
+ Point rel_x_position { 3.98, 2.47 };
+ IndexVec rel_x_components { 0, 2 }; // X: A + C = 0
+
+ Point rel_y_position { 2.5, 4 };
+ IndexVec rel_y_components { 0, 1 }; // Y: A + B = 0
+
+ Point rel_z_position { 5, 5 };
+ IndexVec rel_z_components { 1 }; // Z: B = 0
+
+
+ rels.emplace_back(rel_x_position, rel_x_components);
+ rels.emplace_back(rel_y_position, rel_y_components);
+ rels.emplace_back(rel_z_position, rel_z_components);
+
+ Module module { gens, rels };
+
+ {
+ DualPoint slice_1(AxisType::x_type, AngleType::flat, 6.0 / 7.0, 3.0);
+ std::vector<Real> gen_ps_1, rel_ps_1;
+ phat::boundary_matrix<> matr_1;
+
+ module.get_slice_projection_matrix(slice_1, matr_1, gen_ps_1, rel_ps_1);
+
+ phat::column c_1_0, c_1_1, c_1_2;
+
+ matr_1.get_col(0, c_1_0);
+ matr_1.get_col(1, c_1_1);
+ matr_1.get_col(2, c_1_2);
+
+ phat::column c_1_0_correct { 0, 1};
+ phat::column c_1_1_correct { 0, 2};
+ phat::column c_1_2_correct { 2 };
+
+ REQUIRE(c_1_0 == c_1_0_correct);
+ REQUIRE(c_1_1 == c_1_1_correct);
+ REQUIRE(c_1_2 == c_1_2_correct);
+ }
+
+ {
+
+ DualPoint slice_2(AxisType::y_type, AngleType::flat, 2.0 / 9.0, 5.0);
+ std::vector<Real> gen_ps_2, rel_ps_2;
+ phat::boundary_matrix<> matr_2;
+
+ module.get_slice_projection_matrix(slice_2, matr_2, gen_ps_2, rel_ps_2);
+
+ phat::column c_2_0, c_2_1, c_2_2;
+
+ matr_2.get_col(0, c_2_0);
+ matr_2.get_col(1, c_2_1);
+ matr_2.get_col(2, c_2_2);
+
+ phat::column c_2_0_correct { 0, 1};
+ phat::column c_2_1_correct { 0, 2};
+ phat::column c_2_2_correct { 1 };
+
+ //std::cerr << "gen_ps_2: " << md::container_to_string(gen_ps_2) << std::endl;
+ //std::cerr << "rel_ps_2: " << md::container_to_string(rel_ps_2) << std::endl;
+
+ REQUIRE(c_2_0 == c_2_0_correct);
+ REQUIRE(c_2_1 == c_2_1_correct);
+ REQUIRE(c_2_2 == c_2_2_correct);
+ }
+
+
+}
diff --git a/matching/tests/tests_main.cpp b/matching/tests/tests_main.cpp
new file mode 100644
index 0000000..1c77b13
--- /dev/null
+++ b/matching/tests/tests_main.cpp
@@ -0,0 +1,2 @@
+#define CATCH_CONFIG_MAIN
+#include "catch/catch.hpp"