From 6942d80c4d49239bca9cace9833aa74aee11ddcb Mon Sep 17 00:00:00 2001 From: Arnur Nigmetov Date: Tue, 3 Dec 2019 21:14:03 +0100 Subject: Add matching distance code. --- matching/CMakeLists.txt | 112 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 112 insertions(+) create mode 100644 matching/CMakeLists.txt (limited to 'matching/CMakeLists.txt') diff --git a/matching/CMakeLists.txt b/matching/CMakeLists.txt new file mode 100644 index 0000000..9384328 --- /dev/null +++ b/matching/CMakeLists.txt @@ -0,0 +1,112 @@ +project(matching_distance) +cmake_minimum_required(VERSION 3.5.1) + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +include(GenerateExportHeader) +find_package(Boost REQUIRED) + +# Default to Release + +if (NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE "Release" CACHE STRING "Choose the type of build, options are: Debug Release RelWithDebInfo MinSizeRel." FORCE) + set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo") +endif (NOT CMAKE_BUILD_TYPE) + +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include + SYSTEM ${BOOST_INCLUDE_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/../bottleneck/include) + +set(CMAKE_CXX_STANDARD 14) + + +if (NOT WIN32) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") + #set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS} -Wall pedantic -Wextra ") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS} -Wall -Wextra ") + set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -ggdb -D_GLIBCXX_DEBUG") + set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELEASE} -O2 -g -ggdb") +endif (NOT WIN32) + +file(GLOB BT_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/../bottleneck/include/*.h ${CMAKE_CURRENT_SOURCE_DIR}/../bottleneck/include/*.hpp) +file(GLOB MD_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/include/*.h) + +file(GLOB SRC_FILES ${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp) + +file(GLOB SRC_TEST_FILES ${CMAKE_CURRENT_SOURCE_DIR}/src/tests/*.cpp) + +find_package(Threads) +set(libraries ${libraries} "stdc++fs" ${CMAKE_THREAD_LIBS_INIT} ${OpenMP_CXX_LIBRARIES}) + + +find_package(OpenMP) +if (OPENMP_FOUND) +set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") +set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") +set (CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}") +endif() + + +#add_executable(matching_distance ${SRC_FILES} ${BT_HEADERS} ${MD_HEADERS}) +add_executable(matching_distance "src/main.cpp" + "src/box.cpp" + "src/common_util.cpp" + "src/persistence_module.cpp" + "src/simplex.cpp" + "src/bifiltration.cpp" + "src/matching_distance.cpp" + "src/dual_box.cpp" + "src/dual_point.cpp" + "include/box.h" + "include/common_util.h" + "include/persistence_module.h" + "include/simplex.h" + "include/bifiltration.h" + "include/matching_distance.h" + "include/dual_box.h" + "include/dual_point.h" + ${BT_HEADERS} include/cell_with_value.h src/cell_with_value.cpp) +target_link_libraries(matching_distance PUBLIC ${libraries}) + +#add_executable(matching_distance_test ${SRC_TEST_FILES} ${BT_HEADERS} ${MD_HEADERS}) +add_executable(matching_distance_test ${SRC_TEST_FILES} + "src/box.cpp" + "src/common_util.cpp" + "src/persistence_module.cpp" + "src/simplex.cpp" + "src/bifiltration.cpp" + "src/matching_distance.cpp" + "src/dual_box.cpp" + "src/dual_point.cpp" + "include/box.h" + "include/common_util.h" + "include/persistence_module.h" + "include/simplex.h" + "include/bifiltration.h" + "include/matching_distance.h" + "include/dual_box.h" + "include/dual_point.h" + ${BT_HEADERS} src/tests/test_common.cpp src/common_util.cpp src/tests/test_matching_distance.cpp src/cell_with_value.cpp) +target_link_libraries(matching_distance_test PUBLIC ${libraries}) + +add_executable(test_generator "src/test_generator.cpp" + "src/box.cpp" + "src/common_util.cpp" + "src/persistence_module.cpp" + "src/simplex.cpp" + "src/bifiltration.cpp" + "src/matching_distance.cpp" + "src/dual_box.cpp" + "src/dual_point.cpp" + "include/box.h" + "include/common_util.h" + "include/persistence_module.h" + "include/simplex.h" + "include/bifiltration.h" + "include/matching_distance.h" + "include/dual_box.h" + "include/dual_point.h" + ${BT_HEADERS} src/cell_with_value.cpp) +target_link_libraries(test_generator PUBLIC ${libraries}) + +#add_executable(matching_distance "src/main.cpp" "src/box.cpp" "src/common_util.cpp" "src/line.cpp" "src/persistence_module.cpp" ${BT_HEADERS} ${MD_HEADERS}) -- cgit v1.2.3 From 3809e4071827a5959f27e472514eaed08ba6d15e Mon Sep 17 00:00:00 2001 From: Arnur Nigmetov Date: Wed, 4 Mar 2020 00:33:51 +0100 Subject: Make matching distance header-only. --- matching/CMakeLists.txt | 66 +--- matching/include/bifiltration.h | 72 ++--- matching/include/bifiltration.hpp | 421 ++++++++++++++++++++++++++ matching/include/box.h | 77 ++--- matching/include/box.hpp | 52 ++++ matching/include/cell_with_value.h | 47 ++- matching/include/cell_with_value.hpp | 224 ++++++++++++++ matching/include/common_defs.h | 4 +- matching/include/common_util.h | 113 +++---- matching/include/common_util.hpp | 96 ++++++ matching/include/dual_box.h | 78 ++--- matching/include/dual_box.hpp | 190 ++++++++++++ matching/include/dual_point.h | 27 +- matching/include/dual_point.hpp | 299 ++++++++++++++++++ matching/include/matching_distance.h | 203 ++++++++++--- matching/include/matching_distance.hpp | 326 +++++++++++--------- matching/include/persistence_module.h | 34 ++- matching/include/persistence_module.hpp | 177 +++++++++++ matching/include/simplex.h | 70 ++++- matching/include/simplex.hpp | 79 +++++ matching/src/bifiltration.cpp | 407 ------------------------- matching/src/box.cpp | 61 ---- matching/src/cell_with_value.cpp | 247 --------------- matching/src/common_util.cpp | 243 --------------- matching/src/dual_box.cpp | 194 ------------ matching/src/dual_point.cpp | 282 ----------------- matching/src/main.cpp | 29 +- matching/src/matching_distance.cpp | 150 --------- matching/src/persistence_module.cpp | 177 ----------- matching/src/simplex.cpp | 121 -------- matching/src/test_generator.cpp | 19 +- matching/src/tests/test_common.cpp | 66 ++-- matching/src/tests/test_matching_distance.cpp | 22 +- 33 files changed, 2192 insertions(+), 2481 deletions(-) create mode 100644 matching/include/bifiltration.hpp create mode 100644 matching/include/box.hpp create mode 100644 matching/include/cell_with_value.hpp create mode 100644 matching/include/common_util.hpp create mode 100644 matching/include/dual_box.hpp create mode 100644 matching/include/dual_point.hpp create mode 100644 matching/include/persistence_module.hpp create mode 100644 matching/include/simplex.hpp delete mode 100644 matching/src/bifiltration.cpp delete mode 100644 matching/src/box.cpp delete mode 100644 matching/src/cell_with_value.cpp delete mode 100644 matching/src/common_util.cpp delete mode 100644 matching/src/dual_box.cpp delete mode 100644 matching/src/dual_point.cpp delete mode 100644 matching/src/matching_distance.cpp delete mode 100644 matching/src/persistence_module.cpp delete mode 100644 matching/src/simplex.cpp (limited to 'matching/CMakeLists.txt') diff --git a/matching/CMakeLists.txt b/matching/CMakeLists.txt index 9384328..121e25c 100644 --- a/matching/CMakeLists.txt +++ b/matching/CMakeLists.txt @@ -29,84 +29,34 @@ if (NOT WIN32) endif (NOT WIN32) file(GLOB BT_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/../bottleneck/include/*.h ${CMAKE_CURRENT_SOURCE_DIR}/../bottleneck/include/*.hpp) -file(GLOB MD_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/include/*.h) +file(GLOB MD_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/include/*.h ${CMAKE_CURRENT_SOURCE_DIR}/include/*.hpp) file(GLOB SRC_FILES ${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp) file(GLOB SRC_TEST_FILES ${CMAKE_CURRENT_SOURCE_DIR}/src/tests/*.cpp) find_package(Threads) -set(libraries ${libraries} "stdc++fs" ${CMAKE_THREAD_LIBS_INIT} ${OpenMP_CXX_LIBRARIES}) +set(libraries ${libraries} "stdc++fs" ${CMAKE_THREAD_LIBS_INIT}) find_package(OpenMP) if (OPENMP_FOUND) +set(libraries ${libraries} ${OpenMP_CXX_LIBRARIES}) set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") set (CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}") endif() -#add_executable(matching_distance ${SRC_FILES} ${BT_HEADERS} ${MD_HEADERS}) -add_executable(matching_distance "src/main.cpp" - "src/box.cpp" - "src/common_util.cpp" - "src/persistence_module.cpp" - "src/simplex.cpp" - "src/bifiltration.cpp" - "src/matching_distance.cpp" - "src/dual_box.cpp" - "src/dual_point.cpp" - "include/box.h" - "include/common_util.h" - "include/persistence_module.h" - "include/simplex.h" - "include/bifiltration.h" - "include/matching_distance.h" - "include/dual_box.h" - "include/dual_point.h" - ${BT_HEADERS} include/cell_with_value.h src/cell_with_value.cpp) +add_executable(matching_distance "src/main.cpp" ${MD_HEADERS} ${BT_HEADERS} ) target_link_libraries(matching_distance PUBLIC ${libraries}) -#add_executable(matching_distance_test ${SRC_TEST_FILES} ${BT_HEADERS} ${MD_HEADERS}) -add_executable(matching_distance_test ${SRC_TEST_FILES} - "src/box.cpp" - "src/common_util.cpp" - "src/persistence_module.cpp" - "src/simplex.cpp" - "src/bifiltration.cpp" - "src/matching_distance.cpp" - "src/dual_box.cpp" - "src/dual_point.cpp" - "include/box.h" - "include/common_util.h" - "include/persistence_module.h" - "include/simplex.h" - "include/bifiltration.h" - "include/matching_distance.h" - "include/dual_box.h" - "include/dual_point.h" - ${BT_HEADERS} src/tests/test_common.cpp src/common_util.cpp src/tests/test_matching_distance.cpp src/cell_with_value.cpp) +add_executable(matching_distance_test ${SRC_TEST_FILES} ${BT_HEADERS} ${MD_HEADERS}) target_link_libraries(matching_distance_test PUBLIC ${libraries}) add_executable(test_generator "src/test_generator.cpp" - "src/box.cpp" - "src/common_util.cpp" - "src/persistence_module.cpp" - "src/simplex.cpp" - "src/bifiltration.cpp" - "src/matching_distance.cpp" - "src/dual_box.cpp" - "src/dual_point.cpp" - "include/box.h" - "include/common_util.h" - "include/persistence_module.h" - "include/simplex.h" - "include/bifiltration.h" - "include/matching_distance.h" - "include/dual_box.h" - "include/dual_point.h" - ${BT_HEADERS} src/cell_with_value.cpp) + ${MD_HEADERS} + ${BT_HEADERS}) target_link_libraries(test_generator PUBLIC ${libraries}) -#add_executable(matching_distance "src/main.cpp" "src/box.cpp" "src/common_util.cpp" "src/line.cpp" "src/persistence_module.cpp" ${BT_HEADERS} ${MD_HEADERS}) +#add_executable(matching_distance "include/main.cpp" "src/box.cpp" "src/common_util.cpp" "src/line.cpp" "src/persistence_module.hpp" ${BT_HEADERS} ${MD_HEADERS}) diff --git a/matching/include/bifiltration.h b/matching/include/bifiltration.h index f505ed9..4dd8662 100644 --- a/matching/include/bifiltration.h +++ b/matching/include/bifiltration.h @@ -3,19 +3,30 @@ #include #include +#include +#include +#include +#include #include "common_util.h" #include "box.h" #include "simplex.h" #include "dual_point.h" +#include "phat/boundary_matrix.h" +#include "phat/compute_persistence_pairs.h" + +#include "spdlog/spdlog.h" +#include "spdlog/fmt/fmt.h" +#include "spdlog/fmt/ostr.h" + +#include "common_util.h" namespace md { + template class Bifiltration { public: - using Diagram = std::vector>; - using Box = md::Box; - using SimplexVector = std::vector; + using SimplexVector = std::vector>; Bifiltration() = default; @@ -36,7 +47,7 @@ namespace md { init(); } - Diagram weighted_slice_diagram(const DualPoint& line, int dim) const; + Diagram weighted_slice_diagram(const DualPoint& line, int dim) const; SimplexVector simplices() const { return simplices_; } @@ -48,14 +59,12 @@ namespace md { Real minimal_coordinate() const; // return box that contains positions of all simplices - Box bounding_box() const; + Box bounding_box() const; void sanity_check() const; int maximal_dim() const { return maximal_dim_; } - friend std::ostream& operator<<(std::ostream& os, const Bifiltration& bif); - Real max_x() const; Real max_y() const; @@ -64,7 +73,7 @@ namespace md { Real min_y() const; - void add_simplex(Index _id, Point birth, int _dim, const Column& _bdry); + void add_simplex(Index _id, Point birth, int _dim, const Column& _bdry); void save(const std::string& filename, BifiltrationFormat format = BifiltrationFormat::rivet); // save to file @@ -72,11 +81,8 @@ namespace md { private: SimplexVector simplices_; - // axes names, for rivet bifiltration format only - std::string parameter_1_name_ {"axis_1"}; - std::string parameter_2_name_ {"axis_2"}; - Box bounding_box_; + Box bounding_box_; int maximal_dim_ {-1}; void init(); @@ -97,13 +103,15 @@ namespace md { }; - std::ostream& operator<<(std::ostream& os, const Bifiltration& bif); + template + std::ostream& operator<<(std::ostream& os, const Bifiltration& bif); + template class BifiltrationProxy { public: - BifiltrationProxy(const Bifiltration& bif, int dim = 0); + BifiltrationProxy(const Bifiltration& bif, int dim = 0); // return critical values of simplices that are important for current dimension (dim and dim+1) - PointVec positions() const; + PointVec positions() const; // set current dimension int set_dim(int new_dim); @@ -111,46 +119,22 @@ namespace md { int maximal_dim() const; void translate(Real a); Real minimal_coordinate() const; - Box bounding_box() const; + Box bounding_box() const; Real max_x() const; Real max_y() const; Real min_x() const; Real min_y() const; - Diagram weighted_slice_diagram(const DualPoint& slice) const; + Diagram weighted_slice_diagram(const DualPoint& slice) const; private: int dim_ { 0 }; - mutable PointVec cached_positions_; - Bifiltration bif_; + mutable PointVec cached_positions_; + Bifiltration bif_; void cache_positions() const; }; } - +#include "bifiltration.hpp" #endif //MATCHING_DISTANCE_BIFILTRATION_H - -//// The value type of OutputIterator is Simplex_in_2D_filtration -//template -//void read_input(std::string filename, OutputIterator out) -//{ -// std::ifstream ifstr; -// ifstr.open(filename.c_str()); -// long n; -// ifstr >> n; // number of simplices is the first number in file -// -// Index k; // used in loop -// for (int i = 0; i < n; i++) { -// Simplex_in_2D_filtration next; -// next.index = i; -// ifstr >> next.dim >> next.pos.x >> next.pos.y; -// if (next.dim > 0) { -// for (int j = 0; j <= next.dim; j++) { -// ifstr >> k; -// next.bd.push_back(k); -// } -// } -// *out++ = next; -// } -//} diff --git a/matching/include/bifiltration.hpp b/matching/include/bifiltration.hpp new file mode 100644 index 0000000..9e2a82e --- /dev/null +++ b/matching/include/bifiltration.hpp @@ -0,0 +1,421 @@ +namespace md { + + template + void Bifiltration::init() + { + auto lower_left = max_point(); + auto upper_right = min_point(); + for(const auto& simplex : simplices_) { + lower_left = greatest_lower_bound<>(lower_left, simplex.position()); + upper_right = least_upper_bound<>(upper_right, simplex.position()); + maximal_dim_ = std::max(maximal_dim_, simplex.dim()); + } + bounding_box_ = Box(lower_left, upper_right); + } + + template + Bifiltration::Bifiltration(const std::string& fname) + { + std::ifstream ifstr {fname.c_str()}; + if (!ifstr.good()) { + std::string error_message = fmt::format("Cannot open file {0}", fname); + std::cerr << error_message << std::endl; + throw std::runtime_error(error_message); + } + + BifiltrationFormat input_format; + + std::string s; + + while(ignore_line(s)) { + std::getline(ifstr, s); + } + + if (s == "bifiltration") { + input_format = BifiltrationFormat::rivet; + } else if (s == "bifiltration_phat_like") { + input_format = BifiltrationFormat::phat_like; + } else { + std::cerr << "Unknown format: '" << s << "' in file " << fname << std::endl; + throw std::runtime_error("unknown bifiltration format"); + } + + switch(input_format) { + case BifiltrationFormat::rivet : + rivet_format_reader(ifstr); + break; + case BifiltrationFormat::phat_like : + phat_like_format_reader(ifstr); + break; + } + + ifstr.close(); + + init(); + } + + template + void Bifiltration::rivet_format_reader(std::ifstream& ifstr) + { + std::string s; + // read axes names, ignore them + std::getline(ifstr, s); + std::getline(ifstr, s); + + Index index = 0; + while(std::getline(ifstr, s)) { + if (!ignore_line(s)) { + simplices_.emplace_back(index++, s, BifiltrationFormat::rivet); + } + } + } + + template + void Bifiltration::phat_like_format_reader(std::ifstream& ifstr) + { + spd::debug("Enter phat_like_format_reader"); + // read stream line by line; do not use >> operator + std::string s; + std::getline(ifstr, s); + + // first line contains number of simplices + long n_simplices = std::stol(s); + + // all other lines represent a simplex + Index index = 0; + while(index < n_simplices) { + std::getline(ifstr, s); + if (!ignore_line(s)) { + simplices_.emplace_back(index++, s, BifiltrationFormat::phat_like); + } + } + spd::debug("Read {} simplices from file", n_simplices); + } + + template + void Bifiltration::scale(Real lambda) + { + for(auto& s : simplices_) { + s.scale(lambda); + } + init(); + } + + template + void Bifiltration::sanity_check() const + { +#ifdef DEBUG + spd::debug("Enter Bifiltration::sanity_check"); + // check that boundary has correct number of simplices, + // each bounding simplex has correct dim + // and appears in the filtration before the simplex it bounds + for(const auto& s : simplices_) { + assert(s.dim() >= 0); + assert(s.dim() == 0 or s.dim() + 1 == (int) s.boundary().size()); + for(auto bdry_idx : s.boundary()) { + Simplex bdry_simplex = simplices()[bdry_idx]; + assert(bdry_simplex.dim() == s.dim() - 1); + assert(bdry_simplex.position().is_less(s.position(), false)); + } + } + spd::debug("Exit Bifiltration::sanity_check"); +#endif + } + + template + Diagram Bifiltration::weighted_slice_diagram(const DualPoint& line, int dim) const + { + DiagramKeeper dgm; + + // make a copy for now; I want slice_diagram to be const + std::vector> simplices(simplices_); + +// std::vector simplices; +// simplices.reserve(simplices_.size() / 2); +// for(const auto& s : simplices_) { +// if (s.dim() <= dim + 1 and s.dim() >= dim) +// simplices.emplace_back(s); +// } + + spd::debug("Enter slice diagram, line = {}, simplices.size = {}", line, simplices.size()); + + for(auto& simplex : simplices) { + Real value = line.weighted_push(simplex.position()); +// spd::debug("in slice_diagram, simplex = {}, value = {}\n", simplex, value); + simplex.set_value(value); + } + + std::sort(simplices.begin(), simplices.end(), + [](const Simplex& a, const Simplex& b) { return a.value() < b.value(); }); + std::map index_map; + for(Index i = 0; i < (int) simplices.size(); i++) { + index_map[simplices[i].id()] = i; + } + + phat::boundary_matrix<> phat_matrix; + phat_matrix.set_num_cols(simplices.size()); + std::vector bd_in_slice_filtration; + for(Index i = 0; i < (int) simplices.size(); i++) { + phat_matrix.set_dim(i, simplices[i].dim()); + bd_in_slice_filtration.clear(); + //std::cout << "new col" << i << std::endl; + for(int j = 0; j < (int) simplices[i].boundary().size(); j++) { + // F[i] contains the indices of its facet wrt to the + // original filtration. We have to express it, however, + // wrt to the filtration along the slice. That is why + // we need the index_map + //std::cout << "Found " << F[i].bd[j] << ", returning " << index_map[F[i].bd[j]] << std::endl; + bd_in_slice_filtration.push_back(index_map[simplices[i].boundary()[j]]); + } + std::sort(bd_in_slice_filtration.begin(), bd_in_slice_filtration.end()); + phat_matrix.set_col(i, bd_in_slice_filtration); + } + phat::persistence_pairs phat_persistence_pairs; + phat::compute_persistence_pairs(phat_persistence_pairs, phat_matrix); + + dgm.clear(); + constexpr Real real_inf = std::numeric_limits::infinity(); + for(long i = 0; i < (long) phat_persistence_pairs.get_num_pairs(); i++) { + std::pair new_pair = phat_persistence_pairs.get_pair(i); + bool is_finite_pair = new_pair.second != phat::k_infinity_index; + Real birth = simplices.at(new_pair.first).value(); + Real death = is_finite_pair ? simplices.at(new_pair.second).value() : real_inf; + int dim = simplices[new_pair.first].dim(); + assert(dim + 1 == simplices[new_pair.second].dim()); + if (birth != death) { + dgm.add_point(dim, birth, death); + } + } + + spdlog::debug("Exiting slice_diagram, #dgm[0] = {}", dgm.get_diagram(0).size()); + + return dgm.get_diagram(dim); + } + + template + Box Bifiltration::bounding_box() const + { + return bounding_box_; + } + + template + Real Bifiltration::minimal_coordinate() const + { + return std::min(bounding_box_.lower_left().x, bounding_box_.lower_left().y); + } + + template + void Bifiltration::translate(Real a) + { + bounding_box_.translate(a); + for(auto& simplex : simplices_) { + simplex.translate(a); + } + } + + template + Real Bifiltration::max_x() const + { + if (simplices_.empty()) + return 1; + auto me = std::max_element(simplices_.cbegin(), simplices_.cend(), + [](const auto& s_a, const auto& s_b) { return s_a.position().x < s_b.position().x; }); + assert(me != simplices_.cend()); + return me->position().x; + } + + template + Real Bifiltration::max_y() const + { + if (simplices_.empty()) + return 1; + auto me = std::max_element(simplices_.cbegin(), simplices_.cend(), + [](const auto& s_a, const auto& s_b) { return s_a.position().y < s_b.position().y; }); + assert(me != simplices_.cend()); + return me->position().y; + } + + template + Real Bifiltration::min_x() const + { + if (simplices_.empty()) + return 0; + auto me = std::min_element(simplices_.cbegin(), simplices_.cend(), + [](const auto& s_a, const auto& s_b) { return s_a.position().x < s_b.position().x; }); + assert(me != simplices_.cend()); + return me->position().x; + } + + template + Real Bifiltration::min_y() const + { + if (simplices_.empty()) + return 0; + auto me = std::min_element(simplices_.cbegin(), simplices_.cend(), + [](const auto& s_a, const auto& s_b) { return s_a.position().y < s_b.position().y; }); + assert(me != simplices_.cend()); + return me->position().y; + } + + template + void Bifiltration::add_simplex(Index _id, Point birth, int _dim, const Column& _bdry) + { + simplices_.emplace_back(_id, birth, _dim, _bdry); + } + + template + void Bifiltration::save(const std::string& filename, md::BifiltrationFormat format) + { + switch(format) { + case BifiltrationFormat::rivet: + throw std::runtime_error("Not implemented"); + break; + case BifiltrationFormat::phat_like: { + std::ofstream f(filename); + if (not f.good()) { + std::cerr << "Bifiltration::save: cannot open file " << filename << std::endl; + throw std::runtime_error("Cannot open file for writing "); + } + f << simplices_.size() << "\n"; + + for(const auto& s : simplices_) { + f << s.dim() << " " << s.position().x << " " << s.position().y << " "; + for(int b : s.boundary()) { + f << b << " "; + } + f << std::endl; + } + + } + break; + } + } + + template + void Bifiltration::postprocess_rivet_format() + { + std::map facets_to_ids; + + // fill the map + for(Index i = 0; i < (Index) simplices_.size(); ++i) { + assert(simplices_[i].id() == i); + facets_to_ids[simplices_[i].vertices_] = i; + } + +// for(const auto& s : simplices_) { +// facets_to_ids[s] = s.id(); +// } + + // main loop + for(auto& s : simplices_) { + assert(not s.vertices_.empty()); + assert(s.facet_indices_.empty()); + Column facet_indices; + for(Index i = 0; i <= s.dim(); ++i) { + Column facet; + for(Index j : s.vertices_) { + if (j != i) + facet.push_back(j); + } + auto facet_index = facets_to_ids.at(facet); + facet_indices.push_back(facet_index); + } + s.facet_indices_ = facet_indices; + } // loop over simplices + } + + template + std::ostream& operator<<(std::ostream& os, const Bifiltration& bif) + { + os << "Bifiltration [" << std::endl; + for(const auto& s : bif.simplices()) { + os << s << std::endl; + } + os << "]" << std::endl; + return os; + } + + template + BifiltrationProxy::BifiltrationProxy(const Bifiltration& bif, int dim) + : + dim_(dim), + bif_(bif) + { + cache_positions(); + } + + template + void BifiltrationProxy::cache_positions() const + { + cached_positions_.clear(); + for(const auto& simplex : bif_.simplices()) { + if (simplex.dim() == dim_ or simplex.dim() == dim_ + 1) + cached_positions_.push_back(simplex.position()); + } + } + + template + PointVec + BifiltrationProxy::positions() const + { + if (cached_positions_.empty()) { + cache_positions(); + } + return cached_positions_; + } + + // translate all points by vector (a,a) + template + void BifiltrationProxy::translate(Real a) + { + bif_.translate(a); + } + + // return minimal value of x- and y-coordinates + // among all simplices + template + Real BifiltrationProxy::minimal_coordinate() const + { + return bif_.minimal_coordinate(); + } + + // return box that contains positions of all simplices + template + Box BifiltrationProxy::bounding_box() const + { + return bif_.bounding_box(); + } + + template + Real BifiltrationProxy::max_x() const + { + return bif_.max_x(); + } + + template + Real BifiltrationProxy::max_y() const + { + return bif_.max_y(); + } + + template + Real BifiltrationProxy::min_x() const + { + return bif_.min_x(); + } + + template + Real BifiltrationProxy::min_y() const + { + return bif_.min_y(); + } + + + template + Diagram BifiltrationProxy::weighted_slice_diagram(const DualPoint& slice) const + { + return bif_.weighted_slice_diagram(slice, dim_); + } + +} + diff --git a/matching/include/box.h b/matching/include/box.h index 2990fba..4243667 100644 --- a/matching/include/box.h +++ b/matching/include/box.h @@ -8,20 +8,23 @@ namespace md { + template struct Box { + public: + using Real = Real_; private: - Point ll; - Point ur; + Point ll; + Point ur; public: - Box(Point ll = Point(), Point ur = Point()) + Box(Point ll = Point(), Point ur = Point()) :ll(ll), ur(ur) { } - Box(Point center, Real width, Real height) : - ll(Point(center.x - 0.5 * width, center.y - 0.5 * height)), - ur(Point(center.x + 0.5 * width, center.y + 0.5 * height)) + Box(Point center, Real width, Real height) : + ll(Point(center.x - 0.5 * width, center.y - 0.5 * height)), + ur(Point(center.x + 0.5 * width, center.y + 0.5 * height)) { } @@ -30,11 +33,9 @@ namespace md { inline double height() const { return ur.y - ll.y; } - inline Point lower_left() const { return ll; } - inline Point upper_right() const { return ur; } - inline Point center() const { return Point((ll.x + ur.x) / 2, (ll.y + ur.y) / 2); } - -// bool inside(Point& p) const { return ll.x <= p.x && ll.y <= p.y && ur.x >= p.x && ur.y >= p.y; } + inline Point lower_left() const { return ll; } + inline Point upper_right() const { return ur; } + inline Point center() const { return Point((ll.x + ur.x) / 2, (ll.y + ur.y) / 2); } inline bool operator==(const Box& p) { @@ -43,58 +44,16 @@ namespace md { std::vector refine() const; - std::vector corners() const; + std::vector> corners() const; void translate(Real a); - - // return minimal and maximal value of func - // on the corners of the box - template - std::pair min_max_on_corners(const F& func) const; - - friend std::ostream& operator<<(std::ostream& os, const Box& box); }; - std::ostream& operator<<(std::ostream& os, const Box& box); -// template -// Box compute_bounding_box(InputIterator simplices_begin, InputIterator simplices_end) -// { -// if (simplices_begin == simplices_end) { -// return Box(); -// } -// Box bb; -// bb.ll = bb.ur = simplices_begin->pos; -// for (InputIterator it = simplices_begin; it != simplices_end; it++) { -// Point& pos = it->pos; -// if (pos.x < bb.ll.x) { -// bb.ll.x = pos.x; -// } -// if (pos.y < bb.ll.y) { -// bb.ll.y = pos.y; -// } -// if (pos.x > bb.ur.x) { -// bb.ur.x = pos.x; -// } -// if (pos.y > bb.ur.y) { -// bb.ur.y = pos.y; -// } -// } -// return bb; -// } - - Box get_enclosing_box(const Box& box_a, const Box& box_b); - - template - std::pair Box::min_max_on_corners(const F& func) const - { - std::pair min_max { std::numeric_limits::max(), -std::numeric_limits::max() }; - for(Point p : corners()) { - Real value = func(p); - min_max.first = std::min(min_max.first, value); - min_max.second = std::max(min_max.second, value); - } - return min_max; - }; + template + std::ostream& operator<<(std::ostream& os, const Box& box); + } // namespace md +#include "box.hpp" + #endif //MATCHING_DISTANCE_BOX_H diff --git a/matching/include/box.hpp b/matching/include/box.hpp new file mode 100644 index 0000000..f551d84 --- /dev/null +++ b/matching/include/box.hpp @@ -0,0 +1,52 @@ +namespace md { + + template + std::ostream& operator<<(std::ostream& os, const Box& box) + { + os << "Box(lower_left = " << box.lower_left() << ", upper_right = " << box.upper_right() << ")"; + return os; + } + + template + void Box::translate(Real a) + { + ll.x += a; + ll.y += a; + ur.x += a; + ur.y += a; + } + + template + std::vector> Box::refine() const + { + std::vector> result; + +// 1 | 2 +// 0 | 3 + + Point new_ll = lower_left(); + Point new_ur = center(); + result.emplace_back(new_ll, new_ur); + + new_ll.y = center().y; + new_ur.y = ur.y; + result.emplace_back(new_ll, new_ur); + + new_ll = center(); + new_ur = upper_right(); + result.emplace_back(new_ll, new_ur); + + new_ll.y = ll.y; + new_ur.y = center().y; + result.emplace_back(new_ll, new_ur); + + return result; + } + + template + std::vector> Box::corners() const + { + return {ll, Point(ll.x, ur.y), ur, Point(ur.x, ll.y)}; + }; + +} diff --git a/matching/include/cell_with_value.h b/matching/include/cell_with_value.h index 25644d1..3548a11 100644 --- a/matching/include/cell_with_value.h +++ b/matching/include/cell_with_value.h @@ -1,7 +1,3 @@ -// -// Created by narn on 16.07.19. -// - #ifndef MATCHING_DISTANCE_CELL_WITH_VALUE_H #define MATCHING_DISTANCE_CELL_WITH_VALUE_H @@ -21,7 +17,29 @@ namespace md { upper_right }; - std::ostream& operator<<(std::ostream& os, const ValuePoint& vp); + inline std::ostream& operator<<(std::ostream& os, const ValuePoint& vp) + { + switch(vp) { + case ValuePoint::upper_left : + os << "upper_left"; + break; + case ValuePoint::upper_right : + os << "upper_right"; + break; + case ValuePoint::lower_left : + os << "lower_left"; + break; + case ValuePoint::lower_right : + os << "lower_right"; + break; + case ValuePoint::center: + os << "center"; + break; + default: + os << "FORGOTTEN ValuePoint"; + } + return os; + } const std::vector k_all_vps = {ValuePoint::center, ValuePoint::lower_left, ValuePoint::upper_left, ValuePoint::upper_right, ValuePoint::lower_right}; @@ -31,8 +49,10 @@ namespace md { // represents a cell in the dual space with the value // of the weighted bottleneck distance + template class CellWithValue { public: + using Real = Real_; CellWithValue() = default; @@ -44,18 +64,18 @@ namespace md { CellWithValue& operator=(CellWithValue&& other) = default; - CellWithValue(const DualBox& b, int level) + CellWithValue(const DualBox& b, int level) :dual_box_(b), level_(level) { } - DualBox dual_box() const { return dual_box_; } + DualBox dual_box() const { return dual_box_; } - DualPoint center() const { return dual_box_.center(); } + DualPoint center() const { return dual_box_.center(); } Real value_at(ValuePoint vp) const; bool has_value_at(ValuePoint vp) const; - DualPoint value_point(ValuePoint vp) const; + DualPoint value_point(ValuePoint vp) const; int level() const { return level_; } @@ -73,8 +93,6 @@ namespace md { std::vector get_refined_cells() const; - friend std::ostream& operator<<(std::ostream&, const CellWithValue&); - void set_max_possible_value(Real new_upper_bound); int num_values() const; @@ -100,7 +118,7 @@ namespace md { bool has_upper_right_value() const { return upper_right_value_ >= 0; } - DualBox dual_box_; + DualBox dual_box_; Real central_value_ {-1.0}; Real lower_left_value_ {-1.0}; Real lower_right_value_ {-1.0}; @@ -114,7 +132,10 @@ namespace md { bool has_max_possible_value_ {false}; }; - std::ostream& operator<<(std::ostream& os, const CellWithValue& cell); + template + std::ostream& operator<<(std::ostream& os, const CellWithValue& cell); } // namespace md +#include "cell_with_value.hpp" + #endif //MATCHING_DISTANCE_CELL_WITH_VALUE_H diff --git a/matching/include/cell_with_value.hpp b/matching/include/cell_with_value.hpp new file mode 100644 index 0000000..88b2569 --- /dev/null +++ b/matching/include/cell_with_value.hpp @@ -0,0 +1,224 @@ +namespace md { + +#ifdef MD_DEBUG + long long int CellWithValue::max_id = 0; +#endif + + template + Real CellWithValue::value_at(ValuePoint vp) const + { + switch(vp) { + case ValuePoint::upper_left : + return upper_left_value_; + case ValuePoint::upper_right : + return upper_right_value_; + case ValuePoint::lower_left : + return lower_left_value_; + case ValuePoint::lower_right : + return lower_right_value_; + case ValuePoint::center: + return central_value_; + } + // to shut up compiler warning + return 1.0 / 0.0; + } + + template + bool CellWithValue::has_value_at(ValuePoint vp) const + { + switch(vp) { + case ValuePoint::upper_left : + return upper_left_value_ >= 0; + case ValuePoint::upper_right : + return upper_right_value_ >= 0; + case ValuePoint::lower_left : + return lower_left_value_ >= 0; + case ValuePoint::lower_right : + return lower_right_value_ >= 0; + case ValuePoint::center: + return central_value_ >= 0; + } + // to shut up compiler warning + return 1.0 / 0.0; + } + + template + DualPoint CellWithValue::value_point(md::ValuePoint vp) const + { + switch(vp) { + case ValuePoint::upper_left : + return dual_box().upper_left(); + case ValuePoint::upper_right : + return dual_box().upper_right(); + case ValuePoint::lower_left : + return dual_box().lower_left(); + case ValuePoint::lower_right : + return dual_box().lower_right(); + case ValuePoint::center: + return dual_box().center(); + } + // to shut up compiler warning + return DualPoint(); + } + + template + bool CellWithValue::has_corner_value() const + { + return has_lower_left_value() or has_lower_right_value() or has_upper_left_value() + or has_upper_right_value(); + } + + template + Real CellWithValue::stored_upper_bound() const + { + assert(has_max_possible_value_); + return max_possible_value_; + } + + template + Real CellWithValue::max_corner_value() const + { + return std::max({lower_left_value_, lower_right_value_, upper_left_value_, upper_right_value_}); + } + + template + Real CellWithValue::min_value() const + { + Real result = std::numeric_limits::max(); + for(auto vp : k_all_vps) { + if (not has_value_at(vp)) { + continue; + } + result = std::min(result, value_at(vp)); + } + return result; + } + + template + std::vector> CellWithValue::get_refined_cells() const + { + std::vector> result; + result.reserve(4); + for(const auto& refined_box : dual_box_.refine()) { + + CellWithValue refined_cell(refined_box, level() + 1); + +#ifdef MD_DEBUG + refined_cell.parent_ids = parent_ids; + refined_cell.parent_ids.push_back(id); + refined_cell.id = ++max_id; +#endif + + if (refined_box.lower_left() == dual_box_.lower_left()) { + // _|_ + // H|_ + + refined_cell.set_value_at(ValuePoint::lower_left, lower_left_value_); + refined_cell.set_value_at(ValuePoint::upper_right, central_value_); + + } else if (refined_box.upper_right() == dual_box_.upper_right()) { + // _|H + // _|_ + + refined_cell.set_value_at(ValuePoint::lower_left, central_value_); + refined_cell.set_value_at(ValuePoint::upper_right, upper_right_value_); + + } else if (refined_box.lower_right() == dual_box_.lower_right()) { + // _|_ + // _|H + + refined_cell.set_value_at(ValuePoint::lower_right, lower_right_value_); + refined_cell.set_value_at(ValuePoint::upper_left, central_value_); + + } else if (refined_box.upper_left() == dual_box_.upper_left()) { + + // H|_ + // _|_ + + refined_cell.set_value_at(ValuePoint::lower_right, central_value_); + refined_cell.set_value_at(ValuePoint::upper_left, upper_left_value_); + } + result.emplace_back(refined_cell); + } + return result; + } + + template + void CellWithValue::set_value_at(ValuePoint vp, Real new_value) + { + if (has_value_at(vp)) + spd::error("CellWithValue: trying to re-assign value!, this = {}, vp = {}", *this, vp); + + switch(vp) { + case ValuePoint::upper_left : + upper_left_value_ = new_value; + break; + case ValuePoint::upper_right : + upper_right_value_ = new_value; + break; + case ValuePoint::lower_left : + lower_left_value_ = new_value; + break; + case ValuePoint::lower_right : + lower_right_value_ = new_value; + break; + case ValuePoint::center: + central_value_ = new_value; + break; + } + } + + template + int CellWithValue::num_values() const + { + int result = 0; + for(ValuePoint vp : k_all_vps) { + result += has_value_at(vp); + } + return result; + } + + + template + void CellWithValue::set_max_possible_value(Real new_upper_bound) + { + assert(new_upper_bound >= central_value_); + assert(new_upper_bound >= lower_left_value_); + assert(new_upper_bound >= lower_right_value_); + assert(new_upper_bound >= upper_left_value_); + assert(new_upper_bound >= upper_right_value_); + has_max_possible_value_ = true; + max_possible_value_ = new_upper_bound; + } + + + template + std::ostream& operator<<(std::ostream& os, const CellWithValue& cell) + { + os << "CellWithValue(box = " << cell.dual_box() << ", "; + +#ifdef MD_DEBUG + os << "id = " << cell.id; + if (not cell.parent_ids.empty()) + os << ", parent_ids = " << container_to_string(cell.parent_ids) << ", "; +#endif + + for(ValuePoint vp : k_all_vps) { + if (cell.has_value_at(vp)) { + os << "value = " << cell.value_at(vp); + os << ", at " << vp << " " << cell.value_point(vp); + } + } + + os << ", max_corner_value = "; + if (cell.has_max_possible_value()) { + os << cell.stored_upper_bound(); + } else { + os << "-"; + } + + os << ", level = " << cell.level() << ")"; + return os; + } + +} // namespace md diff --git a/matching/include/common_defs.h b/matching/include/common_defs.h index 3f3d937..8d01325 100644 --- a/matching/include/common_defs.h +++ b/matching/include/common_defs.h @@ -1,8 +1,8 @@ #ifndef MATCHING_DISTANCE_DEF_DEBUG_H #define MATCHING_DISTANCE_DEF_DEBUG_H -//#define EXPERIMENTAL_TIMING -//#define PRINT_HEAT_MAP +//#define MD_EXPERIMENTAL_TIMING +//#define MD_PRINT_HEAT_MAP //#define MD_DEBUG //#define MD_DO_CHECKS //#define MD_DO_FULL_CHECK diff --git a/matching/include/common_util.h b/matching/include/common_util.h index 2d8dcb0..778536f 100644 --- a/matching/include/common_util.h +++ b/matching/include/common_util.h @@ -11,22 +11,24 @@ #include #include +#include +#include + +namespace spd = spdlog; + #include "common_defs.h" #include "phat/helpers/misc.h" - namespace md { - - using Real = double; - using RealVec = std::vector; using Index = phat::index; using IndexVec = std::vector; - static constexpr Real pi = M_PI; + //static constexpr Real pi = M_PI; using Column = std::vector; + template struct Point { Real x; Real y; @@ -71,59 +73,56 @@ namespace md { }; - using PointVec = std::vector; - - Point operator+(const Point& u, const Point& v); + template + using PointVec = std::vector>; - Point operator-(const Point& u, const Point& v); + template + Point operator+(const Point& u, const Point& v); - Point least_upper_bound(const Point& u, const Point& v); + template + Point operator-(const Point& u, const Point& v); - Point greatest_lower_bound(const Point& u, const Point& v); - Point max_point(); + template + Point least_upper_bound(const Point& u, const Point& v); - Point min_point(); + template + Point greatest_lower_bound(const Point& u, const Point& v); - std::ostream& operator<<(std::ostream& ostr, const Point& vec); + template + Point max_point(); - Real L_infty(const Point& v); + template + Point min_point(); - Real l_2_norm(const Point& v); + template + std::ostream& operator<<(std::ostream& ostr, const Point& vec); - Real l_2_dist(const Point& x, const Point& y); + template + using DiagramPoint = std::pair; - Real l_infty_dist(const Point& x, const Point& y); + template + using Diagram = std::vector>; - using Interval = std::pair; - - // return minimal interval that contains both a and b - inline Interval minimal_covering_interval(Interval a, Interval b) - { - return {std::min(a.first, b.first), std::max(a.second, b.second)}; - } // to keep diagrams in all dimensions // TODO: store in Hera format? + template class DiagramKeeper { public: - using DiagramPoint = std::pair; - using Diagram = std::vector; DiagramKeeper() { }; void add_point(int dim, Real birth, Real death); - Diagram get_diagram(int dim) const; + Diagram get_diagram(int dim) const; void clear() { data_.clear(); } private: - std::map data_; + std::map> data_; }; - using Diagram = std::vector>; - template std::string container_to_string(const C& cont) { @@ -140,42 +139,18 @@ namespace md { return ss.str(); } - int gcd(int a, int b); - - struct Rational { - int numerator {0}; - int denominator {1}; - Rational() = default; - Rational(int n, int d) : numerator(n / gcd(n, d)), denominator(d / gcd(n, d)) {} - Rational(std::pair p) : Rational(p.first, p.second) {} - Rational(int n) : numerator(n), denominator(1) {} - Real to_real() const { return (Real)numerator / (Real)denominator; } - void reduce(); - Rational& operator+=(const Rational& rhs); - Rational& operator-=(const Rational& rhs); - Rational& operator*=(const Rational& rhs); - Rational& operator/=(const Rational& rhs); - }; - - using namespace std::rel_ops; - - bool operator==(const Rational& a, const Rational& b); - bool operator<(const Rational& a, const Rational& b); - std::ostream& operator<<(std::ostream& os, const Rational& a); - - // arithmetic - Rational operator+(Rational a, const Rational& b); - Rational operator-(Rational a, const Rational& b); - Rational operator*(Rational a, const Rational& b); - Rational operator/(Rational a, const Rational& b); - - Rational reduce(Rational frac); - - Rational midpoint(Rational a, Rational b); - // return true, if s is empty or starts with # (commented out line) // whitespaces in the beginning of s are ignored - bool ignore_line(const std::string& s); + inline bool ignore_line(const std::string& s) + { + for(auto c : s) { + if (isspace(c)) + continue; + return (c == '#'); + } + return true; + } + // split string by delimeter template @@ -195,10 +170,10 @@ namespace md { } namespace std { - template<> - struct hash + template + struct hash> { - std::size_t operator()(const md::Point& p) const + std::size_t operator()(const md::Point& p) const { auto hx = std::hash()(p.x); auto hy = std::hash()(p.y); @@ -207,5 +182,7 @@ namespace std { }; }; +#include "common_util.hpp" + #endif //MATCHING_DISTANCE_COMMON_UTIL_H diff --git a/matching/include/common_util.hpp b/matching/include/common_util.hpp new file mode 100644 index 0000000..76d97af --- /dev/null +++ b/matching/include/common_util.hpp @@ -0,0 +1,96 @@ +#include +#include +#include +#include +#include +#include + +#include + +#include "spdlog/spdlog.h" +#include "spdlog/fmt/ostr.h" + +namespace md { + + template + Point operator+(const Point& u, const Point& v) + { + return Point(u.x + v.x, u.y + v.y); + } + + template + Point operator-(const Point& u, const Point& v) + { + return Point(u.x - v.x, u.y - v.y); + } + + template + Point least_upper_bound(const Point& u, const Point& v) + { + return Point(std::max(u.x, v.x), std::max(u.y, v.y)); + } + + template + Point greatest_lower_bound(const Point& u, const Point& v) + { + return Point(std::min(u.x, v.x), std::min(u.y, v.y)); + } + + template + Point max_point() + { + return Point(std::numeric_limits::max(), std::numeric_limits::min()); + } + + template + Point min_point() + { + return Point(-std::numeric_limits::max(), -std::numeric_limits::min()); + } + + template + std::ostream& operator<<(std::ostream& ostr, const Point& vec) + { + ostr << "(" << vec.x << ", " << vec.y << ")"; + return ostr; + } + + template + Real l_infty_norm(const Point& v) + { + return std::max(std::abs(v.x), std::abs(v.y)); + } + + template + Real l_2_norm(const Point& v) + { + return v.norm(); + } + + template + Real l_2_dist(const Point& x, const Point& y) + { + return l_2_norm(x - y); + } + + template + Real l_infty_dist(const Point& x, const Point& y) + { + return l_infty_norm(x - y); + } + + template + void DiagramKeeper::add_point(int dim, Real birth, Real death) + { + data_[dim].emplace_back(birth, death); + } + + template + Diagram DiagramKeeper::get_diagram(int dim) const + { + if (data_.count(dim) == 1) + return data_.at(dim); + else + return Diagram(); + } +} diff --git a/matching/include/dual_box.h b/matching/include/dual_box.h index ce0384d..0e4f4d5 100644 --- a/matching/include/dual_box.h +++ b/matching/include/dual_box.h @@ -4,16 +4,23 @@ #include #include #include +#include + +#include "spdlog/spdlog.h" +#include "spdlog/fmt/ostr.h" + #include "common_util.h" #include "dual_point.h" namespace md { + + template class DualBox { public: - DualBox(DualPoint ll, DualPoint ur); + DualBox(DualPoint ll, DualPoint ur); DualBox() = default; DualBox(const DualBox&) = default; @@ -23,12 +30,12 @@ namespace md { DualBox& operator=(DualBox&& other) = default; - DualPoint center() const { return midpoint(lower_left_, upper_right_); } - DualPoint lower_left() const { return lower_left_; } - DualPoint upper_right() const { return upper_right_; } + DualPoint center() const { return midpoint(lower_left_, upper_right_); } + DualPoint lower_left() const { return lower_left_; } + DualPoint upper_right() const { return upper_right_; } - DualPoint lower_right() const; - DualPoint upper_left() const; + DualPoint lower_right() const; + DualPoint upper_left() const; AxisType axis_type() const { return lower_left_.axis_type(); } AngleType angle_type() const { return lower_left_.angle_type(); } @@ -42,66 +49,35 @@ namespace md { bool is_flat() const { return upper_right_.is_flat(); } bool is_steep() const { return lower_left_.is_steep(); } - // return minimal and maximal value of func - // on the corners of the box - template - std::pair min_max_on_corners(const F& func) const; - - template - Real max_abs_value(const F& func) const; - - std::vector refine() const; - std::vector corners() const; - std::vector critical_points(const Point& p) const; + std::vector> corners() const; + std::vector> critical_points(const Point& p) const; // sample n points from the box uniformly; for tests - std::vector random_points(int n) const; + std::vector> random_points(int n) const; // return 2 dual points at the boundary // where push changes from horizontal to vertical - std::vector push_change_points(const Point& p) const; - - friend std::ostream& operator<<(std::ostream& os, const DualBox& db); + std::vector> push_change_points(const Point& p) const; // check that a has same sign, angles are all flat or all steep bool sanity_check() const; - bool contains(const DualPoint& dp) const; + bool contains(const DualPoint& dp) const; bool operator==(const DualBox& other) const; private: - DualPoint lower_left_; - DualPoint upper_right_; + DualPoint lower_left_; + DualPoint upper_right_; }; - std::ostream& operator<<(std::ostream& os, const DualBox& db); - - template - std::pair DualBox::min_max_on_corners(const F& func) const + template + std::ostream& operator<<(std::ostream& os, const DualBox& db) { - std::pair min_max { std::numeric_limits::max(), -std::numeric_limits::max() }; - for(auto p : corners()) { - Real value = func(p); - min_max.first = std::min(min_max.first, value); - min_max.second = std::max(min_max.second, value); - } - return min_max; - }; - - - template - Real DualBox::max_abs_value(const F& func) const - { - Real result = 0; - for(auto p_1 : corners()) { - for(auto p_2 : corners()) { - Real value = fabs(func(p_1, p_2)); - result = std::max(value, result); - } - } - return result; - }; - + os << "DualBox(" << db.lower_left() << ", " << db.upper_right() << ")"; + return os; + } } +#include "dual_box.hpp" + #endif //MATCHING_DISTANCE_DUAL_BOX_H diff --git a/matching/include/dual_box.hpp b/matching/include/dual_box.hpp new file mode 100644 index 0000000..85f7f27 --- /dev/null +++ b/matching/include/dual_box.hpp @@ -0,0 +1,190 @@ +namespace md { + + template + DualBox::DualBox(DualPoint ll, DualPoint ur) + :lower_left_(ll), upper_right_(ur) + { + } + + template + std::vector> DualBox::corners() const + { + return {lower_left_, + DualPoint(axis_type(), angle_type(), lower_left_.lambda(), upper_right_.mu()), + upper_right_, + DualPoint(axis_type(), angle_type(), upper_right_.lambda(), lower_left_.mu())}; + } + + template + std::vector> DualBox::push_change_points(const Point& p) const + { + std::vector> result; + result.reserve(2); + + bool is_y_type = lower_left_.is_y_type(); + bool is_flat = lower_left_.is_flat(); + + auto mu_from_lambda = [p, is_y_type, is_flat](Real lambda) { + bool is_x_type = not is_y_type, is_steep = not is_flat; + if (is_y_type && is_flat) { + return p.y - lambda * p.x; + } else if (is_y_type && is_steep) { + return p.y - p.x / lambda; + } else if (is_x_type && is_flat) { + return p.x - p.y / lambda; + } else if (is_x_type && is_steep) { + return p.x - lambda * p.y; + } + // to shut up compiler warning + return static_cast(1.0 / 0.0); + }; + + auto lambda_from_mu = [p, is_y_type, is_flat](Real mu) { + bool is_x_type = not is_y_type, is_steep = not is_flat; + if (is_y_type && is_flat) { + return (p.y - mu) / p.x; + } else if (is_y_type && is_steep) { + return p.x / (p.y - mu); + } else if (is_x_type && is_flat) { + return p.y / (p.x - mu); + } else if (is_x_type && is_steep) { + return (p.x - mu) / p.y; + } + // to shut up compiler warning + return static_cast(1.0 / 0.0); + }; + + // all inequalities below are strict: equality means it is a corner + // && critical_points() returns corners anyway + + Real mu_intersect_min = mu_from_lambda(lambda_min()); + + if (mu_min() < mu_intersect_min && mu_intersect_min < mu_max()) + result.emplace_back(axis_type(), angle_type(), lambda_min(), mu_intersect_min); + + Real mu_intersect_max = mu_from_lambda(lambda_max()); + + if (mu_max() < mu_intersect_max && mu_intersect_max < mu_max()) + result.emplace_back(axis_type(), angle_type(), lambda_max(), mu_intersect_max); + + Real lambda_intersect_min = lambda_from_mu(mu_min()); + + if (lambda_min() < lambda_intersect_min && lambda_intersect_min < lambda_max()) + result.emplace_back(axis_type(), angle_type(), lambda_intersect_min, mu_min()); + + Real lambda_intersect_max = lambda_from_mu(mu_max()); + if (lambda_min() < lambda_intersect_max && lambda_intersect_max < lambda_max()) + result.emplace_back(axis_type(), angle_type(), lambda_intersect_max, mu_max()); + + assert(result.size() <= 2); + + if (result.size() > 2) { + fmt::print("Error in push_change_points, p = {}, dual_box = {}, result = {}\n", p, *this, + container_to_string(result)); + throw std::runtime_error("push_change_points returned more than 2 points"); + } + + return result; + } + + template + std::vector> DualBox::critical_points(const Point& /*p*/) const + { + // maximal difference is attained at corners + return corners(); +// std::vector> result; +// result.reserve(6); +// for(auto dp : corners()) result.push_back(dp); +// for(auto dp : push_change_points(p)) result.push_back(dp); +// return result; + } + + template + std::vector> DualBox::random_points(int n) const + { + assert(n >= 0); + std::mt19937_64 gen(1); + std::vector> result; + result.reserve(n); + std::uniform_real_distribution mu_distr(mu_min(), mu_max()); + std::uniform_real_distribution lambda_distr(lambda_min(), lambda_max()); + for(int i = 0; i < n; ++i) { + result.emplace_back(axis_type(), angle_type(), lambda_distr(gen), mu_distr(gen)); + } + return result; + } + + template + bool DualBox::sanity_check() const + { + lower_left_.sanity_check(); + upper_right_.sanity_check(); + + if (lower_left_.angle_type() != upper_right_.angle_type()) + throw std::runtime_error("angle types differ"); + + if (lower_left_.axis_type() != upper_right_.axis_type()) + throw std::runtime_error("axis types differ"); + + if (lower_left_.lambda() >= upper_right_.lambda()) + throw std::runtime_error("lambda of lower_left_ greater than lambda of upper_right "); + + if (lower_left_.mu() >= upper_right_.mu()) + throw std::runtime_error("mu of lower_left_ greater than mu of upper_right "); + + return true; + } + + template + std::vector> DualBox::refine() const + { + std::vector> result; + + result.reserve(4); + + Real lambda_middle = (lower_left().lambda() + upper_right().lambda()) / 2.0; + Real mu_middle = (lower_left().mu() + upper_right().mu()) / 2.0; + + DualPoint refinement_center(axis_type(), angle_type(), lambda_middle, mu_middle); + + result.emplace_back(lower_left_, refinement_center); + + result.emplace_back(DualPoint(axis_type(), angle_type(), lambda_middle, mu_min()), + DualPoint(axis_type(), angle_type(), lambda_max(), mu_middle)); + + result.emplace_back(refinement_center, upper_right_); + + result.emplace_back(DualPoint(axis_type(), angle_type(), lambda_min(), mu_middle), + DualPoint(axis_type(), angle_type(), lambda_middle, mu_max())); + return result; + } + + template + bool DualBox::operator==(const DualBox& other) const + { + return lower_left() == other.lower_left() && + upper_right() == other.upper_right(); + } + + template + bool DualBox::contains(const DualPoint& dp) const + { + return dp.angle_type() == angle_type() && dp.axis_type() == axis_type() && + mu_max() >= dp.mu() && + mu_min() <= dp.mu() && + lambda_min() <= dp.lambda() && + lambda_max() >= dp.lambda(); + } + + template + DualPoint DualBox::lower_right() const + { + return DualPoint(lower_left_.axis_type(), lower_left_.angle_type(), lambda_max(), mu_min()); + } + + template + DualPoint DualBox::upper_left() const + { + return DualPoint(lower_left_.axis_type(), lower_left_.angle_type(), lambda_min(), mu_max()); + } +} diff --git a/matching/include/dual_point.h b/matching/include/dual_point.h index db32f1a..8438860 100644 --- a/matching/include/dual_point.h +++ b/matching/include/dual_point.h @@ -1,12 +1,9 @@ -// -// Created by narn on 12.02.19. -// - #ifndef MATCHING_DISTANCE_DUAL_POINT_H #define MATCHING_DISTANCE_DUAL_POINT_H #include #include +#include #include "common_util.h" #include "box.h" @@ -25,9 +22,10 @@ namespace md { // so, e.g., line y = x has 4 different non-equal representation. // we are unlikely to ever need this, because 4 cases are // always treated separately. + template class DualPoint { public: - using Real = md::Real; + using Real = Real_; DualPoint() = default; @@ -56,7 +54,6 @@ namespace md { bool is_y_type() const { return axis_type_ == AxisType::y_type; } - friend std::ostream& operator<<(std::ostream& os, const DualPoint& dp); bool operator<(const DualPoint& rhs) const; AxisType axis_type() const { return axis_type_; } @@ -66,16 +63,16 @@ namespace md { // return true otherwise bool sanity_check() const; - Real weighted_push(Point p) const; - Point push(Point p) const; + Real weighted_push(Point p) const; + Point push(Point p) const; bool is_horizontal() const; bool is_vertical() const; - bool goes_below(Point p) const; - bool goes_above(Point p) const; + bool goes_below(Point p) const; + bool goes_above(Point p) const; - bool contains(Point p) const; + bool contains(Point p) const; Real x_slope() const; Real y_slope() const; @@ -98,9 +95,13 @@ namespace md { Real mu_ {-1.0}; }; - std::ostream& operator<<(std::ostream& os, const DualPoint& dp); + template + std::ostream& operator<<(std::ostream& os, const DualPoint& dp); - DualPoint midpoint(DualPoint x, DualPoint y); + template + DualPoint midpoint(DualPoint x, DualPoint y); }; +#include "dual_point.hpp" + #endif //MATCHING_DISTANCE_DUAL_POINT_H diff --git a/matching/include/dual_point.hpp b/matching/include/dual_point.hpp new file mode 100644 index 0000000..04e25f2 --- /dev/null +++ b/matching/include/dual_point.hpp @@ -0,0 +1,299 @@ +namespace md { + + inline std::ostream& operator<<(std::ostream& os, const AxisType& at) + { + if (at == AxisType::x_type) + os << "x-type"; + else + os << "y-type"; + return os; + } + + inline std::ostream& operator<<(std::ostream& os, const AngleType& at) + { + if (at == AngleType::flat) + os << "flat"; + else + os << "steep"; + return os; + } + + template + std::ostream& operator<<(std::ostream& os, const DualPoint& dp) + { + os << "Line(" << dp.axis_type() << ", "; + os << dp.angle_type() << ", "; + os << dp.lambda() << ", "; + os << dp.mu() << ", equation: "; + if (not dp.is_vertical()) { + os << "y = " << dp.y_slope() << " x + " << dp.y_intercept(); + } else { + os << "x = " << dp.x_intercept(); + } + os << " )"; + return os; + } + + template + bool DualPoint::operator<(const DualPoint& rhs) const + { + return std::tie(axis_type_, angle_type_, lambda_, mu_) + < std::tie(rhs.axis_type_, rhs.angle_type_, rhs.lambda_, rhs.mu_); + } + + template + DualPoint::DualPoint(AxisType axis_type, AngleType angle_type, Real lambda, Real mu) + : + axis_type_(axis_type), + angle_type_(angle_type), + lambda_(lambda), + mu_(mu) + { + assert(sanity_check()); + } + + template + bool DualPoint::sanity_check() const + { + if (lambda_ < 0.0) + throw std::runtime_error("Invalid line, negative lambda"); + if (lambda_ > 1.0) + throw std::runtime_error("Invalid line, lambda > 1"); + if (mu_ < 0.0) + throw std::runtime_error("Invalid line, negative mu"); + return true; + } + + template + Real DualPoint::gamma() const + { + if (is_steep()) + return atan(Real(1.0) / lambda_); + else + return atan(lambda_); + } + + template + DualPoint midpoint(DualPoint x, DualPoint y) + { + assert(x.angle_type() == y.angle_type() and x.axis_type() == y.axis_type()); + Real lambda_mid = (x.lambda() + y.lambda()) / 2; + Real mu_mid = (x.mu() + y.mu()) / 2; + return DualPoint(x.axis_type(), x.angle_type(), lambda_mid, mu_mid); + + } + + // return k in the line equation y = kx + b + template + Real DualPoint::y_slope() const + { + if (is_flat()) + return lambda(); + else + return Real(1.0) / lambda(); + } + + // return k in the line equation x = ky + b + template + Real DualPoint::x_slope() const + { + if (is_flat()) + return Real(1.0) / lambda(); + else + return lambda(); + } + + // return b in the line equation y = kx + b + template + Real DualPoint::y_intercept() const + { + if (is_y_type()) { + return mu(); + } else { + // x = x_slope * y + mu = x_slope * (y + mu / x_slope) + // x-intercept is -mu/x_slope = -mu * y_slope + return -mu() * y_slope(); + } + } + + // return k in the line equation x = ky + b + template + Real DualPoint::x_intercept() const + { + if (is_x_type()) { + return mu(); + } else { + // y = y_slope * x + mu = y_slope (x + mu / y_slope) + // x_intercept is -mu/y_slope = -mu * x_slope + return -mu() * x_slope(); + } + } + + template + Real DualPoint::x_from_y(Real y) const + { + if (is_horizontal()) + throw std::runtime_error("x_from_y called on horizontal line"); + else + return x_slope() * y + x_intercept(); + } + + template + Real DualPoint::y_from_x(Real x) const + { + if (is_vertical()) + throw std::runtime_error("x_from_y called on horizontal line"); + else + return y_slope() * x + y_intercept(); + } + + template + bool DualPoint::is_horizontal() const + { + return is_flat() and lambda() == 0; + } + + template + bool DualPoint::is_vertical() const + { + return is_steep() and lambda() == 0; + } + + template + bool DualPoint::contains(Point p) const + { + if (is_vertical()) + return p.x == x_from_y(p.y); + else + return p.y == y_from_x(p.x); + } + + template + bool DualPoint::goes_below(Point p) const + { + if (is_vertical()) + return p.x <= x_from_y(p.y); + else + return p.y >= y_from_x(p.x); + } + + template + bool DualPoint::goes_above(Point p) const + { + if (is_vertical()) + return p.x >= x_from_y(p.y); + else + return p.y <= y_from_x(p.x); + } + + template + Point DualPoint::push(Point p) const + { + Point result; + // if line is below p, we push horizontally + bool horizontal_push = goes_below(p); + if (is_x_type()) { + if (is_flat()) { + if (horizontal_push) { + result.x = p.y / lambda() + mu(); + result.y = p.y; + } else { + // vertical push + result.x = p.x; + result.y = lambda() * (p.x - mu()); + } + } else { + // steep + if (horizontal_push) { + result.x = lambda() * p.y + mu(); + result.y = p.y; + } else { + // vertical push + result.x = p.x; + result.y = (p.x - mu()) / lambda(); + } + } + } else { + // y-type + if (is_flat()) { + if (horizontal_push) { + result.x = (p.y - mu()) / lambda(); + result.y = p.y; + } else { + // vertical push + result.x = p.x; + result.y = lambda() * p.x + mu(); + } + } else { + // steep + if (horizontal_push) { + result.x = (p.y - mu()) * lambda(); + result.y = p.y; + } else { + // vertical push + result.x = p.x; + result.y = p.x / lambda() + mu(); + } + } + } + return result; + } + + template + Real DualPoint::weighted_push(Point p) const + { + // if line is below p, we push horizontally + bool horizontal_push = goes_below(p); + if (is_x_type()) { + if (is_flat()) { + if (horizontal_push) { + return p.y; + } else { + // vertical push + return lambda() * (p.x - mu()); + } + } else { + // steep + if (horizontal_push) { + return lambda() * p.y; + } else { + // vertical push + return (p.x - mu()); + } + } + } else { + // y-type + if (is_flat()) { + if (horizontal_push) { + return p.y - mu(); + } else { + // vertical push + return lambda() * p.x; + } + } else { + // steep + if (horizontal_push) { + return lambda() * (p.y - mu()); + } else { + // vertical push + return p.x; + } + } + } + } + + template + bool DualPoint::operator==(const DualPoint& other) const + { + return axis_type() == other.axis_type() and + angle_type() == other.angle_type() and + mu() == other.mu() and + lambda() == other.lambda(); + } + + template + Real DualPoint::weight() const + { + return lambda_ / sqrt(1 + lambda_ * lambda_); + } +} // namespace md diff --git a/matching/include/matching_distance.h b/matching/include/matching_distance.h index bb10203..5be34c7 100644 --- a/matching/include/matching_distance.h +++ b/matching/include/matching_distance.h @@ -4,9 +4,10 @@ #include #include #include +#include +#include +#include -#include "spdlog/spdlog.h" -#include "spdlog/fmt/ostr.h" #include "common_defs.h" #include "cell_with_value.h" @@ -17,12 +18,15 @@ #include "bifiltration.h" #include "bottleneck.h" -namespace spd = spdlog; - namespace md { - using HeatMap = std::map; - using HeatMaps = std::map; +#ifdef MD_PRINT_HEAT_MAP + template + using HeatMap = std::map, Real>; + + template + using HeatMaps = std::map>; +#endif enum class BoundStrategy { bruteforce, @@ -39,18 +43,107 @@ namespace md { upper_bound }; - std::ostream& operator<<(std::ostream& os, const BoundStrategy& s); - - std::ostream& operator<<(std::ostream& os, const TraverseStrategy& s); - - std::istream& operator>>(std::istream& is, BoundStrategy& s); - - std::istream& operator>>(std::istream& is, TraverseStrategy& s); - - BoundStrategy bs_from_string(std::string s); - - TraverseStrategy ts_from_string(std::string s); - + inline std::ostream& operator<<(std::ostream& os, const BoundStrategy& s) + { + switch(s) { + case BoundStrategy::bruteforce : + os << "bruteforce"; + break; + case BoundStrategy::local_dual_bound : + os << "local_grob"; + break; + case BoundStrategy::local_combined : + os << "local_combined"; + break; + case BoundStrategy::local_dual_bound_refined : + os << "local_refined"; + break; + case BoundStrategy::local_dual_bound_for_each_point : + os << "local_for_each_point"; + break; + default: + os << "FORGOTTEN BOUND STRATEGY"; + } + return os; + } + + inline std::ostream& operator<<(std::ostream& os, const TraverseStrategy& s) + { + switch(s) { + case TraverseStrategy::depth_first : + os << "DFS"; + break; + case TraverseStrategy::breadth_first : + os << "BFS"; + break; + case TraverseStrategy::breadth_first_value : + os << "BFS-VAL"; + break; + case TraverseStrategy::upper_bound : + os << "UB"; + break; + default: + os << "FORGOTTEN TRAVERSE STRATEGY"; + } + return os; + } + + inline std::istream& operator>>(std::istream& is, TraverseStrategy& s) + { + std::string ss; + is >> ss; + if (ss == "DFS") { + s = TraverseStrategy::depth_first; + } else if (ss == "BFS") { + s = TraverseStrategy::breadth_first; + } else if (ss == "BFS-VAL") { + s = TraverseStrategy::breadth_first_value; + } else if (ss == "UB") { + s = TraverseStrategy::upper_bound; + } else { + throw std::runtime_error("UNKNOWN TRAVERSE STRATEGY"); + } + return is; + } + + + inline std::istream& operator>>(std::istream& is, BoundStrategy& s) + { + std::string ss; + is >> ss; + if (ss == "bruteforce") { + s = BoundStrategy::bruteforce; + } else if (ss == "local_grob") { + s = BoundStrategy::local_dual_bound; + } else if (ss == "local_combined") { + s = BoundStrategy::local_combined; + } else if (ss == "local_refined") { + s = BoundStrategy::local_dual_bound_refined; + } else if (ss == "local_for_each_point") { + s = BoundStrategy::local_dual_bound_for_each_point; + } else { + throw std::runtime_error("UNKNOWN BOUND STRATEGY"); + } + return is; + } + + inline BoundStrategy bs_from_string(std::string s) + { + std::stringstream ss(s); + BoundStrategy result; + ss >> result; + return result; + } + + inline TraverseStrategy ts_from_string(std::string s) + { + std::stringstream ss(s); + TraverseStrategy result; + ss >> result; + return result; + } + + template struct CalculationParams { static constexpr int ALL_DIMENSIONS = -1; @@ -75,22 +168,22 @@ namespace md { // print statistics on each quad-tree level bool print_stats { false }; -#ifdef PRINT_HEAT_MAP +#ifdef MD_PRINT_HEAT_MAP HeatMaps heat_maps; #endif }; - template + template class DistanceCalculator { - using DualBox = md::DualBox; - using CellValueVector = std::vector; + using Real = Real_; + using CellValueVector = std::vector>; public: DistanceCalculator(const DiagramProvider& a, const DiagramProvider& b, - CalculationParams& params); + CalculationParams& params); Real distance(); @@ -100,7 +193,7 @@ namespace md { DiagramProvider module_a_; DiagramProvider module_b_; - CalculationParams& params_; + CalculationParams& params_; int n_hera_calls_; std::map n_hera_calls_per_level_; @@ -112,65 +205,83 @@ namespace md { CellValueVector get_initial_dual_grid(Real& lower_bound); +#ifdef MD_PRINT_HEAT_MAP void heatmap_in_dimension(int dim, int depth); +#endif Real get_max_x(int module) const; Real get_max_y(int module) const; - void set_cell_central_value(CellWithValue& dual_cell); + void set_cell_central_value(CellWithValue& dual_cell); Real get_distance(); Real get_distance_pq(); - // temporary, to try priority queue - Real get_max_possible_value(const CellWithValue* first_cell_ptr, int n_cells); + Real get_max_possible_value(const CellWithValue* first_cell_ptr, int n_cells); - Real get_upper_bound(const CellWithValue& dual_cell, Real good_enough_upper_bound) const; + Real get_upper_bound(const CellWithValue& dual_cell, Real good_enough_upper_bound) const; - Real get_single_dgm_bound(const CellWithValue& dual_cell, ValuePoint vp, int module, + Real get_single_dgm_bound(const CellWithValue& dual_cell, ValuePoint vp, int module, Real good_enough_value) const; // this bound depends only on dual box - Real get_local_dual_bound(int module, const DualBox& dual_box) const; + Real get_local_dual_bound(int module, const DualBox& dual_box) const; - Real get_local_dual_bound(const DualBox& dual_box) const; + Real get_local_dual_bound(const DualBox& dual_box) const; // this bound depends only on dual box, is more accurate - Real get_local_refined_bound(int module, const md::DualBox& dual_box) const; + Real get_local_refined_bound(int module, const DualBox& dual_box) const; - Real get_local_refined_bound(const md::DualBox& dual_box) const; + Real get_local_refined_bound(const DualBox& dual_box) const; Real get_good_enough_upper_bound(Real lower_bound) const; - Real - get_max_displacement_single_point(const CellWithValue& dual_cell, ValuePoint value_point, const Point& p) const; + Real get_max_displacement_single_point(const CellWithValue& dual_cell, ValuePoint value_point, + const Point& p) const; - void check_upper_bound(const CellWithValue& dual_cell) const; + void check_upper_bound(const CellWithValue& dual_cell) const; - Real distance_on_line(DualPoint line); - Real distance_on_line_const(DualPoint line) const; + Real distance_on_line(DualPoint line); + Real distance_on_line_const(DualPoint line) const; Real current_error(Real lower_bound, Real upper_bound); }; - Real matching_distance(const Bifiltration& bif_a, const Bifiltration& bif_b, CalculationParams& params); + template + Real matching_distance(const Bifiltration& bif_a, const Bifiltration& bif_b, + CalculationParams& params); - Real matching_distance(const ModulePresentation& mod_a, const ModulePresentation& mod_b, CalculationParams& params); + template + Real matching_distance(const ModulePresentation& mod_a, const ModulePresentation& mod_b, + CalculationParams& params); // for upper bound experiment struct UbExperimentRecord { - Real error; - Real lower_bound; - Real upper_bound; - CellWithValue cell; + double error; + double lower_bound; + double upper_bound; + CellWithValue cell; long long int time; long long int n_hera_calls; }; - std::ostream& operator<<(std::ostream& os, const UbExperimentRecord& r); + inline std::ostream& operator<<(std::ostream& os, const UbExperimentRecord& r) + { + os << r.time << "\t" << r.n_hera_calls << "\t" << r.error << "\t" << r.lower_bound << "\t" << r.upper_bound; + return os; + } + + + template + void print_map(const std::map& dic) + { + for(const auto kv : dic) { + fmt::print("{} -> {}\n", kv.first, kv.second); + } + } -} +} // namespace md #include "matching_distance.hpp" diff --git a/matching/include/matching_distance.hpp b/matching/include/matching_distance.hpp index d2d2fbc..48c8464 100644 --- a/matching/include/matching_distance.hpp +++ b/matching/include/matching_distance.hpp @@ -1,34 +1,26 @@ namespace md { - template - void print_map(const std::map& dic) - { - for(const auto kv : dic) { - fmt::print("{} -> {}\n", kv.first, kv.second); - } - } - - template - void DistanceCalculator::check_upper_bound(const CellWithValue& dual_cell) const + template + void DistanceCalculator::check_upper_bound(const CellWithValue& dual_cell) const { spd::debug("Enter check_get_max_delta_on_cell"); const int n_samples_lambda = 100; const int n_samples_mu = 100; - DualBox db = dual_cell.dual_box(); - Real min_lambda = db.lambda_min(); - Real max_lambda = db.lambda_max(); - Real min_mu = db.mu_min(); - Real max_mu = db.mu_max(); - - Real h_lambda = (max_lambda - min_lambda) / n_samples_lambda; - Real h_mu = (max_mu - min_mu) / n_samples_mu; + DualBox db = dual_cell.dual_box(); + R min_lambda = db.lambda_min(); + R max_lambda = db.lambda_max(); + R min_mu = db.mu_min(); + R max_mu = db.mu_max(); + + R h_lambda = (max_lambda - min_lambda) / n_samples_lambda; + R h_mu = (max_mu - min_mu) / n_samples_mu; for(int i = 1; i < n_samples_lambda; ++i) { for(int j = 1; j < n_samples_mu; ++j) { - Real lambda = min_lambda + i * h_lambda; - Real mu = min_mu + j * h_mu; - DualPoint l(db.axis_type(), db.angle_type(), lambda, mu); - Real other_result = distance_on_line_const(l); - Real diff = fabs(dual_cell.stored_upper_bound() - other_result); + R lambda = min_lambda + i * h_lambda; + R mu = min_mu + j * h_mu; + DualPoint l(db.axis_type(), db.angle_type(), lambda, mu); + R other_result = distance_on_line_const(l); + R diff = fabs(dual_cell.stored_upper_bound() - other_result); if (other_result > dual_cell.stored_upper_bound()) { spd::error( "in check_upper_bound, upper_bound = {}, other_result = {}, diff = {}\ndual_cell = {}", @@ -42,10 +34,10 @@ namespace md { // for all lines l, l' inside dual box, // find the upper bound on the difference of weighted pushes of p - template - Real - DistanceCalculator::get_max_displacement_single_point(const CellWithValue& dual_cell, ValuePoint vp, - const Point& p) const + template + R + DistanceCalculator::get_max_displacement_single_point(const CellWithValue& dual_cell, ValuePoint vp, + const Point& p) const { assert(p.x >= 0 && p.y >= 0); @@ -53,15 +45,15 @@ namespace md { std::vector debug_ids = {3, 13, 54, 218, 350, 382, 484, 795, 2040, 8415, 44076}; bool debug = false; // std::find(debug_ids.begin(), debug_ids.end(), dual_cell.id) != debug_ids.end(); #endif - DualPoint line = dual_cell.value_point(vp); - const Real base_value = line.weighted_push(p); + DualPoint line = dual_cell.value_point(vp); + const R base_value = line.weighted_push(p); spd::debug("Enter get_max_displacement_single_point, p = {},\ndual_cell = {},\nline = {}, base_value = {}\n", p, dual_cell, line, base_value); - Real result = 0.0; - for(DualPoint dp : dual_cell.dual_box().critical_points(p)) { - Real dp_value = dp.weighted_push(p); + R result = 0.0; + for(DualPoint dp : dual_cell.dual_box().critical_points(p)) { + R dp_value = dp.weighted_push(p); spd::debug( "In get_max_displacement_single_point, p = {}, critical dp = {},\ndp_value = {}, diff = {},\ndual_cell = {}\n", p, dp, dp_value, fabs(base_value - dp_value), dual_cell); @@ -69,15 +61,15 @@ namespace md { } #ifdef MD_DO_FULL_CHECK - DualBox db = dual_cell.dual_box(); - std::uniform_real_distribution dlambda(db.lambda_min(), db.lambda_max()); - std::uniform_real_distribution dmu(db.mu_min(), db.mu_max()); + auto db = dual_cell.dual_box(); + std::uniform_real_distribution dlambda(db.lambda_min(), db.lambda_max()); + std::uniform_real_distribution dmu(db.mu_min(), db.mu_max()); std::mt19937 gen(1); for(int i = 0; i < 1000; ++i) { - Real lambda = dlambda(gen); - Real mu = dmu(gen); - DualPoint dp_random { db.axis_type(), db.angle_type(), lambda, mu }; - Real dp_value = dp_random.weighted_push(p); + R lambda = dlambda(gen); + R mu = dmu(gen); + DualPoint dp_random { db.axis_type(), db.angle_type(), lambda, mu }; + R dp_value = dp_random.weighted_push(p); if (fabs(base_value - dp_value) > result) { spd::error("in get_max_displacement_single_point, p = {}, vp = {}\ndb = {}\nresult = {}, base_value = {}, dp_value = {}, dp_random = {}", p, vp, db, result, base_value, dp_value, dp_random); @@ -89,12 +81,12 @@ namespace md { return result; } - template - typename DistanceCalculator::CellValueVector DistanceCalculator::get_initial_dual_grid(Real& lower_bound) + template + typename DistanceCalculator::CellValueVector DistanceCalculator::get_initial_dual_grid(R& lower_bound) { CellValueVector result = get_refined_grid(params_.initialization_depth, false, true); - lower_bound = -1.0; + lower_bound = -1; for(const auto& dc : result) { lower_bound = std::max(lower_bound, dc.max_corner_value()); } @@ -102,8 +94,8 @@ namespace md { assert(lower_bound >= 0); for(auto& dual_cell : result) { - Real good_enough_ub = get_good_enough_upper_bound(lower_bound); - Real max_value_on_cell = get_upper_bound(dual_cell, good_enough_ub); + R good_enough_ub = get_good_enough_upper_bound(lower_bound); + R max_value_on_cell = get_upper_bound(dual_cell, good_enough_ub); dual_cell.set_max_possible_value(max_value_on_cell); #ifdef MD_DO_FULL_CHECK @@ -116,39 +108,39 @@ namespace md { return result; } - template - typename DistanceCalculator::CellValueVector - DistanceCalculator::get_refined_grid(int init_depth, bool calculate_on_intermediate, bool calculate_on_last) + template + typename DistanceCalculator::CellValueVector + DistanceCalculator::get_refined_grid(int init_depth, bool calculate_on_intermediate, bool calculate_on_last) { - const Real y_max = std::max(module_a_.max_y(), module_b_.max_y()); - const Real x_max = std::max(module_a_.max_x(), module_b_.max_x()); + const R y_max = std::max(module_a_.max_y(), module_b_.max_y()); + const R x_max = std::max(module_a_.max_x(), module_b_.max_x()); - const Real lambda_min = 0; - const Real lambda_max = 1; + const R lambda_min = 0; + const R lambda_max = 1; - const Real mu_min = 0; + const R mu_min = 0; - DualBox x_flat(DualPoint(AxisType::x_type, AngleType::flat, lambda_min, mu_min), - DualPoint(AxisType::x_type, AngleType::flat, lambda_max, x_max)); + DualBox x_flat(DualPoint(AxisType::x_type, AngleType::flat, lambda_min, mu_min), + DualPoint(AxisType::x_type, AngleType::flat, lambda_max, x_max)); - DualBox x_steep(DualPoint(AxisType::x_type, AngleType::steep, lambda_min, mu_min), - DualPoint(AxisType::x_type, AngleType::steep, lambda_max, x_max)); + DualBox x_steep(DualPoint(AxisType::x_type, AngleType::steep, lambda_min, mu_min), + DualPoint(AxisType::x_type, AngleType::steep, lambda_max, x_max)); - DualBox y_flat(DualPoint(AxisType::y_type, AngleType::flat, lambda_min, mu_min), - DualPoint(AxisType::y_type, AngleType::flat, lambda_max, y_max)); + DualBox y_flat(DualPoint(AxisType::y_type, AngleType::flat, lambda_min, mu_min), + DualPoint(AxisType::y_type, AngleType::flat, lambda_max, y_max)); - DualBox y_steep(DualPoint(AxisType::y_type, AngleType::steep, lambda_min, mu_min), - DualPoint(AxisType::y_type, AngleType::steep, lambda_max, y_max)); + DualBox y_steep(DualPoint(AxisType::y_type, AngleType::steep, lambda_min, mu_min), + DualPoint(AxisType::y_type, AngleType::steep, lambda_max, y_max)); - CellWithValue x_flat_cell(x_flat, 0); - CellWithValue x_steep_cell(x_steep, 0); - CellWithValue y_flat_cell(y_flat, 0); - CellWithValue y_steep_cell(y_steep, 0); + CellWithValue x_flat_cell(x_flat, 0); + CellWithValue x_steep_cell(x_steep, 0); + CellWithValue y_flat_cell(y_flat, 0); + CellWithValue y_steep_cell(y_steep, 0); if (init_depth == 0) { - DualPoint diagonal_x_flat(AxisType::x_type, AngleType::flat, 1, 0); + DualPoint diagonal_x_flat(AxisType::x_type, AngleType::flat, 1, 0); - Real diagonal_value = distance_on_line(diagonal_x_flat); + R diagonal_value = distance_on_line(diagonal_x_flat); n_hera_calls_per_level_[0]++; x_flat_cell.set_value_at(ValuePoint::lower_right, diagonal_value); @@ -162,7 +154,7 @@ namespace md { x_steep_cell.id = 2; y_flat_cell.id = 3; y_steep_cell.id = 4; - CellWithValue::max_id = 4; + CellWithValue::max_id = 4; #endif CellValueVector result {x_flat_cell, x_steep_cell, y_flat_cell, y_steep_cell}; @@ -189,10 +181,10 @@ namespace md { return result; } - template - DistanceCalculator::DistanceCalculator(const T& a, + template + DistanceCalculator::DistanceCalculator(const T& a, const T& b, - CalculationParams& params) + CalculationParams& params) : module_a_(a), module_b_(b), @@ -213,33 +205,33 @@ namespace md { module_a_.max_x(), module_a_.max_y(), module_b_.max_x(), module_b_.max_y()); } - template - Real DistanceCalculator::get_max_x(int module) const + template + R DistanceCalculator::get_max_x(int module) const { return (module == 0) ? module_a_.max_x() : module_b_.max_x(); } - template - Real DistanceCalculator::get_max_y(int module) const + template + R DistanceCalculator::get_max_y(int module) const { return (module == 0) ? module_a_.max_y() : module_b_.max_y(); } - template - Real - DistanceCalculator::get_local_refined_bound(const md::DualBox& dual_box) const + template + R + DistanceCalculator::get_local_refined_bound(const DualBox& dual_box) const { return get_local_refined_bound(0, dual_box) + get_local_refined_bound(1, dual_box); } - template - Real - DistanceCalculator::get_local_refined_bound(int module, const md::DualBox& dual_box) const + template + R + DistanceCalculator::get_local_refined_bound(int module, const DualBox& dual_box) const { spd::debug("Enter get_local_refined_bound, dual_box = {}", dual_box); - Real d_lambda = dual_box.lambda_max() - dual_box.lambda_min(); - Real d_mu = dual_box.mu_max() - dual_box.mu_min(); - Real result; + R d_lambda = dual_box.lambda_max() - dual_box.lambda_min(); + R d_mu = dual_box.mu_max() - dual_box.mu_min(); + R result; if (dual_box.axis_type() == AxisType::x_type) { if (dual_box.is_flat()) { result = dual_box.lambda_max() * d_mu + (get_max_x(module) - dual_box.mu_min()) * d_lambda; @@ -258,11 +250,11 @@ namespace md { return result; } - template - Real DistanceCalculator::get_local_dual_bound(int module, const md::DualBox& dual_box) const + template + R DistanceCalculator::get_local_dual_bound(int module, const DualBox& dual_box) const { - Real dlambda = dual_box.lambda_max() - dual_box.lambda_min(); - Real dmu = dual_box.mu_max() - dual_box.mu_min(); + R dlambda = dual_box.lambda_max() - dual_box.lambda_min(); + R dmu = dual_box.mu_max() - dual_box.mu_min(); if (dual_box.is_flat()) { return get_max_x(module) * dlambda + dmu; @@ -271,20 +263,20 @@ namespace md { } } - template - Real DistanceCalculator::get_local_dual_bound(const md::DualBox& dual_box) const + template + R DistanceCalculator::get_local_dual_bound(const DualBox& dual_box) const { return get_local_dual_bound(0, dual_box) + get_local_dual_bound(1, dual_box); } - template - Real DistanceCalculator::get_upper_bound(const CellWithValue& dual_cell, Real good_enough_ub) const + template + R DistanceCalculator::get_upper_bound(const CellWithValue& dual_cell, R good_enough_ub) const { assert(good_enough_ub >= 0); switch(params_.bound_strategy) { case BoundStrategy::bruteforce: - return std::numeric_limits::max(); + return std::numeric_limits::max(); case BoundStrategy::local_dual_bound: return dual_cell.min_value() + get_local_dual_bound(dual_cell.dual_box()); @@ -293,7 +285,7 @@ namespace md { return dual_cell.min_value() + get_local_refined_bound(dual_cell.dual_box()); case BoundStrategy::local_combined: { - Real cheap_upper_bound = dual_cell.min_value() + get_local_refined_bound(dual_cell.dual_box()); + R cheap_upper_bound = dual_cell.min_value() + get_local_refined_bound(dual_cell.dual_box()); if (cheap_upper_bound < good_enough_ub) { return cheap_upper_bound; } else { @@ -302,14 +294,14 @@ namespace md { } case BoundStrategy::local_dual_bound_for_each_point: { - Real result = std::numeric_limits::max(); + R result = std::numeric_limits::max(); for(ValuePoint vp : k_corner_vps) { if (not dual_cell.has_value_at(vp)) { continue; } - Real base_value = dual_cell.value_at(vp); - Real bound_dgm_a = get_single_dgm_bound(dual_cell, vp, 0, good_enough_ub); + R base_value = dual_cell.value_at(vp); + R bound_dgm_a = get_single_dgm_bound(dual_cell, vp, 0, good_enough_ub); if (params_.stop_asap and bound_dgm_a + base_value >= good_enough_ub) { // we want to return a valid upper bound, not just something that will prevent discarding the cell @@ -318,8 +310,8 @@ namespace md { return dual_cell.min_value() + get_local_refined_bound(dual_cell.dual_box()); } - Real bound_dgm_b = get_single_dgm_bound(dual_cell, vp, 1, - std::max(Real(0), good_enough_ub - bound_dgm_a)); + R bound_dgm_b = get_single_dgm_bound(dual_cell, vp, 1, + std::max(R(0), good_enough_ub - bound_dgm_a)); result = std::min(result, base_value + bound_dgm_a + bound_dgm_b); @@ -336,19 +328,19 @@ namespace md { } } // to suppress compiler warning - return std::numeric_limits::max(); + return std::numeric_limits::max(); } // find maximal displacement of weighted points of m for all lines in dual_box - template - Real - DistanceCalculator::get_single_dgm_bound(const CellWithValue& dual_cell, + template + R + DistanceCalculator::get_single_dgm_bound(const CellWithValue& dual_cell, ValuePoint vp, int module, - [[maybe_unused]] Real good_enough_value) const + R good_enough_value) const { - Real result = 0; - Point max_point; + R result = 0; + Point max_point; spd::debug( "Enter get_single_dgm_bound, module = {}, dual_cell = {}, vp = {}, good_enough_value = {}, stop_asap = {}\n", @@ -358,7 +350,7 @@ namespace md { for(const auto& position : m.positions()) { spd::debug("in get_single_dgm_bound, simplex = {}\n", position); - Real x = get_max_displacement_single_point(dual_cell, vp, position); + R x = get_max_displacement_single_point(dual_cell, vp, position); spd::debug("In get_single_dgm_bound, point = {}, displacement = {}", position, x); @@ -385,30 +377,30 @@ namespace md { return result; } - template - Real DistanceCalculator::distance() + template + R DistanceCalculator::distance() { return get_distance_pq(); } // calculate weighted bottleneneck distance between slices on line // increments hera calls counter - template - Real DistanceCalculator::distance_on_line(DualPoint line) + template + R DistanceCalculator::distance_on_line(DualPoint line) { ++n_hera_calls_; - Real result = distance_on_line_const(line); + R result = distance_on_line_const(line); return result; } - template - Real DistanceCalculator::distance_on_line_const(DualPoint line) const + template + R DistanceCalculator::distance_on_line_const(DualPoint line) const { // TODO: think about this - how to call Hera auto dgm_a = module_a_.weighted_slice_diagram(line); auto dgm_b = module_b_.weighted_slice_diagram(line); - Real result; - if (params_.hera_epsilon > static_cast(0)) { + R result; + if (params_.hera_epsilon > static_cast(0)) { result = hera::bottleneckDistApprox(dgm_a, dgm_b, params_.hera_epsilon) / ( params_.hera_epsilon + 1); } else { result = hera::bottleneckDistExact(dgm_a, dgm_b); @@ -423,10 +415,10 @@ namespace md { return result; } - template - Real DistanceCalculator::get_good_enough_upper_bound(Real lower_bound) const + template + R DistanceCalculator::get_good_enough_upper_bound(R lower_bound) const { - Real result; + R result; // in upper_bound strategy we only prune cells if they cannot improve the lower bound, // otherwise the experiment is supposed to run indefinitely if (params_.traverse_strategy == TraverseStrategy::upper_bound) { @@ -440,14 +432,14 @@ namespace md { // helper function // calculate weighted bt distance on cell center, // assign distance value to cell, keep it in heat_map, and return - template - void DistanceCalculator::set_cell_central_value(CellWithValue& dual_cell) + template + void DistanceCalculator::set_cell_central_value(CellWithValue& dual_cell) { - DualPoint central_line {dual_cell.center()}; + DualPoint central_line {dual_cell.center()}; spd::debug("In set_cell_central_value, processing dual cell = {}, line = {}", dual_cell.dual_box(), central_line); - Real new_value = distance_on_line(central_line); + R new_value = distance_on_line(central_line); n_hera_calls_per_level_[dual_cell.level() + 1]++; dual_cell.set_value_at(ValuePoint::center, new_value); params_.actual_max_depth = std::max(params_.actual_max_depth, dual_cell.level() + 1); @@ -472,10 +464,10 @@ namespace md { // assumes that the underlying container is vector! // cell_ptr: pointer to the first element in queue // n_cells: queue size - template - Real DistanceCalculator::get_max_possible_value(const CellWithValue* cell_ptr, int n_cells) + template + R DistanceCalculator::get_max_possible_value(const CellWithValue* cell_ptr, int n_cells) { - Real result = (n_cells > 0) ? cell_ptr->stored_upper_bound() : 0; + R result = (n_cells > 0) ? cell_ptr->stored_upper_bound() : 0; for(int i = 0; i < n_cells; ++i, ++cell_ptr) { result = std::max(result, cell_ptr->stored_upper_bound()); } @@ -485,11 +477,11 @@ namespace md { // helper function: // return current error from lower and upper bounds // and save it in params_ (hence not const) - template - Real DistanceCalculator::current_error(Real lower_bound, Real upper_bound) + template + R DistanceCalculator::current_error(R lower_bound, R upper_bound) { - Real current_error = (lower_bound > 0.0) ? (upper_bound - lower_bound) / lower_bound - : std::numeric_limits::max(); + R current_error = (lower_bound > 0.0) ? (upper_bound - lower_bound) / lower_bound + : std::numeric_limits::max(); params_.actual_error = current_error; @@ -505,8 +497,8 @@ namespace md { // use priority queue to store dual cells // comparison function depends on the strategies in params_ // ressets hera calls counter - template - Real DistanceCalculator::get_distance_pq() + template + R DistanceCalculator::get_distance_pq() { std::map n_cells_considered; std::map n_cells_pushed_into_queue; @@ -527,26 +519,26 @@ namespace md { // if cell is too deep and is not pushed into queue, // we still need to take its max value into account; // the max over such cells is stored in max_result_on_too_fine_cells - Real upper_bound_on_deep_cells = -1; + R upper_bound_on_deep_cells = -1; spd::debug("Started iterations in dual space, delta = {}, bound_strategy = {}", params_.delta, params_.bound_strategy); // user-defined less lambda function // to regulate priority queue depending on strategy - auto dual_cell_less = [this](const CellWithValue& a, const CellWithValue& b) { + auto dual_cell_less = [this](const CellWithValue& a, const CellWithValue& b) { int a_level = a.level(); int b_level = b.level(); - Real a_value = a.max_corner_value(); - Real b_value = b.max_corner_value(); - Real a_ub = a.stored_upper_bound(); - Real b_ub = b.stored_upper_bound(); + R a_value = a.max_corner_value(); + R b_value = b.max_corner_value(); + R a_ub = a.stored_upper_bound(); + R b_ub = b.stored_upper_bound(); if (this->params_.traverse_strategy == TraverseStrategy::upper_bound and (not a.has_max_possible_value() or not b.has_max_possible_value())) { throw std::runtime_error("no upper bound on cell"); } - DualPoint a_lower_left = a.dual_box().lower_left(); - DualPoint b_lower_left = b.dual_box().lower_left(); + DualPoint a_lower_left = a.dual_box().lower_left(); + DualPoint b_lower_left = b.dual_box().lower_left(); switch(this->params_.traverse_strategy) { // in both breadth_first searches we want coarser cells @@ -569,24 +561,24 @@ namespace md { } }; - std::priority_queue dual_cells_queue( + std::priority_queue, CellValueVector, decltype(dual_cell_less)> dual_cells_queue( dual_cell_less); // weighted bt distance on the center of current cell - Real lower_bound = std::numeric_limits::min(); + R lower_bound = std::numeric_limits::min(); // init pq and lower bound for(auto& init_cell : get_initial_dual_grid(lower_bound)) { dual_cells_queue.push(init_cell); } - Real upper_bound = get_max_possible_value(&dual_cells_queue.top(), dual_cells_queue.size()); + R upper_bound = get_max_possible_value(&dual_cells_queue.top(), dual_cells_queue.size()); std::vector ub_experiment_results; while(not dual_cells_queue.empty()) { - CellWithValue dual_cell = dual_cells_queue.top(); + CellWithValue dual_cell = dual_cells_queue.top(); dual_cells_queue.pop(); assert(dual_cell.has_corner_value() and dual_cell.has_max_possible_value() @@ -620,7 +612,7 @@ namespace md { // until now, dual_cell knows its value in one of its corners // new_value will be the weighted distance at its center set_cell_central_value(dual_cell); - Real new_value = dual_cell.value_at(ValuePoint::center); + R new_value = dual_cell.value_at(ValuePoint::center); lower_bound = std::max(new_value, lower_bound); spd::debug("Processed cell = {}, weighted value = {}, lower_bound = {}", dual_cell, new_value, lower_bound); @@ -638,11 +630,11 @@ namespace md { throw std::runtime_error("no value on cell"); // if delta is smaller than good_enough_value, it allows to prune cell - Real good_enough_ub = get_good_enough_upper_bound(lower_bound); + R good_enough_ub = get_good_enough_upper_bound(lower_bound); // upper bound of the parent holds for refined_cell // and can sometimes be smaller! - Real upper_bound_on_refined_cell = std::min(dual_cell.stored_upper_bound(), + R upper_bound_on_refined_cell = std::min(dual_cell.stored_upper_bound(), get_upper_bound(refined_cell, good_enough_ub)); spd::debug("upper_bound_on_refined_cell = {}, dual_cell.stored_upper_bound = {}, get_upper_bound = {}", @@ -774,10 +766,46 @@ namespace md { return lower_bound; } - template - int DistanceCalculator::get_hera_calls_number() const + template + int DistanceCalculator::get_hera_calls_number() const { return n_hera_calls_; } -} \ No newline at end of file + template + R matching_distance(const Bifiltration& bif_a, const Bifiltration& bif_b, + CalculationParams& params) + { + R result; + // compute distance only in one dimension + if (params.dim != CalculationParams::ALL_DIMENSIONS) { + BifiltrationProxy bifp_a(bif_a, params.dim); + BifiltrationProxy bifp_b(bif_b, params.dim); + DistanceCalculator> runner(bifp_a, bifp_b, params); + result = runner.distance(); + params.n_hera_calls = runner.get_hera_calls_number(); + } else { + // compute distance in all dimensions, return maximal + result = -1; + for(int dim = 0; dim < std::max(bif_a.maximal_dim(), bif_b.maximal_dim()); ++dim) { + BifiltrationProxy bifp_a(bif_a, params.dim); + BifiltrationProxy bifp_b(bif_a, params.dim); + DistanceCalculator> runner(bifp_a, bifp_b, params); + result = std::max(result, runner.distance()); + params.n_hera_calls += runner.get_hera_calls_number(); + } + } + return result; + } + + + template + R matching_distance(const ModulePresentation& mod_a, const ModulePresentation& mod_b, + CalculationParams& params) + { + DistanceCalculator> runner(mod_a, mod_b, params); + R result = runner.distance(); + params.n_hera_calls = runner.get_hera_calls_number(); + return result; + } +} // namespace md diff --git a/matching/include/persistence_module.h b/matching/include/persistence_module.h index a1fc67e..e99771f 100644 --- a/matching/include/persistence_module.h +++ b/matching/include/persistence_module.h @@ -5,6 +5,12 @@ #include #include #include +#include +#include +#include + +#include "phat/boundary_matrix.h" +#include "phat/compute_persistence_pairs.h" #include "common_util.h" #include "dual_point.h" @@ -28,17 +34,20 @@ namespace md { */ + template class ModulePresentation { public: + using RealVec = std::vector; + enum Format { rivet_firep }; struct Relation { - Point position_; + Point position_; IndexVec components_; Relation() {} - Relation(const Point& _pos, const IndexVec& _components); + Relation(const Point& _pos, const IndexVec& _components); Real get_x() const { return position_.x; } Real get_y() const { return position_.y; } @@ -48,9 +57,9 @@ namespace md { ModulePresentation() {} - ModulePresentation(const PointVec& _generators, const RelVec& _relations); + ModulePresentation(const PointVec& _generators, const RelVec& _relations); - Diagram weighted_slice_diagram(const DualPoint& line) const; + Diagram weighted_slice_diagram(const DualPoint& line) const; // translate all points by vector (a,a) void translate(Real a); @@ -59,9 +68,7 @@ namespace md { Real minimal_coordinate() const { return std::min(min_x(), min_y()); } // return box that contains all positions of all simplices - Box bounding_box() const; - - friend std::ostream& operator<<(std::ostream& os, const ModulePresentation& mp); + Box bounding_box() const; Real max_x() const { return max_x_; } @@ -71,26 +78,27 @@ namespace md { Real min_y() const { return min_y_; } - PointVec positions() const; + PointVec positions() const; private: - PointVec generators_; + PointVec generators_; std::vector relations_; - PointVec positions_; + PointVec positions_; Real max_x_ { std::numeric_limits::max() }; Real max_y_ { std::numeric_limits::max() }; Real min_x_ { -std::numeric_limits::max() }; Real min_y_ { -std::numeric_limits::max() }; - Box bounding_box_; + Box bounding_box_; void init_boundaries(); - void project_generators(const DualPoint& slice, IndexVec& sorted_indices, RealVec& projections) const; - void project_relations(const DualPoint& slice, IndexVec& sorted_indices, RealVec& projections) const; + void project_generators(const DualPoint& slice, IndexVec& sorted_indices, RealVec& projections) const; + void project_relations(const DualPoint& slice, IndexVec& sorted_indices, RealVec& projections) const; }; } // namespace md +#include "persistence_module.hpp" #endif //MATCHING_DISTANCE_PERSISTENCE_MODULE_H diff --git a/matching/include/persistence_module.hpp b/matching/include/persistence_module.hpp new file mode 100644 index 0000000..6e49b2e --- /dev/null +++ b/matching/include/persistence_module.hpp @@ -0,0 +1,177 @@ +namespace md { + + /** + * + * @param values vector of length n + * @return [a_1,...,a_n] such that + * 1) values[a_1] <= values[a_2] <= ... <= values[a_n] + * 2) a_1,...,a_n is a permutation of 1,..,n + */ + + template + IndexVec get_sorted_indices(const std::vector& values) + { + IndexVec result(values.size()); + std::iota(result.begin(), result.end(), 0); + std::sort(result.begin(), result.end(), + [&values](size_t a, size_t b) { return values[a] < values[b]; }); + return result; + } + + // helper function to initialize const member positions_ in ModulePresentation + template + PointVec concat_gen_and_rel_positions(const PointVec& generators, + const typename ModulePresentation::RelVec& relations) + { + std::unordered_set> ps(generators.begin(), generators.end()); + for(const auto& rel : relations) { + ps.insert(rel.position_); + } + return PointVec(ps.begin(), ps.end()); + } + + + template + void ModulePresentation::init_boundaries() + { + max_x_ = std::numeric_limits::max(); + max_y_ = std::numeric_limits::max(); + min_x_ = -std::numeric_limits::max(); + min_y_ = -std::numeric_limits::max(); + + for(const auto& gen : positions_) { + min_x_ = std::min(gen.x, min_x_); + min_y_ = std::min(gen.y, min_y_); + max_x_ = std::max(gen.x, max_x_); + max_y_ = std::max(gen.y, max_y_); + } + + bounding_box_ = Box(Point(min_x_, min_y_), Point(max_x_, max_y_)); + } + + + template + ModulePresentation::ModulePresentation(const PointVec& _generators, const RelVec& _relations) : + generators_(_generators), + relations_(_relations) + { + init_boundaries(); + } + + template + void ModulePresentation::translate(Real a) + { + for(auto& g : generators_) { + g.translate(a); + } + + for(auto& r : relations_) { + r.position_.translate(a); + } + + positions_ = concat_gen_and_rel_positions(generators_, relations_); + init_boundaries(); + } + + + /** + * + * @param slice line on which generators are projected + * @param sorted_indices [a_1,...,a_n] s.t. wpush(generator[a_1]) <= wpush(generator[a_2]) <= .. + * @param projections sorted weighted pushes of generators + */ + + template + void ModulePresentation::project_generators(const DualPoint& slice, + IndexVec& sorted_indices, RealVec& projections) const + { + size_t num_gens = generators_.size(); + + RealVec gen_values; + gen_values.reserve(num_gens); + for(const auto& pos : generators_) { + gen_values.push_back(slice.weighted_push(pos)); + } + sorted_indices = get_sorted_indices(gen_values); + projections.clear(); + projections.reserve(num_gens); + for(auto i : sorted_indices) { + projections.push_back(gen_values[i]); + } + } + + template + void ModulePresentation::project_relations(const DualPoint& slice, IndexVec& sorted_rel_indices, + RealVec& projections) const + { + size_t num_rels = relations_.size(); + + RealVec rel_values; + rel_values.reserve(num_rels); + for(const auto& rel : relations_) { + rel_values.push_back(slice.weighted_push(rel.position_)); + } + sorted_rel_indices = get_sorted_indices(rel_values); + projections.clear(); + projections.reserve(num_rels); + for(auto i : sorted_rel_indices) { + projections.push_back(rel_values[i]); + } + } + + template + Diagram ModulePresentation::weighted_slice_diagram(const DualPoint& slice) const + { + IndexVec sorted_gen_indices, sorted_rel_indices; + RealVec gen_projections, rel_projections; + + project_generators(slice, sorted_gen_indices, gen_projections); + project_relations(slice, sorted_rel_indices, rel_projections); + + phat::boundary_matrix<> phat_matrix; + + phat_matrix.set_num_cols(relations_.size()); + + for(Index i = 0; i < (Index) relations_.size(); i++) { + IndexVec current_relation = relations_[sorted_rel_indices[i]].components_; + for(auto& j : current_relation) { + j = sorted_gen_indices[j]; + } + std::sort(current_relation.begin(), current_relation.end()); + phat_matrix.set_dim(i, current_relation.size()); + phat_matrix.set_col(i, current_relation); + } + + phat::persistence_pairs phat_persistence_pairs; + phat::compute_persistence_pairs(phat_persistence_pairs, phat_matrix); + + Diagram dgm; + + constexpr Real real_inf = std::numeric_limits::infinity(); + + for(Index i = 0; i < (Index) phat_persistence_pairs.get_num_pairs(); i++) { + std::pair new_pair = phat_persistence_pairs.get_pair(i); + bool is_finite_pair = new_pair.second != phat::k_infinity_index; + Real birth = gen_projections.at(new_pair.first); + Real death = is_finite_pair ? rel_projections.at(new_pair.second) : real_inf; + if (birth != death) { + dgm.emplace_back(birth, death); + } + } + + return dgm; + } + + template + PointVec ModulePresentation::positions() const + { + return positions_; + } + + template + Box ModulePresentation::bounding_box() const + { + return bounding_box_; + } + +} // namespace md diff --git a/matching/include/simplex.h b/matching/include/simplex.h index e9d0e30..75bbcae 100644 --- a/matching/include/simplex.h +++ b/matching/include/simplex.h @@ -9,6 +9,7 @@ namespace md { + template class Bifiltration; enum class BifiltrationFormat { @@ -38,11 +39,21 @@ namespace md { int dim() const { return vertices_.size() - 1; } - void push_back(int v); + void push_back(int v) + { + vertices_.push_back(v); + std::sort(vertices_.begin(), vertices_.end()); + } AbstractSimplex() { } - AbstractSimplex(std::vector vertices, bool sort = true); + AbstractSimplex(std::vector vertices, bool sort = true) + :vertices_(vertices) + { + if (sort) + std::sort(vertices_.begin(), vertices_.end()); + } + template AbstractSimplex(Iter beg_iter, Iter end_iter, bool sort = true) @@ -53,22 +64,51 @@ namespace md { std::sort(vertices_.begin(), end()); } - std::vector facets() const; + std::vector facets() const + { + std::vector result; + for (int i = 0; i < static_cast(vertices_.size()); ++i) { + std::vector facet_vertices; + facet_vertices.reserve(dim()); + for (int j = 0; j < static_cast(vertices_.size()); ++j) { + if (j != i) + facet_vertices.push_back(vertices_[j]); + } + if (!facet_vertices.empty()) { + result.emplace_back(facet_vertices, false); + } + } + return result; + } friend std::ostream& operator<<(std::ostream& os, const AbstractSimplex& s); - // compare by vertices_ only friend bool operator==(const AbstractSimplex& s1, const AbstractSimplex& s2); friend bool operator<(const AbstractSimplex&, const AbstractSimplex&); }; - std::ostream& operator<<(std::ostream& os, const AbstractSimplex& s); + inline std::ostream& operator<<(std::ostream& os, const AbstractSimplex& s) + { + os << "AbstractSimplex(id = " << s.id << ", vertices_ = " << container_to_string(s.vertices_) << ")"; + return os; + } + + inline bool operator<(const AbstractSimplex& a, const AbstractSimplex& b) + { + return a.vertices_ < b.vertices_; + } + + inline bool operator==(const AbstractSimplex& s1, const AbstractSimplex& s2) + { + return s1.vertices_ == s2.vertices_; + } + template class Simplex { private: Index id_; - Point pos_; + Point pos_; int dim_; // in our format we use facet indices, // this is the fastest representation for homology @@ -77,11 +117,11 @@ namespace md { // conversion routines are in Bifiltration Column facet_indices_; Column vertices_; - Real v {0.0}; // used when constructed a filtration for a slice + Real v {0}; // used when constructed a filtration for a slice public: Simplex(Index _id, std::string s, BifiltrationFormat input_format); - Simplex(Index _id, Point birth, int _dim, const Column& _bdry); + Simplex(Index _id, Point birth, int _dim, const Column& _bdry); void init_rivet(std::string s); @@ -96,9 +136,9 @@ namespace md { Real value() const { return v; } // assumes 1-criticality - Point position() const { return pos_; } + Point position() const { return pos_; } - void set_position(const Point& new_pos) { pos_ = new_pos; } + void set_position(const Point& new_pos) { pos_ = new_pos; } void scale(Real lambda) { @@ -110,12 +150,14 @@ namespace md { void set_value(Real new_val) { v = new_val; } - friend std::ostream& operator<<(std::ostream& os, const Simplex& s); - - friend Bifiltration; + friend Bifiltration; }; - std::ostream& operator<<(std::ostream& os, const Simplex& s); + template + std::ostream& operator<<(std::ostream& os, const Simplex& s); } + +#include "simplex.hpp" + #endif //MATCHING_DISTANCE_SIMPLEX_H diff --git a/matching/include/simplex.hpp b/matching/include/simplex.hpp new file mode 100644 index 0000000..ce0e30f --- /dev/null +++ b/matching/include/simplex.hpp @@ -0,0 +1,79 @@ +namespace md { + + template + Simplex::Simplex(Index id, Point birth, int dim, const Column& bdry) + : + id_(id), + pos_(birth), + dim_(dim), + facet_indices_(bdry) { } + + template + void Simplex::translate(Real a) + { + pos_.translate(a); + } + + template + void Simplex::init_rivet(std::string s) + { + auto delim_pos = s.find_first_of(";"); + assert(delim_pos > 0); + std::string vertices_str = s.substr(0, delim_pos); + std::string pos_str = s.substr(delim_pos + 1); + assert(not vertices_str.empty() and not pos_str.empty()); + // get vertices + std::stringstream vertices_ss(vertices_str); + int dim = 0; + int vertex; + while (vertices_ss >> vertex) { + dim++; + vertices_.push_back(vertex); + } + // + std::sort(vertices_.begin(), vertices_.end()); + assert(dim > 0); + + std::stringstream pos_ss(pos_str); + // TODO: get rid of 1-criticaltiy assumption + pos_ss >> pos_.x >> pos_.y; + } + + template + void Simplex::init_phat_like(std::string s) + { + facet_indices_.clear(); + std::stringstream ss(s); + ss >> dim_ >> pos_.x >> pos_.y; + if (dim_ > 0) { + facet_indices_.reserve(dim_ + 1); + for (int j = 0; j <= dim_; j++) { + Index k; + ss >> k; + facet_indices_.push_back(k); + } + } + } + + template + Simplex::Simplex(Index _id, std::string s, BifiltrationFormat input_format) + :id_(_id) + { + switch (input_format) { + case BifiltrationFormat::phat_like : + init_phat_like(s); + break; + case BifiltrationFormat::rivet : + init_rivet(s); + break; + } + } + + template + std::ostream& operator<<(std::ostream& os, const Simplex& x) + { + os << "Simplex(id = " << x.id() << ", dim = " << x.dim(); + os << ", boundary = " << container_to_string(x.boundary()) << ", pos = " << x.position() << ")"; + return os; + } +} diff --git a/matching/src/bifiltration.cpp b/matching/src/bifiltration.cpp deleted file mode 100644 index 44b12cf..0000000 --- a/matching/src/bifiltration.cpp +++ /dev/null @@ -1,407 +0,0 @@ -#include -#include -#include -#include - -#include -#include - -#include "spdlog/spdlog.h" -#include "spdlog/fmt/fmt.h" -#include "spdlog/fmt/ostr.h" - -#include "common_util.h" -#include "bifiltration.h" - -namespace spd = spdlog; - -namespace md { - - void Bifiltration::init() - { - Point lower_left = max_point(); - Point upper_right = min_point(); - for(const auto& simplex : simplices_) { - lower_left = greatest_lower_bound(lower_left, simplex.position()); - upper_right = least_upper_bound(upper_right, simplex.position()); - maximal_dim_ = std::max(maximal_dim_, simplex.dim()); - } - bounding_box_ = Box(lower_left, upper_right); - } - - Bifiltration::Bifiltration(const std::string& fname ) - { - std::ifstream ifstr {fname.c_str()}; - if (!ifstr.good()) { - std::string error_message = fmt::format("Cannot open file {0}", fname); - std::cerr << error_message << std::endl; - throw std::runtime_error(error_message); - } - - BifiltrationFormat input_format; - - std::string s; - - while(ignore_line(s)) { - std::getline(ifstr, s); - } - - if (s == "bifiltration") { - input_format = BifiltrationFormat::rivet; - } else if (s == "bifiltration_phat_like") { - input_format = BifiltrationFormat::phat_like; - } else { - std::cerr << "Unknown format: '" << s << "' in file " << fname << std::endl; - throw std::runtime_error("unknown bifiltration format"); - } - - switch(input_format) { - case BifiltrationFormat::rivet : - rivet_format_reader(ifstr); - break; - case BifiltrationFormat::phat_like : - phat_like_format_reader(ifstr); - break; - } - - ifstr.close(); - - init(); - } - - void Bifiltration::rivet_format_reader(std::ifstream& ifstr) - { - std::string s; - // read axes names - std::getline(ifstr, parameter_1_name_); - std::getline(ifstr, parameter_2_name_); - - Index index = 0; - while(std::getline(ifstr, s)) { - if (!ignore_line(s)) { - simplices_.emplace_back(index++, s, BifiltrationFormat::rivet); - } - } - } - - void Bifiltration::phat_like_format_reader(std::ifstream& ifstr) - { - spd::debug("Enter phat_like_format_reader"); - // read stream line by line; do not use >> operator - std::string s; - std::getline(ifstr, s); - - // first line contains number of simplices - long n_simplices = std::stol(s); - - // all other lines represent a simplex - Index index = 0; - while(index < n_simplices) { - std::getline(ifstr, s); - if (!ignore_line(s)) { - simplices_.emplace_back(index++, s, BifiltrationFormat::phat_like); - } - } - spd::debug("Read {} simplices from file", n_simplices); - } - - void Bifiltration::scale(Real lambda) - { - for(auto& s : simplices_) { - s.scale(lambda); - } - init(); - } - - void Bifiltration::sanity_check() const - { -#ifdef DEBUG - spd::debug("Enter Bifiltration::sanity_check"); - // check that boundary has correct number of simplices, - // each bounding simplex has correct dim - // and appears in the filtration before the simplex it bounds - for(const auto& s : simplices_) { - assert(s.dim() >= 0); - assert(s.dim() == 0 or s.dim() + 1 == (int) s.boundary().size()); - for(auto bdry_idx : s.boundary()) { - Simplex bdry_simplex = simplices()[bdry_idx]; - assert(bdry_simplex.dim() == s.dim() - 1); - assert(bdry_simplex.position().is_less(s.position(), false)); - } - } - spd::debug("Exit Bifiltration::sanity_check"); -#endif - } - - Diagram Bifiltration::weighted_slice_diagram(const DualPoint& line, int dim) const - { - DiagramKeeper dgm; - - // make a copy for now; I want slice_diagram to be const - std::vector simplices(simplices_); - -// std::vector simplices; -// simplices.reserve(simplices_.size() / 2); -// for(const auto& s : simplices_) { -// if (s.dim() <= dim + 1 and s.dim() >= dim) -// simplices.emplace_back(s); -// } - - spd::debug("Enter slice diagram, line = {}, simplices.size = {}", line, simplices.size()); - - for(auto& simplex : simplices) { - Real value = line.weighted_push(simplex.position()); -// spd::debug("in slice_diagram, simplex = {}, value = {}\n", simplex, value); - simplex.set_value(value); - } - - std::sort(simplices.begin(), simplices.end(), - [](const Simplex& a, const Simplex& b) { return a.value() < b.value(); }); - std::map index_map; - for(Index i = 0; i < (int) simplices.size(); i++) { - index_map[simplices[i].id()] = i; - } - - phat::boundary_matrix<> phat_matrix; - phat_matrix.set_num_cols(simplices.size()); - std::vector bd_in_slice_filtration; - for(Index i = 0; i < (int) simplices.size(); i++) { - phat_matrix.set_dim(i, simplices[i].dim()); - bd_in_slice_filtration.clear(); - //std::cout << "new col" << i << std::endl; - for(int j = 0; j < (int) simplices[i].boundary().size(); j++) { - // F[i] contains the indices of its facet wrt to the - // original filtration. We have to express it, however, - // wrt to the filtration along the slice. That is why - // we need the index_map - //std::cout << "Found " << F[i].bd[j] << ", returning " << index_map[F[i].bd[j]] << std::endl; - bd_in_slice_filtration.push_back(index_map[simplices[i].boundary()[j]]); - } - std::sort(bd_in_slice_filtration.begin(), bd_in_slice_filtration.end()); - phat_matrix.set_col(i, bd_in_slice_filtration); - } - phat::persistence_pairs phat_persistence_pairs; - phat::compute_persistence_pairs(phat_persistence_pairs, phat_matrix); - - dgm.clear(); - constexpr Real real_inf = std::numeric_limits::infinity(); - for(long i = 0; i < (long) phat_persistence_pairs.get_num_pairs(); i++) { - std::pair new_pair = phat_persistence_pairs.get_pair(i); - bool is_finite_pair = new_pair.second != phat::k_infinity_index; - Real birth = simplices.at(new_pair.first).value(); - Real death = is_finite_pair ? simplices.at(new_pair.second).value() : real_inf; - int dim = simplices[new_pair.first].dim(); - assert(dim + 1 == simplices[new_pair.second].dim()); - if (birth != death) { - dgm.add_point(dim, birth, death); - } - } - - spdlog::debug("Exiting slice_diagram, #dgm[0] = {}", dgm.get_diagram(0).size()); - - return dgm.get_diagram(dim); - } - - Box Bifiltration::bounding_box() const - { - return bounding_box_; - } - - Real Bifiltration::minimal_coordinate() const - { - return std::min(bounding_box_.lower_left().x, bounding_box_.lower_left().y); - } - - void Bifiltration::translate(Real a) - { - bounding_box_.translate(a); - for(auto& simplex : simplices_) { - simplex.translate(a); - } - } - - Real Bifiltration::max_x() const - { - if (simplices_.empty()) - return 1; - auto me = std::max_element(simplices_.cbegin(), simplices_.cend(), - [](const auto& s_a, const auto& s_b) { return s_a.position().x < s_b.position().x; }); - assert(me != simplices_.cend()); - return me->position().x; - } - - Real Bifiltration::max_y() const - { - if (simplices_.empty()) - return 1; - auto me = std::max_element(simplices_.cbegin(), simplices_.cend(), - [](const auto& s_a, const auto& s_b) { return s_a.position().y < s_b.position().y; }); - assert(me != simplices_.cend()); - return me->position().y; - } - - Real Bifiltration::min_x() const - { - if (simplices_.empty()) - return 0; - auto me = std::min_element(simplices_.cbegin(), simplices_.cend(), - [](const auto& s_a, const auto& s_b) { return s_a.position().x < s_b.position().x; }); - assert(me != simplices_.cend()); - return me->position().x; - } - - Real Bifiltration::min_y() const - { - if (simplices_.empty()) - return 0; - auto me = std::min_element(simplices_.cbegin(), simplices_.cend(), - [](const auto& s_a, const auto& s_b) { return s_a.position().y < s_b.position().y; }); - assert(me != simplices_.cend()); - return me->position().y; - } - - void Bifiltration::add_simplex(md::Index _id, md::Point birth, int _dim, const md::Column& _bdry) - { - simplices_.emplace_back(_id, birth, _dim, _bdry); - } - - void Bifiltration::save(const std::string& filename, md::BifiltrationFormat format) - { - switch(format) { - case BifiltrationFormat::rivet: - throw std::runtime_error("Not implemented"); - break; - case BifiltrationFormat::phat_like: { - std::ofstream f(filename); - if (not f.good()) { - std::cerr << "Bifiltration::save: cannot open file " << filename << std::endl; - throw std::runtime_error("Cannot open file for writing "); - } - f << simplices_.size() << "\n"; - - for(const auto& s : simplices_) { - f << s.dim() << " " << s.position().x << " " << s.position().y << " "; - for(int b : s.boundary()) { - f << b << " "; - } - f << std::endl; - } - - } - break; - } - } - - void Bifiltration::postprocess_rivet_format() - { - std::map facets_to_ids; - - // fill the map - for(Index i = 0; i < (Index) simplices_.size(); ++i) { - assert(simplices_[i].id() == i); - facets_to_ids[simplices_[i].vertices_] = i; - } - -// for(const auto& s : simplices_) { -// facets_to_ids[s] = s.id(); -// } - - // main loop - for(auto& s : simplices_) { - assert(not s.vertices_.empty()); - assert(s.facet_indices_.empty()); - Column facet_indices; - for(Index i = 0; i <= s.dim(); ++i) { - Column facet; - for(Index j : s.vertices_) { - if (j != i) - facet.push_back(j); - } - auto facet_index = facets_to_ids.at(facet); - facet_indices.push_back(facet_index); - } - s.facet_indices_ = facet_indices; - } // loop over simplices - } - - std::ostream& operator<<(std::ostream& os, const Bifiltration& bif) - { - os << "Bifiltration, axes = " << bif.parameter_1_name_ << ", " << bif.parameter_2_name_ << std::endl; - for(const auto& s : bif.simplices()) { - os << s << std::endl; - } - return os; - } - - BifiltrationProxy::BifiltrationProxy(const md::Bifiltration& bif, int dim) - : - dim_(dim), - bif_(bif) - { - cache_positions(); - } - - void BifiltrationProxy::cache_positions() const - { - cached_positions_.clear(); - for(const auto& simplex : bif_.simplices()) { - if (simplex.dim() == dim_ or simplex.dim() == dim_ + 1) - cached_positions_.push_back(simplex.position()); - } - } - - PointVec BifiltrationProxy::positions() const - { - if (cached_positions_.empty()) { - cache_positions(); - } - return cached_positions_; - } - - // translate all points by vector (a,a) - void BifiltrationProxy::translate(Real a) - { - bif_.translate(a); - } - - // return minimal value of x- and y-coordinates - // among all simplices - Real BifiltrationProxy::minimal_coordinate() const - { - return bif_.minimal_coordinate(); - } - - // return box that contains positions of all simplices - Box BifiltrationProxy::bounding_box() const - { - return bif_.bounding_box(); - } - - Real BifiltrationProxy::max_x() const - { - return bif_.max_x(); - } - - Real BifiltrationProxy::max_y() const - { - return bif_.max_y(); - } - - Real BifiltrationProxy::min_x() const - { - return bif_.min_x(); - } - - Real BifiltrationProxy::min_y() const - { - return bif_.min_y(); - } - - - Diagram BifiltrationProxy::weighted_slice_diagram(const DualPoint& slice) const - { - return bif_.weighted_slice_diagram(slice, dim_); - } - -} - diff --git a/matching/src/box.cpp b/matching/src/box.cpp deleted file mode 100644 index c128698..0000000 --- a/matching/src/box.cpp +++ /dev/null @@ -1,61 +0,0 @@ - -#include "box.h" - -namespace md { - - std::ostream& operator<<(std::ostream& os, const Box& box) - { - os << "Box(lower_left = " << box.lower_left() << ", upper_right = " << box.upper_right() << ")"; - return os; - } - - Box get_enclosing_box(const Box& box_a, const Box& box_b) - { - Point lower_left(std::min(box_a.lower_left().x, box_b.lower_left().x), - std::min(box_a.lower_left().y, box_b.lower_left().y)); - Point upper_right(std::max(box_a.upper_right().x, box_b.upper_right().x), - std::max(box_a.upper_right().y, box_b.upper_right().y)); - return Box(lower_left, upper_right); - } - - void Box::translate(md::Real a) - { - ll.x += a; - ll.y += a; - ur.x += a; - ur.y += a; - } - - std::vector Box::refine() const - { - std::vector result; - -// 1 | 2 -// 0 | 3 - - Point new_ll = lower_left(); - Point new_ur = center(); - result.emplace_back(new_ll, new_ur); - - new_ll.y = center().y; - new_ur.y = ur.y; - result.emplace_back(new_ll, new_ur); - - new_ll = center(); - new_ur = upper_right(); - result.emplace_back(new_ll, new_ur); - - new_ll.y = ll.y; - new_ur.y = center().y; - result.emplace_back(new_ll, new_ur); - - return result; - } - - std::vector Box::corners() const - { - return {ll, Point(ll.x, ur.y), ur, Point(ur.x, ll.y)}; - }; - - -} diff --git a/matching/src/cell_with_value.cpp b/matching/src/cell_with_value.cpp deleted file mode 100644 index d8fd7d4..0000000 --- a/matching/src/cell_with_value.cpp +++ /dev/null @@ -1,247 +0,0 @@ -#include -#include - -namespace spd = spdlog; - -#include "cell_with_value.h" - -namespace md { - -#ifdef MD_DEBUG - long long int CellWithValue::max_id = 0; -#endif - - Real CellWithValue::value_at(ValuePoint vp) const - { - switch(vp) { - case ValuePoint::upper_left : - return upper_left_value_; - case ValuePoint::upper_right : - return upper_right_value_; - case ValuePoint::lower_left : - return lower_left_value_; - case ValuePoint::lower_right : - return lower_right_value_; - case ValuePoint::center: - return central_value_; - } - // to shut up compiler warning - return 1.0 / 0.0; - } - - bool CellWithValue::has_value_at(ValuePoint vp) const - { - switch(vp) { - case ValuePoint::upper_left : - return upper_left_value_ >= 0; - case ValuePoint::upper_right : - return upper_right_value_ >= 0; - case ValuePoint::lower_left : - return lower_left_value_ >= 0; - case ValuePoint::lower_right : - return lower_right_value_ >= 0; - case ValuePoint::center: - return central_value_ >= 0; - } - // to shut up compiler warning - return 1.0 / 0.0; - } - - DualPoint CellWithValue::value_point(md::ValuePoint vp) const - { - switch(vp) { - case ValuePoint::upper_left : - return dual_box().upper_left(); - case ValuePoint::upper_right : - return dual_box().upper_right(); - case ValuePoint::lower_left : - return dual_box().lower_left(); - case ValuePoint::lower_right : - return dual_box().lower_right(); - case ValuePoint::center: - return dual_box().center(); - } - // to shut up compiler warning - return DualPoint(); - } - - bool CellWithValue::has_corner_value() const - { - return has_lower_left_value() or has_lower_right_value() or has_upper_left_value() - or has_upper_right_value(); - } - - Real CellWithValue::stored_upper_bound() const - { - assert(has_max_possible_value_); - return max_possible_value_; - } - - Real CellWithValue::max_corner_value() const - { - return std::max({lower_left_value_, lower_right_value_, upper_left_value_, upper_right_value_}); - } - - Real CellWithValue::min_value() const - { - Real result = std::numeric_limits::max(); - for(auto vp : k_all_vps) { - if (not has_value_at(vp)) { - continue; - } - result = std::min(result, value_at(vp)); - } - return result; - } - - std::vector CellWithValue::get_refined_cells() const - { - std::vector result; - result.reserve(4); - for(const auto& refined_box : dual_box_.refine()) { - - CellWithValue refined_cell(refined_box, level() + 1); - -#ifdef MD_DEBUG - refined_cell.parent_ids = parent_ids; - refined_cell.parent_ids.push_back(id); - refined_cell.id = ++max_id; -#endif - - if (refined_box.lower_left() == dual_box_.lower_left()) { - // _|_ - // H|_ - - refined_cell.set_value_at(ValuePoint::lower_left, lower_left_value_); - refined_cell.set_value_at(ValuePoint::upper_right, central_value_); - - } else if (refined_box.upper_right() == dual_box_.upper_right()) { - // _|H - // _|_ - - refined_cell.set_value_at(ValuePoint::lower_left, central_value_); - refined_cell.set_value_at(ValuePoint::upper_right, upper_right_value_); - - } else if (refined_box.lower_right() == dual_box_.lower_right()) { - // _|_ - // _|H - - refined_cell.set_value_at(ValuePoint::lower_right, lower_right_value_); - refined_cell.set_value_at(ValuePoint::upper_left, central_value_); - - } else if (refined_box.upper_left() == dual_box_.upper_left()) { - - // H|_ - // _|_ - - refined_cell.set_value_at(ValuePoint::lower_right, central_value_); - refined_cell.set_value_at(ValuePoint::upper_left, upper_left_value_); - } - result.emplace_back(refined_cell); - } - return result; - } - - void CellWithValue::set_value_at(md::ValuePoint vp, md::Real new_value) - { - if (has_value_at(vp)) - spd::error("CellWithValue: trying to re-assign value!, this = {}, vp = {}", *this, vp); - - switch(vp) { - case ValuePoint::upper_left : - upper_left_value_ = new_value; - break; - case ValuePoint::upper_right : - upper_right_value_ = new_value; - break; - case ValuePoint::lower_left : - lower_left_value_ = new_value; - break; - case ValuePoint::lower_right : - lower_right_value_ = new_value; - break; - case ValuePoint::center: - central_value_ = new_value; - break; - } - - - } - - int CellWithValue::num_values() const - { - int result = 0; - for(ValuePoint vp : k_all_vps) { - result += has_value_at(vp); - } - return result; - } - - - void CellWithValue::set_max_possible_value(Real new_upper_bound) - { - assert(new_upper_bound >= central_value_); - assert(new_upper_bound >= lower_left_value_); - assert(new_upper_bound >= lower_right_value_); - assert(new_upper_bound >= upper_left_value_); - assert(new_upper_bound >= upper_right_value_); - has_max_possible_value_ = true; - max_possible_value_ = new_upper_bound; - } - - std::ostream& operator<<(std::ostream& os, const ValuePoint& vp) - { - switch(vp) { - case ValuePoint::upper_left : - os << "upper_left"; - break; - case ValuePoint::upper_right : - os << "upper_right"; - break; - case ValuePoint::lower_left : - os << "lower_left"; - break; - case ValuePoint::lower_right : - os << "lower_right"; - break; - case ValuePoint::center: - os << "center"; - break; - default: - os << "FORGOTTEN ValuePoint"; - } - return os; - } - - - - std::ostream& operator<<(std::ostream& os, const CellWithValue& cell) - { - os << "CellWithValue(box = " << cell.dual_box() << ", "; - -#ifdef MD_DEBUG - os << "id = " << cell.id; - if (not cell.parent_ids.empty()) - os << ", parent_ids = " << container_to_string(cell.parent_ids) << ", "; -#endif - - for(ValuePoint vp : k_all_vps) { - if (cell.has_value_at(vp)) { - os << "value = " << cell.value_at(vp); - os << ", at " << vp << " " << cell.value_point(vp); - } - } - - os << ", max_corner_value = "; - if (cell.has_max_possible_value()) { - os << cell.stored_upper_bound(); - } else { - os << "-"; - } - - os << ", level = " << cell.level() << ")"; - return os; - } - -} // namespace md - diff --git a/matching/src/common_util.cpp b/matching/src/common_util.cpp deleted file mode 100644 index 96c3388..0000000 --- a/matching/src/common_util.cpp +++ /dev/null @@ -1,243 +0,0 @@ -#include -#include -#include -#include -#include -#include - -#include - -#include "spdlog/spdlog.h" -#include "spdlog/fmt/ostr.h" - -namespace md { - - - int gcd(int a, int b) - { - assert(a != 0 or b != 0); - // make b <= a - std::tie(b, a) = std::minmax({ abs(a), abs(b) }); - if (b == 0) - return a; - while((a = a % b)) { - std::swap(a, b); - } - return b; - } - - int signum(int a) - { - if (a < 0) - return -1; - else if (a > 0) - return 1; - else - return 0; - } - - Rational reduce(Rational frac) - { - int d = gcd(frac.numerator, frac.denominator); - frac.numerator /= d; - frac.denominator /= d; - return frac; - } - - void Rational::reduce() { *this = md::reduce(*this); } - - - Rational& Rational::operator*=(const md::Rational& rhs) - { - numerator *= rhs.numerator; - denominator *= rhs.denominator; - reduce(); - return *this; - } - - Rational& Rational::operator/=(const md::Rational& rhs) - { - numerator *= rhs.denominator; - denominator *= rhs.numerator; - reduce(); - return *this; - } - - Rational& Rational::operator+=(const md::Rational& rhs) - { - numerator = numerator * rhs.denominator + denominator * rhs.numerator; - denominator *= rhs.denominator; - reduce(); - return *this; - } - - Rational& Rational::operator-=(const md::Rational& rhs) - { - numerator = numerator * rhs.denominator - denominator * rhs.numerator; - denominator *= rhs.denominator; - reduce(); - return *this; - } - - - Rational midpoint(Rational a, Rational b) - { - return reduce({a.numerator * b.denominator + a.denominator * b.numerator, 2 * a.denominator * b.denominator }); - } - - Rational operator+(Rational a, const Rational& b) - { - a += b; - return a; - } - - Rational operator-(Rational a, const Rational& b) - { - a -= b; - return a; - } - - Rational operator*(Rational a, const Rational& b) - { - a *= b; - return a; - } - - Rational operator/(Rational a, const Rational& b) - { - a /= b; - return a; - } - - bool is_less(Rational a, Rational b) - { - // compute a - b = a_1 / a_2 - b_1 / b_2 - long numer = a.numerator * b.denominator - a.denominator * b.numerator; - long denom = a.denominator * b.denominator; - assert(denom != 0); - return signum(numer) * signum(denom) < 0; - } - - bool operator==(const Rational& a, const Rational& b) - { - return std::tie(a.numerator, a.denominator) == std::tie(b.numerator, b.denominator); - } - - bool operator<(const Rational& a, const Rational& b) - { - // do not remove signum - overflow - long numer = a.numerator * b.denominator - a.denominator * b.numerator; - long denom = a.denominator * b.denominator; - assert(denom != 0); -// spdlog::debug("a = {}, b = {}, numer = {}, denom = {}, result = {}", a, b, numer, denom, signum(numer) * signum(denom) <= 0); - return signum(numer) * signum(denom) < 0; - } - - bool is_leq(Rational a, Rational b) - { - // compute a - b = a_1 / a_2 - b_1 / b_2 - long numer = a.numerator * b.denominator - a.denominator * b.numerator; - long denom = a.denominator * b.denominator; - assert(denom != 0); - return signum(numer) * signum(denom) <= 0; - } - - bool is_greater(Rational a, Rational b) - { - return not is_leq(a, b); - } - - bool is_geq(Rational a, Rational b) - { - return not is_less(a, b); - } - - Point operator+(const Point& u, const Point& v) - { - return Point(u.x + v.x, u.y + v.y); - } - - Point operator-(const Point& u, const Point& v) - { - return Point(u.x - v.x, u.y - v.y); - } - - Point least_upper_bound(const Point& u, const Point& v) - { - return Point(std::max(u.x, v.x), std::max(u.y, v.y)); - } - - Point greatest_lower_bound(const Point& u, const Point& v) - { - return Point(std::min(u.x, v.x), std::min(u.y, v.y)); - } - - Point max_point() - { - return Point(std::numeric_limits::max(), std::numeric_limits::min()); - } - - Point min_point() - { - return Point(-std::numeric_limits::max(), -std::numeric_limits::min()); - } - - std::ostream& operator<<(std::ostream& ostr, const Point& vec) - { - ostr << "(" << vec.x << ", " << vec.y << ")"; - return ostr; - } - - Real l_infty_norm(const Point& v) - { - return std::max(std::abs(v.x), std::abs(v.y)); - } - - Real l_2_norm(const Point& v) - { - return v.norm(); - } - - Real l_2_dist(const Point& x, const Point& y) - { - return l_2_norm(x - y); - } - - Real l_infty_dist(const Point& x, const Point& y) - { - return l_infty_norm(x - y); - } - - void DiagramKeeper::add_point(int dim, md::Real birth, md::Real death) - { - data_[dim].emplace_back(birth, death); - } - - DiagramKeeper::Diagram DiagramKeeper::get_diagram(int dim) const - { - if (data_.count(dim) == 1) - return data_.at(dim); - else - return DiagramKeeper::Diagram(); - } - - // return true, if line starts with # - // or contains only spaces - bool ignore_line(const std::string& s) - { - for(auto c : s) { - if (isspace(c)) - continue; - return (c == '#'); - } - return true; - } - - - - std::ostream& operator<<(std::ostream& os, const Rational& a) - { - os << a.numerator << " / " << a.denominator; - return os; - } -} diff --git a/matching/src/dual_box.cpp b/matching/src/dual_box.cpp deleted file mode 100644 index ff4d30c..0000000 --- a/matching/src/dual_box.cpp +++ /dev/null @@ -1,194 +0,0 @@ -#include - -#include "spdlog/spdlog.h" -#include "spdlog/fmt/ostr.h" - -namespace spd = spdlog; - -#include "dual_box.h" - -namespace md { - - std::ostream& operator<<(std::ostream& os, const DualBox& db) - { - os << "DualBox(" << db.lower_left_ << ", " << db.upper_right_ << ")"; - return os; - } - - DualBox::DualBox(DualPoint ll, DualPoint ur) - :lower_left_(ll), upper_right_(ur) - { - } - - std::vector DualBox::corners() const - { - return {lower_left_, - DualPoint(axis_type(), angle_type(), lower_left_.lambda(), upper_right_.mu()), - upper_right_, - DualPoint(axis_type(), angle_type(), upper_right_.lambda(), lower_left_.mu())}; - } - - std::vector DualBox::push_change_points(const Point& p) const - { - std::vector result; - result.reserve(2); - - bool is_y_type = lower_left_.is_y_type(); - bool is_flat = lower_left_.is_flat(); - - auto mu_from_lambda = [p, is_y_type, is_flat](Real lambda) { - bool is_x_type = not is_y_type, is_steep = not is_flat; - if (is_y_type and is_flat) { - return p.y - lambda * p.x; - } else if (is_y_type and is_steep) { - return p.y - p.x / lambda; - } else if (is_x_type and is_flat) { - return p.x - p.y / lambda; - } else if (is_x_type and is_steep) { - return p.x - lambda * p.y; - } - // to shut up compiler warning - return static_cast(1.0 / 0.0); - }; - - auto lambda_from_mu = [p, is_y_type, is_flat](Real mu) { - bool is_x_type = not is_y_type, is_steep = not is_flat; - if (is_y_type and is_flat) { - return (p.y - mu) / p.x; - } else if (is_y_type and is_steep) { - return p.x / (p.y - mu); - } else if (is_x_type and is_flat) { - return p.y / (p.x - mu); - } else if (is_x_type and is_steep) { - return (p.x - mu) / p.y; - } - // to shut up compiler warning - return static_cast(1.0 / 0.0); - }; - - // all inequalities below are strict: equality means it is a corner - // and critical_points() returns corners anyway - - Real mu_intersect_min = mu_from_lambda(lambda_min()); - - if (mu_min() < mu_intersect_min && mu_intersect_min < mu_max()) - result.emplace_back(axis_type(), angle_type(), lambda_min(), mu_intersect_min); - - Real mu_intersect_max = mu_from_lambda(lambda_max()); - - if (mu_max() < mu_intersect_max && mu_intersect_max < mu_max()) - result.emplace_back(axis_type(), angle_type(), lambda_max(), mu_intersect_max); - - Real lambda_intersect_min = lambda_from_mu(mu_min()); - - if (lambda_min() < lambda_intersect_min && lambda_intersect_min < lambda_max()) - result.emplace_back(axis_type(), angle_type(), lambda_intersect_min, mu_min()); - - Real lambda_intersect_max = lambda_from_mu(mu_max()); - if (lambda_min() < lambda_intersect_max && lambda_intersect_max < lambda_max()) - result.emplace_back(axis_type(), angle_type(), lambda_intersect_max, mu_max()); - - assert(result.size() <= 2); - - if (result.size() > 2) { - fmt::print("Error in push_change_points, p = {}, dual_box = {}, result = {}\n", p, *this, - container_to_string(result)); - throw std::runtime_error("push_change_points returned more than 2 points"); - } - - return result; - } - - std::vector DualBox::critical_points(const Point& /*p*/) const - { - // maximal difference is attained at corners - return corners(); -// std::vector result; -// result.reserve(6); -// for(auto dp : corners()) result.push_back(dp); -// for(auto dp : push_change_points(p)) result.push_back(dp); -// return result; - } - - std::vector DualBox::random_points(int n) const - { - assert(n >= 0); - std::mt19937_64 gen(1); - std::vector result; - result.reserve(n); - std::uniform_real_distribution mu_distr(mu_min(), mu_max()); - std::uniform_real_distribution lambda_distr(lambda_min(), lambda_max()); - for(int i = 0; i < n; ++i) { - result.emplace_back(axis_type(), angle_type(), lambda_distr(gen), mu_distr(gen)); - } - return result; - } - - bool DualBox::sanity_check() const - { - lower_left_.sanity_check(); - upper_right_.sanity_check(); - - if (lower_left_.angle_type() != upper_right_.angle_type()) - throw std::runtime_error("angle types differ"); - - if (lower_left_.axis_type() != upper_right_.axis_type()) - throw std::runtime_error("axis types differ"); - - if (lower_left_.lambda() >= upper_right_.lambda()) - throw std::runtime_error("lambda of lower_left_ greater than lambda of upper_right "); - - if (lower_left_.mu() >= upper_right_.mu()) - throw std::runtime_error("mu of lower_left_ greater than mu of upper_right "); - - return true; - } - - std::vector DualBox::refine() const - { - std::vector result; - - result.reserve(4); - - Real lambda_middle = (lower_left().lambda() + upper_right().lambda()) / 2.0; - Real mu_middle = (lower_left().mu() + upper_right().mu()) / 2.0; - - DualPoint refinement_center(axis_type(), angle_type(), lambda_middle, mu_middle); - - result.emplace_back(lower_left_, refinement_center); - - result.emplace_back(DualPoint(axis_type(), angle_type(), lambda_middle, mu_min()), - DualPoint(axis_type(), angle_type(), lambda_max(), mu_middle)); - - result.emplace_back(refinement_center, upper_right_); - - result.emplace_back(DualPoint(axis_type(), angle_type(), lambda_min(), mu_middle), - DualPoint(axis_type(), angle_type(), lambda_middle, mu_max())); - return result; - } - - bool DualBox::operator==(const DualBox& other) const - { - return lower_left() == other.lower_left() and - upper_right() == other.upper_right(); - } - - bool DualBox::contains(const DualPoint& dp) const - { - return dp.angle_type() == angle_type() and dp.axis_type() == axis_type() and - mu_max() >= dp.mu() and - mu_min() <= dp.mu() and - lambda_min() <= dp.lambda() and - lambda_max() >= dp.lambda(); - } - - DualPoint DualBox::lower_right() const - { - return DualPoint(lower_left_.axis_type(), lower_left_.angle_type(), lambda_max(), mu_min()); - } - - DualPoint DualBox::upper_left() const - { - return DualPoint(lower_left_.axis_type(), lower_left_.angle_type(), lambda_min(), mu_max()); - } -} diff --git a/matching/src/dual_point.cpp b/matching/src/dual_point.cpp deleted file mode 100644 index 1c00b58..0000000 --- a/matching/src/dual_point.cpp +++ /dev/null @@ -1,282 +0,0 @@ -#include - -#include "dual_point.h" - -namespace md { - - std::ostream& operator<<(std::ostream& os, const AxisType& at) - { - if (at == AxisType::x_type) - os << "x-type"; - else - os << "y-type"; - return os; - } - - std::ostream& operator<<(std::ostream& os, const AngleType& at) - { - if (at == AngleType::flat) - os << "flat"; - else - os << "steep"; - return os; - } - - std::ostream& operator<<(std::ostream& os, const DualPoint& dp) - { - os << "Line(" << dp.axis_type() << ", "; - os << dp.angle_type() << ", "; - os << dp.lambda() << ", "; - os << dp.mu() << ", equation: "; - if (not dp.is_vertical()) { - os << "y = " << dp.y_slope() << " x + " << dp.y_intercept(); - } else { - os << "x = " << dp.x_intercept(); - } - os << " )"; - return os; - } - - bool DualPoint::operator<(const DualPoint& rhs) const - { - return std::tie(axis_type_, angle_type_, lambda_, mu_) - < std::tie(rhs.axis_type_, rhs.angle_type_, rhs.lambda_, rhs.mu_); - } - - DualPoint::DualPoint(AxisType axis_type, AngleType angle_type, Real lambda, Real mu) - : - axis_type_(axis_type), - angle_type_(angle_type), - lambda_(lambda), - mu_(mu) - { - assert(sanity_check()); - } - - bool DualPoint::sanity_check() const - { - if (lambda_ < 0.0) - throw std::runtime_error("Invalid line, negative lambda"); - if (lambda_ > 1.0) - throw std::runtime_error("Invalid line, lambda > 1"); - if (mu_ < 0.0) - throw std::runtime_error("Invalid line, negative mu"); - return true; - } - - Real DualPoint::gamma() const - { - if (is_steep()) - return atan(Real(1.0) / lambda_); - else - return atan(lambda_); - } - - DualPoint midpoint(DualPoint x, DualPoint y) - { - assert(x.angle_type() == y.angle_type() and x.axis_type() == y.axis_type()); - Real lambda_mid = (x.lambda() + y.lambda()) / 2; - Real mu_mid = (x.mu() + y.mu()) / 2; - return DualPoint(x.axis_type(), x.angle_type(), lambda_mid, mu_mid); - - } - - // return k in the line equation y = kx + b - Real DualPoint::y_slope() const - { - if (is_flat()) - return lambda(); - else - return Real(1.0) / lambda(); - } - - // return k in the line equation x = ky + b - Real DualPoint::x_slope() const - { - if (is_flat()) - return Real(1.0) / lambda(); - else - return lambda(); - } - - // return b in the line equation y = kx + b - Real DualPoint::y_intercept() const - { - if (is_y_type()) { - return mu(); - } else { - // x = x_slope * y + mu = x_slope * (y + mu / x_slope) - // x-intercept is -mu/x_slope = -mu * y_slope - return -mu() * y_slope(); - } - } - - // return k in the line equation x = ky + b - Real DualPoint::x_intercept() const - { - if (is_x_type()) { - return mu(); - } else { - // y = y_slope * x + mu = y_slope (x + mu / y_slope) - // x_intercept is -mu/y_slope = -mu * x_slope - return -mu() * x_slope(); - } - } - - Real DualPoint::x_from_y(Real y) const - { - if (is_horizontal()) - throw std::runtime_error("x_from_y called on horizontal line"); - else - return x_slope() * y + x_intercept(); - } - - Real DualPoint::y_from_x(Real x) const - { - if (is_vertical()) - throw std::runtime_error("x_from_y called on horizontal line"); - else - return y_slope() * x + y_intercept(); - } - - bool DualPoint::is_horizontal() const - { - return is_flat() and lambda() == 0; - } - - bool DualPoint::is_vertical() const - { - return is_steep() and lambda() == 0; - } - - bool DualPoint::contains(Point p) const - { - if (is_vertical()) - return p.x == x_from_y(p.y); - else - return p.y == y_from_x(p.x); - } - - bool DualPoint::goes_below(Point p) const - { - if (is_vertical()) - return p.x <= x_from_y(p.y); - else - return p.y >= y_from_x(p.x); - } - - bool DualPoint::goes_above(Point p) const - { - if (is_vertical()) - return p.x >= x_from_y(p.y); - else - return p.y <= y_from_x(p.x); - } - - Point DualPoint::push(Point p) const - { - Point result; - // if line is below p, we push horizontally - bool horizontal_push = goes_below(p); - if (is_x_type()) { - if (is_flat()) { - if (horizontal_push) { - result.x = p.y / lambda() + mu(); - result.y = p.y; - } else { - // vertical push - result.x = p.x; - result.y = lambda() * (p.x - mu()); - } - } else { - // steep - if (horizontal_push) { - result.x = lambda() * p.y + mu(); - result.y = p.y; - } else { - // vertical push - result.x = p.x; - result.y = (p.x - mu()) / lambda(); - } - } - } else { - // y-type - if (is_flat()) { - if (horizontal_push) { - result.x = (p.y - mu()) / lambda(); - result.y = p.y; - } else { - // vertical push - result.x = p.x; - result.y = lambda() * p.x + mu(); - } - } else { - // steep - if (horizontal_push) { - result.x = (p.y - mu()) * lambda(); - result.y = p.y; - } else { - // vertical push - result.x = p.x; - result.y = p.x / lambda() + mu(); - } - } - } - return result; - } - - Real DualPoint::weighted_push(Point p) const - { - // if line is below p, we push horizontally - bool horizontal_push = goes_below(p); - if (is_x_type()) { - if (is_flat()) { - if (horizontal_push) { - return p.y; - } else { - // vertical push - return lambda() * (p.x - mu()); - } - } else { - // steep - if (horizontal_push) { - return lambda() * p.y; - } else { - // vertical push - return (p.x - mu()); - } - } - } else { - // y-type - if (is_flat()) { - if (horizontal_push) { - return p.y - mu(); - } else { - // vertical push - return lambda() * p.x; - } - } else { - // steep - if (horizontal_push) { - return lambda() * (p.y - mu()); - } else { - // vertical push - return p.x; - } - } - } - } - - bool DualPoint::operator==(const DualPoint& other) const - { - return axis_type() == other.axis_type() and - angle_type() == other.angle_type() and - mu() == other.mu() and - lambda() == other.lambda(); - } - - Real DualPoint::weight() const - { - return lambda_ / sqrt(1 + lambda_ * lambda_); - } -} // namespace md diff --git a/matching/src/main.cpp b/matching/src/main.cpp index f1472be..2093457 100644 --- a/matching/src/main.cpp +++ b/matching/src/main.cpp @@ -18,12 +18,20 @@ #include "box.h" #include "matching_distance.h" +using Real = double; + using namespace md; namespace fs = std::experimental::filesystem; +void force_instantiation() +{ + DualBox db; + std::cout << db; +} + #ifdef PRINT_HEAT_MAP -void print_heat_map(const md::HeatMaps& hms, std::string fname, const CalculationParams& params) +void print_heat_map(const md::HeatMaps& hms, std::string fname, const CalculationParams& params) { spd::debug("Entered print_heat_map"); std::set mu_vals, lambda_vals; @@ -143,7 +151,7 @@ int main(int argc, char** argv) bool help = false; bool no_stop_asap = false; - CalculationParams params; + CalculationParams params; #ifdef PRINT_HEAT_MAP bool heatmap_only = false; @@ -178,8 +186,8 @@ int main(int argc, char** argv) auto bounds_list = split_by_delim(bounds_list_str, ','); auto traverse_list = split_by_delim(traverse_list_str, ','); - Bifiltration bif_a(fname_a); - Bifiltration bif_b(fname_b); + Bifiltration bif_a(fname_a); + Bifiltration bif_b(fname_b); bif_a.sanity_check(); bif_b.sanity_check(); @@ -207,11 +215,11 @@ int main(int argc, char** argv) } struct ExperimentResult { - CalculationParams params {CalculationParams()}; + CalculationParams params {CalculationParams()}; int n_hera_calls {0}; double total_milliseconds_elapsed {0}; - double distance {0}; - double actual_error {std::numeric_limits::max()}; + Real distance {0}; + Real actual_error {std::numeric_limits::max()}; int actual_max_depth {0}; int x_wins {0}; @@ -250,7 +258,7 @@ int main(int argc, char** argv) ExperimentResult() { } - ExperimentResult(CalculationParams p, int nhc, double tme, double d) + ExperimentResult(CalculationParams p, int nhc, double tme, double d) : params(p), n_hera_calls(nhc), total_milliseconds_elapsed(tme), distance(d) { } }; @@ -267,7 +275,7 @@ int main(int argc, char** argv) std::map, ExperimentResult> results; for(BoundStrategy bound_strategy : bound_strategies) { for(TraverseStrategy traverse_strategy : traverse_strategies) { - CalculationParams params_experiment; + CalculationParams params_experiment; params_experiment.bound_strategy = bound_strategy; params_experiment.traverse_strategy = traverse_strategy; params_experiment.max_depth = params.max_depth; @@ -366,8 +374,9 @@ int main(int argc, char** argv) spd::debug("Will use {} bound, {} traverse strategy", params.bound_strategy, params.traverse_strategy); - Real dist = matching_distance(bif_a, bif_b, params); + Real dist = matching_distance(bif_a, bif_b, params); std::cout << dist << std::endl; #endif + force_instantiation(); return 0; } diff --git a/matching/src/matching_distance.cpp b/matching/src/matching_distance.cpp deleted file mode 100644 index e53233f..0000000 --- a/matching/src/matching_distance.cpp +++ /dev/null @@ -1,150 +0,0 @@ -#include -#include -#include - -#include "common_defs.h" - -#include "matching_distance.h" - -namespace md { - - Real matching_distance(const Bifiltration& bif_a, const Bifiltration& bif_b, - CalculationParams& params) - { - Real result; - // compute distance only in one dimension - if (params.dim != CalculationParams::ALL_DIMENSIONS) { - BifiltrationProxy bifp_a(bif_a, params.dim); - BifiltrationProxy bifp_b(bif_b, params.dim); - DistanceCalculator runner(bifp_a, bifp_b, params); - result = runner.distance(); - params.n_hera_calls = runner.get_hera_calls_number(); - } else { - // compute distance in all dimensions, return maximal - result = -1; - for(int dim = 0; dim < std::max(bif_a.maximal_dim(), bif_b.maximal_dim()); ++dim) { - BifiltrationProxy bifp_a(bif_a, params.dim); - BifiltrationProxy bifp_b(bif_a, params.dim); - DistanceCalculator runner(bifp_a, bifp_b, params); - result = std::max(result, runner.distance()); - params.n_hera_calls += runner.get_hera_calls_number(); - } - } - return result; - } - - - Real matching_distance(const ModulePresentation& mod_a, const ModulePresentation& mod_b, - CalculationParams& params) - { - DistanceCalculator runner(mod_a, mod_b, params); - Real result = runner.distance(); - params.n_hera_calls = runner.get_hera_calls_number(); - return result; - } - - std::istream& operator>>(std::istream& is, BoundStrategy& s) - { - std::string ss; - is >> ss; - if (ss == "bruteforce") { - s = BoundStrategy::bruteforce; - } else if (ss == "local_grob") { - s = BoundStrategy::local_dual_bound; - } else if (ss == "local_combined") { - s = BoundStrategy::local_combined; - } else if (ss == "local_refined") { - s = BoundStrategy::local_dual_bound_refined; - } else if (ss == "local_for_each_point") { - s = BoundStrategy::local_dual_bound_for_each_point; - } else { - throw std::runtime_error("UNKNOWN BOUND STRATEGY"); - } - return is; - } - - BoundStrategy bs_from_string(std::string s) - { - std::stringstream ss(s); - BoundStrategy result; - ss >> result; - return result; - } - - TraverseStrategy ts_from_string(std::string s) - { - std::stringstream ss(s); - TraverseStrategy result; - ss >> result; - return result; - } - - std::istream& operator>>(std::istream& is, TraverseStrategy& s) - { - std::string ss; - is >> ss; - if (ss == "DFS") { - s = TraverseStrategy::depth_first; - } else if (ss == "BFS") { - s = TraverseStrategy::breadth_first; - } else if (ss == "BFS-VAL") { - s = TraverseStrategy::breadth_first_value; - } else if (ss == "UB") { - s = TraverseStrategy::upper_bound; - } else { - throw std::runtime_error("UNKNOWN TRAVERSE STRATEGY"); - } - return is; - } - - std::ostream& operator<<(std::ostream& os, const UbExperimentRecord& r) - { - os << r.time << "\t" << r.n_hera_calls << "\t" << r.error << "\t" << r.lower_bound << "\t" << r.upper_bound; - return os; - } - - std::ostream& operator<<(std::ostream& os, const BoundStrategy& s) - { - switch(s) { - case BoundStrategy::bruteforce : - os << "bruteforce"; - break; - case BoundStrategy::local_dual_bound : - os << "local_grob"; - break; - case BoundStrategy::local_combined : - os << "local_combined"; - break; - case BoundStrategy::local_dual_bound_refined : - os << "local_refined"; - break; - case BoundStrategy::local_dual_bound_for_each_point : - os << "local_for_each_point"; - break; - default: - os << "FORGOTTEN BOUND STRATEGY"; - } - return os; - } - - std::ostream& operator<<(std::ostream& os, const TraverseStrategy& s) - { - switch(s) { - case TraverseStrategy::depth_first : - os << "DFS"; - break; - case TraverseStrategy::breadth_first : - os << "BFS"; - break; - case TraverseStrategy::breadth_first_value : - os << "BFS-VAL"; - break; - case TraverseStrategy::upper_bound : - os << "UB"; - break; - default: - os << "FORGOTTEN TRAVERSE STRATEGY"; - } - return os; - } -} diff --git a/matching/src/persistence_module.cpp b/matching/src/persistence_module.cpp deleted file mode 100644 index efb20ef..0000000 --- a/matching/src/persistence_module.cpp +++ /dev/null @@ -1,177 +0,0 @@ -#include -#include -#include - -#include -#include - -#include "persistence_module.h" - -namespace md { - - /** - * - * @param values vector of length n - * @return [a_1,...,a_n] such that - * 1) values[a_1] <= values[a_2] <= ... <= values[a_n] - * 2) a_1,...,a_n is a permutation of 1,..,n - */ - - template - IndexVec get_sorted_indices(const std::vector& values) - { - IndexVec result(values.size()); - std::iota(result.begin(), result.end(), 0); - std::sort(result.begin(), result.end(), - [&values](size_t a, size_t b) { return values[a] < values[b]; }); - return result; - } - - // helper function to initialize const member positions_ in ModulePresentation - PointVec - concat_gen_and_rel_positions(const PointVec& generators, const ModulePresentation::RelVec& relations) - { - std::unordered_set ps(generators.begin(), generators.end()); - for(const auto& rel : relations) { - ps.insert(rel.position_); - } - return PointVec(ps.begin(), ps.end()); - } - - - void ModulePresentation::init_boundaries() - { - max_x_ = std::numeric_limits::max(); - max_y_ = std::numeric_limits::max(); - min_x_ = -std::numeric_limits::max(); - min_y_ = -std::numeric_limits::max(); - - for(const auto& gen : positions_) { - min_x_ = std::min(gen.x, min_x_); - min_y_ = std::min(gen.y, min_y_); - max_x_ = std::max(gen.x, max_x_); - max_y_ = std::max(gen.y, max_y_); - } - - bounding_box_ = Box(Point(min_x_, min_y_), Point(max_x_, max_y_)); - } - - - ModulePresentation::ModulePresentation(const PointVec& _generators, const RelVec& _relations) : - generators_(_generators), - relations_(_relations) - { - init_boundaries(); - } - - void ModulePresentation::translate(md::Real a) - { - for(auto& g : generators_) { - g.translate(a); - } - - for(auto& r : relations_) { - r.position_.translate(a); - } - - positions_ = concat_gen_and_rel_positions(generators_, relations_); - init_boundaries(); - } - - - /** - * - * @param slice line on which generators are projected - * @param sorted_indices [a_1,...,a_n] s.t. wpush(generator[a_1]) <= wpush(generator[a_2]) <= .. - * @param projections sorted weighted pushes of generators - */ - - void - ModulePresentation::project_generators(const DualPoint& slice, IndexVec& sorted_indices, RealVec& projections) const - { - size_t num_gens = generators_.size(); - - RealVec gen_values; - gen_values.reserve(num_gens); - for(const auto& pos : generators_) { - gen_values.push_back(slice.weighted_push(pos)); - } - sorted_indices = get_sorted_indices(gen_values); - projections.clear(); - projections.reserve(num_gens); - for(auto i : sorted_indices) { - projections.push_back(gen_values[i]); - } - } - - void ModulePresentation::project_relations(const DualPoint& slice, IndexVec& sorted_rel_indices, - RealVec& projections) const - { - size_t num_rels = relations_.size(); - - RealVec rel_values; - rel_values.reserve(num_rels); - for(const auto& rel : relations_) { - rel_values.push_back(slice.weighted_push(rel.position_)); - } - sorted_rel_indices = get_sorted_indices(rel_values); - projections.clear(); - projections.reserve(num_rels); - for(auto i : sorted_rel_indices) { - projections.push_back(rel_values[i]); - } - } - - Diagram ModulePresentation::weighted_slice_diagram(const DualPoint& slice) const - { - IndexVec sorted_gen_indices, sorted_rel_indices; - RealVec gen_projections, rel_projections; - - project_generators(slice, sorted_gen_indices, gen_projections); - project_relations(slice, sorted_rel_indices, rel_projections); - - phat::boundary_matrix<> phat_matrix; - - phat_matrix.set_num_cols(relations_.size()); - - for(Index i = 0; i < (Index) relations_.size(); i++) { - IndexVec current_relation = relations_[sorted_rel_indices[i]].components_; - for(auto& j : current_relation) { - j = sorted_gen_indices[j]; - } - std::sort(current_relation.begin(), current_relation.end()); - phat_matrix.set_dim(i, current_relation.size()); - phat_matrix.set_col(i, current_relation); - } - - phat::persistence_pairs phat_persistence_pairs; - phat::compute_persistence_pairs(phat_persistence_pairs, phat_matrix); - - Diagram dgm; - - constexpr Real real_inf = std::numeric_limits::infinity(); - - for(Index i = 0; i < (Index) phat_persistence_pairs.get_num_pairs(); i++) { - std::pair new_pair = phat_persistence_pairs.get_pair(i); - bool is_finite_pair = new_pair.second != phat::k_infinity_index; - Real birth = gen_projections.at(new_pair.first); - Real death = is_finite_pair ? rel_projections.at(new_pair.second) : real_inf; - if (birth != death) { - dgm.emplace_back(birth, death); - } - } - - return dgm; - } - - PointVec ModulePresentation::positions() const - { - return positions_; - } - - Box ModulePresentation::bounding_box() const - { - return bounding_box_; - } - -} diff --git a/matching/src/simplex.cpp b/matching/src/simplex.cpp deleted file mode 100644 index 6b53680..0000000 --- a/matching/src/simplex.cpp +++ /dev/null @@ -1,121 +0,0 @@ -#include "simplex.h" - -namespace md { - - std::ostream& operator<<(std::ostream& os, const AbstractSimplex& s) - { - os << "AbstractSimplex(id = " << s.id << ", vertices_ = " << container_to_string(s.vertices_) << ")"; - return os; - } - - bool operator<(const AbstractSimplex& a, const AbstractSimplex& b) - { - return a.vertices_ < b.vertices_; - } - - bool operator==(const AbstractSimplex& s1, const AbstractSimplex& s2) - { - return s1.vertices_ == s2.vertices_; - } - - void AbstractSimplex::push_back(int v) - { - vertices_.push_back(v); - std::sort(vertices_.begin(), vertices_.end()); - } - - AbstractSimplex::AbstractSimplex(std::vector vertices, bool sort) - :vertices_(vertices) - { - if (sort) - std::sort(vertices_.begin(), vertices_.end()); - } - - std::vector AbstractSimplex::facets() const - { - std::vector result; - for (int i = 0; i < static_cast(vertices_.size()); ++i) { - std::vector facet_vertices; - facet_vertices.reserve(dim()); - for (int j = 0; j < static_cast(vertices_.size()); ++j) { - if (j != i) - facet_vertices.push_back(vertices_[j]); - } - if (!facet_vertices.empty()) { - result.emplace_back(facet_vertices, false); - } - } - return result; - } - - Simplex::Simplex(md::Index id, md::Point birth, int dim, const md::Column& bdry) - : - id_(id), - pos_(birth), - dim_(dim), - facet_indices_(bdry) { } - - void Simplex::translate(Real a) - { - pos_.translate(a); - } - - void Simplex::init_rivet(std::string s) - { - auto delim_pos = s.find_first_of(";"); - assert(delim_pos > 0); - std::string vertices_str = s.substr(0, delim_pos); - std::string pos_str = s.substr(delim_pos + 1); - assert(not vertices_str.empty() and not pos_str.empty()); - // get vertices - std::stringstream vertices_ss(vertices_str); - int dim = 0; - int vertex; - while (vertices_ss >> vertex) { - dim++; - vertices_.push_back(vertex); - } - // - std::sort(vertices_.begin(), vertices_.end()); - assert(dim > 0); - - std::stringstream pos_ss(pos_str); - // TODO: get rid of 1-criticaltiy assumption - pos_ss >> pos_.x >> pos_.y; - } - - void Simplex::init_phat_like(std::string s) - { - facet_indices_.clear(); - std::stringstream ss(s); - ss >> dim_ >> pos_.x >> pos_.y; - if (dim_ > 0) { - facet_indices_.reserve(dim_ + 1); - for (int j = 0; j <= dim_; j++) { - Index k; - ss >> k; - facet_indices_.push_back(k); - } - } - } - - Simplex::Simplex(Index _id, std::string s, BifiltrationFormat input_format) - :id_(_id) - { - switch (input_format) { - case BifiltrationFormat::phat_like : - init_phat_like(s); - break; - case BifiltrationFormat::rivet : - init_rivet(s); - break; - } - } - - std::ostream& operator<<(std::ostream& os, const Simplex& x) - { - os << "Simplex(id = " << x.id() << ", dim = " << x.dim(); - os << ", boundary = " << container_to_string(x.boundary()) << ", pos = " << x.position() << ")"; - return os; - } -} diff --git a/matching/src/test_generator.cpp b/matching/src/test_generator.cpp index e8f128f..a2f0625 100644 --- a/matching/src/test_generator.cpp +++ b/matching/src/test_generator.cpp @@ -11,9 +11,12 @@ #include "common_util.h" #include "bifiltration.h" +using Real = double; using Index = md::Index; -using Point = md::Point; +using Point = md::Point; +using Bifiltration = md::Bifiltration; using Column = md::Column; +using Simplex = md::Simplex; int g_max_coord = 100; @@ -100,7 +103,7 @@ void generate_positions(const ASimplex& s, ASimplexToBirthMap& simplex_to_birth, } } -md::Bifiltration get_random_bifiltration(int n_vertices, int max_dim, int n_top_simplices) +Bifiltration get_random_bifiltration(int n_vertices, int max_dim, int n_top_simplices) { ASimplexToBirthMap simplex_to_birth; @@ -122,13 +125,13 @@ md::Bifiltration get_random_bifiltration(int n_vertices, int max_dim, int n_top_ add_if_top(candidate_simplex, top_simplices); } - Point upper_bound{static_cast(g_max_coord), static_cast(g_max_coord)}; + Point upper_bound{static_cast(g_max_coord), static_cast(g_max_coord)}; for(const auto& top_simplex : top_simplices) { generate_positions(top_simplex, simplex_to_birth, upper_bound); } std::vector> simplex_birth_pairs{simplex_to_birth.begin(), simplex_to_birth.end()}; - std::vector boundaries{simplex_to_birth.size(), md::Column()}; + std::vector boundaries{simplex_to_birth.size(), Column()}; // assign ids and save boundaries int id = 0; @@ -138,7 +141,7 @@ md::Bifiltration get_random_bifiltration(int n_vertices, int max_dim, int n_top_ ASimplex& simplex = simplex_birth_pairs[i].first; if (simplex.dim() == dim) { simplex.id = id++; - md::Column bdry; + Column bdry; for(auto& facet : simplex.facets()) { auto facet_iter = std::find_if(simplex_birth_pairs.begin(), simplex_birth_pairs.end(), [facet](const std::pair& sbp) { return facet == sbp.first; }); @@ -153,7 +156,7 @@ md::Bifiltration get_random_bifiltration(int n_vertices, int max_dim, int n_top_ } // create vector of Simplex-es - std::vector simplices; + std::vector simplices; for(int i = 0; i < (int) simplex_birth_pairs.size(); ++i) { int id = simplex_birth_pairs[i].first.id; int dim = simplex_birth_pairs[i].first.dim(); @@ -164,13 +167,13 @@ md::Bifiltration get_random_bifiltration(int n_vertices, int max_dim, int n_top_ // sort by id std::sort(simplices.begin(), simplices.end(), - [](const md::Simplex& s1, const md::Simplex& s2) { return s1.id() < s2.id(); }); + [](const Simplex& s1, const Simplex& s2) { return s1.id() < s2.id(); }); for(int i = 0; i < (int)simplices.size(); ++i) { assert(simplices[i].id() == i); assert(i == 0 || simplices[i].dim() >= simplices[i - 1].dim()); } - return md::Bifiltration(simplices.begin(), simplices.end()); + return Bifiltration(simplices.begin(), simplices.end()); } int main(int argc, char** argv) diff --git a/matching/src/tests/test_common.cpp b/matching/src/tests/test_common.cpp index c55577e..9079a56 100644 --- a/matching/src/tests/test_common.cpp +++ b/matching/src/tests/test_common.cpp @@ -8,56 +8,24 @@ #include "simplex.h" #include "matching_distance.h" -using namespace md; +//using namespace md; +using Real = double; +using Point = md::Point; +using Bifiltration = md::Bifiltration; +using BifiltrationProxy = md::BifiltrationProxy; +using CalculationParams = md::CalculationParams; +using CellWithValue = md::CellWithValue; +using DualPoint = md::DualPoint; +using DualBox = md::DualBox; +using Simplex = md::Simplex; +using AbstractSimplex = md::AbstractSimplex; +using BoundStrategy = md::BoundStrategy; +using TraverseStrategy = md::TraverseStrategy; +using AxisType = md::AxisType; +using AngleType = md::AngleType; +using ValuePoint = md::ValuePoint; +using Column = md::Column; -TEST_CASE("Rational", "[common_utils][rational]") -{ - // gcd - REQUIRE(gcd(10, 5) == 5); - REQUIRE(gcd(5, 10) == 5); - REQUIRE(gcd(5, 7) == 1); - REQUIRE(gcd(7, 5) == 1); - REQUIRE(gcd(13, 0) == 13); - REQUIRE(gcd(0, 13) == 13); - REQUIRE(gcd(16, 24) == 8); - REQUIRE(gcd(24, 16) == 8); - REQUIRE(gcd(16, 32) == 16); - REQUIRE(gcd(32, 16) == 16); - - - // reduce - REQUIRE(reduce({2, 1}) == std::make_pair(2, 1)); - REQUIRE(reduce({1, 2}) == std::make_pair(1, 2)); - REQUIRE(reduce({2, 2}) == std::make_pair(1, 1)); - REQUIRE(reduce({0, 2}) == std::make_pair(0, 1)); - REQUIRE(reduce({0, 20}) == std::make_pair(0, 1)); - REQUIRE(reduce({35, 49}) == std::make_pair(5, 7)); - REQUIRE(reduce({35, 25}) == std::make_pair(7, 5)); - - // midpoint - REQUIRE(midpoint(Rational {0, 1}, Rational {1, 2}) == std::make_pair(1, 4)); - REQUIRE(midpoint(Rational {1, 4}, Rational {1, 2}) == std::make_pair(3, 8)); - REQUIRE(midpoint(Rational {1, 2}, Rational {1, 2}) == std::make_pair(1, 2)); - REQUIRE(midpoint(Rational {1, 2}, Rational {1, 1}) == std::make_pair(3, 4)); - REQUIRE(midpoint(Rational {3, 7}, Rational {5, 14}) == std::make_pair(11, 28)); - - - // arithmetic - - REQUIRE(Rational(1, 2) + Rational(3, 5) == Rational(11, 10)); - REQUIRE(Rational(2, 5) - Rational(3, 10) == Rational(1, 10)); - REQUIRE(Rational(2, 3) * Rational(4, 7) == Rational(8, 21)); - REQUIRE(Rational(2, 3) * Rational(3, 2) == Rational(1)); - REQUIRE(Rational(2, 3) / Rational(3, 2) == Rational(4, 9)); - REQUIRE(Rational(1, 2) * Rational(3, 5) == Rational(3, 10)); - - // comparison - REQUIRE(Rational(100000, 2000000) < Rational(100001, 2000000)); - REQUIRE(!(Rational(100001, 2000000) < Rational(100000, 2000000))); - REQUIRE(!(Rational(100000, 2000000) < Rational(100000, 2000000))); - REQUIRE(Rational(-100000, 2000000) < Rational(100001, 2000000)); - REQUIRE(Rational(-100001, 2000000) < Rational(100000, 2000000)); -}; TEST_CASE("AbstractSimplex", "[abstract_simplex]") { diff --git a/matching/src/tests/test_matching_distance.cpp b/matching/src/tests/test_matching_distance.cpp index df9345e..82da530 100644 --- a/matching/src/tests/test_matching_distance.cpp +++ b/matching/src/tests/test_matching_distance.cpp @@ -11,7 +11,25 @@ #include "simplex.h" #include "matching_distance.h" -using namespace md; +using Real = double; +using Point = md::Point; +using Bifiltration = md::Bifiltration; +using BifiltrationProxy = md::BifiltrationProxy; +using CalculationParams = md::CalculationParams; +using CellWithValue = md::CellWithValue; +using DualPoint = md::DualPoint; +using DualBox = md::DualBox; +using Simplex = md::Simplex; +using AbstractSimplex = md::AbstractSimplex; +using BoundStrategy = md::BoundStrategy; +using TraverseStrategy = md::TraverseStrategy; +using AxisType = md::AxisType; +using AngleType = md::AngleType; +using ValuePoint = md::ValuePoint; +using Column = md::Column; + +using md::k_corner_vps; + namespace spd = spdlog; TEST_CASE("Different bounds", "[bounds]") @@ -40,7 +58,7 @@ TEST_CASE("Different bounds", "[bounds]") BifiltrationProxy bifp_a(bif_a, params.dim); BifiltrationProxy bifp_b(bif_b, params.dim); - DistanceCalculator calc(bifp_a, bifp_b, params); + md::DistanceCalculator calc(bifp_a, bifp_b, params); // REQUIRE(calc.max_x_ == Approx(max_x)); // REQUIRE(calc.max_y_ == Approx(max_y)); -- cgit v1.2.3 From 5ebf7142b00554b3f5d151c8b4e81b746962a5b8 Mon Sep 17 00:00:00 2001 From: Arnur Nigmetov Date: Fri, 6 Mar 2020 18:29:25 +0100 Subject: Reorganize matching_dist code, minor fixes. --- matching/CMakeLists.txt | 26 +- matching/README | 5 - matching/README.md | 5 + matching/example/matching_dist.cpp | 396 +++++++++++++++++++++ matching/include/matching_distance.h | 2 +- matching/src/main.cpp | 382 -------------------- matching/src/test_generator.cpp | 211 ----------- matching/src/tests/prism_1.bif | 28 -- matching/src/tests/prism_2.bif | 29 -- matching/src/tests/test_bifiltration.cpp | 36 -- matching/src/tests/test_bifiltration_1.txt | 1 - .../tests/test_bifiltration_full_triangle_rene.txt | 1 - matching/src/tests/test_common.cpp | 156 -------- matching/src/tests/test_list.txt | 1 - matching/src/tests/test_matching_distance.cpp | 167 --------- matching/src/tests/tests_main.cpp | 2 - matching/tests/prism_1.bif | 28 ++ matching/tests/prism_2.bif | 29 ++ matching/tests/test_bifiltration.cpp | 36 ++ matching/tests/test_bifiltration_1.txt | 1 + .../tests/test_bifiltration_full_triangle_rene.txt | 1 + matching/tests/test_common.cpp | 156 ++++++++ matching/tests/test_list.txt | 1 + matching/tests/test_matching_distance.cpp | 167 +++++++++ matching/tests/tests_main.cpp | 2 + 25 files changed, 831 insertions(+), 1038 deletions(-) delete mode 100644 matching/README create mode 100644 matching/README.md create mode 100644 matching/example/matching_dist.cpp delete mode 100644 matching/src/main.cpp delete mode 100644 matching/src/test_generator.cpp delete mode 100644 matching/src/tests/prism_1.bif delete mode 100644 matching/src/tests/prism_2.bif delete mode 100644 matching/src/tests/test_bifiltration.cpp delete mode 120000 matching/src/tests/test_bifiltration_1.txt delete mode 120000 matching/src/tests/test_bifiltration_full_triangle_rene.txt delete mode 100644 matching/src/tests/test_common.cpp delete mode 100644 matching/src/tests/test_list.txt delete mode 100644 matching/src/tests/test_matching_distance.cpp delete mode 100644 matching/src/tests/tests_main.cpp create mode 100644 matching/tests/prism_1.bif create mode 100644 matching/tests/prism_2.bif create mode 100644 matching/tests/test_bifiltration.cpp create mode 120000 matching/tests/test_bifiltration_1.txt create mode 120000 matching/tests/test_bifiltration_full_triangle_rene.txt create mode 100644 matching/tests/test_common.cpp create mode 100644 matching/tests/test_list.txt create mode 100644 matching/tests/test_matching_distance.cpp create mode 100644 matching/tests/tests_main.cpp (limited to 'matching/CMakeLists.txt') diff --git a/matching/CMakeLists.txt b/matching/CMakeLists.txt index 121e25c..3ee0f6b 100644 --- a/matching/CMakeLists.txt +++ b/matching/CMakeLists.txt @@ -22,7 +22,6 @@ set(CMAKE_CXX_STANDARD 14) if (NOT WIN32) set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") - #set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS} -Wall pedantic -Wextra ") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS} -Wall -Wextra ") set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -ggdb -D_GLIBCXX_DEBUG") set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELEASE} -O2 -g -ggdb") @@ -31,32 +30,23 @@ endif (NOT WIN32) file(GLOB BT_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/../bottleneck/include/*.h ${CMAKE_CURRENT_SOURCE_DIR}/../bottleneck/include/*.hpp) file(GLOB MD_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/include/*.h ${CMAKE_CURRENT_SOURCE_DIR}/include/*.hpp) -file(GLOB SRC_FILES ${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp) - -file(GLOB SRC_TEST_FILES ${CMAKE_CURRENT_SOURCE_DIR}/src/tests/*.cpp) - +file(GLOB SRC_TEST_FILES ${CMAKE_CURRENT_SOURCE_DIR}/tests/*.cpp) + find_package(Threads) set(libraries ${libraries} "stdc++fs" ${CMAKE_THREAD_LIBS_INIT}) find_package(OpenMP) if (OPENMP_FOUND) -set(libraries ${libraries} ${OpenMP_CXX_LIBRARIES}) -set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") -set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") -set (CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}") + set(libraries ${libraries} ${OpenMP_CXX_LIBRARIES}) + set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") + set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") + set (CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}") endif() -add_executable(matching_distance "src/main.cpp" ${MD_HEADERS} ${BT_HEADERS} ) -target_link_libraries(matching_distance PUBLIC ${libraries}) +add_executable(matching_dist "example/matching_dist.cpp" ${MD_HEADERS} ${BT_HEADERS} ) +target_link_libraries(matching_dist PUBLIC ${libraries}) add_executable(matching_distance_test ${SRC_TEST_FILES} ${BT_HEADERS} ${MD_HEADERS}) target_link_libraries(matching_distance_test PUBLIC ${libraries}) - -add_executable(test_generator "src/test_generator.cpp" - ${MD_HEADERS} - ${BT_HEADERS}) -target_link_libraries(test_generator PUBLIC ${libraries}) - -#add_executable(matching_distance "include/main.cpp" "src/box.cpp" "src/common_util.cpp" "src/line.cpp" "src/persistence_module.hpp" ${BT_HEADERS} ${MD_HEADERS}) diff --git a/matching/README b/matching/README deleted file mode 100644 index 7dc1874..0000000 --- a/matching/README +++ /dev/null @@ -1,5 +0,0 @@ -Matching distance between bifiltrations. - -Currently supports only 1-critical bi-filtrations -in PHAT-like format: boundary matrix + critical values -of a simplex in each row diff --git a/matching/README.md b/matching/README.md new file mode 100644 index 0000000..7dc1874 --- /dev/null +++ b/matching/README.md @@ -0,0 +1,5 @@ +Matching distance between bifiltrations. + +Currently supports only 1-critical bi-filtrations +in PHAT-like format: boundary matrix + critical values +of a simplex in each row diff --git a/matching/example/matching_dist.cpp b/matching/example/matching_dist.cpp new file mode 100644 index 0000000..13e5c6d --- /dev/null +++ b/matching/example/matching_dist.cpp @@ -0,0 +1,396 @@ +#include "common_defs.h" + +#include +#include +#include +#include + +#ifdef MD_EXPERIMENTAL_TIMING +#include +#endif + +#include "opts/opts.h" +#include "spdlog/spdlog.h" +#include "spdlog/fmt/ostr.h" + +//#include "persistence_module.h" +#include "bifiltration.h" +#include "box.h" +#include "matching_distance.h" + +using Real = double; + +using namespace md; + +namespace fs = std::experimental::filesystem; + +#ifdef PRINT_HEAT_MAP +void print_heat_map(const md::HeatMaps& hms, std::string fname, const CalculationParams& params) +{ + spd::debug("Entered print_heat_map"); + std::set mu_vals, lambda_vals; + auto hm_iter = hms.end(); + --hm_iter; + int max_level = hm_iter->first; + + int level_cardinality = 4; + for(int i = 0; i < params.initialization_depth; ++i) { + level_cardinality *= 4; + } + for(int i = params.initialization_depth + 1; i <= max_level; ++i) { + spd::debug("hms.at({}).size = {}, must be {}", i, hms.at(i).size(), level_cardinality); + assert(static_cast(hms.at(i).size()) == level_cardinality); + level_cardinality *= 4; + } + + std::map, Real> hm_x_flat, hm_x_steep, hm_y_flat, hm_y_steep; + + for(const auto& dual_point_value_pair : hms.at(max_level)) { + const DualPoint& k = dual_point_value_pair.first; + spd::debug("HM DP: {}", k); + mu_vals.insert(k.mu()); + lambda_vals.insert(k.lambda()); + } + + std::vector lambda_vals_vec(lambda_vals.begin(), lambda_vals.end()); + std::vector mu_vals_vec(mu_vals.begin(), mu_vals.end()); + + std::ofstream ofs {fname}; + if (not ofs.good()) { + std::cerr << "Cannot write heat map to file " << fname << std::endl; + throw std::runtime_error("Cannot open file for writing heat map"); + } + + std::vector> heatmap_to_print(2 * mu_vals_vec.size(), + std::vector(2 * lambda_vals_vec.size(), 0.0)); + + for(auto axis_type : {AxisType::x_type, AxisType::y_type}) { + bool is_x_type = axis_type == AxisType::x_type; + for(auto angle_type : {AngleType::flat, AngleType::steep}) { + bool is_flat = angle_type == AngleType::flat; + + int mu_idx_begin, mu_idx_end; + + if (is_x_type) { + mu_idx_begin = mu_vals_vec.size() - 1; + mu_idx_end = -1; + } else { + mu_idx_begin = 0; + mu_idx_end = mu_vals_vec.size(); + } + + int lambda_idx_begin, lambda_idx_end; + + if (is_flat) { + lambda_idx_begin = 0; + lambda_idx_end = lambda_vals_vec.size(); + } else { + lambda_idx_begin = lambda_vals_vec.size() - 1; + lambda_idx_end = -1; + } + + int mu_idx_final = is_x_type ? 0 : mu_vals_vec.size(); + + for(int mu_idx = mu_idx_begin; mu_idx != mu_idx_end; (mu_idx_begin < mu_idx_end) ? mu_idx++ : mu_idx--) { + Real mu = mu_vals_vec.at(mu_idx); + + if (mu == 0.0 and axis_type == AxisType::x_type) + continue; + + int lambda_idx_final = is_flat ? 0 : lambda_vals_vec.size(); + + for(int lambda_idx = lambda_idx_begin; + lambda_idx != lambda_idx_end; + (lambda_idx_begin < lambda_idx_end) ? lambda_idx++ : lambda_idx--) { + + Real lambda = lambda_vals_vec.at(lambda_idx); + + if (lambda == 0.0 and angle_type == AngleType::flat) + continue; + + DualPoint dp(axis_type, angle_type, lambda, mu); + Real dist_value = hms.at(max_level).at(dp); + + heatmap_to_print.at(mu_idx_final).at(lambda_idx_final) = dist_value; + +// fmt::print("HM, dp = {}, mu_idx_final = {}, lambda_idx_final = {}, value = {}\n", dp, mu_idx_final, +// lambda_idx_final, dist_value); + + lambda_idx_final++; + } + mu_idx_final++; + } + } + } + + for(size_t m_idx = 0; m_idx < heatmap_to_print.size(); ++m_idx) { + for(size_t l_idx = 0; l_idx < heatmap_to_print.at(m_idx).size(); ++l_idx) { + ofs << heatmap_to_print.at(m_idx).at(l_idx) << " "; + } + ofs << std::endl; + } + + ofs.close(); + spd::debug("Exit print_heat_map"); +} +#endif + +int main(int argc, char** argv) +{ + spdlog::set_level(spdlog::level::info); + + using opts::Option; + using opts::PosOption; + opts::Options ops; + + CalculationParams params; + + bool help = false; + +#ifdef MD_EXPERIMENTAL_TIMING + bool no_stop_asap = false; +#endif + + +#ifdef PRINT_HEAT_MAP + bool heatmap_only = false; +#endif + + // default values are the best for practical use + // if compiled with MD_EXPERIMENTAL_TIMING, user can supply + // different bounds and traverse strategies + // separated by commas, all combinations will be timed. + // See corresponding >> operators in matching_distance.h + // for possible values. + std::string bounds_list_str = "local_combined"; + std::string traverse_list_str = "BFS"; + + ops >> Option('m', "max-iterations", params.max_depth, "maximal number of iterations (refinements)") + >> Option('e', "max-error", params.delta, "error threshold") + >> Option('d', "dim", params.dim, "dim") + >> Option('i', "initial-depth", params.initialization_depth, "initialization depth") +#ifdef MD_EXPERIMENTAL_TIMING + >> Option("no-stop-asap", no_stop_asap, + "don't stop looping over points, if cell cannot be pruned (asap is on by default)") + >> Option("bounds", bounds_list_str, "bounds to use, separated by ,") + >> Option("traverse", traverse_list_str, "traverse strategy to use, separated by ,") +#endif +#ifdef PRINT_HEAT_MAP + >> Option("heatmap-only", heatmap_only, "only save heatmap (bruteforce)") +#endif + >> Option('h', "help", help, "show help message"); + + std::string fname_a; + std::string fname_b; + + if (!ops.parse(argc, argv) || help || !(ops >> PosOption(fname_a) >> PosOption(fname_b))) { + std::cerr << "Usage: " << argv[0] << " bifiltration-file-1 bifiltration-file-2\n" << ops << std::endl; + return 1; + } + +#ifdef MD_EXPERIMENTAL_TIMING + params.stop_asap = not no_stop_asap; +#else + params.stop_asap = true; +#endif + + auto bounds_list = split_by_delim(bounds_list_str, ','); + auto traverse_list = split_by_delim(traverse_list_str, ','); + + Bifiltration bif_a(fname_a); + Bifiltration bif_b(fname_b); + + bif_a.sanity_check(); + bif_b.sanity_check(); + + spd::debug("Read bifiltrations {} {}", fname_a, fname_b); + + std::vector bound_strategies; + std::vector traverse_strategies; + + for(std::string s : bounds_list) { + bound_strategies.push_back(bs_from_string(s)); + } + + for(std::string s : traverse_list) { + traverse_strategies.push_back(ts_from_string(s)); + } + + +#ifdef MD_EXPERIMENTAL_TIMING + + for(auto bs : bound_strategies) { + for(auto ts : traverse_strategies) { + spd::info("Will test combination {} {}", bs, ts); + } + } + + struct ExperimentResult { + CalculationParams params {CalculationParams()}; + int n_hera_calls {0}; + double total_milliseconds_elapsed {0}; + Real distance {0}; + Real actual_error {std::numeric_limits::max()}; + int actual_max_depth {0}; + + int x_wins {0}; + int y_wins {0}; + int ad_wins {0}; + + int seconds_elapsed() const + { + return static_cast(total_milliseconds_elapsed / 1000); + } + + double savings_ratio_old() const + { + long int max_possible_calls = 0; + long int calls_on_level = 4; + for(int i = 0; i <= actual_max_depth; ++i) { + max_possible_calls += calls_on_level; + calls_on_level *= 4; + } + return static_cast(n_hera_calls) / static_cast(max_possible_calls); + } + + double savings_ratio() const + { + return static_cast(n_hera_calls) / calls_on_actual_max_depth(); + } + + long long int calls_on_actual_max_depth() const + { + long long int result = 1; + for(int i = 0; i < actual_max_depth; ++i) { + result *= 4; + } + return result; + } + + ExperimentResult() { } + + ExperimentResult(CalculationParams p, int nhc, double tme, double d) + : + params(p), n_hera_calls(nhc), total_milliseconds_elapsed(tme), distance(d) { } + }; + + const int n_repetitions = 1; + +#ifdef PRINT_HEAT_MAP + if (heatmap_only) { + bound_strategies.clear(); + bound_strategies.push_back(BoundStrategy::bruteforce); + traverse_strategies.clear(); + traverse_strategies.push_back(TraverseStrategy::breadth_first); + } +#endif + + std::map, ExperimentResult> results; + for(BoundStrategy bound_strategy : bound_strategies) { + for(TraverseStrategy traverse_strategy : traverse_strategies) { + CalculationParams params_experiment; + params_experiment.bound_strategy = bound_strategy; + params_experiment.traverse_strategy = traverse_strategy; + params_experiment.max_depth = params.max_depth; + params_experiment.initialization_depth = params.initialization_depth; + params_experiment.delta = params.delta; + params_experiment.dim = params.dim; + params_experiment.hera_epsilon = params.hera_epsilon; + params_experiment.stop_asap = params.stop_asap; + + if (traverse_strategy == TraverseStrategy::depth_first and bound_strategy == BoundStrategy::bruteforce) + continue; + + // if bruteforce, clamp max iterations number to 7, + // save user-provided max_iters in user_max_iters and restore it later. + // remember: params is passed by reference to return real relative error and heat map + + int user_max_iters = params.max_depth; +#ifdef PRINT_HEAT_MAP + if (bound_strategy == BoundStrategy::bruteforce and not heatmap_only) { + params_experiment.max_depth = std::min(7, params.max_depth); + } +#endif + double total_milliseconds_elapsed = 0; + int total_n_hera_calls = 0; + Real dist; + for(int i = 0; i < n_repetitions; ++i) { + spd::debug("Processing bound_strategy {}, traverse_strategy {}, iteration = {}", bound_strategy, + traverse_strategy, i); + auto t1 = std::chrono::high_resolution_clock().now(); + dist = matching_distance(bif_a, bif_b, params_experiment); + auto t2 = std::chrono::high_resolution_clock().now(); + total_milliseconds_elapsed += std::chrono::duration_cast( + t2 - t1).count(); + total_n_hera_calls += params_experiment.n_hera_calls; + } + + auto key = std::make_tuple(bound_strategy, traverse_strategy); + results[key].params = params_experiment; + results[key].n_hera_calls = total_n_hera_calls / n_repetitions; + results[key].total_milliseconds_elapsed = total_milliseconds_elapsed / n_repetitions; + results[key].distance = dist; + results[key].actual_error = params_experiment.actual_error; + results[key].actual_max_depth = params_experiment.actual_max_depth; + + spd::info( + "Done (bound = {}, traverse = {}), n_hera_calls = {}, time = {} sec, d = {}, error = {}, savings = {}, max_depth = {}", + bound_strategy, traverse_strategy, results[key].n_hera_calls, results[key].seconds_elapsed(), + dist, + params_experiment.actual_error, results[key].savings_ratio(), results[key].actual_max_depth); + + if (bound_strategy == BoundStrategy::bruteforce) { params_experiment.max_depth = user_max_iters; } + +#ifdef PRINT_HEAT_MAP + if (bound_strategy == BoundStrategy::bruteforce) { + fs::path fname_a_path {fname_a.c_str()}; + fs::path fname_b_path {fname_b.c_str()}; + fs::path fname_a_wo = fname_a_path.filename(); + fs::path fname_b_wo = fname_b_path.filename(); + std::string heat_map_fname = fmt::format("{0}_{1}_dim_{2}_weighted_values_xyp.txt", fname_a_wo.string(), + fname_b_wo.string(), params_experiment.dim); + fs::path path_hm = fname_a_path.replace_filename(fs::path(heat_map_fname.c_str())); + spd::debug("Saving heatmap to {}", heat_map_fname); + print_heat_map(params_experiment.heat_maps, path_hm.string(), params); + } +#endif + spd::debug("Finished processing bound_strategy {}", bound_strategy); + } + } + +// std::cout << "File_1;File_2;Boundstrategy;TraverseStrategy;InitalDepth;NHeraCalls;SavingsRatio;Time;Distance;Error;PushStrategy;MaxDepth;CallsOnMaxDepth;Delta;Dimension" << std::endl; + for(auto bs : bound_strategies) { + for(auto ts : traverse_strategies) { + auto key = std::make_tuple(bs, ts); + + fs::path fname_a_path {fname_a.c_str()}; + fs::path fname_b_path {fname_b.c_str()}; + fs::path fname_a_wo = fname_a_path.filename(); + fs::path fname_b_wo = fname_b_path.filename(); + + std::cout << fname_a_wo.string() << ";" << fname_b_wo.string() << ";" << bs << ";" << ts << ";"; + std::cout << results[key].params.initialization_depth << ";"; + std::cout << results[key].n_hera_calls << ";" + << results[key].savings_ratio() << ";" + << results[key].total_milliseconds_elapsed << ";" + << results[key].distance << ";" + << results[key].actual_error << ";" + << "xyp" << ";" + << results[key].actual_max_depth << ";" + << results[key].calls_on_actual_max_depth() << ";" + << params.delta << ";" + << params.dim + << std::endl; + } + } +#else + params.bound_strategy = bound_strategies.back(); + params.traverse_strategy = traverse_strategies.back(); + + spd::debug("Will use {} bound, {} traverse strategy", params.bound_strategy, params.traverse_strategy); + + Real dist = matching_distance(bif_a, bif_b, params); + std::cout << dist << std::endl; +#endif + return 0; +} diff --git a/matching/include/matching_distance.h b/matching/include/matching_distance.h index 5be34c7..276cc3a 100644 --- a/matching/include/matching_distance.h +++ b/matching/include/matching_distance.h @@ -169,7 +169,7 @@ namespace md { bool print_stats { false }; #ifdef MD_PRINT_HEAT_MAP - HeatMaps heat_maps; + HeatMaps heat_maps; #endif }; diff --git a/matching/src/main.cpp b/matching/src/main.cpp deleted file mode 100644 index 2093457..0000000 --- a/matching/src/main.cpp +++ /dev/null @@ -1,382 +0,0 @@ -#include "common_defs.h" - -#include -#include -#include -#include - -#ifdef EXPERIMENTAL_TIMING -#include -#endif - -#include "opts/opts.h" -#include "spdlog/spdlog.h" -#include "spdlog/fmt/ostr.h" - -//#include "persistence_module.h" -#include "bifiltration.h" -#include "box.h" -#include "matching_distance.h" - -using Real = double; - -using namespace md; - -namespace fs = std::experimental::filesystem; - -void force_instantiation() -{ - DualBox db; - std::cout << db; -} - -#ifdef PRINT_HEAT_MAP -void print_heat_map(const md::HeatMaps& hms, std::string fname, const CalculationParams& params) -{ - spd::debug("Entered print_heat_map"); - std::set mu_vals, lambda_vals; - auto hm_iter = hms.end(); - --hm_iter; - int max_level = hm_iter->first; - - int level_cardinality = 4; - for(int i = 0; i < params.initialization_depth; ++i) { - level_cardinality *= 4; - } - for(int i = params.initialization_depth + 1; i <= max_level; ++i) { - spd::debug("hms.at({}).size = {}, must be {}", i, hms.at(i).size(), level_cardinality); - assert(static_cast(hms.at(i).size()) == level_cardinality); - level_cardinality *= 4; - } - - std::map, Real> hm_x_flat, hm_x_steep, hm_y_flat, hm_y_steep; - - for(const auto& dual_point_value_pair : hms.at(max_level)) { - const DualPoint& k = dual_point_value_pair.first; - spd::debug("HM DP: {}", k); - mu_vals.insert(k.mu()); - lambda_vals.insert(k.lambda()); - } - - std::vector lambda_vals_vec(lambda_vals.begin(), lambda_vals.end()); - std::vector mu_vals_vec(mu_vals.begin(), mu_vals.end()); - - std::ofstream ofs {fname}; - if (not ofs.good()) { - std::cerr << "Cannot write heat map to file " << fname << std::endl; - throw std::runtime_error("Cannot open file for writing heat map"); - } - - std::vector> heatmap_to_print(2 * mu_vals_vec.size(), - std::vector(2 * lambda_vals_vec.size(), 0.0)); - - for(auto axis_type : {AxisType::x_type, AxisType::y_type}) { - bool is_x_type = axis_type == AxisType::x_type; - for(auto angle_type : {AngleType::flat, AngleType::steep}) { - bool is_flat = angle_type == AngleType::flat; - - int mu_idx_begin, mu_idx_end; - - if (is_x_type) { - mu_idx_begin = mu_vals_vec.size() - 1; - mu_idx_end = -1; - } else { - mu_idx_begin = 0; - mu_idx_end = mu_vals_vec.size(); - } - - int lambda_idx_begin, lambda_idx_end; - - if (is_flat) { - lambda_idx_begin = 0; - lambda_idx_end = lambda_vals_vec.size(); - } else { - lambda_idx_begin = lambda_vals_vec.size() - 1; - lambda_idx_end = -1; - } - - int mu_idx_final = is_x_type ? 0 : mu_vals_vec.size(); - - for(int mu_idx = mu_idx_begin; mu_idx != mu_idx_end; (mu_idx_begin < mu_idx_end) ? mu_idx++ : mu_idx--) { - Real mu = mu_vals_vec.at(mu_idx); - - if (mu == 0.0 and axis_type == AxisType::x_type) - continue; - - int lambda_idx_final = is_flat ? 0 : lambda_vals_vec.size(); - - for(int lambda_idx = lambda_idx_begin; - lambda_idx != lambda_idx_end; - (lambda_idx_begin < lambda_idx_end) ? lambda_idx++ : lambda_idx--) { - - Real lambda = lambda_vals_vec.at(lambda_idx); - - if (lambda == 0.0 and angle_type == AngleType::flat) - continue; - - DualPoint dp(axis_type, angle_type, lambda, mu); - Real dist_value = hms.at(max_level).at(dp); - - heatmap_to_print.at(mu_idx_final).at(lambda_idx_final) = dist_value; - -// fmt::print("HM, dp = {}, mu_idx_final = {}, lambda_idx_final = {}, value = {}\n", dp, mu_idx_final, -// lambda_idx_final, dist_value); - - lambda_idx_final++; - } - mu_idx_final++; - } - } - } - - for(size_t m_idx = 0; m_idx < heatmap_to_print.size(); ++m_idx) { - for(size_t l_idx = 0; l_idx < heatmap_to_print.at(m_idx).size(); ++l_idx) { - ofs << heatmap_to_print.at(m_idx).at(l_idx) << " "; - } - ofs << std::endl; - } - - ofs.close(); - spd::debug("Exit print_heat_map"); -} -#endif - -int main(int argc, char** argv) -{ - spdlog::set_level(spdlog::level::info); - - using opts::Option; - using opts::PosOption; - opts::Options ops; - - bool help = false; - bool no_stop_asap = false; - CalculationParams params; - -#ifdef PRINT_HEAT_MAP - bool heatmap_only = false; -#endif - - std::string bounds_list_str = "local_combined"; - std::string traverse_list_str = "BFS"; - - ops >> Option('m', "max-iterations", params.max_depth, "maximal number of iterations (refinements)") - >> Option('e', "max-error", params.delta, "error threshold") - >> Option('d', "dim", params.dim, "dim") - >> Option('i', "initial-depth", params.initialization_depth, "initialization depth") - >> Option("no-stop-asap", no_stop_asap, - "don't stop looping over points, if cell cannot be pruned (asap is on by default)") - >> Option("bounds", bounds_list_str, "bounds to use, separated by ,") - >> Option("traverse", traverse_list_str, "traverse to use, separated by ,") -#ifdef PRINT_HEAT_MAP - >> Option("heatmap-only", heatmap_only, "only save heatmap (bruteforce)") -#endif - >> Option('h', "help", help, "show help message"); - - std::string fname_a; - std::string fname_b; - - if (!ops.parse(argc, argv) || help || !(ops >> PosOption(fname_a) >> PosOption(fname_b))) { - std::cerr << "Usage: " << argv[0] << "bifiltration-file-1 bifiltration-file-2\n" << ops << std::endl; - return 1; - } - - params.stop_asap = not no_stop_asap; - - auto bounds_list = split_by_delim(bounds_list_str, ','); - auto traverse_list = split_by_delim(traverse_list_str, ','); - - Bifiltration bif_a(fname_a); - Bifiltration bif_b(fname_b); - - bif_a.sanity_check(); - bif_b.sanity_check(); - - spd::debug("Read bifiltrations {} {}", fname_a, fname_b); - - std::vector bound_strategies; - std::vector traverse_strategies; - - for(std::string s : bounds_list) { - bound_strategies.push_back(bs_from_string(s)); - } - - for(std::string s : traverse_list) { - traverse_strategies.push_back(ts_from_string(s)); - } - - -#ifdef EXPERIMENTAL_TIMING - - for(auto bs : bound_strategies) { - for(auto ts : traverse_strategies) { - spd::info("Will test combination {} {}", bs, ts); - } - } - - struct ExperimentResult { - CalculationParams params {CalculationParams()}; - int n_hera_calls {0}; - double total_milliseconds_elapsed {0}; - Real distance {0}; - Real actual_error {std::numeric_limits::max()}; - int actual_max_depth {0}; - - int x_wins {0}; - int y_wins {0}; - int ad_wins {0}; - - int seconds_elapsed() const - { - return static_cast(total_milliseconds_elapsed / 1000); - } - - double savings_ratio_old() const - { - long int max_possible_calls = 0; - long int calls_on_level = 4; - for(int i = 0; i <= actual_max_depth; ++i) { - max_possible_calls += calls_on_level; - calls_on_level *= 4; - } - return static_cast(n_hera_calls) / static_cast(max_possible_calls); - } - - double savings_ratio() const - { - return static_cast(n_hera_calls) / calls_on_actual_max_depth(); - } - - long long int calls_on_actual_max_depth() const - { - long long int result = 1; - for(int i = 0; i < actual_max_depth; ++i) { - result *= 4; - } - return result; - } - - ExperimentResult() { } - - ExperimentResult(CalculationParams p, int nhc, double tme, double d) - : - params(p), n_hera_calls(nhc), total_milliseconds_elapsed(tme), distance(d) { } - }; - - const int n_repetitions = 1; - - if (heatmap_only) { - bound_strategies.clear(); - bound_strategies.push_back(BoundStrategy::bruteforce); - traverse_strategies.clear(); - traverse_strategies.push_back(TraverseStrategy::breadth_first); - } - - std::map, ExperimentResult> results; - for(BoundStrategy bound_strategy : bound_strategies) { - for(TraverseStrategy traverse_strategy : traverse_strategies) { - CalculationParams params_experiment; - params_experiment.bound_strategy = bound_strategy; - params_experiment.traverse_strategy = traverse_strategy; - params_experiment.max_depth = params.max_depth; - params_experiment.initialization_depth = params.initialization_depth; - params_experiment.delta = params.delta; - params_experiment.dim = params.dim; - params_experiment.hera_epsilon = params.hera_epsilon; - params_experiment.stop_asap = params.stop_asap; - - if (traverse_strategy == TraverseStrategy::depth_first and bound_strategy == BoundStrategy::bruteforce) - continue; - - // if bruteforce, clamp max iterations number to 7, - // save user-provided max_iters in user_max_iters and restore it later. - // remember: params is passed by reference to return real relative error and heat map - - int user_max_iters = params.max_depth; - if (bound_strategy == BoundStrategy::bruteforce and not heatmap_only) { - params_experiment.max_depth = std::min(7, params.max_depth); - } - double total_milliseconds_elapsed = 0; - int total_n_hera_calls = 0; - Real dist; - for(int i = 0; i < n_repetitions; ++i) { - spd::debug("Processing bound_strategy {}, traverse_strategy {}, iteration = {}", bound_strategy, - traverse_strategy, i); - auto t1 = std::chrono::high_resolution_clock().now(); - dist = matching_distance(bif_a, bif_b, params_experiment); - auto t2 = std::chrono::high_resolution_clock().now(); - total_milliseconds_elapsed += std::chrono::duration_cast( - t2 - t1).count(); - total_n_hera_calls += params_experiment.n_hera_calls; - } - - auto key = std::make_tuple(bound_strategy, traverse_strategy); - results[key].params = params_experiment; - results[key].n_hera_calls = total_n_hera_calls / n_repetitions; - results[key].total_milliseconds_elapsed = total_milliseconds_elapsed / n_repetitions; - results[key].distance = dist; - results[key].actual_error = params_experiment.actual_error; - results[key].actual_max_depth = params_experiment.actual_max_depth; - - spd::info( - "Done (bound = {}, traverse = {}), n_hera_calls = {}, time = {} sec, d = {}, error = {}, savings = {}, max_depth = {}", - bound_strategy, traverse_strategy, results[key].n_hera_calls, results[key].seconds_elapsed(), - dist, - params_experiment.actual_error, results[key].savings_ratio(), results[key].actual_max_depth); - - if (bound_strategy == BoundStrategy::bruteforce) { params_experiment.max_depth = user_max_iters; } - -#ifdef PRINT_HEAT_MAP - if (bound_strategy == BoundStrategy::bruteforce) { - fs::path fname_a_path {fname_a.c_str()}; - fs::path fname_b_path {fname_b.c_str()}; - fs::path fname_a_wo = fname_a_path.filename(); - fs::path fname_b_wo = fname_b_path.filename(); - std::string heat_map_fname = fmt::format("{0}_{1}_dim_{2}_weighted_values_xyp.txt", fname_a_wo.string(), - fname_b_wo.string(), params_experiment.dim); - fs::path path_hm = fname_a_path.replace_filename(fs::path(heat_map_fname.c_str())); - spd::debug("Saving heatmap to {}", heat_map_fname); - print_heat_map(params_experiment.heat_maps, path_hm.string(), params); - } -#endif - spd::debug("Finished processing bound_strategy {}", bound_strategy); - } - } - -// std::cout << "File_1;File_2;Boundstrategy;TraverseStrategy;InitalDepth;NHeraCalls;SavingsRatio;Time;Distance;Error;PushStrategy;MaxDepth;CallsOnMaxDepth;Delta;Dimension" << std::endl; - for(auto bs : bound_strategies) { - for(auto ts : traverse_strategies) { - auto key = std::make_tuple(bs, ts); - - fs::path fname_a_path {fname_a.c_str()}; - fs::path fname_b_path {fname_b.c_str()}; - fs::path fname_a_wo = fname_a_path.filename(); - fs::path fname_b_wo = fname_b_path.filename(); - - std::cout << fname_a_wo.string() << ";" << fname_b_wo.string() << ";" << bs << ";" << ts << ";"; - std::cout << results[key].params.initialization_depth << ";"; - std::cout << results[key].n_hera_calls << ";" - << results[key].savings_ratio() << ";" - << results[key].total_milliseconds_elapsed << ";" - << results[key].distance << ";" - << results[key].actual_error << ";" - << "xyp" << ";" - << results[key].actual_max_depth << ";" - << results[key].calls_on_actual_max_depth() << ";" - << params.delta << ";" - << params.dim - << std::endl; - } - } -#else - params.bound_strategy = bound_strategies.back(); - params.traverse_strategy = traverse_strategies.back(); - - spd::debug("Will use {} bound, {} traverse strategy", params.bound_strategy, params.traverse_strategy); - - Real dist = matching_distance(bif_a, bif_b, params); - std::cout << dist << std::endl; -#endif - force_instantiation(); - return 0; -} diff --git a/matching/src/test_generator.cpp b/matching/src/test_generator.cpp deleted file mode 100644 index a2f0625..0000000 --- a/matching/src/test_generator.cpp +++ /dev/null @@ -1,211 +0,0 @@ -#include -#include -#include -#include -#include - -#include "opts/opts.h" -#include "spdlog/spdlog.h" -#include "spdlog/fmt/ostr.h" - -#include "common_util.h" -#include "bifiltration.h" - -using Real = double; -using Index = md::Index; -using Point = md::Point; -using Bifiltration = md::Bifiltration; -using Column = md::Column; -using Simplex = md::Simplex; - -int g_max_coord = 100; - -using ASimplex = md::AbstractSimplex; - -using ASimplexToBirthMap = std::map; - -namespace spd = spdlog; - -// random generator is global -std::random_device rd; -std::mt19937_64 gen(rd()); - -//std::mt19937_64 gen(42); - -Point get_random_position(int max_coord) -{ - assert(max_coord > 0); - std::uniform_int_distribution distr(0, max_coord); - return Point(distr(gen), distr(gen)); - -} - -Point get_random_position_less_than(Point ub) -{ - std::uniform_int_distribution distr_x(0, ub.x); - std::uniform_int_distribution distr_y(0, ub.y); - return Point(distr_x(gen), distr_y(gen)); -} - -Point get_random_position_greater_than(Point lb) -{ - std::uniform_int_distribution distr_x(lb.x, g_max_coord); - std::uniform_int_distribution distr_y(lb.y, g_max_coord); - return Point(distr_x(gen), distr_y(gen)); -} - -// non-proper faces (empty and simplex itself) are also faces -bool is_face(const ASimplex& face_candidate, const ASimplex& coface_candidate) -{ - return std::includes(coface_candidate.begin(), coface_candidate.end(), face_candidate.begin(), - face_candidate.end()); -} - -bool is_top_simplex(const ASimplex& candidate_simplex, const std::vector& current_top_simplices) -{ - return std::none_of(current_top_simplices.begin(), current_top_simplices.end(), - [&candidate_simplex](const ASimplex& ts) { return is_face(candidate_simplex, ts); }); -} - -void add_if_top(const ASimplex& candidate_simplex, std::vector& current_top_simplices) -{ - // check that candidate simplex is not face of someone in current_top_simplices - if (!is_top_simplex(candidate_simplex, current_top_simplices)) - return; - - spd::debug("candidate_simplex is top, will be added to top_simplices"); - // remove s from currrent_top_simplices, if s is face of candidate_simplex - current_top_simplices.erase(std::remove_if(current_top_simplices.begin(), current_top_simplices.end(), - [candidate_simplex](const ASimplex& s) { return is_face(s, candidate_simplex); }), - current_top_simplices.end()); - - current_top_simplices.push_back(candidate_simplex); -} - -ASimplex get_random_simplex(int n_vertices, int dim) -{ - std::vector all_vertices(n_vertices, 0); - // fill in with 0..n-1 - std::iota(all_vertices.begin(), all_vertices.end(), 0); - std::shuffle(all_vertices.begin(), all_vertices.end(), gen); - return ASimplex(all_vertices.begin(), all_vertices.begin() + dim + 1, true); -} - -void generate_positions(const ASimplex& s, ASimplexToBirthMap& simplex_to_birth, Point upper_bound) -{ - auto pos = get_random_position_less_than(upper_bound); - auto curr_pos_iter = simplex_to_birth.find(s); - if (curr_pos_iter != simplex_to_birth.end()) - pos = md::greatest_lower_bound(pos, curr_pos_iter->second); - simplex_to_birth[s] = pos; - for(const ASimplex& facet : s.facets()) { - generate_positions(facet, simplex_to_birth, pos); - } -} - -Bifiltration get_random_bifiltration(int n_vertices, int max_dim, int n_top_simplices) -{ - ASimplexToBirthMap simplex_to_birth; - - // generate vertices - for(int i = 0; i < n_vertices; ++i) { - Point vertex_birth = get_random_position(g_max_coord / 10); - ASimplex vertex; - vertex.push_back(i); - simplex_to_birth[vertex] = vertex_birth; - } - - std::vector top_simplices; - // generate top simplices - while((int)top_simplices.size() < n_top_simplices) { - std::uniform_int_distribution dimension_distr(1, max_dim); - int dim = dimension_distr(gen); - auto candidate_simplex = get_random_simplex(n_vertices, dim); - spd::debug("candidate_simplex = {}", candidate_simplex); - add_if_top(candidate_simplex, top_simplices); - } - - Point upper_bound{static_cast(g_max_coord), static_cast(g_max_coord)}; - for(const auto& top_simplex : top_simplices) { - generate_positions(top_simplex, simplex_to_birth, upper_bound); - } - - std::vector> simplex_birth_pairs{simplex_to_birth.begin(), simplex_to_birth.end()}; - std::vector boundaries{simplex_to_birth.size(), Column()}; - -// assign ids and save boundaries - int id = 0; - - for(int dim = 0; dim <= max_dim; ++dim) { - for(int i = 0; i < (int) simplex_birth_pairs.size(); ++i) { - ASimplex& simplex = simplex_birth_pairs[i].first; - if (simplex.dim() == dim) { - simplex.id = id++; - Column bdry; - for(auto& facet : simplex.facets()) { - auto facet_iter = std::find_if(simplex_birth_pairs.begin(), simplex_birth_pairs.end(), - [facet](const std::pair& sbp) { return facet == sbp.first; }); - assert(facet_iter != simplex_birth_pairs.end()); - assert(facet_iter->first.id >= 0); - bdry.push_back(facet_iter->first.id); - } - std::sort(bdry.begin(), bdry.end()); - boundaries[i] = bdry; - } - } - } - -// create vector of Simplex-es - std::vector simplices; - for(int i = 0; i < (int) simplex_birth_pairs.size(); ++i) { - int id = simplex_birth_pairs[i].first.id; - int dim = simplex_birth_pairs[i].first.dim(); - Point birth = simplex_birth_pairs[i].second; - Column bdry = boundaries[i]; - simplices.emplace_back(id, birth, dim, bdry); - } - -// sort by id - std::sort(simplices.begin(), simplices.end(), - [](const Simplex& s1, const Simplex& s2) { return s1.id() < s2.id(); }); - for(int i = 0; i < (int)simplices.size(); ++i) { - assert(simplices[i].id() == i); - assert(i == 0 || simplices[i].dim() >= simplices[i - 1].dim()); - } - - return Bifiltration(simplices.begin(), simplices.end()); -} - -int main(int argc, char** argv) -{ - spd::set_level(spd::level::info); - int n_vertices; - int max_dim; - int n_top_simplices; - std::string fname; - - using opts::Option; - using opts::PosOption; - opts::Options ops; - - bool help = false; - - ops >> Option('v', "n-vertices", n_vertices, "number of vertices") - >> Option('d', "max-dim", max_dim, "maximal dim") - >> Option('m', "max-coord", g_max_coord, "maximal coordinate") - >> Option('t', "n-top-simplices", n_top_simplices, "number of top simplices") - >> Option('h', "help", help, "show help message"); - - if (!ops.parse(argc, argv) || help || !(ops >> PosOption(fname))) { - std::cerr << "Usage: " << argv[0] << "\n" << ops << std::endl; - return 1; - } - - - auto bif1 = get_random_bifiltration(n_vertices, max_dim, n_top_simplices); - std::cout << "Generated bifiltration." << std::endl; - bif1.save(fname, md::BifiltrationFormat::phat_like); - std::cout << "Saved to file " << fname << std::endl; - return 0; -} - diff --git a/matching/src/tests/prism_1.bif b/matching/src/tests/prism_1.bif deleted file mode 100644 index b37e807..0000000 --- a/matching/src/tests/prism_1.bif +++ /dev/null @@ -1,28 +0,0 @@ -bifiltration_phat_like -26 -0 0 0 -0 0 0 -0 0 0 -0 0 0 -0 0 0 -0 0 0 -1 1 0 1 0 -1 1 0 2 1 -1 0 0 3 1 -1 0 1 5 4 -1 0 0 4 1 -1 0 0 3 0 -1 0 0 5 0 -1 0 0 4 2 -1 0 0 5 2 -1 0 1 4 3 -1 1 0 2 0 -1 0 1 5 3 -2 1 1 13 10 7 -2 1 1 9 14 13 -2 1 1 8 11 6 -2 1 1 15 10 8 -2 1 1 14 12 16 -2 1 1 17 12 11 -2 4 0 7 16 6 -2 0 4 9 17 15 diff --git a/matching/src/tests/prism_2.bif b/matching/src/tests/prism_2.bif deleted file mode 100644 index 49885e6..0000000 --- a/matching/src/tests/prism_2.bif +++ /dev/null @@ -1,29 +0,0 @@ -bifiltration_phat_like -27 -0 0 0 -0 0 0 -0 0 0 -0 0 0 -0 0 0 -0 0 0 -1 0 0 2 1 -1 0 0 1 0 -1 0 0 5 3 -1 0 0 3 1 -1 0 0 5 4 -1 0 0 2 0 -1 0 0 4 1 -1 0 0 3 2 -1 0 0 5 2 -1 0 0 4 3 -1 0 0 4 2 -1 0 0 3 0 -2 0 0 16 12 6 -2 0 0 10 14 16 -2 0 0 9 17 7 -2 0 0 15 12 9 -2 0 0 13 17 11 -2 0 0 8 14 13 -2 4 0 6 11 7 -2 0 4 10 8 15 -2 3 3 15 16 13 diff --git a/matching/src/tests/test_bifiltration.cpp b/matching/src/tests/test_bifiltration.cpp deleted file mode 100644 index 742dab8..0000000 --- a/matching/src/tests/test_bifiltration.cpp +++ /dev/null @@ -1,36 +0,0 @@ -#include "catch/catch.hpp" - -#include -#include - -#include "common_util.h" -#include "box.h" -#include "bifiltration.h" - -using namespace md; - -//TEST_CASE("Small check", "[bifiltration][dim2]") -//{ -// Bifiltration bif("/home/narn/code/matching_distance/code/src/tests/test_bifiltration_full_triangle_phat_like.txt", BifiltrationFormat::phat_like); -// auto simplices = bif.simplices(); -// bif.sanity_check(); -// -// REQUIRE( simplices.size() == 7 ); -// -// REQUIRE( simplices[0].dim() == 0 ); -// REQUIRE( simplices[1].dim() == 0 ); -// REQUIRE( simplices[2].dim() == 0 ); -// REQUIRE( simplices[3].dim() == 1 ); -// REQUIRE( simplices[4].dim() == 1 ); -// REQUIRE( simplices[5].dim() == 1 ); -// REQUIRE( simplices[6].dim() == 2); -// -// REQUIRE( simplices[0].position() == Point(0, 0)); -// REQUIRE( simplices[1].position() == Point(0, 0)); -// REQUIRE( simplices[2].position() == Point(0, 0)); -// REQUIRE( simplices[3].position() == Point(3, 1)); -// REQUIRE( simplices[6].position() == Point(30, 40)); -// -// Line line_1(Line::pi / 2.0, 0.0); -// auto dgm = bif.slice_diagram(line_1); -//} diff --git a/matching/src/tests/test_bifiltration_1.txt b/matching/src/tests/test_bifiltration_1.txt deleted file mode 120000 index ddd23e9..0000000 --- a/matching/src/tests/test_bifiltration_1.txt +++ /dev/null @@ -1 +0,0 @@ -../../data/test_bifiltration_1.txt \ No newline at end of file diff --git a/matching/src/tests/test_bifiltration_full_triangle_rene.txt b/matching/src/tests/test_bifiltration_full_triangle_rene.txt deleted file mode 120000 index 47f49fd..0000000 --- a/matching/src/tests/test_bifiltration_full_triangle_rene.txt +++ /dev/null @@ -1 +0,0 @@ -../../data/test_bifiltration_full_triangle_rene.txt \ No newline at end of file diff --git a/matching/src/tests/test_common.cpp b/matching/src/tests/test_common.cpp deleted file mode 100644 index 9079a56..0000000 --- a/matching/src/tests/test_common.cpp +++ /dev/null @@ -1,156 +0,0 @@ -#include "catch/catch.hpp" - -#include -#include -#include - -#include "common_util.h" -#include "simplex.h" -#include "matching_distance.h" - -//using namespace md; -using Real = double; -using Point = md::Point; -using Bifiltration = md::Bifiltration; -using BifiltrationProxy = md::BifiltrationProxy; -using CalculationParams = md::CalculationParams; -using CellWithValue = md::CellWithValue; -using DualPoint = md::DualPoint; -using DualBox = md::DualBox; -using Simplex = md::Simplex; -using AbstractSimplex = md::AbstractSimplex; -using BoundStrategy = md::BoundStrategy; -using TraverseStrategy = md::TraverseStrategy; -using AxisType = md::AxisType; -using AngleType = md::AngleType; -using ValuePoint = md::ValuePoint; -using Column = md::Column; - - -TEST_CASE("AbstractSimplex", "[abstract_simplex]") -{ - AbstractSimplex as; - REQUIRE(as.dim() == -1); - - as.push_back(1); - REQUIRE(as.dim() == 0); - REQUIRE(as.facets().size() == 0); - - as.push_back(0); - REQUIRE(as.dim() == 1); - REQUIRE(as.facets().size() == 2); - REQUIRE(as.facets()[0].dim() == 0); - REQUIRE(as.facets()[1].dim() == 0); - -} - -TEST_CASE("Vertical line", "[vertical_line]") -{ - // line x = 1 - DualPoint l_vertical(AxisType::x_type, AngleType::steep, 0, 1); - - REQUIRE(l_vertical.is_vertical()); - REQUIRE(l_vertical.is_steep()); - - Point p_1(0.5, 0.5); - Point p_2(1.5, 0.5); - Point p_3(1.5, 1.5); - Point p_4(0.5, 1.5); - Point p_5(1, 10); - - REQUIRE(l_vertical.x_from_y(10) == 1); - REQUIRE(l_vertical.x_from_y(-10) == 1); - REQUIRE(l_vertical.x_from_y(0) == 1); - - REQUIRE(not l_vertical.contains(p_1)); - REQUIRE(not l_vertical.contains(p_2)); - REQUIRE(not l_vertical.contains(p_3)); - REQUIRE(not l_vertical.contains(p_4)); - REQUIRE(l_vertical.contains(p_5)); - - REQUIRE(l_vertical.goes_below(p_1)); - REQUIRE(not l_vertical.goes_below(p_2)); - REQUIRE(not l_vertical.goes_below(p_3)); - REQUIRE(l_vertical.goes_below(p_4)); - - REQUIRE(not l_vertical.goes_above(p_1)); - REQUIRE(l_vertical.goes_above(p_2)); - REQUIRE(l_vertical.goes_above(p_3)); - REQUIRE(not l_vertical.goes_above(p_4)); - -} - -TEST_CASE("Horizontal line", "[horizontal_line]") -{ - // line y = 1 - DualPoint l_horizontal(AxisType::y_type, AngleType::flat, 0, 1); - - REQUIRE(l_horizontal.is_horizontal()); - REQUIRE(l_horizontal.is_flat()); - REQUIRE(l_horizontal.y_slope() == 0); - REQUIRE(l_horizontal.y_intercept() == 1); - - Point p_1(0.5, 0.5); - Point p_2(1.5, 0.5); - Point p_3(1.5, 1.5); - Point p_4(0.5, 1.5); - Point p_5(2, 1); - - REQUIRE((not l_horizontal.contains(p_1) and - not l_horizontal.contains(p_2) and - not l_horizontal.contains(p_3) and - not l_horizontal.contains(p_4) and - l_horizontal.contains(p_5))); - - REQUIRE(not l_horizontal.goes_below(p_1)); - REQUIRE(not l_horizontal.goes_below(p_2)); - REQUIRE(l_horizontal.goes_below(p_3)); - REQUIRE(l_horizontal.goes_below(p_4)); - REQUIRE(l_horizontal.goes_below(p_5)); - - REQUIRE(l_horizontal.goes_above(p_1)); - REQUIRE(l_horizontal.goes_above(p_2)); - REQUIRE(not l_horizontal.goes_above(p_3)); - REQUIRE(not l_horizontal.goes_above(p_4)); - REQUIRE(l_horizontal.goes_above(p_5)); -} - -TEST_CASE("Flat Line with positive slope", "[flat_line]") -{ - // line y = x / 2 + 1 - DualPoint l_flat(AxisType::y_type, AngleType::flat, 0.5, 1); - - REQUIRE(not l_flat.is_horizontal()); - REQUIRE(l_flat.is_flat()); - REQUIRE(l_flat.y_slope() == 0.5); - REQUIRE(l_flat.y_intercept() == 1); - - REQUIRE(l_flat.y_from_x(0) == 1); - REQUIRE(l_flat.y_from_x(1) == 1.5); - REQUIRE(l_flat.y_from_x(2) == 2); - - Point p_1(3, 2); - Point p_2(-2, 0.01); - Point p_3(0, 1.25); - Point p_4(6, 4.5); - Point p_5(2, 2); - - REQUIRE((not l_flat.contains(p_1) and - not l_flat.contains(p_2) and - not l_flat.contains(p_3) and - not l_flat.contains(p_4) and - l_flat.contains(p_5))); - - REQUIRE(not l_flat.goes_below(p_1)); - REQUIRE(l_flat.goes_below(p_2)); - REQUIRE(l_flat.goes_below(p_3)); - REQUIRE(l_flat.goes_below(p_4)); - REQUIRE(l_flat.goes_below(p_5)); - - REQUIRE(l_flat.goes_above(p_1)); - REQUIRE(not l_flat.goes_above(p_2)); - REQUIRE(not l_flat.goes_above(p_3)); - REQUIRE(not l_flat.goes_above(p_4)); - REQUIRE(l_flat.goes_above(p_5)); - -} diff --git a/matching/src/tests/test_list.txt b/matching/src/tests/test_list.txt deleted file mode 100644 index 1984606..0000000 --- a/matching/src/tests/test_list.txt +++ /dev/null @@ -1 +0,0 @@ -prism_lesnick_1.bif prism_lesnick_2.bif 1.0 diff --git a/matching/src/tests/test_matching_distance.cpp b/matching/src/tests/test_matching_distance.cpp deleted file mode 100644 index 82da530..0000000 --- a/matching/src/tests/test_matching_distance.cpp +++ /dev/null @@ -1,167 +0,0 @@ -#include "catch/catch.hpp" - -#include -#include -#include - -#include "spdlog/spdlog.h" -#include "spdlog/fmt/ostr.h" - -#include "common_util.h" -#include "simplex.h" -#include "matching_distance.h" - -using Real = double; -using Point = md::Point; -using Bifiltration = md::Bifiltration; -using BifiltrationProxy = md::BifiltrationProxy; -using CalculationParams = md::CalculationParams; -using CellWithValue = md::CellWithValue; -using DualPoint = md::DualPoint; -using DualBox = md::DualBox; -using Simplex = md::Simplex; -using AbstractSimplex = md::AbstractSimplex; -using BoundStrategy = md::BoundStrategy; -using TraverseStrategy = md::TraverseStrategy; -using AxisType = md::AxisType; -using AngleType = md::AngleType; -using ValuePoint = md::ValuePoint; -using Column = md::Column; - -using md::k_corner_vps; - -namespace spd = spdlog; - -TEST_CASE("Different bounds", "[bounds]") -{ - std::vector simplices; - std::vector points; - - Real max_x = 10; - Real max_y = 20; - - int simplex_id = 0; - for(int i = 0; i <= max_x; ++i) { - for(int j = 0; j <= max_y; ++j) { - Point p(i, j); - simplices.emplace_back(simplex_id++, p, 0, Column()); - points.push_back(p); - } - } - - Bifiltration bif_a(simplices.begin(), simplices.end()); - Bifiltration bif_b(simplices.begin(), simplices.end()); - - CalculationParams params; - params.initialization_depth = 2; - - BifiltrationProxy bifp_a(bif_a, params.dim); - BifiltrationProxy bifp_b(bif_b, params.dim); - - md::DistanceCalculator calc(bifp_a, bifp_b, params); - -// REQUIRE(calc.max_x_ == Approx(max_x)); -// REQUIRE(calc.max_y_ == Approx(max_y)); - - std::vector boxes; - - for(CellWithValue c : calc.get_refined_grid(5, false, false)) { - boxes.push_back(c.dual_box()); - } - - // fill in boxes and points - - for(DualBox db : boxes) { - Real local_bound = calc.get_local_dual_bound(db); - Real local_bound_refined = calc.get_local_refined_bound(db); - REQUIRE(local_bound >= local_bound_refined); - for(Point p : points) { - for(ValuePoint vp_a : k_corner_vps) { - CellWithValue dual_cell(db, 1); - DualPoint corner_a = dual_cell.value_point(vp_a); - Real wp_a = corner_a.weighted_push(p); - dual_cell.set_value_at(vp_a, wp_a); - Real point_bound = calc.get_max_displacement_single_point(dual_cell, vp_a, p); - for(ValuePoint vp_b : k_corner_vps) { - if (vp_b <= vp_a) - continue; - DualPoint corner_b = dual_cell.value_point(vp_b); - Real wp_b = corner_b.weighted_push(p); - Real diff = fabs(wp_a - wp_b); - if (not(point_bound <= Approx(local_bound_refined))) { - std::cerr << "ERROR point: " << p << ", box = " << db << ", point bound = " << point_bound - << ", refined local = " << local_bound_refined << std::endl; - spd::set_level(spd::level::debug); - calc.get_max_displacement_single_point(dual_cell, vp_a, p); - calc.get_local_refined_bound(db); - spd::set_level(spd::level::info); - } - - REQUIRE(point_bound <= Approx(local_bound_refined)); - REQUIRE(diff <= Approx(point_bound)); - REQUIRE(diff <= Approx(local_bound_refined)); - } - - for(DualPoint l_random : db.random_points(100)) { - Real wp_random = l_random.weighted_push(p); - Real diff = fabs(wp_a - wp_random); - if (not(diff <= Approx(point_bound))) { - if (db.critical_points(p).size() > 4) { - std::cerr << "ERROR interesting case" << std::endl; - } else { - std::cerr << "ERROR boring case" << std::endl; - } - spd::set_level(spd::level::debug); - l_random.weighted_push(p); - spd::set_level(spd::level::info); - std::cerr << "ERROR point: " << p << ", box = " << db << ", point bound = " << point_bound - << ", refined local = " << local_bound_refined; - std::cerr << ", random_dual_point = " << l_random << ", wp_random = " << wp_random - << ", diff = " << diff << std::endl; - } - REQUIRE(diff <= Approx(point_bound)); - } - } - } - } -} - -TEST_CASE("Bifiltrations from file", "[matching_distance][small_example][lesnick]") -{ - std::string fname_a, fname_b; - - fname_a = "../src/tests/prism_1.bif"; - fname_b = "../src/tests/prism_2.bif"; - - Bifiltration bif_a(fname_a); - Bifiltration bif_b(fname_b); - - CalculationParams params; - - std::vector bound_strategies {BoundStrategy::local_combined, - BoundStrategy::local_dual_bound_refined}; - - std::vector traverse_strategies {TraverseStrategy::breadth_first, TraverseStrategy::depth_first}; - - std::vector scaling_factors {10, 1.0}; - - for(auto bs : bound_strategies) { - for(auto ts : traverse_strategies) { - for(double lambda : scaling_factors) { - Bifiltration bif_a_copy(bif_a); - Bifiltration bif_b_copy(bif_b); - bif_a_copy.scale(lambda); - bif_b_copy.scale(lambda); - params.bound_strategy = bs; - params.traverse_strategy = ts; - params.max_depth = 7; - params.delta = 0.1; - params.dim = 1; - Real answer = matching_distance(bif_a_copy, bif_b_copy, params); - Real correct_answer = lambda * 1.0; - REQUIRE(fabs(answer - correct_answer) / correct_answer < params.delta); - } - } - } -} - diff --git a/matching/src/tests/tests_main.cpp b/matching/src/tests/tests_main.cpp deleted file mode 100644 index 1c77b13..0000000 --- a/matching/src/tests/tests_main.cpp +++ /dev/null @@ -1,2 +0,0 @@ -#define CATCH_CONFIG_MAIN -#include "catch/catch.hpp" diff --git a/matching/tests/prism_1.bif b/matching/tests/prism_1.bif new file mode 100644 index 0000000..b37e807 --- /dev/null +++ b/matching/tests/prism_1.bif @@ -0,0 +1,28 @@ +bifiltration_phat_like +26 +0 0 0 +0 0 0 +0 0 0 +0 0 0 +0 0 0 +0 0 0 +1 1 0 1 0 +1 1 0 2 1 +1 0 0 3 1 +1 0 1 5 4 +1 0 0 4 1 +1 0 0 3 0 +1 0 0 5 0 +1 0 0 4 2 +1 0 0 5 2 +1 0 1 4 3 +1 1 0 2 0 +1 0 1 5 3 +2 1 1 13 10 7 +2 1 1 9 14 13 +2 1 1 8 11 6 +2 1 1 15 10 8 +2 1 1 14 12 16 +2 1 1 17 12 11 +2 4 0 7 16 6 +2 0 4 9 17 15 diff --git a/matching/tests/prism_2.bif b/matching/tests/prism_2.bif new file mode 100644 index 0000000..49885e6 --- /dev/null +++ b/matching/tests/prism_2.bif @@ -0,0 +1,29 @@ +bifiltration_phat_like +27 +0 0 0 +0 0 0 +0 0 0 +0 0 0 +0 0 0 +0 0 0 +1 0 0 2 1 +1 0 0 1 0 +1 0 0 5 3 +1 0 0 3 1 +1 0 0 5 4 +1 0 0 2 0 +1 0 0 4 1 +1 0 0 3 2 +1 0 0 5 2 +1 0 0 4 3 +1 0 0 4 2 +1 0 0 3 0 +2 0 0 16 12 6 +2 0 0 10 14 16 +2 0 0 9 17 7 +2 0 0 15 12 9 +2 0 0 13 17 11 +2 0 0 8 14 13 +2 4 0 6 11 7 +2 0 4 10 8 15 +2 3 3 15 16 13 diff --git a/matching/tests/test_bifiltration.cpp b/matching/tests/test_bifiltration.cpp new file mode 100644 index 0000000..742dab8 --- /dev/null +++ b/matching/tests/test_bifiltration.cpp @@ -0,0 +1,36 @@ +#include "catch/catch.hpp" + +#include +#include + +#include "common_util.h" +#include "box.h" +#include "bifiltration.h" + +using namespace md; + +//TEST_CASE("Small check", "[bifiltration][dim2]") +//{ +// Bifiltration bif("/home/narn/code/matching_distance/code/src/tests/test_bifiltration_full_triangle_phat_like.txt", BifiltrationFormat::phat_like); +// auto simplices = bif.simplices(); +// bif.sanity_check(); +// +// REQUIRE( simplices.size() == 7 ); +// +// REQUIRE( simplices[0].dim() == 0 ); +// REQUIRE( simplices[1].dim() == 0 ); +// REQUIRE( simplices[2].dim() == 0 ); +// REQUIRE( simplices[3].dim() == 1 ); +// REQUIRE( simplices[4].dim() == 1 ); +// REQUIRE( simplices[5].dim() == 1 ); +// REQUIRE( simplices[6].dim() == 2); +// +// REQUIRE( simplices[0].position() == Point(0, 0)); +// REQUIRE( simplices[1].position() == Point(0, 0)); +// REQUIRE( simplices[2].position() == Point(0, 0)); +// REQUIRE( simplices[3].position() == Point(3, 1)); +// REQUIRE( simplices[6].position() == Point(30, 40)); +// +// Line line_1(Line::pi / 2.0, 0.0); +// auto dgm = bif.slice_diagram(line_1); +//} diff --git a/matching/tests/test_bifiltration_1.txt b/matching/tests/test_bifiltration_1.txt new file mode 120000 index 0000000..ddd23e9 --- /dev/null +++ b/matching/tests/test_bifiltration_1.txt @@ -0,0 +1 @@ +../../data/test_bifiltration_1.txt \ No newline at end of file diff --git a/matching/tests/test_bifiltration_full_triangle_rene.txt b/matching/tests/test_bifiltration_full_triangle_rene.txt new file mode 120000 index 0000000..47f49fd --- /dev/null +++ b/matching/tests/test_bifiltration_full_triangle_rene.txt @@ -0,0 +1 @@ +../../data/test_bifiltration_full_triangle_rene.txt \ No newline at end of file diff --git a/matching/tests/test_common.cpp b/matching/tests/test_common.cpp new file mode 100644 index 0000000..9079a56 --- /dev/null +++ b/matching/tests/test_common.cpp @@ -0,0 +1,156 @@ +#include "catch/catch.hpp" + +#include +#include +#include + +#include "common_util.h" +#include "simplex.h" +#include "matching_distance.h" + +//using namespace md; +using Real = double; +using Point = md::Point; +using Bifiltration = md::Bifiltration; +using BifiltrationProxy = md::BifiltrationProxy; +using CalculationParams = md::CalculationParams; +using CellWithValue = md::CellWithValue; +using DualPoint = md::DualPoint; +using DualBox = md::DualBox; +using Simplex = md::Simplex; +using AbstractSimplex = md::AbstractSimplex; +using BoundStrategy = md::BoundStrategy; +using TraverseStrategy = md::TraverseStrategy; +using AxisType = md::AxisType; +using AngleType = md::AngleType; +using ValuePoint = md::ValuePoint; +using Column = md::Column; + + +TEST_CASE("AbstractSimplex", "[abstract_simplex]") +{ + AbstractSimplex as; + REQUIRE(as.dim() == -1); + + as.push_back(1); + REQUIRE(as.dim() == 0); + REQUIRE(as.facets().size() == 0); + + as.push_back(0); + REQUIRE(as.dim() == 1); + REQUIRE(as.facets().size() == 2); + REQUIRE(as.facets()[0].dim() == 0); + REQUIRE(as.facets()[1].dim() == 0); + +} + +TEST_CASE("Vertical line", "[vertical_line]") +{ + // line x = 1 + DualPoint l_vertical(AxisType::x_type, AngleType::steep, 0, 1); + + REQUIRE(l_vertical.is_vertical()); + REQUIRE(l_vertical.is_steep()); + + Point p_1(0.5, 0.5); + Point p_2(1.5, 0.5); + Point p_3(1.5, 1.5); + Point p_4(0.5, 1.5); + Point p_5(1, 10); + + REQUIRE(l_vertical.x_from_y(10) == 1); + REQUIRE(l_vertical.x_from_y(-10) == 1); + REQUIRE(l_vertical.x_from_y(0) == 1); + + REQUIRE(not l_vertical.contains(p_1)); + REQUIRE(not l_vertical.contains(p_2)); + REQUIRE(not l_vertical.contains(p_3)); + REQUIRE(not l_vertical.contains(p_4)); + REQUIRE(l_vertical.contains(p_5)); + + REQUIRE(l_vertical.goes_below(p_1)); + REQUIRE(not l_vertical.goes_below(p_2)); + REQUIRE(not l_vertical.goes_below(p_3)); + REQUIRE(l_vertical.goes_below(p_4)); + + REQUIRE(not l_vertical.goes_above(p_1)); + REQUIRE(l_vertical.goes_above(p_2)); + REQUIRE(l_vertical.goes_above(p_3)); + REQUIRE(not l_vertical.goes_above(p_4)); + +} + +TEST_CASE("Horizontal line", "[horizontal_line]") +{ + // line y = 1 + DualPoint l_horizontal(AxisType::y_type, AngleType::flat, 0, 1); + + REQUIRE(l_horizontal.is_horizontal()); + REQUIRE(l_horizontal.is_flat()); + REQUIRE(l_horizontal.y_slope() == 0); + REQUIRE(l_horizontal.y_intercept() == 1); + + Point p_1(0.5, 0.5); + Point p_2(1.5, 0.5); + Point p_3(1.5, 1.5); + Point p_4(0.5, 1.5); + Point p_5(2, 1); + + REQUIRE((not l_horizontal.contains(p_1) and + not l_horizontal.contains(p_2) and + not l_horizontal.contains(p_3) and + not l_horizontal.contains(p_4) and + l_horizontal.contains(p_5))); + + REQUIRE(not l_horizontal.goes_below(p_1)); + REQUIRE(not l_horizontal.goes_below(p_2)); + REQUIRE(l_horizontal.goes_below(p_3)); + REQUIRE(l_horizontal.goes_below(p_4)); + REQUIRE(l_horizontal.goes_below(p_5)); + + REQUIRE(l_horizontal.goes_above(p_1)); + REQUIRE(l_horizontal.goes_above(p_2)); + REQUIRE(not l_horizontal.goes_above(p_3)); + REQUIRE(not l_horizontal.goes_above(p_4)); + REQUIRE(l_horizontal.goes_above(p_5)); +} + +TEST_CASE("Flat Line with positive slope", "[flat_line]") +{ + // line y = x / 2 + 1 + DualPoint l_flat(AxisType::y_type, AngleType::flat, 0.5, 1); + + REQUIRE(not l_flat.is_horizontal()); + REQUIRE(l_flat.is_flat()); + REQUIRE(l_flat.y_slope() == 0.5); + REQUIRE(l_flat.y_intercept() == 1); + + REQUIRE(l_flat.y_from_x(0) == 1); + REQUIRE(l_flat.y_from_x(1) == 1.5); + REQUIRE(l_flat.y_from_x(2) == 2); + + Point p_1(3, 2); + Point p_2(-2, 0.01); + Point p_3(0, 1.25); + Point p_4(6, 4.5); + Point p_5(2, 2); + + REQUIRE((not l_flat.contains(p_1) and + not l_flat.contains(p_2) and + not l_flat.contains(p_3) and + not l_flat.contains(p_4) and + l_flat.contains(p_5))); + + REQUIRE(not l_flat.goes_below(p_1)); + REQUIRE(l_flat.goes_below(p_2)); + REQUIRE(l_flat.goes_below(p_3)); + REQUIRE(l_flat.goes_below(p_4)); + REQUIRE(l_flat.goes_below(p_5)); + + REQUIRE(l_flat.goes_above(p_1)); + REQUIRE(not l_flat.goes_above(p_2)); + REQUIRE(not l_flat.goes_above(p_3)); + REQUIRE(not l_flat.goes_above(p_4)); + REQUIRE(l_flat.goes_above(p_5)); + +} diff --git a/matching/tests/test_list.txt b/matching/tests/test_list.txt new file mode 100644 index 0000000..1984606 --- /dev/null +++ b/matching/tests/test_list.txt @@ -0,0 +1 @@ +prism_lesnick_1.bif prism_lesnick_2.bif 1.0 diff --git a/matching/tests/test_matching_distance.cpp b/matching/tests/test_matching_distance.cpp new file mode 100644 index 0000000..82da530 --- /dev/null +++ b/matching/tests/test_matching_distance.cpp @@ -0,0 +1,167 @@ +#include "catch/catch.hpp" + +#include +#include +#include + +#include "spdlog/spdlog.h" +#include "spdlog/fmt/ostr.h" + +#include "common_util.h" +#include "simplex.h" +#include "matching_distance.h" + +using Real = double; +using Point = md::Point; +using Bifiltration = md::Bifiltration; +using BifiltrationProxy = md::BifiltrationProxy; +using CalculationParams = md::CalculationParams; +using CellWithValue = md::CellWithValue; +using DualPoint = md::DualPoint; +using DualBox = md::DualBox; +using Simplex = md::Simplex; +using AbstractSimplex = md::AbstractSimplex; +using BoundStrategy = md::BoundStrategy; +using TraverseStrategy = md::TraverseStrategy; +using AxisType = md::AxisType; +using AngleType = md::AngleType; +using ValuePoint = md::ValuePoint; +using Column = md::Column; + +using md::k_corner_vps; + +namespace spd = spdlog; + +TEST_CASE("Different bounds", "[bounds]") +{ + std::vector simplices; + std::vector points; + + Real max_x = 10; + Real max_y = 20; + + int simplex_id = 0; + for(int i = 0; i <= max_x; ++i) { + for(int j = 0; j <= max_y; ++j) { + Point p(i, j); + simplices.emplace_back(simplex_id++, p, 0, Column()); + points.push_back(p); + } + } + + Bifiltration bif_a(simplices.begin(), simplices.end()); + Bifiltration bif_b(simplices.begin(), simplices.end()); + + CalculationParams params; + params.initialization_depth = 2; + + BifiltrationProxy bifp_a(bif_a, params.dim); + BifiltrationProxy bifp_b(bif_b, params.dim); + + md::DistanceCalculator calc(bifp_a, bifp_b, params); + +// REQUIRE(calc.max_x_ == Approx(max_x)); +// REQUIRE(calc.max_y_ == Approx(max_y)); + + std::vector boxes; + + for(CellWithValue c : calc.get_refined_grid(5, false, false)) { + boxes.push_back(c.dual_box()); + } + + // fill in boxes and points + + for(DualBox db : boxes) { + Real local_bound = calc.get_local_dual_bound(db); + Real local_bound_refined = calc.get_local_refined_bound(db); + REQUIRE(local_bound >= local_bound_refined); + for(Point p : points) { + for(ValuePoint vp_a : k_corner_vps) { + CellWithValue dual_cell(db, 1); + DualPoint corner_a = dual_cell.value_point(vp_a); + Real wp_a = corner_a.weighted_push(p); + dual_cell.set_value_at(vp_a, wp_a); + Real point_bound = calc.get_max_displacement_single_point(dual_cell, vp_a, p); + for(ValuePoint vp_b : k_corner_vps) { + if (vp_b <= vp_a) + continue; + DualPoint corner_b = dual_cell.value_point(vp_b); + Real wp_b = corner_b.weighted_push(p); + Real diff = fabs(wp_a - wp_b); + if (not(point_bound <= Approx(local_bound_refined))) { + std::cerr << "ERROR point: " << p << ", box = " << db << ", point bound = " << point_bound + << ", refined local = " << local_bound_refined << std::endl; + spd::set_level(spd::level::debug); + calc.get_max_displacement_single_point(dual_cell, vp_a, p); + calc.get_local_refined_bound(db); + spd::set_level(spd::level::info); + } + + REQUIRE(point_bound <= Approx(local_bound_refined)); + REQUIRE(diff <= Approx(point_bound)); + REQUIRE(diff <= Approx(local_bound_refined)); + } + + for(DualPoint l_random : db.random_points(100)) { + Real wp_random = l_random.weighted_push(p); + Real diff = fabs(wp_a - wp_random); + if (not(diff <= Approx(point_bound))) { + if (db.critical_points(p).size() > 4) { + std::cerr << "ERROR interesting case" << std::endl; + } else { + std::cerr << "ERROR boring case" << std::endl; + } + spd::set_level(spd::level::debug); + l_random.weighted_push(p); + spd::set_level(spd::level::info); + std::cerr << "ERROR point: " << p << ", box = " << db << ", point bound = " << point_bound + << ", refined local = " << local_bound_refined; + std::cerr << ", random_dual_point = " << l_random << ", wp_random = " << wp_random + << ", diff = " << diff << std::endl; + } + REQUIRE(diff <= Approx(point_bound)); + } + } + } + } +} + +TEST_CASE("Bifiltrations from file", "[matching_distance][small_example][lesnick]") +{ + std::string fname_a, fname_b; + + fname_a = "../src/tests/prism_1.bif"; + fname_b = "../src/tests/prism_2.bif"; + + Bifiltration bif_a(fname_a); + Bifiltration bif_b(fname_b); + + CalculationParams params; + + std::vector bound_strategies {BoundStrategy::local_combined, + BoundStrategy::local_dual_bound_refined}; + + std::vector traverse_strategies {TraverseStrategy::breadth_first, TraverseStrategy::depth_first}; + + std::vector scaling_factors {10, 1.0}; + + for(auto bs : bound_strategies) { + for(auto ts : traverse_strategies) { + for(double lambda : scaling_factors) { + Bifiltration bif_a_copy(bif_a); + Bifiltration bif_b_copy(bif_b); + bif_a_copy.scale(lambda); + bif_b_copy.scale(lambda); + params.bound_strategy = bs; + params.traverse_strategy = ts; + params.max_depth = 7; + params.delta = 0.1; + params.dim = 1; + Real answer = matching_distance(bif_a_copy, bif_b_copy, params); + Real correct_answer = lambda * 1.0; + REQUIRE(fabs(answer - correct_answer) / correct_answer < params.delta); + } + } + } +} + diff --git a/matching/tests/tests_main.cpp b/matching/tests/tests_main.cpp new file mode 100644 index 0000000..1c77b13 --- /dev/null +++ b/matching/tests/tests_main.cpp @@ -0,0 +1,2 @@ +#define CATCH_CONFIG_MAIN +#include "catch/catch.hpp" -- cgit v1.2.3 From 490fed367bb97a96b90caa6ef04265c063d91df1 Mon Sep 17 00:00:00 2001 From: Arnur Nigmetov Date: Sat, 7 Mar 2020 08:44:20 +0100 Subject: Fix multiple bugs in matching distance for modules, add example. --- README.txt | 15 +++++-- bottleneck/README | 2 +- matching/CMakeLists.txt | 4 ++ matching/README.md | 72 +++++++++++++++++++++++++++++++-- matching/example/matching_dist.cpp | 3 -- matching/include/matching_distance.h | 2 +- matching/include/matching_distance.hpp | 1 + matching/include/persistence_module.h | 2 +- matching/include/persistence_module.hpp | 19 ++++++--- 9 files changed, 101 insertions(+), 19 deletions(-) (limited to 'matching/CMakeLists.txt') diff --git a/README.txt b/README.txt index ae3f6a0..48a77c0 100644 --- a/README.txt +++ b/README.txt @@ -1,5 +1,6 @@ This repository contains software to compute bottleneck and Wasserstein -distances between persistence diagrams. +distances between persistence diagrams, and matching distance +between 2-parameter persistence modules and (1-critical) bi-filtrations. The software is licensed under BSD license, see license.txt file. @@ -9,13 +10,14 @@ you probably do not need to worry about that. See README files in subdirectories for usage and building. If you use Hera in your project, we would appreciate if you -cite the corresponding paper: +cite the corresponding paper. + +Bottleneck or Wasserstein distance: + Michael Kerber, Dmitriy Morozov, and Arnur Nigmetov, "Geometry Helps to Compare Persistence Diagrams.", Journal of Experimental Algorithmics, vol. 22, 2017, pp. 1--20. (conference version: ALENEX 2016). -The BibTeX is below: - @article{jea_hera, title={Geometry helps to compare persistence diagrams}, @@ -26,3 +28,8 @@ The BibTeX is below: year={2017}, publisher={ACM New York, NY, USA} } + +Matching distance: + +Michael Kerber, Arnur Nigmetov, "Efficient Approximation of the Matching +Distance for 2-parameter persistence.", SoCG 2020 diff --git a/bottleneck/README b/bottleneck/README index 8b368af..c04390b 100644 --- a/bottleneck/README +++ b/bottleneck/README @@ -1,6 +1,6 @@ Accompanying paper: M. Kerber, D. Morozov, A. Nigmetov. Geometry Helps To Compare Persistence Diagrams (ALENEX 2016, http://www.geometrie.tugraz.at/nigmetov/geom_dist.pdf) -Bug reports can be sent to "nigmetov EMAIL SIGN tugraz DOT at". +Bug reports can be sent to "anigmetov EMAIL SIGN lbl DOT gov". # Dependencies diff --git a/matching/CMakeLists.txt b/matching/CMakeLists.txt index 3ee0f6b..a391d84 100644 --- a/matching/CMakeLists.txt +++ b/matching/CMakeLists.txt @@ -48,5 +48,9 @@ endif() add_executable(matching_dist "example/matching_dist.cpp" ${MD_HEADERS} ${BT_HEADERS} ) target_link_libraries(matching_dist PUBLIC ${libraries}) +add_executable(module_example "example/module_example.cpp" ${MD_HEADERS} ${BT_HEADERS} ) +target_link_libraries(module_example PUBLIC ${libraries}) + + add_executable(matching_distance_test ${SRC_TEST_FILES} ${BT_HEADERS} ${MD_HEADERS}) target_link_libraries(matching_distance_test PUBLIC ${libraries}) diff --git a/matching/README.md b/matching/README.md index 7dc1874..fc97441 100644 --- a/matching/README.md +++ b/matching/README.md @@ -1,5 +1,69 @@ -Matching distance between bifiltrations. +# Matching distance between bifiltrations and 2-persistence modules. -Currently supports only 1-critical bi-filtrations -in PHAT-like format: boundary matrix + critical values -of a simplex in each row +## Accompanying paper +M. Kerber, A. Nigmetov, +Efficient Approximation of the Matching Distance for 2-parameter persistence. +SoCG 2020. + +Bug reports can be sent to "anigmetov EMAIL SIGN lbl DOT gov". + +## Dependencies + +* Your compiler must support C++11. +* Boost. + +## Usage: + +1. To use a standalone command-line utility matching_dist: + +`./matching_dist -d dimension -e relative_error file1 file2` + +If `relative_error` is not specified, the default 0.1 is used. +If `dimension` is not specified, the default value is 0. +Run `./matching_dist` without parameters to see other options. + +The output is an approximation of the exact distance (which is assumed to be non-zero). +Precisely: if *d_exact* is the true distance and *d_approx* is the output, then + +> | d_exact - d_approx | / d_exact < relative_error. + +Files file1 and file2 must contain 1-critical bi-filtrations in a plain text format which is similar to PHAT. The first line of a file must say *bifiltration_phat_like*. The second line contains the total number of simplices *N*. The next *N* lines contain simplices in the format *dim x y boundary*. +* *dim*: the dimension of the simplex +* *x, y*: coordinates of the critical value +* *boundary*: indices of simplices forming the boundary of the current simplex. Indices are separated by space. +* Simplices are indexed starting from 0. + +For example, the bi-filtration of a segment with vertices appearing at (0,0) and the 1-segment appearing at (3,4) shall be written as: + +> bifiltration_phat_like +> 3 +> \# lines starting with \# are ignored +> \# vertex A has dimension 0, hence no boundary, its index is 0 +> 0 0 0 +> \# vertex B has index 1 +> 0 0 0 +> \# 1-dimensional simplex {A, B} +> 1 3 4 0 1 + +2. To use from your code. + +Here you can compute the matching distance either between bi-filtrations or between persistence modules. +First, you need to include `#include "matching_distance.h"` Practically every class you need is parameterized by Real type, which should be either float or double. The header provides two functions called `matching_distance.` +See `example/module_example.cpp` for additional details. + +## License + +See `licence.txt` in the repository root folder. + +## Building + +CMakeLists.txt can be used to build the command-line utility in the standard +way. On Linux/Mac/Windows with Cygwin: +> `mkdir build` +> `cd build` +> `cmake ..` +> `make` + +On Windows with Visual Studio: use `cmake-gui` to create the solution in build directory and build it with VS. + +The library itself is header-only and does not require separate compilation. diff --git a/matching/example/matching_dist.cpp b/matching/example/matching_dist.cpp index 13e5c6d..734f6cc 100644 --- a/matching/example/matching_dist.cpp +++ b/matching/example/matching_dist.cpp @@ -13,9 +13,6 @@ #include "spdlog/spdlog.h" #include "spdlog/fmt/ostr.h" -//#include "persistence_module.h" -#include "bifiltration.h" -#include "box.h" #include "matching_distance.h" using Real = double; diff --git a/matching/include/matching_distance.h b/matching/include/matching_distance.h index 276cc3a..e82a97c 100644 --- a/matching/include/matching_distance.h +++ b/matching/include/matching_distance.h @@ -152,7 +152,7 @@ namespace md { int max_depth {6}; // maximal number of refinenemnts int initialization_depth {3}; int dim {0}; // in which dim to calculate the distance; use ALL_DIMENSIONS to get max over all dims - BoundStrategy bound_strategy {BoundStrategy::bruteforce}; + BoundStrategy bound_strategy {BoundStrategy::local_combined}; TraverseStrategy traverse_strategy {TraverseStrategy::breadth_first}; bool tolerate_max_iter_exceeded {true}; Real actual_error {std::numeric_limits::max()}; diff --git a/matching/include/matching_distance.hpp b/matching/include/matching_distance.hpp index 48c8464..7cff073 100644 --- a/matching/include/matching_distance.hpp +++ b/matching/include/matching_distance.hpp @@ -193,6 +193,7 @@ namespace md { // make all coordinates non-negative auto min_coord = std::min(module_a_.minimal_coordinate(), module_b_.minimal_coordinate()); + spd::debug("in DistanceCalculator ctor, min_coord = {}", min_coord); if (min_coord < 0) { module_a_.translate(-min_coord); module_b_.translate(-min_coord); diff --git a/matching/include/persistence_module.h b/matching/include/persistence_module.h index e99771f..4a261bb 100644 --- a/matching/include/persistence_module.h +++ b/matching/include/persistence_module.h @@ -47,7 +47,7 @@ namespace md { IndexVec components_; Relation() {} - Relation(const Point& _pos, const IndexVec& _components); + Relation(const Point& _pos, const IndexVec& _components) : position_(_pos), components_(_components) {} Real get_x() const { return position_.x; } Real get_y() const { return position_.y; } diff --git a/matching/include/persistence_module.hpp b/matching/include/persistence_module.hpp index 6e49b2e..128fed9 100644 --- a/matching/include/persistence_module.hpp +++ b/matching/include/persistence_module.hpp @@ -34,10 +34,10 @@ namespace md { template void ModulePresentation::init_boundaries() { - max_x_ = std::numeric_limits::max(); - max_y_ = std::numeric_limits::max(); - min_x_ = -std::numeric_limits::max(); - min_y_ = -std::numeric_limits::max(); + max_x_ = -std::numeric_limits::max(); + max_y_ = -std::numeric_limits::max(); + min_x_ = std::numeric_limits::max(); + min_y_ = std::numeric_limits::max(); for(const auto& gen : positions_) { min_x_ = std::min(gen.x, min_x_); @@ -55,6 +55,7 @@ namespace md { generators_(_generators), relations_(_relations) { + positions_ = concat_gen_and_rel_positions(generators_, relations_); init_boundaries(); } @@ -85,6 +86,7 @@ namespace md { void ModulePresentation::project_generators(const DualPoint& slice, IndexVec& sorted_indices, RealVec& projections) const { + spd::debug("Enter project_generators, slice = {}", slice); size_t num_gens = generators_.size(); RealVec gen_values; @@ -97,6 +99,7 @@ namespace md { projections.reserve(num_gens); for(auto i : sorted_indices) { projections.push_back(gen_values[i]); + spd::debug("added push = {}", gen_values[i]); } } @@ -104,6 +107,8 @@ namespace md { void ModulePresentation::project_relations(const DualPoint& slice, IndexVec& sorted_rel_indices, RealVec& projections) const { + + spd::debug("Enter project_relations, slice = {}", slice); size_t num_rels = relations_.size(); RealVec rel_values; @@ -116,12 +121,14 @@ namespace md { projections.reserve(num_rels); for(auto i : sorted_rel_indices) { projections.push_back(rel_values[i]); + spd::debug("added push = {}", rel_values[i]); } } template Diagram ModulePresentation::weighted_slice_diagram(const DualPoint& slice) const { + spd::debug("Enter weighted_slice_diagram, slice = {}", slice); IndexVec sorted_gen_indices, sorted_rel_indices; RealVec gen_projections, rel_projections; @@ -138,7 +145,8 @@ namespace md { j = sorted_gen_indices[j]; } std::sort(current_relation.begin(), current_relation.end()); - phat_matrix.set_dim(i, current_relation.size()); + // modules do not have dimension, set all to 0 + phat_matrix.set_dim(i, 0); phat_matrix.set_col(i, current_relation); } @@ -154,6 +162,7 @@ namespace md { bool is_finite_pair = new_pair.second != phat::k_infinity_index; Real birth = gen_projections.at(new_pair.first); Real death = is_finite_pair ? rel_projections.at(new_pair.second) : real_inf; + spd::debug("i = {}, birth = {}, death = {}", i, new_pair.first, new_pair.second); if (birth != death) { dgm.emplace_back(birth, death); } -- cgit v1.2.3