diff options
author | Arnur Nigmetov <nigmetov@tugraz.at> | 2020-03-04 00:33:51 +0100 |
---|---|---|
committer | Arnur Nigmetov <nigmetov@tugraz.at> | 2020-03-04 00:33:51 +0100 |
commit | 3809e4071827a5959f27e472514eaed08ba6d15e (patch) | |
tree | 113fd1c373e4a04b568f8daf25324efafff6b107 /matching | |
parent | d56c07b093bfea1690a81ebbef41b8bb9c7c2464 (diff) |
Make matching distance header-only.
Diffstat (limited to 'matching')
28 files changed, 1110 insertions, 1399 deletions
diff --git a/matching/CMakeLists.txt b/matching/CMakeLists.txt index 9384328..121e25c 100644 --- a/matching/CMakeLists.txt +++ b/matching/CMakeLists.txt @@ -29,84 +29,34 @@ if (NOT WIN32) endif (NOT WIN32) file(GLOB BT_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/../bottleneck/include/*.h ${CMAKE_CURRENT_SOURCE_DIR}/../bottleneck/include/*.hpp) -file(GLOB MD_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/include/*.h) +file(GLOB MD_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/include/*.h ${CMAKE_CURRENT_SOURCE_DIR}/include/*.hpp) file(GLOB SRC_FILES ${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp) file(GLOB SRC_TEST_FILES ${CMAKE_CURRENT_SOURCE_DIR}/src/tests/*.cpp) find_package(Threads) -set(libraries ${libraries} "stdc++fs" ${CMAKE_THREAD_LIBS_INIT} ${OpenMP_CXX_LIBRARIES}) +set(libraries ${libraries} "stdc++fs" ${CMAKE_THREAD_LIBS_INIT}) find_package(OpenMP) if (OPENMP_FOUND) +set(libraries ${libraries} ${OpenMP_CXX_LIBRARIES}) set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") set (CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}") endif() -#add_executable(matching_distance ${SRC_FILES} ${BT_HEADERS} ${MD_HEADERS}) -add_executable(matching_distance "src/main.cpp" - "src/box.cpp" - "src/common_util.cpp" - "src/persistence_module.cpp" - "src/simplex.cpp" - "src/bifiltration.cpp" - "src/matching_distance.cpp" - "src/dual_box.cpp" - "src/dual_point.cpp" - "include/box.h" - "include/common_util.h" - "include/persistence_module.h" - "include/simplex.h" - "include/bifiltration.h" - "include/matching_distance.h" - "include/dual_box.h" - "include/dual_point.h" - ${BT_HEADERS} include/cell_with_value.h src/cell_with_value.cpp) +add_executable(matching_distance "src/main.cpp" ${MD_HEADERS} ${BT_HEADERS} ) target_link_libraries(matching_distance PUBLIC ${libraries}) -#add_executable(matching_distance_test ${SRC_TEST_FILES} ${BT_HEADERS} ${MD_HEADERS}) -add_executable(matching_distance_test ${SRC_TEST_FILES} - "src/box.cpp" - "src/common_util.cpp" - "src/persistence_module.cpp" - "src/simplex.cpp" - "src/bifiltration.cpp" - "src/matching_distance.cpp" - "src/dual_box.cpp" - "src/dual_point.cpp" - "include/box.h" - "include/common_util.h" - "include/persistence_module.h" - "include/simplex.h" - "include/bifiltration.h" - "include/matching_distance.h" - "include/dual_box.h" - "include/dual_point.h" - ${BT_HEADERS} src/tests/test_common.cpp src/common_util.cpp src/tests/test_matching_distance.cpp src/cell_with_value.cpp) +add_executable(matching_distance_test ${SRC_TEST_FILES} ${BT_HEADERS} ${MD_HEADERS}) target_link_libraries(matching_distance_test PUBLIC ${libraries}) add_executable(test_generator "src/test_generator.cpp" - "src/box.cpp" - "src/common_util.cpp" - "src/persistence_module.cpp" - "src/simplex.cpp" - "src/bifiltration.cpp" - "src/matching_distance.cpp" - "src/dual_box.cpp" - "src/dual_point.cpp" - "include/box.h" - "include/common_util.h" - "include/persistence_module.h" - "include/simplex.h" - "include/bifiltration.h" - "include/matching_distance.h" - "include/dual_box.h" - "include/dual_point.h" - ${BT_HEADERS} src/cell_with_value.cpp) + ${MD_HEADERS} + ${BT_HEADERS}) target_link_libraries(test_generator PUBLIC ${libraries}) -#add_executable(matching_distance "src/main.cpp" "src/box.cpp" "src/common_util.cpp" "src/line.cpp" "src/persistence_module.cpp" ${BT_HEADERS} ${MD_HEADERS}) +#add_executable(matching_distance "include/main.cpp" "src/box.cpp" "src/common_util.cpp" "src/line.cpp" "src/persistence_module.hpp" ${BT_HEADERS} ${MD_HEADERS}) diff --git a/matching/include/bifiltration.h b/matching/include/bifiltration.h index f505ed9..4dd8662 100644 --- a/matching/include/bifiltration.h +++ b/matching/include/bifiltration.h @@ -3,19 +3,30 @@ #include <string> #include <ostream> +#include <iostream> +#include <fstream> +#include <sstream> +#include <cassert> #include "common_util.h" #include "box.h" #include "simplex.h" #include "dual_point.h" +#include "phat/boundary_matrix.h" +#include "phat/compute_persistence_pairs.h" + +#include "spdlog/spdlog.h" +#include "spdlog/fmt/fmt.h" +#include "spdlog/fmt/ostr.h" + +#include "common_util.h" namespace md { + template<class Real> class Bifiltration { public: - using Diagram = std::vector<std::pair<Real, Real>>; - using Box = md::Box; - using SimplexVector = std::vector<Simplex>; + using SimplexVector = std::vector<Simplex<Real>>; Bifiltration() = default; @@ -36,7 +47,7 @@ namespace md { init(); } - Diagram weighted_slice_diagram(const DualPoint& line, int dim) const; + Diagram<Real> weighted_slice_diagram(const DualPoint<Real>& line, int dim) const; SimplexVector simplices() const { return simplices_; } @@ -48,14 +59,12 @@ namespace md { Real minimal_coordinate() const; // return box that contains positions of all simplices - Box bounding_box() const; + Box<Real> bounding_box() const; void sanity_check() const; int maximal_dim() const { return maximal_dim_; } - friend std::ostream& operator<<(std::ostream& os, const Bifiltration& bif); - Real max_x() const; Real max_y() const; @@ -64,7 +73,7 @@ namespace md { Real min_y() const; - void add_simplex(Index _id, Point birth, int _dim, const Column& _bdry); + void add_simplex(Index _id, Point<Real> birth, int _dim, const Column& _bdry); void save(const std::string& filename, BifiltrationFormat format = BifiltrationFormat::rivet); // save to file @@ -72,11 +81,8 @@ namespace md { private: SimplexVector simplices_; - // axes names, for rivet bifiltration format only - std::string parameter_1_name_ {"axis_1"}; - std::string parameter_2_name_ {"axis_2"}; - Box bounding_box_; + Box<Real> bounding_box_; int maximal_dim_ {-1}; void init(); @@ -97,13 +103,15 @@ namespace md { }; - std::ostream& operator<<(std::ostream& os, const Bifiltration& bif); + template<class Real> + std::ostream& operator<<(std::ostream& os, const Bifiltration<Real>& bif); + template<class Real> class BifiltrationProxy { public: - BifiltrationProxy(const Bifiltration& bif, int dim = 0); + BifiltrationProxy(const Bifiltration<Real>& bif, int dim = 0); // return critical values of simplices that are important for current dimension (dim and dim+1) - PointVec positions() const; + PointVec<Real> positions() const; // set current dimension int set_dim(int new_dim); @@ -111,46 +119,22 @@ namespace md { int maximal_dim() const; void translate(Real a); Real minimal_coordinate() const; - Box bounding_box() const; + Box<Real> bounding_box() const; Real max_x() const; Real max_y() const; Real min_x() const; Real min_y() const; - Diagram weighted_slice_diagram(const DualPoint& slice) const; + Diagram<Real> weighted_slice_diagram(const DualPoint<Real>& slice) const; private: int dim_ { 0 }; - mutable PointVec cached_positions_; - Bifiltration bif_; + mutable PointVec<Real> cached_positions_; + Bifiltration<Real> bif_; void cache_positions() const; }; } - +#include "bifiltration.hpp" #endif //MATCHING_DISTANCE_BIFILTRATION_H - -//// The value type of OutputIterator is Simplex_in_2D_filtration -//template<typename OutputIterator> -//void read_input(std::string filename, OutputIterator out) -//{ -// std::ifstream ifstr; -// ifstr.open(filename.c_str()); -// long n; -// ifstr >> n; // number of simplices is the first number in file -// -// Index k; // used in loop -// for (int i = 0; i < n; i++) { -// Simplex_in_2D_filtration next; -// next.index = i; -// ifstr >> next.dim >> next.pos.x >> next.pos.y; -// if (next.dim > 0) { -// for (int j = 0; j <= next.dim; j++) { -// ifstr >> k; -// next.bd.push_back(k); -// } -// } -// *out++ = next; -// } -//} diff --git a/matching/src/bifiltration.cpp b/matching/include/bifiltration.hpp index 44b12cf..9e2a82e 100644 --- a/matching/src/bifiltration.cpp +++ b/matching/include/bifiltration.hpp @@ -1,35 +1,20 @@ -#include <iostream> -#include <fstream> -#include <sstream> -#include <cassert> - -#include<phat/boundary_matrix.h> -#include<phat/compute_persistence_pairs.h> - -#include "spdlog/spdlog.h" -#include "spdlog/fmt/fmt.h" -#include "spdlog/fmt/ostr.h" - -#include "common_util.h" -#include "bifiltration.h" - -namespace spd = spdlog; - namespace md { - void Bifiltration::init() + template<class Real> + void Bifiltration<Real>::init() { - Point lower_left = max_point(); - Point upper_right = min_point(); + auto lower_left = max_point<Real>(); + auto upper_right = min_point<Real>(); for(const auto& simplex : simplices_) { - lower_left = greatest_lower_bound(lower_left, simplex.position()); - upper_right = least_upper_bound(upper_right, simplex.position()); + lower_left = greatest_lower_bound<>(lower_left, simplex.position()); + upper_right = least_upper_bound<>(upper_right, simplex.position()); maximal_dim_ = std::max(maximal_dim_, simplex.dim()); } - bounding_box_ = Box(lower_left, upper_right); + bounding_box_ = Box<Real>(lower_left, upper_right); } - Bifiltration::Bifiltration(const std::string& fname ) + template<class Real> + Bifiltration<Real>::Bifiltration(const std::string& fname) { std::ifstream ifstr {fname.c_str()}; if (!ifstr.good()) { @@ -69,12 +54,13 @@ namespace md { init(); } - void Bifiltration::rivet_format_reader(std::ifstream& ifstr) + template<class Real> + void Bifiltration<Real>::rivet_format_reader(std::ifstream& ifstr) { std::string s; - // read axes names - std::getline(ifstr, parameter_1_name_); - std::getline(ifstr, parameter_2_name_); + // read axes names, ignore them + std::getline(ifstr, s); + std::getline(ifstr, s); Index index = 0; while(std::getline(ifstr, s)) { @@ -84,7 +70,8 @@ namespace md { } } - void Bifiltration::phat_like_format_reader(std::ifstream& ifstr) + template<class Real> + void Bifiltration<Real>::phat_like_format_reader(std::ifstream& ifstr) { spd::debug("Enter phat_like_format_reader"); // read stream line by line; do not use >> operator @@ -105,7 +92,8 @@ namespace md { spd::debug("Read {} simplices from file", n_simplices); } - void Bifiltration::scale(Real lambda) + template<class Real> + void Bifiltration<Real>::scale(Real lambda) { for(auto& s : simplices_) { s.scale(lambda); @@ -113,10 +101,11 @@ namespace md { init(); } - void Bifiltration::sanity_check() const + template<class Real> + void Bifiltration<Real>::sanity_check() const { #ifdef DEBUG - spd::debug("Enter Bifiltration::sanity_check"); + spd::debug("Enter Bifiltration<Real>::sanity_check"); // check that boundary has correct number of simplices, // each bounding simplex has correct dim // and appears in the filtration before the simplex it bounds @@ -129,16 +118,17 @@ namespace md { assert(bdry_simplex.position().is_less(s.position(), false)); } } - spd::debug("Exit Bifiltration::sanity_check"); + spd::debug("Exit Bifiltration<Real>::sanity_check"); #endif } - Diagram Bifiltration::weighted_slice_diagram(const DualPoint& line, int dim) const + template<class Real> + Diagram<Real> Bifiltration<Real>::weighted_slice_diagram(const DualPoint<Real>& line, int dim) const { - DiagramKeeper dgm; + DiagramKeeper<Real> dgm; // make a copy for now; I want slice_diagram to be const - std::vector<Simplex> simplices(simplices_); + std::vector<Simplex<Real>> simplices(simplices_); // std::vector<Simplex> simplices; // simplices.reserve(simplices_.size() / 2); @@ -156,7 +146,7 @@ namespace md { } std::sort(simplices.begin(), simplices.end(), - [](const Simplex& a, const Simplex& b) { return a.value() < b.value(); }); + [](const Simplex<Real>& a, const Simplex<Real>& b) { return a.value() < b.value(); }); std::map<Index, Index> index_map; for(Index i = 0; i < (int) simplices.size(); i++) { index_map[simplices[i].id()] = i; @@ -202,17 +192,20 @@ namespace md { return dgm.get_diagram(dim); } - Box Bifiltration::bounding_box() const + template<class Real> + Box<Real> Bifiltration<Real>::bounding_box() const { return bounding_box_; } - Real Bifiltration::minimal_coordinate() const + template<class Real> + Real Bifiltration<Real>::minimal_coordinate() const { return std::min(bounding_box_.lower_left().x, bounding_box_.lower_left().y); } - void Bifiltration::translate(Real a) + template<class Real> + void Bifiltration<Real>::translate(Real a) { bounding_box_.translate(a); for(auto& simplex : simplices_) { @@ -220,7 +213,8 @@ namespace md { } } - Real Bifiltration::max_x() const + template<class Real> + Real Bifiltration<Real>::max_x() const { if (simplices_.empty()) return 1; @@ -230,7 +224,8 @@ namespace md { return me->position().x; } - Real Bifiltration::max_y() const + template<class Real> + Real Bifiltration<Real>::max_y() const { if (simplices_.empty()) return 1; @@ -240,7 +235,8 @@ namespace md { return me->position().y; } - Real Bifiltration::min_x() const + template<class Real> + Real Bifiltration<Real>::min_x() const { if (simplices_.empty()) return 0; @@ -250,7 +246,8 @@ namespace md { return me->position().x; } - Real Bifiltration::min_y() const + template<class Real> + Real Bifiltration<Real>::min_y() const { if (simplices_.empty()) return 0; @@ -260,12 +257,14 @@ namespace md { return me->position().y; } - void Bifiltration::add_simplex(md::Index _id, md::Point birth, int _dim, const md::Column& _bdry) + template<class Real> + void Bifiltration<Real>::add_simplex(Index _id, Point<Real> birth, int _dim, const Column& _bdry) { simplices_.emplace_back(_id, birth, _dim, _bdry); } - void Bifiltration::save(const std::string& filename, md::BifiltrationFormat format) + template<class Real> + void Bifiltration<Real>::save(const std::string& filename, md::BifiltrationFormat format) { switch(format) { case BifiltrationFormat::rivet: @@ -292,7 +291,8 @@ namespace md { } } - void Bifiltration::postprocess_rivet_format() + template<class Real> + void Bifiltration<Real>::postprocess_rivet_format() { std::map<Column, Index> facets_to_ids; @@ -324,16 +324,19 @@ namespace md { } // loop over simplices } - std::ostream& operator<<(std::ostream& os, const Bifiltration& bif) + template<class Real> + std::ostream& operator<<(std::ostream& os, const Bifiltration<Real>& bif) { - os << "Bifiltration, axes = " << bif.parameter_1_name_ << ", " << bif.parameter_2_name_ << std::endl; + os << "Bifiltration [" << std::endl; for(const auto& s : bif.simplices()) { os << s << std::endl; } + os << "]" << std::endl; return os; } - BifiltrationProxy::BifiltrationProxy(const md::Bifiltration& bif, int dim) + template<class Real> + BifiltrationProxy<Real>::BifiltrationProxy(const Bifiltration<Real>& bif, int dim) : dim_(dim), bif_(bif) @@ -341,7 +344,8 @@ namespace md { cache_positions(); } - void BifiltrationProxy::cache_positions() const + template<class Real> + void BifiltrationProxy<Real>::cache_positions() const { cached_positions_.clear(); for(const auto& simplex : bif_.simplices()) { @@ -350,7 +354,9 @@ namespace md { } } - PointVec BifiltrationProxy::positions() const + template<class Real> + PointVec<Real> + BifiltrationProxy<Real>::positions() const { if (cached_positions_.empty()) { cache_positions(); @@ -359,46 +365,54 @@ namespace md { } // translate all points by vector (a,a) - void BifiltrationProxy::translate(Real a) + template<class Real> + void BifiltrationProxy<Real>::translate(Real a) { bif_.translate(a); } // return minimal value of x- and y-coordinates // among all simplices - Real BifiltrationProxy::minimal_coordinate() const + template<class Real> + Real BifiltrationProxy<Real>::minimal_coordinate() const { return bif_.minimal_coordinate(); } // return box that contains positions of all simplices - Box BifiltrationProxy::bounding_box() const + template<class Real> + Box<Real> BifiltrationProxy<Real>::bounding_box() const { return bif_.bounding_box(); } - Real BifiltrationProxy::max_x() const + template<class Real> + Real BifiltrationProxy<Real>::max_x() const { return bif_.max_x(); } - Real BifiltrationProxy::max_y() const + template<class Real> + Real BifiltrationProxy<Real>::max_y() const { return bif_.max_y(); } - Real BifiltrationProxy::min_x() const + template<class Real> + Real BifiltrationProxy<Real>::min_x() const { return bif_.min_x(); } - Real BifiltrationProxy::min_y() const + template<class Real> + Real BifiltrationProxy<Real>::min_y() const { return bif_.min_y(); } - Diagram BifiltrationProxy::weighted_slice_diagram(const DualPoint& slice) const + template<class Real> + Diagram<Real> BifiltrationProxy<Real>::weighted_slice_diagram(const DualPoint<Real>& slice) const { return bif_.weighted_slice_diagram(slice, dim_); } diff --git a/matching/include/box.h b/matching/include/box.h index 2990fba..4243667 100644 --- a/matching/include/box.h +++ b/matching/include/box.h @@ -8,20 +8,23 @@ namespace md { + template<class Real_> struct Box { + public: + using Real = Real_; private: - Point ll; - Point ur; + Point<Real> ll; + Point<Real> ur; public: - Box(Point ll = Point(), Point ur = Point()) + Box(Point<Real> ll = Point<Real>(), Point<Real> ur = Point<Real>()) :ll(ll), ur(ur) { } - Box(Point center, Real width, Real height) : - ll(Point(center.x - 0.5 * width, center.y - 0.5 * height)), - ur(Point(center.x + 0.5 * width, center.y + 0.5 * height)) + Box(Point<Real> center, Real width, Real height) : + ll(Point<Real>(center.x - 0.5 * width, center.y - 0.5 * height)), + ur(Point<Real>(center.x + 0.5 * width, center.y + 0.5 * height)) { } @@ -30,11 +33,9 @@ namespace md { inline double height() const { return ur.y - ll.y; } - inline Point lower_left() const { return ll; } - inline Point upper_right() const { return ur; } - inline Point center() const { return Point((ll.x + ur.x) / 2, (ll.y + ur.y) / 2); } - -// bool inside(Point& p) const { return ll.x <= p.x && ll.y <= p.y && ur.x >= p.x && ur.y >= p.y; } + inline Point<Real> lower_left() const { return ll; } + inline Point<Real> upper_right() const { return ur; } + inline Point<Real> center() const { return Point<Real>((ll.x + ur.x) / 2, (ll.y + ur.y) / 2); } inline bool operator==(const Box& p) { @@ -43,58 +44,16 @@ namespace md { std::vector<Box> refine() const; - std::vector<Point> corners() const; + std::vector<Point<Real>> corners() const; void translate(Real a); - - // return minimal and maximal value of func - // on the corners of the box - template<typename F> - std::pair<Real, Real> min_max_on_corners(const F& func) const; - - friend std::ostream& operator<<(std::ostream& os, const Box& box); }; - std::ostream& operator<<(std::ostream& os, const Box& box); -// template<typename InputIterator> -// Box compute_bounding_box(InputIterator simplices_begin, InputIterator simplices_end) -// { -// if (simplices_begin == simplices_end) { -// return Box(); -// } -// Box bb; -// bb.ll = bb.ur = simplices_begin->pos; -// for (InputIterator it = simplices_begin; it != simplices_end; it++) { -// Point& pos = it->pos; -// if (pos.x < bb.ll.x) { -// bb.ll.x = pos.x; -// } -// if (pos.y < bb.ll.y) { -// bb.ll.y = pos.y; -// } -// if (pos.x > bb.ur.x) { -// bb.ur.x = pos.x; -// } -// if (pos.y > bb.ur.y) { -// bb.ur.y = pos.y; -// } -// } -// return bb; -// } - - Box get_enclosing_box(const Box& box_a, const Box& box_b); - - template<typename F> - std::pair<Real, Real> Box::min_max_on_corners(const F& func) const - { - std::pair<Real, Real> min_max { std::numeric_limits<Real>::max(), -std::numeric_limits<Real>::max() }; - for(Point p : corners()) { - Real value = func(p); - min_max.first = std::min(min_max.first, value); - min_max.second = std::max(min_max.second, value); - } - return min_max; - }; + template<class Real> + std::ostream& operator<<(std::ostream& os, const Box<Real>& box); + } // namespace md +#include "box.hpp" + #endif //MATCHING_DISTANCE_BOX_H diff --git a/matching/include/box.hpp b/matching/include/box.hpp new file mode 100644 index 0000000..f551d84 --- /dev/null +++ b/matching/include/box.hpp @@ -0,0 +1,52 @@ +namespace md { + + template<class Real> + std::ostream& operator<<(std::ostream& os, const Box<Real>& box) + { + os << "Box(lower_left = " << box.lower_left() << ", upper_right = " << box.upper_right() << ")"; + return os; + } + + template<class Real> + void Box<Real>::translate(Real a) + { + ll.x += a; + ll.y += a; + ur.x += a; + ur.y += a; + } + + template<class Real> + std::vector<Box<Real>> Box<Real>::refine() const + { + std::vector<Box<Real>> result; + +// 1 | 2 +// 0 | 3 + + Point<Real> new_ll = lower_left(); + Point<Real> new_ur = center(); + result.emplace_back(new_ll, new_ur); + + new_ll.y = center().y; + new_ur.y = ur.y; + result.emplace_back(new_ll, new_ur); + + new_ll = center(); + new_ur = upper_right(); + result.emplace_back(new_ll, new_ur); + + new_ll.y = ll.y; + new_ur.y = center().y; + result.emplace_back(new_ll, new_ur); + + return result; + } + + template<class Real> + std::vector<Point<Real>> Box<Real>::corners() const + { + return {ll, Point<Real>(ll.x, ur.y), ur, Point<Real>(ur.x, ll.y)}; + }; + +} diff --git a/matching/include/cell_with_value.h b/matching/include/cell_with_value.h index 25644d1..3548a11 100644 --- a/matching/include/cell_with_value.h +++ b/matching/include/cell_with_value.h @@ -1,7 +1,3 @@ -// -// Created by narn on 16.07.19. -// - #ifndef MATCHING_DISTANCE_CELL_WITH_VALUE_H #define MATCHING_DISTANCE_CELL_WITH_VALUE_H @@ -21,7 +17,29 @@ namespace md { upper_right }; - std::ostream& operator<<(std::ostream& os, const ValuePoint& vp); + inline std::ostream& operator<<(std::ostream& os, const ValuePoint& vp) + { + switch(vp) { + case ValuePoint::upper_left : + os << "upper_left"; + break; + case ValuePoint::upper_right : + os << "upper_right"; + break; + case ValuePoint::lower_left : + os << "lower_left"; + break; + case ValuePoint::lower_right : + os << "lower_right"; + break; + case ValuePoint::center: + os << "center"; + break; + default: + os << "FORGOTTEN ValuePoint"; + } + return os; + } const std::vector<ValuePoint> k_all_vps = {ValuePoint::center, ValuePoint::lower_left, ValuePoint::upper_left, ValuePoint::upper_right, ValuePoint::lower_right}; @@ -31,8 +49,10 @@ namespace md { // represents a cell in the dual space with the value // of the weighted bottleneck distance + template<class Real_> class CellWithValue { public: + using Real = Real_; CellWithValue() = default; @@ -44,18 +64,18 @@ namespace md { CellWithValue& operator=(CellWithValue&& other) = default; - CellWithValue(const DualBox& b, int level) + CellWithValue(const DualBox<Real>& b, int level) :dual_box_(b), level_(level) { } - DualBox dual_box() const { return dual_box_; } + DualBox<Real> dual_box() const { return dual_box_; } - DualPoint center() const { return dual_box_.center(); } + DualPoint<Real> center() const { return dual_box_.center(); } Real value_at(ValuePoint vp) const; bool has_value_at(ValuePoint vp) const; - DualPoint value_point(ValuePoint vp) const; + DualPoint<Real> value_point(ValuePoint vp) const; int level() const { return level_; } @@ -73,8 +93,6 @@ namespace md { std::vector<CellWithValue> get_refined_cells() const; - friend std::ostream& operator<<(std::ostream&, const CellWithValue&); - void set_max_possible_value(Real new_upper_bound); int num_values() const; @@ -100,7 +118,7 @@ namespace md { bool has_upper_right_value() const { return upper_right_value_ >= 0; } - DualBox dual_box_; + DualBox<Real> dual_box_; Real central_value_ {-1.0}; Real lower_left_value_ {-1.0}; Real lower_right_value_ {-1.0}; @@ -114,7 +132,10 @@ namespace md { bool has_max_possible_value_ {false}; }; - std::ostream& operator<<(std::ostream& os, const CellWithValue& cell); + template<class Real> + std::ostream& operator<<(std::ostream& os, const CellWithValue<Real>& cell); } // namespace md +#include "cell_with_value.hpp" + #endif //MATCHING_DISTANCE_CELL_WITH_VALUE_H diff --git a/matching/src/cell_with_value.cpp b/matching/include/cell_with_value.hpp index d8fd7d4..88b2569 100644 --- a/matching/src/cell_with_value.cpp +++ b/matching/include/cell_with_value.hpp @@ -1,17 +1,11 @@ -#include <spdlog/spdlog.h> -#include <spdlog/fmt/ostr.h> - -namespace spd = spdlog; - -#include "cell_with_value.h" - namespace md { #ifdef MD_DEBUG - long long int CellWithValue::max_id = 0; + long long int CellWithValue<Real>::max_id = 0; #endif - Real CellWithValue::value_at(ValuePoint vp) const + template<class Real> + Real CellWithValue<Real>::value_at(ValuePoint vp) const { switch(vp) { case ValuePoint::upper_left : @@ -29,7 +23,8 @@ namespace md { return 1.0 / 0.0; } - bool CellWithValue::has_value_at(ValuePoint vp) const + template<class Real> + bool CellWithValue<Real>::has_value_at(ValuePoint vp) const { switch(vp) { case ValuePoint::upper_left : @@ -45,9 +40,10 @@ namespace md { } // to shut up compiler warning return 1.0 / 0.0; - } + } - DualPoint CellWithValue::value_point(md::ValuePoint vp) const + template<class Real> + DualPoint<Real> CellWithValue<Real>::value_point(md::ValuePoint vp) const { switch(vp) { case ValuePoint::upper_left : @@ -62,27 +58,31 @@ namespace md { return dual_box().center(); } // to shut up compiler warning - return DualPoint(); - } + return DualPoint<Real>(); + } - bool CellWithValue::has_corner_value() const + template<class Real> + bool CellWithValue<Real>::has_corner_value() const { return has_lower_left_value() or has_lower_right_value() or has_upper_left_value() or has_upper_right_value(); } - Real CellWithValue::stored_upper_bound() const + template<class Real> + Real CellWithValue<Real>::stored_upper_bound() const { assert(has_max_possible_value_); return max_possible_value_; } - Real CellWithValue::max_corner_value() const + template<class Real> + Real CellWithValue<Real>::max_corner_value() const { return std::max({lower_left_value_, lower_right_value_, upper_left_value_, upper_right_value_}); } - Real CellWithValue::min_value() const + template<class Real> + Real CellWithValue<Real>::min_value() const { Real result = std::numeric_limits<Real>::max(); for(auto vp : k_all_vps) { @@ -94,13 +94,14 @@ namespace md { return result; } - std::vector<CellWithValue> CellWithValue::get_refined_cells() const + template<class Real> + std::vector<CellWithValue<Real>> CellWithValue<Real>::get_refined_cells() const { - std::vector<CellWithValue> result; + std::vector<CellWithValue<Real>> result; result.reserve(4); for(const auto& refined_box : dual_box_.refine()) { - CellWithValue refined_cell(refined_box, level() + 1); + CellWithValue<Real> refined_cell(refined_box, level() + 1); #ifdef MD_DEBUG refined_cell.parent_ids = parent_ids; @@ -142,10 +143,11 @@ namespace md { return result; } - void CellWithValue::set_value_at(md::ValuePoint vp, md::Real new_value) + template<class Real> + void CellWithValue<Real>::set_value_at(ValuePoint vp, Real new_value) { if (has_value_at(vp)) - spd::error("CellWithValue: trying to re-assign value!, this = {}, vp = {}", *this, vp); + spd::error("CellWithValue<Real>: trying to re-assign value!, this = {}, vp = {}", *this, vp); switch(vp) { case ValuePoint::upper_left : @@ -164,11 +166,10 @@ namespace md { central_value_ = new_value; break; } - - } - - int CellWithValue::num_values() const + + template<class Real> + int CellWithValue<Real>::num_values() const { int result = 0; for(ValuePoint vp : k_all_vps) { @@ -177,8 +178,9 @@ namespace md { return result; } - - void CellWithValue::set_max_possible_value(Real new_upper_bound) + + template<class Real> + void CellWithValue<Real>::set_max_possible_value(Real new_upper_bound) { assert(new_upper_bound >= central_value_); assert(new_upper_bound >= lower_left_value_); @@ -189,33 +191,9 @@ namespace md { max_possible_value_ = new_upper_bound; } - std::ostream& operator<<(std::ostream& os, const ValuePoint& vp) - { - switch(vp) { - case ValuePoint::upper_left : - os << "upper_left"; - break; - case ValuePoint::upper_right : - os << "upper_right"; - break; - case ValuePoint::lower_left : - os << "lower_left"; - break; - case ValuePoint::lower_right : - os << "lower_right"; - break; - case ValuePoint::center: - os << "center"; - break; - default: - os << "FORGOTTEN ValuePoint"; - } - return os; - } - - - std::ostream& operator<<(std::ostream& os, const CellWithValue& cell) + template<class Real> + std::ostream& operator<<(std::ostream& os, const CellWithValue<Real>& cell) { os << "CellWithValue(box = " << cell.dual_box() << ", "; @@ -244,4 +222,3 @@ namespace md { } } // namespace md - diff --git a/matching/include/common_defs.h b/matching/include/common_defs.h index 3f3d937..8d01325 100644 --- a/matching/include/common_defs.h +++ b/matching/include/common_defs.h @@ -1,8 +1,8 @@ #ifndef MATCHING_DISTANCE_DEF_DEBUG_H #define MATCHING_DISTANCE_DEF_DEBUG_H -//#define EXPERIMENTAL_TIMING -//#define PRINT_HEAT_MAP +//#define MD_EXPERIMENTAL_TIMING +//#define MD_PRINT_HEAT_MAP //#define MD_DEBUG //#define MD_DO_CHECKS //#define MD_DO_FULL_CHECK diff --git a/matching/include/common_util.h b/matching/include/common_util.h index 2d8dcb0..778536f 100644 --- a/matching/include/common_util.h +++ b/matching/include/common_util.h @@ -11,22 +11,24 @@ #include <map> #include <functional> +#include <spdlog/spdlog.h> +#include <spdlog/fmt/ostr.h> + +namespace spd = spdlog; + #include "common_defs.h" #include "phat/helpers/misc.h" - namespace md { - - using Real = double; - using RealVec = std::vector<Real>; using Index = phat::index; using IndexVec = std::vector<Index>; - static constexpr Real pi = M_PI; + //static constexpr Real pi = M_PI; using Column = std::vector<Index>; + template<class Real> struct Point { Real x; Real y; @@ -71,59 +73,56 @@ namespace md { }; - using PointVec = std::vector<Point>; - - Point operator+(const Point& u, const Point& v); + template<class Real> + using PointVec = std::vector<Point<Real>>; - Point operator-(const Point& u, const Point& v); + template<class Real> + Point<Real> operator+(const Point<Real>& u, const Point<Real>& v); - Point least_upper_bound(const Point& u, const Point& v); + template<class Real> + Point<Real> operator-(const Point<Real>& u, const Point<Real>& v); - Point greatest_lower_bound(const Point& u, const Point& v); - Point max_point(); + template<class Real> + Point<Real> least_upper_bound(const Point<Real>& u, const Point<Real>& v); - Point min_point(); + template<class Real> + Point<Real> greatest_lower_bound(const Point<Real>& u, const Point<Real>& v); - std::ostream& operator<<(std::ostream& ostr, const Point& vec); + template<class Real> + Point<Real> max_point(); - Real L_infty(const Point& v); + template<class Real> + Point<Real> min_point(); - Real l_2_norm(const Point& v); + template<class Real> + std::ostream& operator<<(std::ostream& ostr, const Point<Real>& vec); - Real l_2_dist(const Point& x, const Point& y); + template<class Real> + using DiagramPoint = std::pair<Real, Real>; - Real l_infty_dist(const Point& x, const Point& y); + template<class Real> + using Diagram = std::vector<DiagramPoint<Real>>; - using Interval = std::pair<Real, Real>; - - // return minimal interval that contains both a and b - inline Interval minimal_covering_interval(Interval a, Interval b) - { - return {std::min(a.first, b.first), std::max(a.second, b.second)}; - } // to keep diagrams in all dimensions // TODO: store in Hera format? + template<class Real> class DiagramKeeper { public: - using DiagramPoint = std::pair<Real, Real>; - using Diagram = std::vector<DiagramPoint>; DiagramKeeper() { }; void add_point(int dim, Real birth, Real death); - Diagram get_diagram(int dim) const; + Diagram<Real> get_diagram(int dim) const; void clear() { data_.clear(); } private: - std::map<int, Diagram> data_; + std::map<int, Diagram<Real>> data_; }; - using Diagram = std::vector<std::pair<Real, Real>>; - template<typename C> std::string container_to_string(const C& cont) { @@ -140,42 +139,18 @@ namespace md { return ss.str(); } - int gcd(int a, int b); - - struct Rational { - int numerator {0}; - int denominator {1}; - Rational() = default; - Rational(int n, int d) : numerator(n / gcd(n, d)), denominator(d / gcd(n, d)) {} - Rational(std::pair<int, int> p) : Rational(p.first, p.second) {} - Rational(int n) : numerator(n), denominator(1) {} - Real to_real() const { return (Real)numerator / (Real)denominator; } - void reduce(); - Rational& operator+=(const Rational& rhs); - Rational& operator-=(const Rational& rhs); - Rational& operator*=(const Rational& rhs); - Rational& operator/=(const Rational& rhs); - }; - - using namespace std::rel_ops; - - bool operator==(const Rational& a, const Rational& b); - bool operator<(const Rational& a, const Rational& b); - std::ostream& operator<<(std::ostream& os, const Rational& a); - - // arithmetic - Rational operator+(Rational a, const Rational& b); - Rational operator-(Rational a, const Rational& b); - Rational operator*(Rational a, const Rational& b); - Rational operator/(Rational a, const Rational& b); - - Rational reduce(Rational frac); - - Rational midpoint(Rational a, Rational b); - // return true, if s is empty or starts with # (commented out line) // whitespaces in the beginning of s are ignored - bool ignore_line(const std::string& s); + inline bool ignore_line(const std::string& s) + { + for(auto c : s) { + if (isspace(c)) + continue; + return (c == '#'); + } + return true; + } + // split string by delimeter template<typename Out> @@ -195,10 +170,10 @@ namespace md { } namespace std { - template<> - struct hash<md::Point> + template<class Real> + struct hash<md::Point<Real>> { - std::size_t operator()(const md::Point& p) const + std::size_t operator()(const md::Point<Real>& p) const { auto hx = std::hash<decltype(p.x)>()(p.x); auto hy = std::hash<decltype(p.y)>()(p.y); @@ -207,5 +182,7 @@ namespace std { }; }; +#include "common_util.hpp" + #endif //MATCHING_DISTANCE_COMMON_UTIL_H diff --git a/matching/include/common_util.hpp b/matching/include/common_util.hpp new file mode 100644 index 0000000..76d97af --- /dev/null +++ b/matching/include/common_util.hpp @@ -0,0 +1,96 @@ +#include <vector> +#include <utility> +#include <cmath> +#include <ostream> +#include <limits> +#include <algorithm> + +#include <common_util.h> + +#include "spdlog/spdlog.h" +#include "spdlog/fmt/ostr.h" + +namespace md { + + template<class Real> + Point<Real> operator+(const Point<Real>& u, const Point<Real>& v) + { + return Point<Real>(u.x + v.x, u.y + v.y); + } + + template<class Real> + Point<Real> operator-(const Point<Real>& u, const Point<Real>& v) + { + return Point<Real>(u.x - v.x, u.y - v.y); + } + + template<class Real> + Point<Real> least_upper_bound(const Point<Real>& u, const Point<Real>& v) + { + return Point<Real>(std::max(u.x, v.x), std::max(u.y, v.y)); + } + + template<class Real> + Point<Real> greatest_lower_bound(const Point<Real>& u, const Point<Real>& v) + { + return Point<Real>(std::min(u.x, v.x), std::min(u.y, v.y)); + } + + template<class Real> + Point<Real> max_point() + { + return Point<Real>(std::numeric_limits<Real>::max(), std::numeric_limits<Real>::min()); + } + + template<class Real> + Point<Real> min_point() + { + return Point<Real>(-std::numeric_limits<Real>::max(), -std::numeric_limits<Real>::min()); + } + + template<class Real> + std::ostream& operator<<(std::ostream& ostr, const Point<Real>& vec) + { + ostr << "(" << vec.x << ", " << vec.y << ")"; + return ostr; + } + + template<class Real> + Real l_infty_norm(const Point<Real>& v) + { + return std::max(std::abs(v.x), std::abs(v.y)); + } + + template<class Real> + Real l_2_norm(const Point<Real>& v) + { + return v.norm(); + } + + template<class Real> + Real l_2_dist(const Point<Real>& x, const Point<Real>& y) + { + return l_2_norm(x - y); + } + + template<class Real> + Real l_infty_dist(const Point<Real>& x, const Point<Real>& y) + { + return l_infty_norm(x - y); + } + + template<class Real> + void DiagramKeeper<Real>::add_point(int dim, Real birth, Real death) + { + data_[dim].emplace_back(birth, death); + } + + template<class Real> + Diagram<Real> DiagramKeeper<Real>::get_diagram(int dim) const + { + if (data_.count(dim) == 1) + return data_.at(dim); + else + return Diagram<Real>(); + } +} diff --git a/matching/include/dual_box.h b/matching/include/dual_box.h index ce0384d..0e4f4d5 100644 --- a/matching/include/dual_box.h +++ b/matching/include/dual_box.h @@ -4,16 +4,23 @@ #include <ostream> #include <limits> #include <vector> +#include <random> + +#include "spdlog/spdlog.h" +#include "spdlog/fmt/ostr.h" + #include "common_util.h" #include "dual_point.h" namespace md { + + template<class Real> class DualBox { public: - DualBox(DualPoint ll, DualPoint ur); + DualBox(DualPoint<Real> ll, DualPoint<Real> ur); DualBox() = default; DualBox(const DualBox&) = default; @@ -23,12 +30,12 @@ namespace md { DualBox& operator=(DualBox&& other) = default; - DualPoint center() const { return midpoint(lower_left_, upper_right_); } - DualPoint lower_left() const { return lower_left_; } - DualPoint upper_right() const { return upper_right_; } + DualPoint<Real> center() const { return midpoint(lower_left_, upper_right_); } + DualPoint<Real> lower_left() const { return lower_left_; } + DualPoint<Real> upper_right() const { return upper_right_; } - DualPoint lower_right() const; - DualPoint upper_left() const; + DualPoint<Real> lower_right() const; + DualPoint<Real> upper_left() const; AxisType axis_type() const { return lower_left_.axis_type(); } AngleType angle_type() const { return lower_left_.angle_type(); } @@ -42,66 +49,35 @@ namespace md { bool is_flat() const { return upper_right_.is_flat(); } bool is_steep() const { return lower_left_.is_steep(); } - // return minimal and maximal value of func - // on the corners of the box - template<typename F> - std::pair<Real, Real> min_max_on_corners(const F& func) const; - - template<typename F> - Real max_abs_value(const F& func) const; - - std::vector<DualBox> refine() const; - std::vector<DualPoint> corners() const; - std::vector<DualPoint> critical_points(const Point& p) const; + std::vector<DualPoint<Real>> corners() const; + std::vector<DualPoint<Real>> critical_points(const Point<Real>& p) const; // sample n points from the box uniformly; for tests - std::vector<DualPoint> random_points(int n) const; + std::vector<DualPoint<Real>> random_points(int n) const; // return 2 dual points at the boundary // where push changes from horizontal to vertical - std::vector<DualPoint> push_change_points(const Point& p) const; - - friend std::ostream& operator<<(std::ostream& os, const DualBox& db); + std::vector<DualPoint<Real>> push_change_points(const Point<Real>& p) const; // check that a has same sign, angles are all flat or all steep bool sanity_check() const; - bool contains(const DualPoint& dp) const; + bool contains(const DualPoint<Real>& dp) const; bool operator==(const DualBox& other) const; private: - DualPoint lower_left_; - DualPoint upper_right_; + DualPoint<Real> lower_left_; + DualPoint<Real> upper_right_; }; - std::ostream& operator<<(std::ostream& os, const DualBox& db); - - template<typename F> - std::pair<Real, Real> DualBox::min_max_on_corners(const F& func) const + template<class Real> + std::ostream& operator<<(std::ostream& os, const DualBox<Real>& db) { - std::pair<Real, Real> min_max { std::numeric_limits<Real>::max(), -std::numeric_limits<Real>::max() }; - for(auto p : corners()) { - Real value = func(p); - min_max.first = std::min(min_max.first, value); - min_max.second = std::max(min_max.second, value); - } - return min_max; - }; - - - template<typename F> - Real DualBox::max_abs_value(const F& func) const - { - Real result = 0; - for(auto p_1 : corners()) { - for(auto p_2 : corners()) { - Real value = fabs(func(p_1, p_2)); - result = std::max(value, result); - } - } - return result; - }; - + os << "DualBox(" << db.lower_left() << ", " << db.upper_right() << ")"; + return os; + } } +#include "dual_box.hpp" + #endif //MATCHING_DISTANCE_DUAL_BOX_H diff --git a/matching/src/dual_box.cpp b/matching/include/dual_box.hpp index ff4d30c..85f7f27 100644 --- a/matching/src/dual_box.cpp +++ b/matching/include/dual_box.hpp @@ -1,36 +1,24 @@ -#include <random> - -#include "spdlog/spdlog.h" -#include "spdlog/fmt/ostr.h" - -namespace spd = spdlog; - -#include "dual_box.h" - namespace md { - std::ostream& operator<<(std::ostream& os, const DualBox& db) - { - os << "DualBox(" << db.lower_left_ << ", " << db.upper_right_ << ")"; - return os; - } - - DualBox::DualBox(DualPoint ll, DualPoint ur) + template<class Real> + DualBox<Real>::DualBox(DualPoint<Real> ll, DualPoint<Real> ur) :lower_left_(ll), upper_right_(ur) { } - std::vector<DualPoint> DualBox::corners() const + template<class Real> + std::vector<DualPoint<Real>> DualBox<Real>::corners() const { return {lower_left_, - DualPoint(axis_type(), angle_type(), lower_left_.lambda(), upper_right_.mu()), + DualPoint<Real>(axis_type(), angle_type(), lower_left_.lambda(), upper_right_.mu()), upper_right_, - DualPoint(axis_type(), angle_type(), upper_right_.lambda(), lower_left_.mu())}; + DualPoint<Real>(axis_type(), angle_type(), upper_right_.lambda(), lower_left_.mu())}; } - std::vector<DualPoint> DualBox::push_change_points(const Point& p) const + template<class Real> + std::vector<DualPoint<Real>> DualBox<Real>::push_change_points(const Point<Real>& p) const { - std::vector<DualPoint> result; + std::vector<DualPoint<Real>> result; result.reserve(2); bool is_y_type = lower_left_.is_y_type(); @@ -38,13 +26,13 @@ namespace md { auto mu_from_lambda = [p, is_y_type, is_flat](Real lambda) { bool is_x_type = not is_y_type, is_steep = not is_flat; - if (is_y_type and is_flat) { + if (is_y_type && is_flat) { return p.y - lambda * p.x; - } else if (is_y_type and is_steep) { + } else if (is_y_type && is_steep) { return p.y - p.x / lambda; - } else if (is_x_type and is_flat) { + } else if (is_x_type && is_flat) { return p.x - p.y / lambda; - } else if (is_x_type and is_steep) { + } else if (is_x_type && is_steep) { return p.x - lambda * p.y; } // to shut up compiler warning @@ -53,13 +41,13 @@ namespace md { auto lambda_from_mu = [p, is_y_type, is_flat](Real mu) { bool is_x_type = not is_y_type, is_steep = not is_flat; - if (is_y_type and is_flat) { + if (is_y_type && is_flat) { return (p.y - mu) / p.x; - } else if (is_y_type and is_steep) { + } else if (is_y_type && is_steep) { return p.x / (p.y - mu); - } else if (is_x_type and is_flat) { + } else if (is_x_type && is_flat) { return p.y / (p.x - mu); - } else if (is_x_type and is_steep) { + } else if (is_x_type && is_steep) { return (p.x - mu) / p.y; } // to shut up compiler warning @@ -67,7 +55,7 @@ namespace md { }; // all inequalities below are strict: equality means it is a corner - // and critical_points() returns corners anyway + // && critical_points() returns corners anyway Real mu_intersect_min = mu_from_lambda(lambda_min()); @@ -99,22 +87,24 @@ namespace md { return result; } - std::vector<DualPoint> DualBox::critical_points(const Point& /*p*/) const + template<class Real> + std::vector<DualPoint<Real>> DualBox<Real>::critical_points(const Point<Real>& /*p*/) const { // maximal difference is attained at corners return corners(); -// std::vector<DualPoint> result; +// std::vector<DualPoint<Real>> result; // result.reserve(6); // for(auto dp : corners()) result.push_back(dp); // for(auto dp : push_change_points(p)) result.push_back(dp); // return result; } - std::vector<DualPoint> DualBox::random_points(int n) const + template<class Real> + std::vector<DualPoint<Real>> DualBox<Real>::random_points(int n) const { assert(n >= 0); std::mt19937_64 gen(1); - std::vector<DualPoint> result; + std::vector<DualPoint<Real>> result; result.reserve(n); std::uniform_real_distribution<Real> mu_distr(mu_min(), mu_max()); std::uniform_real_distribution<Real> lambda_distr(lambda_min(), lambda_max()); @@ -124,7 +114,8 @@ namespace md { return result; } - bool DualBox::sanity_check() const + template<class Real> + bool DualBox<Real>::sanity_check() const { lower_left_.sanity_check(); upper_right_.sanity_check(); @@ -144,51 +135,56 @@ namespace md { return true; } - std::vector<DualBox> DualBox::refine() const + template<class Real> + std::vector<DualBox<Real>> DualBox<Real>::refine() const { - std::vector<DualBox> result; + std::vector<DualBox<Real>> result; result.reserve(4); Real lambda_middle = (lower_left().lambda() + upper_right().lambda()) / 2.0; Real mu_middle = (lower_left().mu() + upper_right().mu()) / 2.0; - DualPoint refinement_center(axis_type(), angle_type(), lambda_middle, mu_middle); + DualPoint<Real> refinement_center(axis_type(), angle_type(), lambda_middle, mu_middle); result.emplace_back(lower_left_, refinement_center); - result.emplace_back(DualPoint(axis_type(), angle_type(), lambda_middle, mu_min()), - DualPoint(axis_type(), angle_type(), lambda_max(), mu_middle)); + result.emplace_back(DualPoint<Real>(axis_type(), angle_type(), lambda_middle, mu_min()), + DualPoint<Real>(axis_type(), angle_type(), lambda_max(), mu_middle)); result.emplace_back(refinement_center, upper_right_); - result.emplace_back(DualPoint(axis_type(), angle_type(), lambda_min(), mu_middle), - DualPoint(axis_type(), angle_type(), lambda_middle, mu_max())); + result.emplace_back(DualPoint<Real>(axis_type(), angle_type(), lambda_min(), mu_middle), + DualPoint<Real>(axis_type(), angle_type(), lambda_middle, mu_max())); return result; } - bool DualBox::operator==(const DualBox& other) const + template<class Real> + bool DualBox<Real>::operator==(const DualBox& other) const { - return lower_left() == other.lower_left() and + return lower_left() == other.lower_left() && upper_right() == other.upper_right(); } - bool DualBox::contains(const DualPoint& dp) const + template<class Real> + bool DualBox<Real>::contains(const DualPoint<Real>& dp) const { - return dp.angle_type() == angle_type() and dp.axis_type() == axis_type() and - mu_max() >= dp.mu() and - mu_min() <= dp.mu() and - lambda_min() <= dp.lambda() and + return dp.angle_type() == angle_type() && dp.axis_type() == axis_type() && + mu_max() >= dp.mu() && + mu_min() <= dp.mu() && + lambda_min() <= dp.lambda() && lambda_max() >= dp.lambda(); } - DualPoint DualBox::lower_right() const + template<class Real> + DualPoint<Real> DualBox<Real>::lower_right() const { - return DualPoint(lower_left_.axis_type(), lower_left_.angle_type(), lambda_max(), mu_min()); + return DualPoint<Real>(lower_left_.axis_type(), lower_left_.angle_type(), lambda_max(), mu_min()); } - DualPoint DualBox::upper_left() const + template<class Real> + DualPoint<Real> DualBox<Real>::upper_left() const { - return DualPoint(lower_left_.axis_type(), lower_left_.angle_type(), lambda_min(), mu_max()); + return DualPoint<Real>(lower_left_.axis_type(), lower_left_.angle_type(), lambda_min(), mu_max()); } } diff --git a/matching/include/dual_point.h b/matching/include/dual_point.h index db32f1a..8438860 100644 --- a/matching/include/dual_point.h +++ b/matching/include/dual_point.h @@ -1,12 +1,9 @@ -// -// Created by narn on 12.02.19. -// - #ifndef MATCHING_DISTANCE_DUAL_POINT_H #define MATCHING_DISTANCE_DUAL_POINT_H #include <vector> #include <ostream> +#include <tuple> #include "common_util.h" #include "box.h" @@ -25,9 +22,10 @@ namespace md { // so, e.g., line y = x has 4 different non-equal representation. // we are unlikely to ever need this, because 4 cases are // always treated separately. + template<class Real_> class DualPoint { public: - using Real = md::Real; + using Real = Real_; DualPoint() = default; @@ -56,7 +54,6 @@ namespace md { bool is_y_type() const { return axis_type_ == AxisType::y_type; } - friend std::ostream& operator<<(std::ostream& os, const DualPoint& dp); bool operator<(const DualPoint& rhs) const; AxisType axis_type() const { return axis_type_; } @@ -66,16 +63,16 @@ namespace md { // return true otherwise bool sanity_check() const; - Real weighted_push(Point p) const; - Point push(Point p) const; + Real weighted_push(Point<Real> p) const; + Point<Real> push(Point<Real> p) const; bool is_horizontal() const; bool is_vertical() const; - bool goes_below(Point p) const; - bool goes_above(Point p) const; + bool goes_below(Point<Real> p) const; + bool goes_above(Point<Real> p) const; - bool contains(Point p) const; + bool contains(Point<Real> p) const; Real x_slope() const; Real y_slope() const; @@ -98,9 +95,13 @@ namespace md { Real mu_ {-1.0}; }; - std::ostream& operator<<(std::ostream& os, const DualPoint& dp); + template<class Real> + std::ostream& operator<<(std::ostream& os, const DualPoint<Real>& dp); - DualPoint midpoint(DualPoint x, DualPoint y); + template<class Real> + DualPoint<Real> midpoint(DualPoint<Real> x, DualPoint<Real> y); }; +#include "dual_point.hpp" + #endif //MATCHING_DISTANCE_DUAL_POINT_H diff --git a/matching/src/dual_point.cpp b/matching/include/dual_point.hpp index 1c00b58..04e25f2 100644 --- a/matching/src/dual_point.cpp +++ b/matching/include/dual_point.hpp @@ -1,10 +1,6 @@ -#include <tuple> - -#include "dual_point.h" - namespace md { - std::ostream& operator<<(std::ostream& os, const AxisType& at) + inline std::ostream& operator<<(std::ostream& os, const AxisType& at) { if (at == AxisType::x_type) os << "x-type"; @@ -13,7 +9,7 @@ namespace md { return os; } - std::ostream& operator<<(std::ostream& os, const AngleType& at) + inline std::ostream& operator<<(std::ostream& os, const AngleType& at) { if (at == AngleType::flat) os << "flat"; @@ -22,7 +18,8 @@ namespace md { return os; } - std::ostream& operator<<(std::ostream& os, const DualPoint& dp) + template<class Real> + std::ostream& operator<<(std::ostream& os, const DualPoint<Real>& dp) { os << "Line(" << dp.axis_type() << ", "; os << dp.angle_type() << ", "; @@ -37,13 +34,15 @@ namespace md { return os; } - bool DualPoint::operator<(const DualPoint& rhs) const + template<class Real> + bool DualPoint<Real>::operator<(const DualPoint<Real>& rhs) const { return std::tie(axis_type_, angle_type_, lambda_, mu_) < std::tie(rhs.axis_type_, rhs.angle_type_, rhs.lambda_, rhs.mu_); } - DualPoint::DualPoint(AxisType axis_type, AngleType angle_type, Real lambda, Real mu) + template<class Real> + DualPoint<Real>::DualPoint(AxisType axis_type, AngleType angle_type, Real lambda, Real mu) : axis_type_(axis_type), angle_type_(angle_type), @@ -53,7 +52,8 @@ namespace md { assert(sanity_check()); } - bool DualPoint::sanity_check() const + template<class Real> + bool DualPoint<Real>::sanity_check() const { if (lambda_ < 0.0) throw std::runtime_error("Invalid line, negative lambda"); @@ -64,7 +64,8 @@ namespace md { return true; } - Real DualPoint::gamma() const + template<class Real> + Real DualPoint<Real>::gamma() const { if (is_steep()) return atan(Real(1.0) / lambda_); @@ -72,17 +73,19 @@ namespace md { return atan(lambda_); } - DualPoint midpoint(DualPoint x, DualPoint y) + template<class Real> + DualPoint<Real> midpoint(DualPoint<Real> x, DualPoint<Real> y) { assert(x.angle_type() == y.angle_type() and x.axis_type() == y.axis_type()); Real lambda_mid = (x.lambda() + y.lambda()) / 2; Real mu_mid = (x.mu() + y.mu()) / 2; - return DualPoint(x.axis_type(), x.angle_type(), lambda_mid, mu_mid); + return DualPoint<Real>(x.axis_type(), x.angle_type(), lambda_mid, mu_mid); } // return k in the line equation y = kx + b - Real DualPoint::y_slope() const + template<class Real> + Real DualPoint<Real>::y_slope() const { if (is_flat()) return lambda(); @@ -91,7 +94,8 @@ namespace md { } // return k in the line equation x = ky + b - Real DualPoint::x_slope() const + template<class Real> + Real DualPoint<Real>::x_slope() const { if (is_flat()) return Real(1.0) / lambda(); @@ -100,7 +104,8 @@ namespace md { } // return b in the line equation y = kx + b - Real DualPoint::y_intercept() const + template<class Real> + Real DualPoint<Real>::y_intercept() const { if (is_y_type()) { return mu(); @@ -112,7 +117,8 @@ namespace md { } // return k in the line equation x = ky + b - Real DualPoint::x_intercept() const + template<class Real> + Real DualPoint<Real>::x_intercept() const { if (is_x_type()) { return mu(); @@ -123,7 +129,8 @@ namespace md { } } - Real DualPoint::x_from_y(Real y) const + template<class Real> + Real DualPoint<Real>::x_from_y(Real y) const { if (is_horizontal()) throw std::runtime_error("x_from_y called on horizontal line"); @@ -131,7 +138,8 @@ namespace md { return x_slope() * y + x_intercept(); } - Real DualPoint::y_from_x(Real x) const + template<class Real> + Real DualPoint<Real>::y_from_x(Real x) const { if (is_vertical()) throw std::runtime_error("x_from_y called on horizontal line"); @@ -139,17 +147,20 @@ namespace md { return y_slope() * x + y_intercept(); } - bool DualPoint::is_horizontal() const + template<class Real> + bool DualPoint<Real>::is_horizontal() const { return is_flat() and lambda() == 0; } - bool DualPoint::is_vertical() const + template<class Real> + bool DualPoint<Real>::is_vertical() const { return is_steep() and lambda() == 0; } - - bool DualPoint::contains(Point p) const + + template<class Real> + bool DualPoint<Real>::contains(Point<Real> p) const { if (is_vertical()) return p.x == x_from_y(p.y); @@ -157,7 +168,8 @@ namespace md { return p.y == y_from_x(p.x); } - bool DualPoint::goes_below(Point p) const + template<class Real> + bool DualPoint<Real>::goes_below(Point<Real> p) const { if (is_vertical()) return p.x <= x_from_y(p.y); @@ -165,7 +177,8 @@ namespace md { return p.y >= y_from_x(p.x); } - bool DualPoint::goes_above(Point p) const + template<class Real> + bool DualPoint<Real>::goes_above(Point<Real> p) const { if (is_vertical()) return p.x >= x_from_y(p.y); @@ -173,9 +186,10 @@ namespace md { return p.y <= y_from_x(p.x); } - Point DualPoint::push(Point p) const + template<class Real> + Point<Real> DualPoint<Real>::push(Point<Real> p) const { - Point result; + Point<Real> result; // if line is below p, we push horizontally bool horizontal_push = goes_below(p); if (is_x_type()) { @@ -225,7 +239,8 @@ namespace md { return result; } - Real DualPoint::weighted_push(Point p) const + template<class Real> + Real DualPoint<Real>::weighted_push(Point<Real> p) const { // if line is below p, we push horizontally bool horizontal_push = goes_below(p); @@ -267,7 +282,8 @@ namespace md { } } - bool DualPoint::operator==(const DualPoint& other) const + template<class Real> + bool DualPoint<Real>::operator==(const DualPoint<Real>& other) const { return axis_type() == other.axis_type() and angle_type() == other.angle_type() and @@ -275,7 +291,8 @@ namespace md { lambda() == other.lambda(); } - Real DualPoint::weight() const + template<class Real> + Real DualPoint<Real>::weight() const { return lambda_ / sqrt(1 + lambda_ * lambda_); } diff --git a/matching/include/matching_distance.h b/matching/include/matching_distance.h index bb10203..5be34c7 100644 --- a/matching/include/matching_distance.h +++ b/matching/include/matching_distance.h @@ -4,9 +4,10 @@ #include <limits> #include <utility> #include <ostream> +#include <chrono> +#include <tuple> +#include <algorithm> -#include "spdlog/spdlog.h" -#include "spdlog/fmt/ostr.h" #include "common_defs.h" #include "cell_with_value.h" @@ -17,12 +18,15 @@ #include "bifiltration.h" #include "bottleneck.h" -namespace spd = spdlog; - namespace md { - using HeatMap = std::map<DualPoint, Real>; - using HeatMaps = std::map<int, HeatMap>; +#ifdef MD_PRINT_HEAT_MAP + template<class Real> + using HeatMap = std::map<DualPoint<Real>, Real>; + + template<class Real> + using HeatMaps = std::map<int, HeatMap<Real>>; +#endif enum class BoundStrategy { bruteforce, @@ -39,18 +43,107 @@ namespace md { upper_bound }; - std::ostream& operator<<(std::ostream& os, const BoundStrategy& s); - - std::ostream& operator<<(std::ostream& os, const TraverseStrategy& s); - - std::istream& operator>>(std::istream& is, BoundStrategy& s); - - std::istream& operator>>(std::istream& is, TraverseStrategy& s); - - BoundStrategy bs_from_string(std::string s); - - TraverseStrategy ts_from_string(std::string s); - + inline std::ostream& operator<<(std::ostream& os, const BoundStrategy& s) + { + switch(s) { + case BoundStrategy::bruteforce : + os << "bruteforce"; + break; + case BoundStrategy::local_dual_bound : + os << "local_grob"; + break; + case BoundStrategy::local_combined : + os << "local_combined"; + break; + case BoundStrategy::local_dual_bound_refined : + os << "local_refined"; + break; + case BoundStrategy::local_dual_bound_for_each_point : + os << "local_for_each_point"; + break; + default: + os << "FORGOTTEN BOUND STRATEGY"; + } + return os; + } + + inline std::ostream& operator<<(std::ostream& os, const TraverseStrategy& s) + { + switch(s) { + case TraverseStrategy::depth_first : + os << "DFS"; + break; + case TraverseStrategy::breadth_first : + os << "BFS"; + break; + case TraverseStrategy::breadth_first_value : + os << "BFS-VAL"; + break; + case TraverseStrategy::upper_bound : + os << "UB"; + break; + default: + os << "FORGOTTEN TRAVERSE STRATEGY"; + } + return os; + } + + inline std::istream& operator>>(std::istream& is, TraverseStrategy& s) + { + std::string ss; + is >> ss; + if (ss == "DFS") { + s = TraverseStrategy::depth_first; + } else if (ss == "BFS") { + s = TraverseStrategy::breadth_first; + } else if (ss == "BFS-VAL") { + s = TraverseStrategy::breadth_first_value; + } else if (ss == "UB") { + s = TraverseStrategy::upper_bound; + } else { + throw std::runtime_error("UNKNOWN TRAVERSE STRATEGY"); + } + return is; + } + + + inline std::istream& operator>>(std::istream& is, BoundStrategy& s) + { + std::string ss; + is >> ss; + if (ss == "bruteforce") { + s = BoundStrategy::bruteforce; + } else if (ss == "local_grob") { + s = BoundStrategy::local_dual_bound; + } else if (ss == "local_combined") { + s = BoundStrategy::local_combined; + } else if (ss == "local_refined") { + s = BoundStrategy::local_dual_bound_refined; + } else if (ss == "local_for_each_point") { + s = BoundStrategy::local_dual_bound_for_each_point; + } else { + throw std::runtime_error("UNKNOWN BOUND STRATEGY"); + } + return is; + } + + inline BoundStrategy bs_from_string(std::string s) + { + std::stringstream ss(s); + BoundStrategy result; + ss >> result; + return result; + } + + inline TraverseStrategy ts_from_string(std::string s) + { + std::stringstream ss(s); + TraverseStrategy result; + ss >> result; + return result; + } + + template<class Real> struct CalculationParams { static constexpr int ALL_DIMENSIONS = -1; @@ -75,22 +168,22 @@ namespace md { // print statistics on each quad-tree level bool print_stats { false }; -#ifdef PRINT_HEAT_MAP +#ifdef MD_PRINT_HEAT_MAP HeatMaps heat_maps; #endif }; - template<class DiagramProvider> + template<class Real_, class DiagramProvider> class DistanceCalculator { - using DualBox = md::DualBox; - using CellValueVector = std::vector<CellWithValue>; + using Real = Real_; + using CellValueVector = std::vector<CellWithValue<Real>>; public: DistanceCalculator(const DiagramProvider& a, const DiagramProvider& b, - CalculationParams& params); + CalculationParams<Real>& params); Real distance(); @@ -100,7 +193,7 @@ namespace md { DiagramProvider module_a_; DiagramProvider module_b_; - CalculationParams& params_; + CalculationParams<Real>& params_; int n_hera_calls_; std::map<int, int> n_hera_calls_per_level_; @@ -112,65 +205,83 @@ namespace md { CellValueVector get_initial_dual_grid(Real& lower_bound); +#ifdef MD_PRINT_HEAT_MAP void heatmap_in_dimension(int dim, int depth); +#endif Real get_max_x(int module) const; Real get_max_y(int module) const; - void set_cell_central_value(CellWithValue& dual_cell); + void set_cell_central_value(CellWithValue<Real>& dual_cell); Real get_distance(); Real get_distance_pq(); - // temporary, to try priority queue - Real get_max_possible_value(const CellWithValue* first_cell_ptr, int n_cells); + Real get_max_possible_value(const CellWithValue<Real>* first_cell_ptr, int n_cells); - Real get_upper_bound(const CellWithValue& dual_cell, Real good_enough_upper_bound) const; + Real get_upper_bound(const CellWithValue<Real>& dual_cell, Real good_enough_upper_bound) const; - Real get_single_dgm_bound(const CellWithValue& dual_cell, ValuePoint vp, int module, + Real get_single_dgm_bound(const CellWithValue<Real>& dual_cell, ValuePoint vp, int module, Real good_enough_value) const; // this bound depends only on dual box - Real get_local_dual_bound(int module, const DualBox& dual_box) const; + Real get_local_dual_bound(int module, const DualBox<Real>& dual_box) const; - Real get_local_dual_bound(const DualBox& dual_box) const; + Real get_local_dual_bound(const DualBox<Real>& dual_box) const; // this bound depends only on dual box, is more accurate - Real get_local_refined_bound(int module, const md::DualBox& dual_box) const; + Real get_local_refined_bound(int module, const DualBox<Real>& dual_box) const; - Real get_local_refined_bound(const md::DualBox& dual_box) const; + Real get_local_refined_bound(const DualBox<Real>& dual_box) const; Real get_good_enough_upper_bound(Real lower_bound) const; - Real - get_max_displacement_single_point(const CellWithValue& dual_cell, ValuePoint value_point, const Point& p) const; + Real get_max_displacement_single_point(const CellWithValue<Real>& dual_cell, ValuePoint value_point, + const Point<Real>& p) const; - void check_upper_bound(const CellWithValue& dual_cell) const; + void check_upper_bound(const CellWithValue<Real>& dual_cell) const; - Real distance_on_line(DualPoint line); - Real distance_on_line_const(DualPoint line) const; + Real distance_on_line(DualPoint<Real> line); + Real distance_on_line_const(DualPoint<Real> line) const; Real current_error(Real lower_bound, Real upper_bound); }; - Real matching_distance(const Bifiltration& bif_a, const Bifiltration& bif_b, CalculationParams& params); + template<class Real> + Real matching_distance(const Bifiltration<Real>& bif_a, const Bifiltration<Real>& bif_b, + CalculationParams<Real>& params); - Real matching_distance(const ModulePresentation& mod_a, const ModulePresentation& mod_b, CalculationParams& params); + template<class Real> + Real matching_distance(const ModulePresentation<Real>& mod_a, const ModulePresentation<Real>& mod_b, + CalculationParams<Real>& params); // for upper bound experiment struct UbExperimentRecord { - Real error; - Real lower_bound; - Real upper_bound; - CellWithValue cell; + double error; + double lower_bound; + double upper_bound; + CellWithValue<double> cell; long long int time; long long int n_hera_calls; }; - std::ostream& operator<<(std::ostream& os, const UbExperimentRecord& r); + inline std::ostream& operator<<(std::ostream& os, const UbExperimentRecord& r) + { + os << r.time << "\t" << r.n_hera_calls << "\t" << r.error << "\t" << r.lower_bound << "\t" << r.upper_bound; + return os; + } + + + template<class K, class V> + void print_map(const std::map<K, V>& dic) + { + for(const auto kv : dic) { + fmt::print("{} -> {}\n", kv.first, kv.second); + } + } -} +} // namespace md #include "matching_distance.hpp" diff --git a/matching/include/matching_distance.hpp b/matching/include/matching_distance.hpp index d2d2fbc..48c8464 100644 --- a/matching/include/matching_distance.hpp +++ b/matching/include/matching_distance.hpp @@ -1,34 +1,26 @@ namespace md { - template<class K, class V> - void print_map(const std::map<K, V>& dic) - { - for(const auto kv : dic) { - fmt::print("{} -> {}\n", kv.first, kv.second); - } - } - - template<class T> - void DistanceCalculator<T>::check_upper_bound(const CellWithValue& dual_cell) const + template<class R, class T> + void DistanceCalculator<R, T>::check_upper_bound(const CellWithValue<R>& dual_cell) const { spd::debug("Enter check_get_max_delta_on_cell"); const int n_samples_lambda = 100; const int n_samples_mu = 100; - DualBox db = dual_cell.dual_box(); - Real min_lambda = db.lambda_min(); - Real max_lambda = db.lambda_max(); - Real min_mu = db.mu_min(); - Real max_mu = db.mu_max(); - - Real h_lambda = (max_lambda - min_lambda) / n_samples_lambda; - Real h_mu = (max_mu - min_mu) / n_samples_mu; + DualBox<R> db = dual_cell.dual_box(); + R min_lambda = db.lambda_min(); + R max_lambda = db.lambda_max(); + R min_mu = db.mu_min(); + R max_mu = db.mu_max(); + + R h_lambda = (max_lambda - min_lambda) / n_samples_lambda; + R h_mu = (max_mu - min_mu) / n_samples_mu; for(int i = 1; i < n_samples_lambda; ++i) { for(int j = 1; j < n_samples_mu; ++j) { - Real lambda = min_lambda + i * h_lambda; - Real mu = min_mu + j * h_mu; - DualPoint l(db.axis_type(), db.angle_type(), lambda, mu); - Real other_result = distance_on_line_const(l); - Real diff = fabs(dual_cell.stored_upper_bound() - other_result); + R lambda = min_lambda + i * h_lambda; + R mu = min_mu + j * h_mu; + DualPoint<R> l(db.axis_type(), db.angle_type(), lambda, mu); + R other_result = distance_on_line_const(l); + R diff = fabs(dual_cell.stored_upper_bound() - other_result); if (other_result > dual_cell.stored_upper_bound()) { spd::error( "in check_upper_bound, upper_bound = {}, other_result = {}, diff = {}\ndual_cell = {}", @@ -42,10 +34,10 @@ namespace md { // for all lines l, l' inside dual box, // find the upper bound on the difference of weighted pushes of p - template<class T> - Real - DistanceCalculator<T>::get_max_displacement_single_point(const CellWithValue& dual_cell, ValuePoint vp, - const Point& p) const + template<class R, class T> + R + DistanceCalculator<R, T>::get_max_displacement_single_point(const CellWithValue<R>& dual_cell, ValuePoint vp, + const Point<R>& p) const { assert(p.x >= 0 && p.y >= 0); @@ -53,15 +45,15 @@ namespace md { std::vector<long long int> debug_ids = {3, 13, 54, 218, 350, 382, 484, 795, 2040, 8415, 44076}; bool debug = false; // std::find(debug_ids.begin(), debug_ids.end(), dual_cell.id) != debug_ids.end(); #endif - DualPoint line = dual_cell.value_point(vp); - const Real base_value = line.weighted_push(p); + DualPoint<R> line = dual_cell.value_point(vp); + const R base_value = line.weighted_push(p); spd::debug("Enter get_max_displacement_single_point, p = {},\ndual_cell = {},\nline = {}, base_value = {}\n", p, dual_cell, line, base_value); - Real result = 0.0; - for(DualPoint dp : dual_cell.dual_box().critical_points(p)) { - Real dp_value = dp.weighted_push(p); + R result = 0.0; + for(DualPoint<R> dp : dual_cell.dual_box().critical_points(p)) { + R dp_value = dp.weighted_push(p); spd::debug( "In get_max_displacement_single_point, p = {}, critical dp = {},\ndp_value = {}, diff = {},\ndual_cell = {}\n", p, dp, dp_value, fabs(base_value - dp_value), dual_cell); @@ -69,15 +61,15 @@ namespace md { } #ifdef MD_DO_FULL_CHECK - DualBox db = dual_cell.dual_box(); - std::uniform_real_distribution<Real> dlambda(db.lambda_min(), db.lambda_max()); - std::uniform_real_distribution<Real> dmu(db.mu_min(), db.mu_max()); + auto db = dual_cell.dual_box(); + std::uniform_real_distribution<R> dlambda(db.lambda_min(), db.lambda_max()); + std::uniform_real_distribution<R> dmu(db.mu_min(), db.mu_max()); std::mt19937 gen(1); for(int i = 0; i < 1000; ++i) { - Real lambda = dlambda(gen); - Real mu = dmu(gen); - DualPoint dp_random { db.axis_type(), db.angle_type(), lambda, mu }; - Real dp_value = dp_random.weighted_push(p); + R lambda = dlambda(gen); + R mu = dmu(gen); + DualPoint<R> dp_random { db.axis_type(), db.angle_type(), lambda, mu }; + R dp_value = dp_random.weighted_push(p); if (fabs(base_value - dp_value) > result) { spd::error("in get_max_displacement_single_point, p = {}, vp = {}\ndb = {}\nresult = {}, base_value = {}, dp_value = {}, dp_random = {}", p, vp, db, result, base_value, dp_value, dp_random); @@ -89,12 +81,12 @@ namespace md { return result; } - template<class T> - typename DistanceCalculator<T>::CellValueVector DistanceCalculator<T>::get_initial_dual_grid(Real& lower_bound) + template<class R, class T> + typename DistanceCalculator<R, T>::CellValueVector DistanceCalculator<R, T>::get_initial_dual_grid(R& lower_bound) { CellValueVector result = get_refined_grid(params_.initialization_depth, false, true); - lower_bound = -1.0; + lower_bound = -1; for(const auto& dc : result) { lower_bound = std::max(lower_bound, dc.max_corner_value()); } @@ -102,8 +94,8 @@ namespace md { assert(lower_bound >= 0); for(auto& dual_cell : result) { - Real good_enough_ub = get_good_enough_upper_bound(lower_bound); - Real max_value_on_cell = get_upper_bound(dual_cell, good_enough_ub); + R good_enough_ub = get_good_enough_upper_bound(lower_bound); + R max_value_on_cell = get_upper_bound(dual_cell, good_enough_ub); dual_cell.set_max_possible_value(max_value_on_cell); #ifdef MD_DO_FULL_CHECK @@ -116,39 +108,39 @@ namespace md { return result; } - template<class T> - typename DistanceCalculator<T>::CellValueVector - DistanceCalculator<T>::get_refined_grid(int init_depth, bool calculate_on_intermediate, bool calculate_on_last) + template<class R, class T> + typename DistanceCalculator<R, T>::CellValueVector + DistanceCalculator<R, T>::get_refined_grid(int init_depth, bool calculate_on_intermediate, bool calculate_on_last) { - const Real y_max = std::max(module_a_.max_y(), module_b_.max_y()); - const Real x_max = std::max(module_a_.max_x(), module_b_.max_x()); + const R y_max = std::max(module_a_.max_y(), module_b_.max_y()); + const R x_max = std::max(module_a_.max_x(), module_b_.max_x()); - const Real lambda_min = 0; - const Real lambda_max = 1; + const R lambda_min = 0; + const R lambda_max = 1; - const Real mu_min = 0; + const R mu_min = 0; - DualBox x_flat(DualPoint(AxisType::x_type, AngleType::flat, lambda_min, mu_min), - DualPoint(AxisType::x_type, AngleType::flat, lambda_max, x_max)); + DualBox<R> x_flat(DualPoint<R>(AxisType::x_type, AngleType::flat, lambda_min, mu_min), + DualPoint<R>(AxisType::x_type, AngleType::flat, lambda_max, x_max)); - DualBox x_steep(DualPoint(AxisType::x_type, AngleType::steep, lambda_min, mu_min), - DualPoint(AxisType::x_type, AngleType::steep, lambda_max, x_max)); + DualBox<R> x_steep(DualPoint<R>(AxisType::x_type, AngleType::steep, lambda_min, mu_min), + DualPoint<R>(AxisType::x_type, AngleType::steep, lambda_max, x_max)); - DualBox y_flat(DualPoint(AxisType::y_type, AngleType::flat, lambda_min, mu_min), - DualPoint(AxisType::y_type, AngleType::flat, lambda_max, y_max)); + DualBox<R> y_flat(DualPoint<R>(AxisType::y_type, AngleType::flat, lambda_min, mu_min), + DualPoint<R>(AxisType::y_type, AngleType::flat, lambda_max, y_max)); - DualBox y_steep(DualPoint(AxisType::y_type, AngleType::steep, lambda_min, mu_min), - DualPoint(AxisType::y_type, AngleType::steep, lambda_max, y_max)); + DualBox<R> y_steep(DualPoint<R>(AxisType::y_type, AngleType::steep, lambda_min, mu_min), + DualPoint<R>(AxisType::y_type, AngleType::steep, lambda_max, y_max)); - CellWithValue x_flat_cell(x_flat, 0); - CellWithValue x_steep_cell(x_steep, 0); - CellWithValue y_flat_cell(y_flat, 0); - CellWithValue y_steep_cell(y_steep, 0); + CellWithValue<R> x_flat_cell(x_flat, 0); + CellWithValue<R> x_steep_cell(x_steep, 0); + CellWithValue<R> y_flat_cell(y_flat, 0); + CellWithValue<R> y_steep_cell(y_steep, 0); if (init_depth == 0) { - DualPoint diagonal_x_flat(AxisType::x_type, AngleType::flat, 1, 0); + DualPoint<R> diagonal_x_flat(AxisType::x_type, AngleType::flat, 1, 0); - Real diagonal_value = distance_on_line(diagonal_x_flat); + R diagonal_value = distance_on_line(diagonal_x_flat); n_hera_calls_per_level_[0]++; x_flat_cell.set_value_at(ValuePoint::lower_right, diagonal_value); @@ -162,7 +154,7 @@ namespace md { x_steep_cell.id = 2; y_flat_cell.id = 3; y_steep_cell.id = 4; - CellWithValue::max_id = 4; + CellWithValue<R>::max_id = 4; #endif CellValueVector result {x_flat_cell, x_steep_cell, y_flat_cell, y_steep_cell}; @@ -189,10 +181,10 @@ namespace md { return result; } - template<class T> - DistanceCalculator<T>::DistanceCalculator(const T& a, + template<class R, class T> + DistanceCalculator<R, T>::DistanceCalculator(const T& a, const T& b, - CalculationParams& params) + CalculationParams<R>& params) : module_a_(a), module_b_(b), @@ -213,33 +205,33 @@ namespace md { module_a_.max_x(), module_a_.max_y(), module_b_.max_x(), module_b_.max_y()); } - template<class T> - Real DistanceCalculator<T>::get_max_x(int module) const + template<class R, class T> + R DistanceCalculator<R, T>::get_max_x(int module) const { return (module == 0) ? module_a_.max_x() : module_b_.max_x(); } - template<class T> - Real DistanceCalculator<T>::get_max_y(int module) const + template<class R, class T> + R DistanceCalculator<R, T>::get_max_y(int module) const { return (module == 0) ? module_a_.max_y() : module_b_.max_y(); } - template<class T> - Real - DistanceCalculator<T>::get_local_refined_bound(const md::DualBox& dual_box) const + template<class R, class T> + R + DistanceCalculator<R, T>::get_local_refined_bound(const DualBox<R>& dual_box) const { return get_local_refined_bound(0, dual_box) + get_local_refined_bound(1, dual_box); } - template<class T> - Real - DistanceCalculator<T>::get_local_refined_bound(int module, const md::DualBox& dual_box) const + template<class R, class T> + R + DistanceCalculator<R, T>::get_local_refined_bound(int module, const DualBox<R>& dual_box) const { spd::debug("Enter get_local_refined_bound, dual_box = {}", dual_box); - Real d_lambda = dual_box.lambda_max() - dual_box.lambda_min(); - Real d_mu = dual_box.mu_max() - dual_box.mu_min(); - Real result; + R d_lambda = dual_box.lambda_max() - dual_box.lambda_min(); + R d_mu = dual_box.mu_max() - dual_box.mu_min(); + R result; if (dual_box.axis_type() == AxisType::x_type) { if (dual_box.is_flat()) { result = dual_box.lambda_max() * d_mu + (get_max_x(module) - dual_box.mu_min()) * d_lambda; @@ -258,11 +250,11 @@ namespace md { return result; } - template<class T> - Real DistanceCalculator<T>::get_local_dual_bound(int module, const md::DualBox& dual_box) const + template<class R, class T> + R DistanceCalculator<R, T>::get_local_dual_bound(int module, const DualBox<R>& dual_box) const { - Real dlambda = dual_box.lambda_max() - dual_box.lambda_min(); - Real dmu = dual_box.mu_max() - dual_box.mu_min(); + R dlambda = dual_box.lambda_max() - dual_box.lambda_min(); + R dmu = dual_box.mu_max() - dual_box.mu_min(); if (dual_box.is_flat()) { return get_max_x(module) * dlambda + dmu; @@ -271,20 +263,20 @@ namespace md { } } - template<class T> - Real DistanceCalculator<T>::get_local_dual_bound(const md::DualBox& dual_box) const + template<class R, class T> + R DistanceCalculator<R, T>::get_local_dual_bound(const DualBox<R>& dual_box) const { return get_local_dual_bound(0, dual_box) + get_local_dual_bound(1, dual_box); } - template<class T> - Real DistanceCalculator<T>::get_upper_bound(const CellWithValue& dual_cell, Real good_enough_ub) const + template<class R, class T> + R DistanceCalculator<R, T>::get_upper_bound(const CellWithValue<R>& dual_cell, R good_enough_ub) const { assert(good_enough_ub >= 0); switch(params_.bound_strategy) { case BoundStrategy::bruteforce: - return std::numeric_limits<Real>::max(); + return std::numeric_limits<R>::max(); case BoundStrategy::local_dual_bound: return dual_cell.min_value() + get_local_dual_bound(dual_cell.dual_box()); @@ -293,7 +285,7 @@ namespace md { return dual_cell.min_value() + get_local_refined_bound(dual_cell.dual_box()); case BoundStrategy::local_combined: { - Real cheap_upper_bound = dual_cell.min_value() + get_local_refined_bound(dual_cell.dual_box()); + R cheap_upper_bound = dual_cell.min_value() + get_local_refined_bound(dual_cell.dual_box()); if (cheap_upper_bound < good_enough_ub) { return cheap_upper_bound; } else { @@ -302,14 +294,14 @@ namespace md { } case BoundStrategy::local_dual_bound_for_each_point: { - Real result = std::numeric_limits<Real>::max(); + R result = std::numeric_limits<R>::max(); for(ValuePoint vp : k_corner_vps) { if (not dual_cell.has_value_at(vp)) { continue; } - Real base_value = dual_cell.value_at(vp); - Real bound_dgm_a = get_single_dgm_bound(dual_cell, vp, 0, good_enough_ub); + R base_value = dual_cell.value_at(vp); + R bound_dgm_a = get_single_dgm_bound(dual_cell, vp, 0, good_enough_ub); if (params_.stop_asap and bound_dgm_a + base_value >= good_enough_ub) { // we want to return a valid upper bound, not just something that will prevent discarding the cell @@ -318,8 +310,8 @@ namespace md { return dual_cell.min_value() + get_local_refined_bound(dual_cell.dual_box()); } - Real bound_dgm_b = get_single_dgm_bound(dual_cell, vp, 1, - std::max(Real(0), good_enough_ub - bound_dgm_a)); + R bound_dgm_b = get_single_dgm_bound(dual_cell, vp, 1, + std::max(R(0), good_enough_ub - bound_dgm_a)); result = std::min(result, base_value + bound_dgm_a + bound_dgm_b); @@ -336,19 +328,19 @@ namespace md { } } // to suppress compiler warning - return std::numeric_limits<Real>::max(); + return std::numeric_limits<R>::max(); } // find maximal displacement of weighted points of m for all lines in dual_box - template<class T> - Real - DistanceCalculator<T>::get_single_dgm_bound(const CellWithValue& dual_cell, + template<class R, class T> + R + DistanceCalculator<R, T>::get_single_dgm_bound(const CellWithValue<R>& dual_cell, ValuePoint vp, int module, - [[maybe_unused]] Real good_enough_value) const + R good_enough_value) const { - Real result = 0; - Point max_point; + R result = 0; + Point<R> max_point; spd::debug( "Enter get_single_dgm_bound, module = {}, dual_cell = {}, vp = {}, good_enough_value = {}, stop_asap = {}\n", @@ -358,7 +350,7 @@ namespace md { for(const auto& position : m.positions()) { spd::debug("in get_single_dgm_bound, simplex = {}\n", position); - Real x = get_max_displacement_single_point(dual_cell, vp, position); + R x = get_max_displacement_single_point(dual_cell, vp, position); spd::debug("In get_single_dgm_bound, point = {}, displacement = {}", position, x); @@ -385,30 +377,30 @@ namespace md { return result; } - template<class T> - Real DistanceCalculator<T>::distance() + template<class R, class T> + R DistanceCalculator<R, T>::distance() { return get_distance_pq(); } // calculate weighted bottleneneck distance between slices on line // increments hera calls counter - template<class T> - Real DistanceCalculator<T>::distance_on_line(DualPoint line) + template<class R, class T> + R DistanceCalculator<R, T>::distance_on_line(DualPoint<R> line) { ++n_hera_calls_; - Real result = distance_on_line_const(line); + R result = distance_on_line_const(line); return result; } - template<class T> - Real DistanceCalculator<T>::distance_on_line_const(DualPoint line) const + template<class R, class T> + R DistanceCalculator<R, T>::distance_on_line_const(DualPoint<R> line) const { // TODO: think about this - how to call Hera auto dgm_a = module_a_.weighted_slice_diagram(line); auto dgm_b = module_b_.weighted_slice_diagram(line); - Real result; - if (params_.hera_epsilon > static_cast<Real>(0)) { + R result; + if (params_.hera_epsilon > static_cast<R>(0)) { result = hera::bottleneckDistApprox(dgm_a, dgm_b, params_.hera_epsilon) / ( params_.hera_epsilon + 1); } else { result = hera::bottleneckDistExact(dgm_a, dgm_b); @@ -423,10 +415,10 @@ namespace md { return result; } - template<class T> - Real DistanceCalculator<T>::get_good_enough_upper_bound(Real lower_bound) const + template<class R, class T> + R DistanceCalculator<R, T>::get_good_enough_upper_bound(R lower_bound) const { - Real result; + R result; // in upper_bound strategy we only prune cells if they cannot improve the lower bound, // otherwise the experiment is supposed to run indefinitely if (params_.traverse_strategy == TraverseStrategy::upper_bound) { @@ -440,14 +432,14 @@ namespace md { // helper function // calculate weighted bt distance on cell center, // assign distance value to cell, keep it in heat_map, and return - template<class T> - void DistanceCalculator<T>::set_cell_central_value(CellWithValue& dual_cell) + template<class R, class T> + void DistanceCalculator<R, T>::set_cell_central_value(CellWithValue<R>& dual_cell) { - DualPoint central_line {dual_cell.center()}; + DualPoint<R> central_line {dual_cell.center()}; spd::debug("In set_cell_central_value, processing dual cell = {}, line = {}", dual_cell.dual_box(), central_line); - Real new_value = distance_on_line(central_line); + R new_value = distance_on_line(central_line); n_hera_calls_per_level_[dual_cell.level() + 1]++; dual_cell.set_value_at(ValuePoint::center, new_value); params_.actual_max_depth = std::max(params_.actual_max_depth, dual_cell.level() + 1); @@ -472,10 +464,10 @@ namespace md { // assumes that the underlying container is vector! // cell_ptr: pointer to the first element in queue // n_cells: queue size - template<class T> - Real DistanceCalculator<T>::get_max_possible_value(const CellWithValue* cell_ptr, int n_cells) + template<class R, class T> + R DistanceCalculator<R, T>::get_max_possible_value(const CellWithValue<R>* cell_ptr, int n_cells) { - Real result = (n_cells > 0) ? cell_ptr->stored_upper_bound() : 0; + R result = (n_cells > 0) ? cell_ptr->stored_upper_bound() : 0; for(int i = 0; i < n_cells; ++i, ++cell_ptr) { result = std::max(result, cell_ptr->stored_upper_bound()); } @@ -485,11 +477,11 @@ namespace md { // helper function: // return current error from lower and upper bounds // and save it in params_ (hence not const) - template<class T> - Real DistanceCalculator<T>::current_error(Real lower_bound, Real upper_bound) + template<class R, class T> + R DistanceCalculator<R, T>::current_error(R lower_bound, R upper_bound) { - Real current_error = (lower_bound > 0.0) ? (upper_bound - lower_bound) / lower_bound - : std::numeric_limits<Real>::max(); + R current_error = (lower_bound > 0.0) ? (upper_bound - lower_bound) / lower_bound + : std::numeric_limits<R>::max(); params_.actual_error = current_error; @@ -505,8 +497,8 @@ namespace md { // use priority queue to store dual cells // comparison function depends on the strategies in params_ // ressets hera calls counter - template<class T> - Real DistanceCalculator<T>::get_distance_pq() + template<class R, class T> + R DistanceCalculator<R, T>::get_distance_pq() { std::map<int, long> n_cells_considered; std::map<int, long> n_cells_pushed_into_queue; @@ -527,26 +519,26 @@ namespace md { // if cell is too deep and is not pushed into queue, // we still need to take its max value into account; // the max over such cells is stored in max_result_on_too_fine_cells - Real upper_bound_on_deep_cells = -1; + R upper_bound_on_deep_cells = -1; spd::debug("Started iterations in dual space, delta = {}, bound_strategy = {}", params_.delta, params_.bound_strategy); // user-defined less lambda function // to regulate priority queue depending on strategy - auto dual_cell_less = [this](const CellWithValue& a, const CellWithValue& b) { + auto dual_cell_less = [this](const CellWithValue<R>& a, const CellWithValue<R>& b) { int a_level = a.level(); int b_level = b.level(); - Real a_value = a.max_corner_value(); - Real b_value = b.max_corner_value(); - Real a_ub = a.stored_upper_bound(); - Real b_ub = b.stored_upper_bound(); + R a_value = a.max_corner_value(); + R b_value = b.max_corner_value(); + R a_ub = a.stored_upper_bound(); + R b_ub = b.stored_upper_bound(); if (this->params_.traverse_strategy == TraverseStrategy::upper_bound and (not a.has_max_possible_value() or not b.has_max_possible_value())) { throw std::runtime_error("no upper bound on cell"); } - DualPoint a_lower_left = a.dual_box().lower_left(); - DualPoint b_lower_left = b.dual_box().lower_left(); + DualPoint<R> a_lower_left = a.dual_box().lower_left(); + DualPoint<R> b_lower_left = b.dual_box().lower_left(); switch(this->params_.traverse_strategy) { // in both breadth_first searches we want coarser cells @@ -569,24 +561,24 @@ namespace md { } }; - std::priority_queue<CellWithValue, CellValueVector, decltype(dual_cell_less)> dual_cells_queue( + std::priority_queue<CellWithValue<R>, CellValueVector, decltype(dual_cell_less)> dual_cells_queue( dual_cell_less); // weighted bt distance on the center of current cell - Real lower_bound = std::numeric_limits<Real>::min(); + R lower_bound = std::numeric_limits<R>::min(); // init pq and lower bound for(auto& init_cell : get_initial_dual_grid(lower_bound)) { dual_cells_queue.push(init_cell); } - Real upper_bound = get_max_possible_value(&dual_cells_queue.top(), dual_cells_queue.size()); + R upper_bound = get_max_possible_value(&dual_cells_queue.top(), dual_cells_queue.size()); std::vector<UbExperimentRecord> ub_experiment_results; while(not dual_cells_queue.empty()) { - CellWithValue dual_cell = dual_cells_queue.top(); + CellWithValue<R> dual_cell = dual_cells_queue.top(); dual_cells_queue.pop(); assert(dual_cell.has_corner_value() and dual_cell.has_max_possible_value() @@ -620,7 +612,7 @@ namespace md { // until now, dual_cell knows its value in one of its corners // new_value will be the weighted distance at its center set_cell_central_value(dual_cell); - Real new_value = dual_cell.value_at(ValuePoint::center); + R new_value = dual_cell.value_at(ValuePoint::center); lower_bound = std::max(new_value, lower_bound); spd::debug("Processed cell = {}, weighted value = {}, lower_bound = {}", dual_cell, new_value, lower_bound); @@ -638,11 +630,11 @@ namespace md { throw std::runtime_error("no value on cell"); // if delta is smaller than good_enough_value, it allows to prune cell - Real good_enough_ub = get_good_enough_upper_bound(lower_bound); + R good_enough_ub = get_good_enough_upper_bound(lower_bound); // upper bound of the parent holds for refined_cell // and can sometimes be smaller! - Real upper_bound_on_refined_cell = std::min(dual_cell.stored_upper_bound(), + R upper_bound_on_refined_cell = std::min(dual_cell.stored_upper_bound(), get_upper_bound(refined_cell, good_enough_ub)); spd::debug("upper_bound_on_refined_cell = {}, dual_cell.stored_upper_bound = {}, get_upper_bound = {}", @@ -774,10 +766,46 @@ namespace md { return lower_bound; } - template<class T> - int DistanceCalculator<T>::get_hera_calls_number() const + template<class R, class T> + int DistanceCalculator<R, T>::get_hera_calls_number() const { return n_hera_calls_; } -}
\ No newline at end of file + template<class R> + R matching_distance(const Bifiltration<R>& bif_a, const Bifiltration<R>& bif_b, + CalculationParams<R>& params) + { + R result; + // compute distance only in one dimension + if (params.dim != CalculationParams<R>::ALL_DIMENSIONS) { + BifiltrationProxy<R> bifp_a(bif_a, params.dim); + BifiltrationProxy<R> bifp_b(bif_b, params.dim); + DistanceCalculator<R, BifiltrationProxy<R>> runner(bifp_a, bifp_b, params); + result = runner.distance(); + params.n_hera_calls = runner.get_hera_calls_number(); + } else { + // compute distance in all dimensions, return maximal + result = -1; + for(int dim = 0; dim < std::max(bif_a.maximal_dim(), bif_b.maximal_dim()); ++dim) { + BifiltrationProxy<R> bifp_a(bif_a, params.dim); + BifiltrationProxy<R> bifp_b(bif_a, params.dim); + DistanceCalculator<R, BifiltrationProxy<R>> runner(bifp_a, bifp_b, params); + result = std::max(result, runner.distance()); + params.n_hera_calls += runner.get_hera_calls_number(); + } + } + return result; + } + + + template<class R> + R matching_distance(const ModulePresentation<R>& mod_a, const ModulePresentation<R>& mod_b, + CalculationParams<R>& params) + { + DistanceCalculator<R, ModulePresentation<R>> runner(mod_a, mod_b, params); + R result = runner.distance(); + params.n_hera_calls = runner.get_hera_calls_number(); + return result; + } +} // namespace md diff --git a/matching/include/persistence_module.h b/matching/include/persistence_module.h index a1fc67e..e99771f 100644 --- a/matching/include/persistence_module.h +++ b/matching/include/persistence_module.h @@ -5,6 +5,12 @@ #include <vector> #include <utility> #include <string> +#include <numeric> +#include <algorithm> +#include <unordered_set> + +#include "phat/boundary_matrix.h" +#include "phat/compute_persistence_pairs.h" #include "common_util.h" #include "dual_point.h" @@ -28,17 +34,20 @@ namespace md { */ + template<class Real> class ModulePresentation { public: + using RealVec = std::vector<Real>; + enum Format { rivet_firep }; struct Relation { - Point position_; + Point<Real> position_; IndexVec components_; Relation() {} - Relation(const Point& _pos, const IndexVec& _components); + Relation(const Point<Real>& _pos, const IndexVec& _components); Real get_x() const { return position_.x; } Real get_y() const { return position_.y; } @@ -48,9 +57,9 @@ namespace md { ModulePresentation() {} - ModulePresentation(const PointVec& _generators, const RelVec& _relations); + ModulePresentation(const PointVec<Real>& _generators, const RelVec& _relations); - Diagram weighted_slice_diagram(const DualPoint& line) const; + Diagram<Real> weighted_slice_diagram(const DualPoint<Real>& line) const; // translate all points by vector (a,a) void translate(Real a); @@ -59,9 +68,7 @@ namespace md { Real minimal_coordinate() const { return std::min(min_x(), min_y()); } // return box that contains all positions of all simplices - Box bounding_box() const; - - friend std::ostream& operator<<(std::ostream& os, const ModulePresentation& mp); + Box<Real> bounding_box() const; Real max_x() const { return max_x_; } @@ -71,26 +78,27 @@ namespace md { Real min_y() const { return min_y_; } - PointVec positions() const; + PointVec<Real> positions() const; private: - PointVec generators_; + PointVec<Real> generators_; std::vector<Relation> relations_; - PointVec positions_; + PointVec<Real> positions_; Real max_x_ { std::numeric_limits<Real>::max() }; Real max_y_ { std::numeric_limits<Real>::max() }; Real min_x_ { -std::numeric_limits<Real>::max() }; Real min_y_ { -std::numeric_limits<Real>::max() }; - Box bounding_box_; + Box<Real> bounding_box_; void init_boundaries(); - void project_generators(const DualPoint& slice, IndexVec& sorted_indices, RealVec& projections) const; - void project_relations(const DualPoint& slice, IndexVec& sorted_indices, RealVec& projections) const; + void project_generators(const DualPoint<Real>& slice, IndexVec& sorted_indices, RealVec& projections) const; + void project_relations(const DualPoint<Real>& slice, IndexVec& sorted_indices, RealVec& projections) const; }; } // namespace md +#include "persistence_module.hpp" #endif //MATCHING_DISTANCE_PERSISTENCE_MODULE_H diff --git a/matching/src/persistence_module.cpp b/matching/include/persistence_module.hpp index efb20ef..6e49b2e 100644 --- a/matching/src/persistence_module.cpp +++ b/matching/include/persistence_module.hpp @@ -1,12 +1,3 @@ -#include <numeric> -#include <algorithm> -#include <unordered_set> - -#include <phat/boundary_matrix.h> -#include <phat/compute_persistence_pairs.h> - -#include "persistence_module.h" - namespace md { /** @@ -17,7 +8,7 @@ namespace md { * 2) a_1,...,a_n is a permutation of 1,..,n */ - template<typename T> + template<class T> IndexVec get_sorted_indices(const std::vector<T>& values) { IndexVec result(values.size()); @@ -28,18 +19,20 @@ namespace md { } // helper function to initialize const member positions_ in ModulePresentation - PointVec - concat_gen_and_rel_positions(const PointVec& generators, const ModulePresentation::RelVec& relations) + template<class Real> + PointVec<Real> concat_gen_and_rel_positions(const PointVec<Real>& generators, + const typename ModulePresentation<Real>::RelVec& relations) { - std::unordered_set<Point> ps(generators.begin(), generators.end()); + std::unordered_set<Point<Real>> ps(generators.begin(), generators.end()); for(const auto& rel : relations) { ps.insert(rel.position_); } - return PointVec(ps.begin(), ps.end()); + return PointVec<Real>(ps.begin(), ps.end()); } - void ModulePresentation::init_boundaries() + template<class Real> + void ModulePresentation<Real>::init_boundaries() { max_x_ = std::numeric_limits<Real>::max(); max_y_ = std::numeric_limits<Real>::max(); @@ -53,18 +46,20 @@ namespace md { max_y_ = std::max(gen.y, max_y_); } - bounding_box_ = Box(Point(min_x_, min_y_), Point(max_x_, max_y_)); + bounding_box_ = Box<Real>(Point<Real>(min_x_, min_y_), Point<Real>(max_x_, max_y_)); } - ModulePresentation::ModulePresentation(const PointVec& _generators, const RelVec& _relations) : + template<class Real> + ModulePresentation<Real>::ModulePresentation(const PointVec<Real>& _generators, const RelVec& _relations) : generators_(_generators), relations_(_relations) { init_boundaries(); } - void ModulePresentation::translate(md::Real a) + template<class Real> + void ModulePresentation<Real>::translate(Real a) { for(auto& g : generators_) { g.translate(a); @@ -86,8 +81,9 @@ namespace md { * @param projections sorted weighted pushes of generators */ - void - ModulePresentation::project_generators(const DualPoint& slice, IndexVec& sorted_indices, RealVec& projections) const + template<class Real> + void ModulePresentation<Real>::project_generators(const DualPoint<Real>& slice, + IndexVec& sorted_indices, RealVec& projections) const { size_t num_gens = generators_.size(); @@ -104,7 +100,8 @@ namespace md { } } - void ModulePresentation::project_relations(const DualPoint& slice, IndexVec& sorted_rel_indices, + template<class Real> + void ModulePresentation<Real>::project_relations(const DualPoint<Real>& slice, IndexVec& sorted_rel_indices, RealVec& projections) const { size_t num_rels = relations_.size(); @@ -122,7 +119,8 @@ namespace md { } } - Diagram ModulePresentation::weighted_slice_diagram(const DualPoint& slice) const + template<class Real> + Diagram<Real> ModulePresentation<Real>::weighted_slice_diagram(const DualPoint<Real>& slice) const { IndexVec sorted_gen_indices, sorted_rel_indices; RealVec gen_projections, rel_projections; @@ -147,7 +145,7 @@ namespace md { phat::persistence_pairs phat_persistence_pairs; phat::compute_persistence_pairs<phat::twist_reduction>(phat_persistence_pairs, phat_matrix); - Diagram dgm; + Diagram<Real> dgm; constexpr Real real_inf = std::numeric_limits<Real>::infinity(); @@ -164,14 +162,16 @@ namespace md { return dgm; } - PointVec ModulePresentation::positions() const + template<class Real> + PointVec<Real> ModulePresentation<Real>::positions() const { return positions_; } - Box ModulePresentation::bounding_box() const + template<class Real> + Box<Real> ModulePresentation<Real>::bounding_box() const { return bounding_box_; } -} +} // namespace md diff --git a/matching/include/simplex.h b/matching/include/simplex.h index e9d0e30..75bbcae 100644 --- a/matching/include/simplex.h +++ b/matching/include/simplex.h @@ -9,6 +9,7 @@ namespace md { + template<class Real> class Bifiltration; enum class BifiltrationFormat { @@ -38,11 +39,21 @@ namespace md { int dim() const { return vertices_.size() - 1; } - void push_back(int v); + void push_back(int v) + { + vertices_.push_back(v); + std::sort(vertices_.begin(), vertices_.end()); + } AbstractSimplex() { } - AbstractSimplex(std::vector<int> vertices, bool sort = true); + AbstractSimplex(std::vector<int> vertices, bool sort = true) + :vertices_(vertices) + { + if (sort) + std::sort(vertices_.begin(), vertices_.end()); + } + template<class Iter> AbstractSimplex(Iter beg_iter, Iter end_iter, bool sort = true) @@ -53,22 +64,51 @@ namespace md { std::sort(vertices_.begin(), end()); } - std::vector<AbstractSimplex> facets() const; + std::vector<AbstractSimplex> facets() const + { + std::vector<AbstractSimplex> result; + for (int i = 0; i < static_cast<int>(vertices_.size()); ++i) { + std::vector<int> facet_vertices; + facet_vertices.reserve(dim()); + for (int j = 0; j < static_cast<int>(vertices_.size()); ++j) { + if (j != i) + facet_vertices.push_back(vertices_[j]); + } + if (!facet_vertices.empty()) { + result.emplace_back(facet_vertices, false); + } + } + return result; + } friend std::ostream& operator<<(std::ostream& os, const AbstractSimplex& s); - // compare by vertices_ only friend bool operator==(const AbstractSimplex& s1, const AbstractSimplex& s2); friend bool operator<(const AbstractSimplex&, const AbstractSimplex&); }; - std::ostream& operator<<(std::ostream& os, const AbstractSimplex& s); + inline std::ostream& operator<<(std::ostream& os, const AbstractSimplex& s) + { + os << "AbstractSimplex(id = " << s.id << ", vertices_ = " << container_to_string(s.vertices_) << ")"; + return os; + } + + inline bool operator<(const AbstractSimplex& a, const AbstractSimplex& b) + { + return a.vertices_ < b.vertices_; + } + + inline bool operator==(const AbstractSimplex& s1, const AbstractSimplex& s2) + { + return s1.vertices_ == s2.vertices_; + } + template<class Real> class Simplex { private: Index id_; - Point pos_; + Point<Real> pos_; int dim_; // in our format we use facet indices, // this is the fastest representation for homology @@ -77,11 +117,11 @@ namespace md { // conversion routines are in Bifiltration Column facet_indices_; Column vertices_; - Real v {0.0}; // used when constructed a filtration for a slice + Real v {0}; // used when constructed a filtration for a slice public: Simplex(Index _id, std::string s, BifiltrationFormat input_format); - Simplex(Index _id, Point birth, int _dim, const Column& _bdry); + Simplex(Index _id, Point<Real> birth, int _dim, const Column& _bdry); void init_rivet(std::string s); @@ -96,9 +136,9 @@ namespace md { Real value() const { return v; } // assumes 1-criticality - Point position() const { return pos_; } + Point<Real> position() const { return pos_; } - void set_position(const Point& new_pos) { pos_ = new_pos; } + void set_position(const Point<Real>& new_pos) { pos_ = new_pos; } void scale(Real lambda) { @@ -110,12 +150,14 @@ namespace md { void set_value(Real new_val) { v = new_val; } - friend std::ostream& operator<<(std::ostream& os, const Simplex& s); - - friend Bifiltration; + friend Bifiltration<Real>; }; - std::ostream& operator<<(std::ostream& os, const Simplex& s); + template<class Real> + std::ostream& operator<<(std::ostream& os, const Simplex<Real>& s); } + +#include "simplex.hpp" + #endif //MATCHING_DISTANCE_SIMPLEX_H diff --git a/matching/include/simplex.hpp b/matching/include/simplex.hpp new file mode 100644 index 0000000..ce0e30f --- /dev/null +++ b/matching/include/simplex.hpp @@ -0,0 +1,79 @@ +namespace md { + + template<class Real> + Simplex<Real>::Simplex(Index id, Point<Real> birth, int dim, const Column& bdry) + : + id_(id), + pos_(birth), + dim_(dim), + facet_indices_(bdry) { } + + template<class Real> + void Simplex<Real>::translate(Real a) + { + pos_.translate(a); + } + + template<class Real> + void Simplex<Real>::init_rivet(std::string s) + { + auto delim_pos = s.find_first_of(";"); + assert(delim_pos > 0); + std::string vertices_str = s.substr(0, delim_pos); + std::string pos_str = s.substr(delim_pos + 1); + assert(not vertices_str.empty() and not pos_str.empty()); + // get vertices + std::stringstream vertices_ss(vertices_str); + int dim = 0; + int vertex; + while (vertices_ss >> vertex) { + dim++; + vertices_.push_back(vertex); + } + // + std::sort(vertices_.begin(), vertices_.end()); + assert(dim > 0); + + std::stringstream pos_ss(pos_str); + // TODO: get rid of 1-criticaltiy assumption + pos_ss >> pos_.x >> pos_.y; + } + + template<class Real> + void Simplex<Real>::init_phat_like(std::string s) + { + facet_indices_.clear(); + std::stringstream ss(s); + ss >> dim_ >> pos_.x >> pos_.y; + if (dim_ > 0) { + facet_indices_.reserve(dim_ + 1); + for (int j = 0; j <= dim_; j++) { + Index k; + ss >> k; + facet_indices_.push_back(k); + } + } + } + + template<class Real> + Simplex<Real>::Simplex(Index _id, std::string s, BifiltrationFormat input_format) + :id_(_id) + { + switch (input_format) { + case BifiltrationFormat::phat_like : + init_phat_like(s); + break; + case BifiltrationFormat::rivet : + init_rivet(s); + break; + } + } + + template<class Real> + std::ostream& operator<<(std::ostream& os, const Simplex<Real>& x) + { + os << "Simplex<Real>(id = " << x.id() << ", dim = " << x.dim(); + os << ", boundary = " << container_to_string(x.boundary()) << ", pos = " << x.position() << ")"; + return os; + } +} diff --git a/matching/src/box.cpp b/matching/src/box.cpp deleted file mode 100644 index c128698..0000000 --- a/matching/src/box.cpp +++ /dev/null @@ -1,61 +0,0 @@ - -#include "box.h" - -namespace md { - - std::ostream& operator<<(std::ostream& os, const Box& box) - { - os << "Box(lower_left = " << box.lower_left() << ", upper_right = " << box.upper_right() << ")"; - return os; - } - - Box get_enclosing_box(const Box& box_a, const Box& box_b) - { - Point lower_left(std::min(box_a.lower_left().x, box_b.lower_left().x), - std::min(box_a.lower_left().y, box_b.lower_left().y)); - Point upper_right(std::max(box_a.upper_right().x, box_b.upper_right().x), - std::max(box_a.upper_right().y, box_b.upper_right().y)); - return Box(lower_left, upper_right); - } - - void Box::translate(md::Real a) - { - ll.x += a; - ll.y += a; - ur.x += a; - ur.y += a; - } - - std::vector<Box> Box::refine() const - { - std::vector<Box> result; - -// 1 | 2 -// 0 | 3 - - Point new_ll = lower_left(); - Point new_ur = center(); - result.emplace_back(new_ll, new_ur); - - new_ll.y = center().y; - new_ur.y = ur.y; - result.emplace_back(new_ll, new_ur); - - new_ll = center(); - new_ur = upper_right(); - result.emplace_back(new_ll, new_ur); - - new_ll.y = ll.y; - new_ur.y = center().y; - result.emplace_back(new_ll, new_ur); - - return result; - } - - std::vector<Point> Box::corners() const - { - return {ll, Point(ll.x, ur.y), ur, Point(ur.x, ll.y)}; - }; - - -} diff --git a/matching/src/common_util.cpp b/matching/src/common_util.cpp deleted file mode 100644 index 96c3388..0000000 --- a/matching/src/common_util.cpp +++ /dev/null @@ -1,243 +0,0 @@ -#include <vector> -#include <utility> -#include <cmath> -#include <ostream> -#include <limits> -#include <algorithm> - -#include <common_util.h> - -#include "spdlog/spdlog.h" -#include "spdlog/fmt/ostr.h" - -namespace md { - - - int gcd(int a, int b) - { - assert(a != 0 or b != 0); - // make b <= a - std::tie(b, a) = std::minmax({ abs(a), abs(b) }); - if (b == 0) - return a; - while((a = a % b)) { - std::swap(a, b); - } - return b; - } - - int signum(int a) - { - if (a < 0) - return -1; - else if (a > 0) - return 1; - else - return 0; - } - - Rational reduce(Rational frac) - { - int d = gcd(frac.numerator, frac.denominator); - frac.numerator /= d; - frac.denominator /= d; - return frac; - } - - void Rational::reduce() { *this = md::reduce(*this); } - - - Rational& Rational::operator*=(const md::Rational& rhs) - { - numerator *= rhs.numerator; - denominator *= rhs.denominator; - reduce(); - return *this; - } - - Rational& Rational::operator/=(const md::Rational& rhs) - { - numerator *= rhs.denominator; - denominator *= rhs.numerator; - reduce(); - return *this; - } - - Rational& Rational::operator+=(const md::Rational& rhs) - { - numerator = numerator * rhs.denominator + denominator * rhs.numerator; - denominator *= rhs.denominator; - reduce(); - return *this; - } - - Rational& Rational::operator-=(const md::Rational& rhs) - { - numerator = numerator * rhs.denominator - denominator * rhs.numerator; - denominator *= rhs.denominator; - reduce(); - return *this; - } - - - Rational midpoint(Rational a, Rational b) - { - return reduce({a.numerator * b.denominator + a.denominator * b.numerator, 2 * a.denominator * b.denominator }); - } - - Rational operator+(Rational a, const Rational& b) - { - a += b; - return a; - } - - Rational operator-(Rational a, const Rational& b) - { - a -= b; - return a; - } - - Rational operator*(Rational a, const Rational& b) - { - a *= b; - return a; - } - - Rational operator/(Rational a, const Rational& b) - { - a /= b; - return a; - } - - bool is_less(Rational a, Rational b) - { - // compute a - b = a_1 / a_2 - b_1 / b_2 - long numer = a.numerator * b.denominator - a.denominator * b.numerator; - long denom = a.denominator * b.denominator; - assert(denom != 0); - return signum(numer) * signum(denom) < 0; - } - - bool operator==(const Rational& a, const Rational& b) - { - return std::tie(a.numerator, a.denominator) == std::tie(b.numerator, b.denominator); - } - - bool operator<(const Rational& a, const Rational& b) - { - // do not remove signum - overflow - long numer = a.numerator * b.denominator - a.denominator * b.numerator; - long denom = a.denominator * b.denominator; - assert(denom != 0); -// spdlog::debug("a = {}, b = {}, numer = {}, denom = {}, result = {}", a, b, numer, denom, signum(numer) * signum(denom) <= 0); - return signum(numer) * signum(denom) < 0; - } - - bool is_leq(Rational a, Rational b) - { - // compute a - b = a_1 / a_2 - b_1 / b_2 - long numer = a.numerator * b.denominator - a.denominator * b.numerator; - long denom = a.denominator * b.denominator; - assert(denom != 0); - return signum(numer) * signum(denom) <= 0; - } - - bool is_greater(Rational a, Rational b) - { - return not is_leq(a, b); - } - - bool is_geq(Rational a, Rational b) - { - return not is_less(a, b); - } - - Point operator+(const Point& u, const Point& v) - { - return Point(u.x + v.x, u.y + v.y); - } - - Point operator-(const Point& u, const Point& v) - { - return Point(u.x - v.x, u.y - v.y); - } - - Point least_upper_bound(const Point& u, const Point& v) - { - return Point(std::max(u.x, v.x), std::max(u.y, v.y)); - } - - Point greatest_lower_bound(const Point& u, const Point& v) - { - return Point(std::min(u.x, v.x), std::min(u.y, v.y)); - } - - Point max_point() - { - return Point(std::numeric_limits<Real>::max(), std::numeric_limits<Real>::min()); - } - - Point min_point() - { - return Point(-std::numeric_limits<Real>::max(), -std::numeric_limits<Real>::min()); - } - - std::ostream& operator<<(std::ostream& ostr, const Point& vec) - { - ostr << "(" << vec.x << ", " << vec.y << ")"; - return ostr; - } - - Real l_infty_norm(const Point& v) - { - return std::max(std::abs(v.x), std::abs(v.y)); - } - - Real l_2_norm(const Point& v) - { - return v.norm(); - } - - Real l_2_dist(const Point& x, const Point& y) - { - return l_2_norm(x - y); - } - - Real l_infty_dist(const Point& x, const Point& y) - { - return l_infty_norm(x - y); - } - - void DiagramKeeper::add_point(int dim, md::Real birth, md::Real death) - { - data_[dim].emplace_back(birth, death); - } - - DiagramKeeper::Diagram DiagramKeeper::get_diagram(int dim) const - { - if (data_.count(dim) == 1) - return data_.at(dim); - else - return DiagramKeeper::Diagram(); - } - - // return true, if line starts with # - // or contains only spaces - bool ignore_line(const std::string& s) - { - for(auto c : s) { - if (isspace(c)) - continue; - return (c == '#'); - } - return true; - } - - - - std::ostream& operator<<(std::ostream& os, const Rational& a) - { - os << a.numerator << " / " << a.denominator; - return os; - } -} diff --git a/matching/src/main.cpp b/matching/src/main.cpp index f1472be..2093457 100644 --- a/matching/src/main.cpp +++ b/matching/src/main.cpp @@ -18,12 +18,20 @@ #include "box.h"
#include "matching_distance.h"
+using Real = double;
+
using namespace md;
namespace fs = std::experimental::filesystem;
+void force_instantiation()
+{
+ DualBox<Real> db;
+ std::cout << db;
+}
+
#ifdef PRINT_HEAT_MAP
-void print_heat_map(const md::HeatMaps& hms, std::string fname, const CalculationParams& params)
+void print_heat_map(const md::HeatMaps<Real>& hms, std::string fname, const CalculationParams<Real>& params)
{
spd::debug("Entered print_heat_map");
std::set<Real> mu_vals, lambda_vals;
@@ -143,7 +151,7 @@ int main(int argc, char** argv) bool help = false;
bool no_stop_asap = false;
- CalculationParams params;
+ CalculationParams<Real> params;
#ifdef PRINT_HEAT_MAP
bool heatmap_only = false;
@@ -178,8 +186,8 @@ int main(int argc, char** argv) auto bounds_list = split_by_delim(bounds_list_str, ',');
auto traverse_list = split_by_delim(traverse_list_str, ',');
- Bifiltration bif_a(fname_a);
- Bifiltration bif_b(fname_b);
+ Bifiltration<Real> bif_a(fname_a);
+ Bifiltration<Real> bif_b(fname_b);
bif_a.sanity_check();
bif_b.sanity_check();
@@ -207,11 +215,11 @@ int main(int argc, char** argv) }
struct ExperimentResult {
- CalculationParams params {CalculationParams()};
+ CalculationParams<Real> params {CalculationParams()};
int n_hera_calls {0};
double total_milliseconds_elapsed {0};
- double distance {0};
- double actual_error {std::numeric_limits<double>::max()};
+ Real distance {0};
+ Real actual_error {std::numeric_limits<double>::max()};
int actual_max_depth {0};
int x_wins {0};
@@ -250,7 +258,7 @@ int main(int argc, char** argv) ExperimentResult() { }
- ExperimentResult(CalculationParams p, int nhc, double tme, double d)
+ ExperimentResult(CalculationParams<Real> p, int nhc, double tme, double d)
:
params(p), n_hera_calls(nhc), total_milliseconds_elapsed(tme), distance(d) { }
};
@@ -267,7 +275,7 @@ int main(int argc, char** argv) std::map<std::tuple<BoundStrategy, TraverseStrategy>, ExperimentResult> results;
for(BoundStrategy bound_strategy : bound_strategies) {
for(TraverseStrategy traverse_strategy : traverse_strategies) {
- CalculationParams params_experiment;
+ CalculationParams<Real> params_experiment;
params_experiment.bound_strategy = bound_strategy;
params_experiment.traverse_strategy = traverse_strategy;
params_experiment.max_depth = params.max_depth;
@@ -366,8 +374,9 @@ int main(int argc, char** argv) spd::debug("Will use {} bound, {} traverse strategy", params.bound_strategy, params.traverse_strategy);
- Real dist = matching_distance(bif_a, bif_b, params);
+ Real dist = matching_distance<Real>(bif_a, bif_b, params);
std::cout << dist << std::endl;
#endif
+ force_instantiation();
return 0;
}
diff --git a/matching/src/matching_distance.cpp b/matching/src/matching_distance.cpp deleted file mode 100644 index e53233f..0000000 --- a/matching/src/matching_distance.cpp +++ /dev/null @@ -1,150 +0,0 @@ -#include <chrono> -#include <tuple> -#include <algorithm> - -#include "common_defs.h" - -#include "matching_distance.h" - -namespace md { - - Real matching_distance(const Bifiltration& bif_a, const Bifiltration& bif_b, - CalculationParams& params) - { - Real result; - // compute distance only in one dimension - if (params.dim != CalculationParams::ALL_DIMENSIONS) { - BifiltrationProxy bifp_a(bif_a, params.dim); - BifiltrationProxy bifp_b(bif_b, params.dim); - DistanceCalculator<BifiltrationProxy> runner(bifp_a, bifp_b, params); - result = runner.distance(); - params.n_hera_calls = runner.get_hera_calls_number(); - } else { - // compute distance in all dimensions, return maximal - result = -1; - for(int dim = 0; dim < std::max(bif_a.maximal_dim(), bif_b.maximal_dim()); ++dim) { - BifiltrationProxy bifp_a(bif_a, params.dim); - BifiltrationProxy bifp_b(bif_a, params.dim); - DistanceCalculator<BifiltrationProxy> runner(bifp_a, bifp_b, params); - result = std::max(result, runner.distance()); - params.n_hera_calls += runner.get_hera_calls_number(); - } - } - return result; - } - - - Real matching_distance(const ModulePresentation& mod_a, const ModulePresentation& mod_b, - CalculationParams& params) - { - DistanceCalculator<ModulePresentation> runner(mod_a, mod_b, params); - Real result = runner.distance(); - params.n_hera_calls = runner.get_hera_calls_number(); - return result; - } - - std::istream& operator>>(std::istream& is, BoundStrategy& s) - { - std::string ss; - is >> ss; - if (ss == "bruteforce") { - s = BoundStrategy::bruteforce; - } else if (ss == "local_grob") { - s = BoundStrategy::local_dual_bound; - } else if (ss == "local_combined") { - s = BoundStrategy::local_combined; - } else if (ss == "local_refined") { - s = BoundStrategy::local_dual_bound_refined; - } else if (ss == "local_for_each_point") { - s = BoundStrategy::local_dual_bound_for_each_point; - } else { - throw std::runtime_error("UNKNOWN BOUND STRATEGY"); - } - return is; - } - - BoundStrategy bs_from_string(std::string s) - { - std::stringstream ss(s); - BoundStrategy result; - ss >> result; - return result; - } - - TraverseStrategy ts_from_string(std::string s) - { - std::stringstream ss(s); - TraverseStrategy result; - ss >> result; - return result; - } - - std::istream& operator>>(std::istream& is, TraverseStrategy& s) - { - std::string ss; - is >> ss; - if (ss == "DFS") { - s = TraverseStrategy::depth_first; - } else if (ss == "BFS") { - s = TraverseStrategy::breadth_first; - } else if (ss == "BFS-VAL") { - s = TraverseStrategy::breadth_first_value; - } else if (ss == "UB") { - s = TraverseStrategy::upper_bound; - } else { - throw std::runtime_error("UNKNOWN TRAVERSE STRATEGY"); - } - return is; - } - - std::ostream& operator<<(std::ostream& os, const UbExperimentRecord& r) - { - os << r.time << "\t" << r.n_hera_calls << "\t" << r.error << "\t" << r.lower_bound << "\t" << r.upper_bound; - return os; - } - - std::ostream& operator<<(std::ostream& os, const BoundStrategy& s) - { - switch(s) { - case BoundStrategy::bruteforce : - os << "bruteforce"; - break; - case BoundStrategy::local_dual_bound : - os << "local_grob"; - break; - case BoundStrategy::local_combined : - os << "local_combined"; - break; - case BoundStrategy::local_dual_bound_refined : - os << "local_refined"; - break; - case BoundStrategy::local_dual_bound_for_each_point : - os << "local_for_each_point"; - break; - default: - os << "FORGOTTEN BOUND STRATEGY"; - } - return os; - } - - std::ostream& operator<<(std::ostream& os, const TraverseStrategy& s) - { - switch(s) { - case TraverseStrategy::depth_first : - os << "DFS"; - break; - case TraverseStrategy::breadth_first : - os << "BFS"; - break; - case TraverseStrategy::breadth_first_value : - os << "BFS-VAL"; - break; - case TraverseStrategy::upper_bound : - os << "UB"; - break; - default: - os << "FORGOTTEN TRAVERSE STRATEGY"; - } - return os; - } -} diff --git a/matching/src/simplex.cpp b/matching/src/simplex.cpp deleted file mode 100644 index 6b53680..0000000 --- a/matching/src/simplex.cpp +++ /dev/null @@ -1,121 +0,0 @@ -#include "simplex.h" - -namespace md { - - std::ostream& operator<<(std::ostream& os, const AbstractSimplex& s) - { - os << "AbstractSimplex(id = " << s.id << ", vertices_ = " << container_to_string(s.vertices_) << ")"; - return os; - } - - bool operator<(const AbstractSimplex& a, const AbstractSimplex& b) - { - return a.vertices_ < b.vertices_; - } - - bool operator==(const AbstractSimplex& s1, const AbstractSimplex& s2) - { - return s1.vertices_ == s2.vertices_; - } - - void AbstractSimplex::push_back(int v) - { - vertices_.push_back(v); - std::sort(vertices_.begin(), vertices_.end()); - } - - AbstractSimplex::AbstractSimplex(std::vector<int> vertices, bool sort) - :vertices_(vertices) - { - if (sort) - std::sort(vertices_.begin(), vertices_.end()); - } - - std::vector<AbstractSimplex> AbstractSimplex::facets() const - { - std::vector<AbstractSimplex> result; - for (int i = 0; i < static_cast<int>(vertices_.size()); ++i) { - std::vector<int> facet_vertices; - facet_vertices.reserve(dim()); - for (int j = 0; j < static_cast<int>(vertices_.size()); ++j) { - if (j != i) - facet_vertices.push_back(vertices_[j]); - } - if (!facet_vertices.empty()) { - result.emplace_back(facet_vertices, false); - } - } - return result; - } - - Simplex::Simplex(md::Index id, md::Point birth, int dim, const md::Column& bdry) - : - id_(id), - pos_(birth), - dim_(dim), - facet_indices_(bdry) { } - - void Simplex::translate(Real a) - { - pos_.translate(a); - } - - void Simplex::init_rivet(std::string s) - { - auto delim_pos = s.find_first_of(";"); - assert(delim_pos > 0); - std::string vertices_str = s.substr(0, delim_pos); - std::string pos_str = s.substr(delim_pos + 1); - assert(not vertices_str.empty() and not pos_str.empty()); - // get vertices - std::stringstream vertices_ss(vertices_str); - int dim = 0; - int vertex; - while (vertices_ss >> vertex) { - dim++; - vertices_.push_back(vertex); - } - // - std::sort(vertices_.begin(), vertices_.end()); - assert(dim > 0); - - std::stringstream pos_ss(pos_str); - // TODO: get rid of 1-criticaltiy assumption - pos_ss >> pos_.x >> pos_.y; - } - - void Simplex::init_phat_like(std::string s) - { - facet_indices_.clear(); - std::stringstream ss(s); - ss >> dim_ >> pos_.x >> pos_.y; - if (dim_ > 0) { - facet_indices_.reserve(dim_ + 1); - for (int j = 0; j <= dim_; j++) { - Index k; - ss >> k; - facet_indices_.push_back(k); - } - } - } - - Simplex::Simplex(Index _id, std::string s, BifiltrationFormat input_format) - :id_(_id) - { - switch (input_format) { - case BifiltrationFormat::phat_like : - init_phat_like(s); - break; - case BifiltrationFormat::rivet : - init_rivet(s); - break; - } - } - - std::ostream& operator<<(std::ostream& os, const Simplex& x) - { - os << "Simplex(id = " << x.id() << ", dim = " << x.dim(); - os << ", boundary = " << container_to_string(x.boundary()) << ", pos = " << x.position() << ")"; - return os; - } -} diff --git a/matching/src/test_generator.cpp b/matching/src/test_generator.cpp index e8f128f..a2f0625 100644 --- a/matching/src/test_generator.cpp +++ b/matching/src/test_generator.cpp @@ -11,9 +11,12 @@ #include "common_util.h" #include "bifiltration.h" +using Real = double; using Index = md::Index; -using Point = md::Point; +using Point = md::Point<Real>; +using Bifiltration = md::Bifiltration<Real>; using Column = md::Column; +using Simplex = md::Simplex<Real>; int g_max_coord = 100; @@ -100,7 +103,7 @@ void generate_positions(const ASimplex& s, ASimplexToBirthMap& simplex_to_birth, } } -md::Bifiltration get_random_bifiltration(int n_vertices, int max_dim, int n_top_simplices) +Bifiltration get_random_bifiltration(int n_vertices, int max_dim, int n_top_simplices) { ASimplexToBirthMap simplex_to_birth; @@ -122,13 +125,13 @@ md::Bifiltration get_random_bifiltration(int n_vertices, int max_dim, int n_top_ add_if_top(candidate_simplex, top_simplices); } - Point upper_bound{static_cast<md::Real>(g_max_coord), static_cast<md::Real>(g_max_coord)}; + Point upper_bound{static_cast<Real>(g_max_coord), static_cast<Real>(g_max_coord)}; for(const auto& top_simplex : top_simplices) { generate_positions(top_simplex, simplex_to_birth, upper_bound); } std::vector<std::pair<ASimplex, Point>> simplex_birth_pairs{simplex_to_birth.begin(), simplex_to_birth.end()}; - std::vector<md::Column> boundaries{simplex_to_birth.size(), md::Column()}; + std::vector<Column> boundaries{simplex_to_birth.size(), Column()}; // assign ids and save boundaries int id = 0; @@ -138,7 +141,7 @@ md::Bifiltration get_random_bifiltration(int n_vertices, int max_dim, int n_top_ ASimplex& simplex = simplex_birth_pairs[i].first; if (simplex.dim() == dim) { simplex.id = id++; - md::Column bdry; + Column bdry; for(auto& facet : simplex.facets()) { auto facet_iter = std::find_if(simplex_birth_pairs.begin(), simplex_birth_pairs.end(), [facet](const std::pair<ASimplex, Point>& sbp) { return facet == sbp.first; }); @@ -153,7 +156,7 @@ md::Bifiltration get_random_bifiltration(int n_vertices, int max_dim, int n_top_ } // create vector of Simplex-es - std::vector<md::Simplex> simplices; + std::vector<Simplex> simplices; for(int i = 0; i < (int) simplex_birth_pairs.size(); ++i) { int id = simplex_birth_pairs[i].first.id; int dim = simplex_birth_pairs[i].first.dim(); @@ -164,13 +167,13 @@ md::Bifiltration get_random_bifiltration(int n_vertices, int max_dim, int n_top_ // sort by id std::sort(simplices.begin(), simplices.end(), - [](const md::Simplex& s1, const md::Simplex& s2) { return s1.id() < s2.id(); }); + [](const Simplex& s1, const Simplex& s2) { return s1.id() < s2.id(); }); for(int i = 0; i < (int)simplices.size(); ++i) { assert(simplices[i].id() == i); assert(i == 0 || simplices[i].dim() >= simplices[i - 1].dim()); } - return md::Bifiltration(simplices.begin(), simplices.end()); + return Bifiltration(simplices.begin(), simplices.end()); } int main(int argc, char** argv) diff --git a/matching/src/tests/test_common.cpp b/matching/src/tests/test_common.cpp index c55577e..9079a56 100644 --- a/matching/src/tests/test_common.cpp +++ b/matching/src/tests/test_common.cpp @@ -8,56 +8,24 @@ #include "simplex.h" #include "matching_distance.h" -using namespace md; +//using namespace md; +using Real = double; +using Point = md::Point<Real>; +using Bifiltration = md::Bifiltration<Real>; +using BifiltrationProxy = md::BifiltrationProxy<Real>; +using CalculationParams = md::CalculationParams<Real>; +using CellWithValue = md::CellWithValue<Real>; +using DualPoint = md::DualPoint<Real>; +using DualBox = md::DualBox<Real>; +using Simplex = md::Simplex<Real>; +using AbstractSimplex = md::AbstractSimplex; +using BoundStrategy = md::BoundStrategy; +using TraverseStrategy = md::TraverseStrategy; +using AxisType = md::AxisType; +using AngleType = md::AngleType; +using ValuePoint = md::ValuePoint; +using Column = md::Column; -TEST_CASE("Rational", "[common_utils][rational]") -{ - // gcd - REQUIRE(gcd(10, 5) == 5); - REQUIRE(gcd(5, 10) == 5); - REQUIRE(gcd(5, 7) == 1); - REQUIRE(gcd(7, 5) == 1); - REQUIRE(gcd(13, 0) == 13); - REQUIRE(gcd(0, 13) == 13); - REQUIRE(gcd(16, 24) == 8); - REQUIRE(gcd(24, 16) == 8); - REQUIRE(gcd(16, 32) == 16); - REQUIRE(gcd(32, 16) == 16); - - - // reduce - REQUIRE(reduce({2, 1}) == std::make_pair(2, 1)); - REQUIRE(reduce({1, 2}) == std::make_pair(1, 2)); - REQUIRE(reduce({2, 2}) == std::make_pair(1, 1)); - REQUIRE(reduce({0, 2}) == std::make_pair(0, 1)); - REQUIRE(reduce({0, 20}) == std::make_pair(0, 1)); - REQUIRE(reduce({35, 49}) == std::make_pair(5, 7)); - REQUIRE(reduce({35, 25}) == std::make_pair(7, 5)); - - // midpoint - REQUIRE(midpoint(Rational {0, 1}, Rational {1, 2}) == std::make_pair(1, 4)); - REQUIRE(midpoint(Rational {1, 4}, Rational {1, 2}) == std::make_pair(3, 8)); - REQUIRE(midpoint(Rational {1, 2}, Rational {1, 2}) == std::make_pair(1, 2)); - REQUIRE(midpoint(Rational {1, 2}, Rational {1, 1}) == std::make_pair(3, 4)); - REQUIRE(midpoint(Rational {3, 7}, Rational {5, 14}) == std::make_pair(11, 28)); - - - // arithmetic - - REQUIRE(Rational(1, 2) + Rational(3, 5) == Rational(11, 10)); - REQUIRE(Rational(2, 5) - Rational(3, 10) == Rational(1, 10)); - REQUIRE(Rational(2, 3) * Rational(4, 7) == Rational(8, 21)); - REQUIRE(Rational(2, 3) * Rational(3, 2) == Rational(1)); - REQUIRE(Rational(2, 3) / Rational(3, 2) == Rational(4, 9)); - REQUIRE(Rational(1, 2) * Rational(3, 5) == Rational(3, 10)); - - // comparison - REQUIRE(Rational(100000, 2000000) < Rational(100001, 2000000)); - REQUIRE(!(Rational(100001, 2000000) < Rational(100000, 2000000))); - REQUIRE(!(Rational(100000, 2000000) < Rational(100000, 2000000))); - REQUIRE(Rational(-100000, 2000000) < Rational(100001, 2000000)); - REQUIRE(Rational(-100001, 2000000) < Rational(100000, 2000000)); -}; TEST_CASE("AbstractSimplex", "[abstract_simplex]") { diff --git a/matching/src/tests/test_matching_distance.cpp b/matching/src/tests/test_matching_distance.cpp index df9345e..82da530 100644 --- a/matching/src/tests/test_matching_distance.cpp +++ b/matching/src/tests/test_matching_distance.cpp @@ -11,7 +11,25 @@ #include "simplex.h" #include "matching_distance.h" -using namespace md; +using Real = double; +using Point = md::Point<Real>; +using Bifiltration = md::Bifiltration<Real>; +using BifiltrationProxy = md::BifiltrationProxy<Real>; +using CalculationParams = md::CalculationParams<Real>; +using CellWithValue = md::CellWithValue<Real>; +using DualPoint = md::DualPoint<Real>; +using DualBox = md::DualBox<Real>; +using Simplex = md::Simplex<Real>; +using AbstractSimplex = md::AbstractSimplex; +using BoundStrategy = md::BoundStrategy; +using TraverseStrategy = md::TraverseStrategy; +using AxisType = md::AxisType; +using AngleType = md::AngleType; +using ValuePoint = md::ValuePoint; +using Column = md::Column; + +using md::k_corner_vps; + namespace spd = spdlog; TEST_CASE("Different bounds", "[bounds]") @@ -40,7 +58,7 @@ TEST_CASE("Different bounds", "[bounds]") BifiltrationProxy bifp_a(bif_a, params.dim); BifiltrationProxy bifp_b(bif_b, params.dim); - DistanceCalculator<BifiltrationProxy> calc(bifp_a, bifp_b, params); + md::DistanceCalculator<Real, BifiltrationProxy> calc(bifp_a, bifp_b, params); // REQUIRE(calc.max_x_ == Approx(max_x)); // REQUIRE(calc.max_y_ == Approx(max_y)); |