From 3809e4071827a5959f27e472514eaed08ba6d15e Mon Sep 17 00:00:00 2001 From: Arnur Nigmetov Date: Wed, 4 Mar 2020 00:33:51 +0100 Subject: Make matching distance header-only. --- matching/include/bifiltration.h | 72 +++--- matching/include/bifiltration.hpp | 421 ++++++++++++++++++++++++++++++++ matching/include/box.h | 77 ++---- matching/include/box.hpp | 52 ++++ matching/include/cell_with_value.h | 47 +++- matching/include/cell_with_value.hpp | 224 +++++++++++++++++ matching/include/common_defs.h | 4 +- matching/include/common_util.h | 113 ++++----- matching/include/common_util.hpp | 96 ++++++++ matching/include/dual_box.h | 78 ++---- matching/include/dual_box.hpp | 190 ++++++++++++++ matching/include/dual_point.h | 27 +- matching/include/dual_point.hpp | 299 +++++++++++++++++++++++ matching/include/matching_distance.h | 203 +++++++++++---- matching/include/matching_distance.hpp | 326 ++++++++++++++----------- matching/include/persistence_module.h | 34 ++- matching/include/persistence_module.hpp | 177 ++++++++++++++ matching/include/simplex.h | 70 ++++-- matching/include/simplex.hpp | 79 ++++++ 19 files changed, 2117 insertions(+), 472 deletions(-) create mode 100644 matching/include/bifiltration.hpp create mode 100644 matching/include/box.hpp create mode 100644 matching/include/cell_with_value.hpp create mode 100644 matching/include/common_util.hpp create mode 100644 matching/include/dual_box.hpp create mode 100644 matching/include/dual_point.hpp create mode 100644 matching/include/persistence_module.hpp create mode 100644 matching/include/simplex.hpp (limited to 'matching/include') diff --git a/matching/include/bifiltration.h b/matching/include/bifiltration.h index f505ed9..4dd8662 100644 --- a/matching/include/bifiltration.h +++ b/matching/include/bifiltration.h @@ -3,19 +3,30 @@ #include #include +#include +#include +#include +#include #include "common_util.h" #include "box.h" #include "simplex.h" #include "dual_point.h" +#include "phat/boundary_matrix.h" +#include "phat/compute_persistence_pairs.h" + +#include "spdlog/spdlog.h" +#include "spdlog/fmt/fmt.h" +#include "spdlog/fmt/ostr.h" + +#include "common_util.h" namespace md { + template class Bifiltration { public: - using Diagram = std::vector>; - using Box = md::Box; - using SimplexVector = std::vector; + using SimplexVector = std::vector>; Bifiltration() = default; @@ -36,7 +47,7 @@ namespace md { init(); } - Diagram weighted_slice_diagram(const DualPoint& line, int dim) const; + Diagram weighted_slice_diagram(const DualPoint& line, int dim) const; SimplexVector simplices() const { return simplices_; } @@ -48,14 +59,12 @@ namespace md { Real minimal_coordinate() const; // return box that contains positions of all simplices - Box bounding_box() const; + Box bounding_box() const; void sanity_check() const; int maximal_dim() const { return maximal_dim_; } - friend std::ostream& operator<<(std::ostream& os, const Bifiltration& bif); - Real max_x() const; Real max_y() const; @@ -64,7 +73,7 @@ namespace md { Real min_y() const; - void add_simplex(Index _id, Point birth, int _dim, const Column& _bdry); + void add_simplex(Index _id, Point birth, int _dim, const Column& _bdry); void save(const std::string& filename, BifiltrationFormat format = BifiltrationFormat::rivet); // save to file @@ -72,11 +81,8 @@ namespace md { private: SimplexVector simplices_; - // axes names, for rivet bifiltration format only - std::string parameter_1_name_ {"axis_1"}; - std::string parameter_2_name_ {"axis_2"}; - Box bounding_box_; + Box bounding_box_; int maximal_dim_ {-1}; void init(); @@ -97,13 +103,15 @@ namespace md { }; - std::ostream& operator<<(std::ostream& os, const Bifiltration& bif); + template + std::ostream& operator<<(std::ostream& os, const Bifiltration& bif); + template class BifiltrationProxy { public: - BifiltrationProxy(const Bifiltration& bif, int dim = 0); + BifiltrationProxy(const Bifiltration& bif, int dim = 0); // return critical values of simplices that are important for current dimension (dim and dim+1) - PointVec positions() const; + PointVec positions() const; // set current dimension int set_dim(int new_dim); @@ -111,46 +119,22 @@ namespace md { int maximal_dim() const; void translate(Real a); Real minimal_coordinate() const; - Box bounding_box() const; + Box bounding_box() const; Real max_x() const; Real max_y() const; Real min_x() const; Real min_y() const; - Diagram weighted_slice_diagram(const DualPoint& slice) const; + Diagram weighted_slice_diagram(const DualPoint& slice) const; private: int dim_ { 0 }; - mutable PointVec cached_positions_; - Bifiltration bif_; + mutable PointVec cached_positions_; + Bifiltration bif_; void cache_positions() const; }; } - +#include "bifiltration.hpp" #endif //MATCHING_DISTANCE_BIFILTRATION_H - -//// The value type of OutputIterator is Simplex_in_2D_filtration -//template -//void read_input(std::string filename, OutputIterator out) -//{ -// std::ifstream ifstr; -// ifstr.open(filename.c_str()); -// long n; -// ifstr >> n; // number of simplices is the first number in file -// -// Index k; // used in loop -// for (int i = 0; i < n; i++) { -// Simplex_in_2D_filtration next; -// next.index = i; -// ifstr >> next.dim >> next.pos.x >> next.pos.y; -// if (next.dim > 0) { -// for (int j = 0; j <= next.dim; j++) { -// ifstr >> k; -// next.bd.push_back(k); -// } -// } -// *out++ = next; -// } -//} diff --git a/matching/include/bifiltration.hpp b/matching/include/bifiltration.hpp new file mode 100644 index 0000000..9e2a82e --- /dev/null +++ b/matching/include/bifiltration.hpp @@ -0,0 +1,421 @@ +namespace md { + + template + void Bifiltration::init() + { + auto lower_left = max_point(); + auto upper_right = min_point(); + for(const auto& simplex : simplices_) { + lower_left = greatest_lower_bound<>(lower_left, simplex.position()); + upper_right = least_upper_bound<>(upper_right, simplex.position()); + maximal_dim_ = std::max(maximal_dim_, simplex.dim()); + } + bounding_box_ = Box(lower_left, upper_right); + } + + template + Bifiltration::Bifiltration(const std::string& fname) + { + std::ifstream ifstr {fname.c_str()}; + if (!ifstr.good()) { + std::string error_message = fmt::format("Cannot open file {0}", fname); + std::cerr << error_message << std::endl; + throw std::runtime_error(error_message); + } + + BifiltrationFormat input_format; + + std::string s; + + while(ignore_line(s)) { + std::getline(ifstr, s); + } + + if (s == "bifiltration") { + input_format = BifiltrationFormat::rivet; + } else if (s == "bifiltration_phat_like") { + input_format = BifiltrationFormat::phat_like; + } else { + std::cerr << "Unknown format: '" << s << "' in file " << fname << std::endl; + throw std::runtime_error("unknown bifiltration format"); + } + + switch(input_format) { + case BifiltrationFormat::rivet : + rivet_format_reader(ifstr); + break; + case BifiltrationFormat::phat_like : + phat_like_format_reader(ifstr); + break; + } + + ifstr.close(); + + init(); + } + + template + void Bifiltration::rivet_format_reader(std::ifstream& ifstr) + { + std::string s; + // read axes names, ignore them + std::getline(ifstr, s); + std::getline(ifstr, s); + + Index index = 0; + while(std::getline(ifstr, s)) { + if (!ignore_line(s)) { + simplices_.emplace_back(index++, s, BifiltrationFormat::rivet); + } + } + } + + template + void Bifiltration::phat_like_format_reader(std::ifstream& ifstr) + { + spd::debug("Enter phat_like_format_reader"); + // read stream line by line; do not use >> operator + std::string s; + std::getline(ifstr, s); + + // first line contains number of simplices + long n_simplices = std::stol(s); + + // all other lines represent a simplex + Index index = 0; + while(index < n_simplices) { + std::getline(ifstr, s); + if (!ignore_line(s)) { + simplices_.emplace_back(index++, s, BifiltrationFormat::phat_like); + } + } + spd::debug("Read {} simplices from file", n_simplices); + } + + template + void Bifiltration::scale(Real lambda) + { + for(auto& s : simplices_) { + s.scale(lambda); + } + init(); + } + + template + void Bifiltration::sanity_check() const + { +#ifdef DEBUG + spd::debug("Enter Bifiltration::sanity_check"); + // check that boundary has correct number of simplices, + // each bounding simplex has correct dim + // and appears in the filtration before the simplex it bounds + for(const auto& s : simplices_) { + assert(s.dim() >= 0); + assert(s.dim() == 0 or s.dim() + 1 == (int) s.boundary().size()); + for(auto bdry_idx : s.boundary()) { + Simplex bdry_simplex = simplices()[bdry_idx]; + assert(bdry_simplex.dim() == s.dim() - 1); + assert(bdry_simplex.position().is_less(s.position(), false)); + } + } + spd::debug("Exit Bifiltration::sanity_check"); +#endif + } + + template + Diagram Bifiltration::weighted_slice_diagram(const DualPoint& line, int dim) const + { + DiagramKeeper dgm; + + // make a copy for now; I want slice_diagram to be const + std::vector> simplices(simplices_); + +// std::vector simplices; +// simplices.reserve(simplices_.size() / 2); +// for(const auto& s : simplices_) { +// if (s.dim() <= dim + 1 and s.dim() >= dim) +// simplices.emplace_back(s); +// } + + spd::debug("Enter slice diagram, line = {}, simplices.size = {}", line, simplices.size()); + + for(auto& simplex : simplices) { + Real value = line.weighted_push(simplex.position()); +// spd::debug("in slice_diagram, simplex = {}, value = {}\n", simplex, value); + simplex.set_value(value); + } + + std::sort(simplices.begin(), simplices.end(), + [](const Simplex& a, const Simplex& b) { return a.value() < b.value(); }); + std::map index_map; + for(Index i = 0; i < (int) simplices.size(); i++) { + index_map[simplices[i].id()] = i; + } + + phat::boundary_matrix<> phat_matrix; + phat_matrix.set_num_cols(simplices.size()); + std::vector bd_in_slice_filtration; + for(Index i = 0; i < (int) simplices.size(); i++) { + phat_matrix.set_dim(i, simplices[i].dim()); + bd_in_slice_filtration.clear(); + //std::cout << "new col" << i << std::endl; + for(int j = 0; j < (int) simplices[i].boundary().size(); j++) { + // F[i] contains the indices of its facet wrt to the + // original filtration. We have to express it, however, + // wrt to the filtration along the slice. That is why + // we need the index_map + //std::cout << "Found " << F[i].bd[j] << ", returning " << index_map[F[i].bd[j]] << std::endl; + bd_in_slice_filtration.push_back(index_map[simplices[i].boundary()[j]]); + } + std::sort(bd_in_slice_filtration.begin(), bd_in_slice_filtration.end()); + phat_matrix.set_col(i, bd_in_slice_filtration); + } + phat::persistence_pairs phat_persistence_pairs; + phat::compute_persistence_pairs(phat_persistence_pairs, phat_matrix); + + dgm.clear(); + constexpr Real real_inf = std::numeric_limits::infinity(); + for(long i = 0; i < (long) phat_persistence_pairs.get_num_pairs(); i++) { + std::pair new_pair = phat_persistence_pairs.get_pair(i); + bool is_finite_pair = new_pair.second != phat::k_infinity_index; + Real birth = simplices.at(new_pair.first).value(); + Real death = is_finite_pair ? simplices.at(new_pair.second).value() : real_inf; + int dim = simplices[new_pair.first].dim(); + assert(dim + 1 == simplices[new_pair.second].dim()); + if (birth != death) { + dgm.add_point(dim, birth, death); + } + } + + spdlog::debug("Exiting slice_diagram, #dgm[0] = {}", dgm.get_diagram(0).size()); + + return dgm.get_diagram(dim); + } + + template + Box Bifiltration::bounding_box() const + { + return bounding_box_; + } + + template + Real Bifiltration::minimal_coordinate() const + { + return std::min(bounding_box_.lower_left().x, bounding_box_.lower_left().y); + } + + template + void Bifiltration::translate(Real a) + { + bounding_box_.translate(a); + for(auto& simplex : simplices_) { + simplex.translate(a); + } + } + + template + Real Bifiltration::max_x() const + { + if (simplices_.empty()) + return 1; + auto me = std::max_element(simplices_.cbegin(), simplices_.cend(), + [](const auto& s_a, const auto& s_b) { return s_a.position().x < s_b.position().x; }); + assert(me != simplices_.cend()); + return me->position().x; + } + + template + Real Bifiltration::max_y() const + { + if (simplices_.empty()) + return 1; + auto me = std::max_element(simplices_.cbegin(), simplices_.cend(), + [](const auto& s_a, const auto& s_b) { return s_a.position().y < s_b.position().y; }); + assert(me != simplices_.cend()); + return me->position().y; + } + + template + Real Bifiltration::min_x() const + { + if (simplices_.empty()) + return 0; + auto me = std::min_element(simplices_.cbegin(), simplices_.cend(), + [](const auto& s_a, const auto& s_b) { return s_a.position().x < s_b.position().x; }); + assert(me != simplices_.cend()); + return me->position().x; + } + + template + Real Bifiltration::min_y() const + { + if (simplices_.empty()) + return 0; + auto me = std::min_element(simplices_.cbegin(), simplices_.cend(), + [](const auto& s_a, const auto& s_b) { return s_a.position().y < s_b.position().y; }); + assert(me != simplices_.cend()); + return me->position().y; + } + + template + void Bifiltration::add_simplex(Index _id, Point birth, int _dim, const Column& _bdry) + { + simplices_.emplace_back(_id, birth, _dim, _bdry); + } + + template + void Bifiltration::save(const std::string& filename, md::BifiltrationFormat format) + { + switch(format) { + case BifiltrationFormat::rivet: + throw std::runtime_error("Not implemented"); + break; + case BifiltrationFormat::phat_like: { + std::ofstream f(filename); + if (not f.good()) { + std::cerr << "Bifiltration::save: cannot open file " << filename << std::endl; + throw std::runtime_error("Cannot open file for writing "); + } + f << simplices_.size() << "\n"; + + for(const auto& s : simplices_) { + f << s.dim() << " " << s.position().x << " " << s.position().y << " "; + for(int b : s.boundary()) { + f << b << " "; + } + f << std::endl; + } + + } + break; + } + } + + template + void Bifiltration::postprocess_rivet_format() + { + std::map facets_to_ids; + + // fill the map + for(Index i = 0; i < (Index) simplices_.size(); ++i) { + assert(simplices_[i].id() == i); + facets_to_ids[simplices_[i].vertices_] = i; + } + +// for(const auto& s : simplices_) { +// facets_to_ids[s] = s.id(); +// } + + // main loop + for(auto& s : simplices_) { + assert(not s.vertices_.empty()); + assert(s.facet_indices_.empty()); + Column facet_indices; + for(Index i = 0; i <= s.dim(); ++i) { + Column facet; + for(Index j : s.vertices_) { + if (j != i) + facet.push_back(j); + } + auto facet_index = facets_to_ids.at(facet); + facet_indices.push_back(facet_index); + } + s.facet_indices_ = facet_indices; + } // loop over simplices + } + + template + std::ostream& operator<<(std::ostream& os, const Bifiltration& bif) + { + os << "Bifiltration [" << std::endl; + for(const auto& s : bif.simplices()) { + os << s << std::endl; + } + os << "]" << std::endl; + return os; + } + + template + BifiltrationProxy::BifiltrationProxy(const Bifiltration& bif, int dim) + : + dim_(dim), + bif_(bif) + { + cache_positions(); + } + + template + void BifiltrationProxy::cache_positions() const + { + cached_positions_.clear(); + for(const auto& simplex : bif_.simplices()) { + if (simplex.dim() == dim_ or simplex.dim() == dim_ + 1) + cached_positions_.push_back(simplex.position()); + } + } + + template + PointVec + BifiltrationProxy::positions() const + { + if (cached_positions_.empty()) { + cache_positions(); + } + return cached_positions_; + } + + // translate all points by vector (a,a) + template + void BifiltrationProxy::translate(Real a) + { + bif_.translate(a); + } + + // return minimal value of x- and y-coordinates + // among all simplices + template + Real BifiltrationProxy::minimal_coordinate() const + { + return bif_.minimal_coordinate(); + } + + // return box that contains positions of all simplices + template + Box BifiltrationProxy::bounding_box() const + { + return bif_.bounding_box(); + } + + template + Real BifiltrationProxy::max_x() const + { + return bif_.max_x(); + } + + template + Real BifiltrationProxy::max_y() const + { + return bif_.max_y(); + } + + template + Real BifiltrationProxy::min_x() const + { + return bif_.min_x(); + } + + template + Real BifiltrationProxy::min_y() const + { + return bif_.min_y(); + } + + + template + Diagram BifiltrationProxy::weighted_slice_diagram(const DualPoint& slice) const + { + return bif_.weighted_slice_diagram(slice, dim_); + } + +} + diff --git a/matching/include/box.h b/matching/include/box.h index 2990fba..4243667 100644 --- a/matching/include/box.h +++ b/matching/include/box.h @@ -8,20 +8,23 @@ namespace md { + template struct Box { + public: + using Real = Real_; private: - Point ll; - Point ur; + Point ll; + Point ur; public: - Box(Point ll = Point(), Point ur = Point()) + Box(Point ll = Point(), Point ur = Point()) :ll(ll), ur(ur) { } - Box(Point center, Real width, Real height) : - ll(Point(center.x - 0.5 * width, center.y - 0.5 * height)), - ur(Point(center.x + 0.5 * width, center.y + 0.5 * height)) + Box(Point center, Real width, Real height) : + ll(Point(center.x - 0.5 * width, center.y - 0.5 * height)), + ur(Point(center.x + 0.5 * width, center.y + 0.5 * height)) { } @@ -30,11 +33,9 @@ namespace md { inline double height() const { return ur.y - ll.y; } - inline Point lower_left() const { return ll; } - inline Point upper_right() const { return ur; } - inline Point center() const { return Point((ll.x + ur.x) / 2, (ll.y + ur.y) / 2); } - -// bool inside(Point& p) const { return ll.x <= p.x && ll.y <= p.y && ur.x >= p.x && ur.y >= p.y; } + inline Point lower_left() const { return ll; } + inline Point upper_right() const { return ur; } + inline Point center() const { return Point((ll.x + ur.x) / 2, (ll.y + ur.y) / 2); } inline bool operator==(const Box& p) { @@ -43,58 +44,16 @@ namespace md { std::vector refine() const; - std::vector corners() const; + std::vector> corners() const; void translate(Real a); - - // return minimal and maximal value of func - // on the corners of the box - template - std::pair min_max_on_corners(const F& func) const; - - friend std::ostream& operator<<(std::ostream& os, const Box& box); }; - std::ostream& operator<<(std::ostream& os, const Box& box); -// template -// Box compute_bounding_box(InputIterator simplices_begin, InputIterator simplices_end) -// { -// if (simplices_begin == simplices_end) { -// return Box(); -// } -// Box bb; -// bb.ll = bb.ur = simplices_begin->pos; -// for (InputIterator it = simplices_begin; it != simplices_end; it++) { -// Point& pos = it->pos; -// if (pos.x < bb.ll.x) { -// bb.ll.x = pos.x; -// } -// if (pos.y < bb.ll.y) { -// bb.ll.y = pos.y; -// } -// if (pos.x > bb.ur.x) { -// bb.ur.x = pos.x; -// } -// if (pos.y > bb.ur.y) { -// bb.ur.y = pos.y; -// } -// } -// return bb; -// } - - Box get_enclosing_box(const Box& box_a, const Box& box_b); - - template - std::pair Box::min_max_on_corners(const F& func) const - { - std::pair min_max { std::numeric_limits::max(), -std::numeric_limits::max() }; - for(Point p : corners()) { - Real value = func(p); - min_max.first = std::min(min_max.first, value); - min_max.second = std::max(min_max.second, value); - } - return min_max; - }; + template + std::ostream& operator<<(std::ostream& os, const Box& box); + } // namespace md +#include "box.hpp" + #endif //MATCHING_DISTANCE_BOX_H diff --git a/matching/include/box.hpp b/matching/include/box.hpp new file mode 100644 index 0000000..f551d84 --- /dev/null +++ b/matching/include/box.hpp @@ -0,0 +1,52 @@ +namespace md { + + template + std::ostream& operator<<(std::ostream& os, const Box& box) + { + os << "Box(lower_left = " << box.lower_left() << ", upper_right = " << box.upper_right() << ")"; + return os; + } + + template + void Box::translate(Real a) + { + ll.x += a; + ll.y += a; + ur.x += a; + ur.y += a; + } + + template + std::vector> Box::refine() const + { + std::vector> result; + +// 1 | 2 +// 0 | 3 + + Point new_ll = lower_left(); + Point new_ur = center(); + result.emplace_back(new_ll, new_ur); + + new_ll.y = center().y; + new_ur.y = ur.y; + result.emplace_back(new_ll, new_ur); + + new_ll = center(); + new_ur = upper_right(); + result.emplace_back(new_ll, new_ur); + + new_ll.y = ll.y; + new_ur.y = center().y; + result.emplace_back(new_ll, new_ur); + + return result; + } + + template + std::vector> Box::corners() const + { + return {ll, Point(ll.x, ur.y), ur, Point(ur.x, ll.y)}; + }; + +} diff --git a/matching/include/cell_with_value.h b/matching/include/cell_with_value.h index 25644d1..3548a11 100644 --- a/matching/include/cell_with_value.h +++ b/matching/include/cell_with_value.h @@ -1,7 +1,3 @@ -// -// Created by narn on 16.07.19. -// - #ifndef MATCHING_DISTANCE_CELL_WITH_VALUE_H #define MATCHING_DISTANCE_CELL_WITH_VALUE_H @@ -21,7 +17,29 @@ namespace md { upper_right }; - std::ostream& operator<<(std::ostream& os, const ValuePoint& vp); + inline std::ostream& operator<<(std::ostream& os, const ValuePoint& vp) + { + switch(vp) { + case ValuePoint::upper_left : + os << "upper_left"; + break; + case ValuePoint::upper_right : + os << "upper_right"; + break; + case ValuePoint::lower_left : + os << "lower_left"; + break; + case ValuePoint::lower_right : + os << "lower_right"; + break; + case ValuePoint::center: + os << "center"; + break; + default: + os << "FORGOTTEN ValuePoint"; + } + return os; + } const std::vector k_all_vps = {ValuePoint::center, ValuePoint::lower_left, ValuePoint::upper_left, ValuePoint::upper_right, ValuePoint::lower_right}; @@ -31,8 +49,10 @@ namespace md { // represents a cell in the dual space with the value // of the weighted bottleneck distance + template class CellWithValue { public: + using Real = Real_; CellWithValue() = default; @@ -44,18 +64,18 @@ namespace md { CellWithValue& operator=(CellWithValue&& other) = default; - CellWithValue(const DualBox& b, int level) + CellWithValue(const DualBox& b, int level) :dual_box_(b), level_(level) { } - DualBox dual_box() const { return dual_box_; } + DualBox dual_box() const { return dual_box_; } - DualPoint center() const { return dual_box_.center(); } + DualPoint center() const { return dual_box_.center(); } Real value_at(ValuePoint vp) const; bool has_value_at(ValuePoint vp) const; - DualPoint value_point(ValuePoint vp) const; + DualPoint value_point(ValuePoint vp) const; int level() const { return level_; } @@ -73,8 +93,6 @@ namespace md { std::vector get_refined_cells() const; - friend std::ostream& operator<<(std::ostream&, const CellWithValue&); - void set_max_possible_value(Real new_upper_bound); int num_values() const; @@ -100,7 +118,7 @@ namespace md { bool has_upper_right_value() const { return upper_right_value_ >= 0; } - DualBox dual_box_; + DualBox dual_box_; Real central_value_ {-1.0}; Real lower_left_value_ {-1.0}; Real lower_right_value_ {-1.0}; @@ -114,7 +132,10 @@ namespace md { bool has_max_possible_value_ {false}; }; - std::ostream& operator<<(std::ostream& os, const CellWithValue& cell); + template + std::ostream& operator<<(std::ostream& os, const CellWithValue& cell); } // namespace md +#include "cell_with_value.hpp" + #endif //MATCHING_DISTANCE_CELL_WITH_VALUE_H diff --git a/matching/include/cell_with_value.hpp b/matching/include/cell_with_value.hpp new file mode 100644 index 0000000..88b2569 --- /dev/null +++ b/matching/include/cell_with_value.hpp @@ -0,0 +1,224 @@ +namespace md { + +#ifdef MD_DEBUG + long long int CellWithValue::max_id = 0; +#endif + + template + Real CellWithValue::value_at(ValuePoint vp) const + { + switch(vp) { + case ValuePoint::upper_left : + return upper_left_value_; + case ValuePoint::upper_right : + return upper_right_value_; + case ValuePoint::lower_left : + return lower_left_value_; + case ValuePoint::lower_right : + return lower_right_value_; + case ValuePoint::center: + return central_value_; + } + // to shut up compiler warning + return 1.0 / 0.0; + } + + template + bool CellWithValue::has_value_at(ValuePoint vp) const + { + switch(vp) { + case ValuePoint::upper_left : + return upper_left_value_ >= 0; + case ValuePoint::upper_right : + return upper_right_value_ >= 0; + case ValuePoint::lower_left : + return lower_left_value_ >= 0; + case ValuePoint::lower_right : + return lower_right_value_ >= 0; + case ValuePoint::center: + return central_value_ >= 0; + } + // to shut up compiler warning + return 1.0 / 0.0; + } + + template + DualPoint CellWithValue::value_point(md::ValuePoint vp) const + { + switch(vp) { + case ValuePoint::upper_left : + return dual_box().upper_left(); + case ValuePoint::upper_right : + return dual_box().upper_right(); + case ValuePoint::lower_left : + return dual_box().lower_left(); + case ValuePoint::lower_right : + return dual_box().lower_right(); + case ValuePoint::center: + return dual_box().center(); + } + // to shut up compiler warning + return DualPoint(); + } + + template + bool CellWithValue::has_corner_value() const + { + return has_lower_left_value() or has_lower_right_value() or has_upper_left_value() + or has_upper_right_value(); + } + + template + Real CellWithValue::stored_upper_bound() const + { + assert(has_max_possible_value_); + return max_possible_value_; + } + + template + Real CellWithValue::max_corner_value() const + { + return std::max({lower_left_value_, lower_right_value_, upper_left_value_, upper_right_value_}); + } + + template + Real CellWithValue::min_value() const + { + Real result = std::numeric_limits::max(); + for(auto vp : k_all_vps) { + if (not has_value_at(vp)) { + continue; + } + result = std::min(result, value_at(vp)); + } + return result; + } + + template + std::vector> CellWithValue::get_refined_cells() const + { + std::vector> result; + result.reserve(4); + for(const auto& refined_box : dual_box_.refine()) { + + CellWithValue refined_cell(refined_box, level() + 1); + +#ifdef MD_DEBUG + refined_cell.parent_ids = parent_ids; + refined_cell.parent_ids.push_back(id); + refined_cell.id = ++max_id; +#endif + + if (refined_box.lower_left() == dual_box_.lower_left()) { + // _|_ + // H|_ + + refined_cell.set_value_at(ValuePoint::lower_left, lower_left_value_); + refined_cell.set_value_at(ValuePoint::upper_right, central_value_); + + } else if (refined_box.upper_right() == dual_box_.upper_right()) { + // _|H + // _|_ + + refined_cell.set_value_at(ValuePoint::lower_left, central_value_); + refined_cell.set_value_at(ValuePoint::upper_right, upper_right_value_); + + } else if (refined_box.lower_right() == dual_box_.lower_right()) { + // _|_ + // _|H + + refined_cell.set_value_at(ValuePoint::lower_right, lower_right_value_); + refined_cell.set_value_at(ValuePoint::upper_left, central_value_); + + } else if (refined_box.upper_left() == dual_box_.upper_left()) { + + // H|_ + // _|_ + + refined_cell.set_value_at(ValuePoint::lower_right, central_value_); + refined_cell.set_value_at(ValuePoint::upper_left, upper_left_value_); + } + result.emplace_back(refined_cell); + } + return result; + } + + template + void CellWithValue::set_value_at(ValuePoint vp, Real new_value) + { + if (has_value_at(vp)) + spd::error("CellWithValue: trying to re-assign value!, this = {}, vp = {}", *this, vp); + + switch(vp) { + case ValuePoint::upper_left : + upper_left_value_ = new_value; + break; + case ValuePoint::upper_right : + upper_right_value_ = new_value; + break; + case ValuePoint::lower_left : + lower_left_value_ = new_value; + break; + case ValuePoint::lower_right : + lower_right_value_ = new_value; + break; + case ValuePoint::center: + central_value_ = new_value; + break; + } + } + + template + int CellWithValue::num_values() const + { + int result = 0; + for(ValuePoint vp : k_all_vps) { + result += has_value_at(vp); + } + return result; + } + + + template + void CellWithValue::set_max_possible_value(Real new_upper_bound) + { + assert(new_upper_bound >= central_value_); + assert(new_upper_bound >= lower_left_value_); + assert(new_upper_bound >= lower_right_value_); + assert(new_upper_bound >= upper_left_value_); + assert(new_upper_bound >= upper_right_value_); + has_max_possible_value_ = true; + max_possible_value_ = new_upper_bound; + } + + + template + std::ostream& operator<<(std::ostream& os, const CellWithValue& cell) + { + os << "CellWithValue(box = " << cell.dual_box() << ", "; + +#ifdef MD_DEBUG + os << "id = " << cell.id; + if (not cell.parent_ids.empty()) + os << ", parent_ids = " << container_to_string(cell.parent_ids) << ", "; +#endif + + for(ValuePoint vp : k_all_vps) { + if (cell.has_value_at(vp)) { + os << "value = " << cell.value_at(vp); + os << ", at " << vp << " " << cell.value_point(vp); + } + } + + os << ", max_corner_value = "; + if (cell.has_max_possible_value()) { + os << cell.stored_upper_bound(); + } else { + os << "-"; + } + + os << ", level = " << cell.level() << ")"; + return os; + } + +} // namespace md diff --git a/matching/include/common_defs.h b/matching/include/common_defs.h index 3f3d937..8d01325 100644 --- a/matching/include/common_defs.h +++ b/matching/include/common_defs.h @@ -1,8 +1,8 @@ #ifndef MATCHING_DISTANCE_DEF_DEBUG_H #define MATCHING_DISTANCE_DEF_DEBUG_H -//#define EXPERIMENTAL_TIMING -//#define PRINT_HEAT_MAP +//#define MD_EXPERIMENTAL_TIMING +//#define MD_PRINT_HEAT_MAP //#define MD_DEBUG //#define MD_DO_CHECKS //#define MD_DO_FULL_CHECK diff --git a/matching/include/common_util.h b/matching/include/common_util.h index 2d8dcb0..778536f 100644 --- a/matching/include/common_util.h +++ b/matching/include/common_util.h @@ -11,22 +11,24 @@ #include #include +#include +#include + +namespace spd = spdlog; + #include "common_defs.h" #include "phat/helpers/misc.h" - namespace md { - - using Real = double; - using RealVec = std::vector; using Index = phat::index; using IndexVec = std::vector; - static constexpr Real pi = M_PI; + //static constexpr Real pi = M_PI; using Column = std::vector; + template struct Point { Real x; Real y; @@ -71,59 +73,56 @@ namespace md { }; - using PointVec = std::vector; - - Point operator+(const Point& u, const Point& v); + template + using PointVec = std::vector>; - Point operator-(const Point& u, const Point& v); + template + Point operator+(const Point& u, const Point& v); - Point least_upper_bound(const Point& u, const Point& v); + template + Point operator-(const Point& u, const Point& v); - Point greatest_lower_bound(const Point& u, const Point& v); - Point max_point(); + template + Point least_upper_bound(const Point& u, const Point& v); - Point min_point(); + template + Point greatest_lower_bound(const Point& u, const Point& v); - std::ostream& operator<<(std::ostream& ostr, const Point& vec); + template + Point max_point(); - Real L_infty(const Point& v); + template + Point min_point(); - Real l_2_norm(const Point& v); + template + std::ostream& operator<<(std::ostream& ostr, const Point& vec); - Real l_2_dist(const Point& x, const Point& y); + template + using DiagramPoint = std::pair; - Real l_infty_dist(const Point& x, const Point& y); + template + using Diagram = std::vector>; - using Interval = std::pair; - - // return minimal interval that contains both a and b - inline Interval minimal_covering_interval(Interval a, Interval b) - { - return {std::min(a.first, b.first), std::max(a.second, b.second)}; - } // to keep diagrams in all dimensions // TODO: store in Hera format? + template class DiagramKeeper { public: - using DiagramPoint = std::pair; - using Diagram = std::vector; DiagramKeeper() { }; void add_point(int dim, Real birth, Real death); - Diagram get_diagram(int dim) const; + Diagram get_diagram(int dim) const; void clear() { data_.clear(); } private: - std::map data_; + std::map> data_; }; - using Diagram = std::vector>; - template std::string container_to_string(const C& cont) { @@ -140,42 +139,18 @@ namespace md { return ss.str(); } - int gcd(int a, int b); - - struct Rational { - int numerator {0}; - int denominator {1}; - Rational() = default; - Rational(int n, int d) : numerator(n / gcd(n, d)), denominator(d / gcd(n, d)) {} - Rational(std::pair p) : Rational(p.first, p.second) {} - Rational(int n) : numerator(n), denominator(1) {} - Real to_real() const { return (Real)numerator / (Real)denominator; } - void reduce(); - Rational& operator+=(const Rational& rhs); - Rational& operator-=(const Rational& rhs); - Rational& operator*=(const Rational& rhs); - Rational& operator/=(const Rational& rhs); - }; - - using namespace std::rel_ops; - - bool operator==(const Rational& a, const Rational& b); - bool operator<(const Rational& a, const Rational& b); - std::ostream& operator<<(std::ostream& os, const Rational& a); - - // arithmetic - Rational operator+(Rational a, const Rational& b); - Rational operator-(Rational a, const Rational& b); - Rational operator*(Rational a, const Rational& b); - Rational operator/(Rational a, const Rational& b); - - Rational reduce(Rational frac); - - Rational midpoint(Rational a, Rational b); - // return true, if s is empty or starts with # (commented out line) // whitespaces in the beginning of s are ignored - bool ignore_line(const std::string& s); + inline bool ignore_line(const std::string& s) + { + for(auto c : s) { + if (isspace(c)) + continue; + return (c == '#'); + } + return true; + } + // split string by delimeter template @@ -195,10 +170,10 @@ namespace md { } namespace std { - template<> - struct hash + template + struct hash> { - std::size_t operator()(const md::Point& p) const + std::size_t operator()(const md::Point& p) const { auto hx = std::hash()(p.x); auto hy = std::hash()(p.y); @@ -207,5 +182,7 @@ namespace std { }; }; +#include "common_util.hpp" + #endif //MATCHING_DISTANCE_COMMON_UTIL_H diff --git a/matching/include/common_util.hpp b/matching/include/common_util.hpp new file mode 100644 index 0000000..76d97af --- /dev/null +++ b/matching/include/common_util.hpp @@ -0,0 +1,96 @@ +#include +#include +#include +#include +#include +#include + +#include + +#include "spdlog/spdlog.h" +#include "spdlog/fmt/ostr.h" + +namespace md { + + template + Point operator+(const Point& u, const Point& v) + { + return Point(u.x + v.x, u.y + v.y); + } + + template + Point operator-(const Point& u, const Point& v) + { + return Point(u.x - v.x, u.y - v.y); + } + + template + Point least_upper_bound(const Point& u, const Point& v) + { + return Point(std::max(u.x, v.x), std::max(u.y, v.y)); + } + + template + Point greatest_lower_bound(const Point& u, const Point& v) + { + return Point(std::min(u.x, v.x), std::min(u.y, v.y)); + } + + template + Point max_point() + { + return Point(std::numeric_limits::max(), std::numeric_limits::min()); + } + + template + Point min_point() + { + return Point(-std::numeric_limits::max(), -std::numeric_limits::min()); + } + + template + std::ostream& operator<<(std::ostream& ostr, const Point& vec) + { + ostr << "(" << vec.x << ", " << vec.y << ")"; + return ostr; + } + + template + Real l_infty_norm(const Point& v) + { + return std::max(std::abs(v.x), std::abs(v.y)); + } + + template + Real l_2_norm(const Point& v) + { + return v.norm(); + } + + template + Real l_2_dist(const Point& x, const Point& y) + { + return l_2_norm(x - y); + } + + template + Real l_infty_dist(const Point& x, const Point& y) + { + return l_infty_norm(x - y); + } + + template + void DiagramKeeper::add_point(int dim, Real birth, Real death) + { + data_[dim].emplace_back(birth, death); + } + + template + Diagram DiagramKeeper::get_diagram(int dim) const + { + if (data_.count(dim) == 1) + return data_.at(dim); + else + return Diagram(); + } +} diff --git a/matching/include/dual_box.h b/matching/include/dual_box.h index ce0384d..0e4f4d5 100644 --- a/matching/include/dual_box.h +++ b/matching/include/dual_box.h @@ -4,16 +4,23 @@ #include #include #include +#include + +#include "spdlog/spdlog.h" +#include "spdlog/fmt/ostr.h" + #include "common_util.h" #include "dual_point.h" namespace md { + + template class DualBox { public: - DualBox(DualPoint ll, DualPoint ur); + DualBox(DualPoint ll, DualPoint ur); DualBox() = default; DualBox(const DualBox&) = default; @@ -23,12 +30,12 @@ namespace md { DualBox& operator=(DualBox&& other) = default; - DualPoint center() const { return midpoint(lower_left_, upper_right_); } - DualPoint lower_left() const { return lower_left_; } - DualPoint upper_right() const { return upper_right_; } + DualPoint center() const { return midpoint(lower_left_, upper_right_); } + DualPoint lower_left() const { return lower_left_; } + DualPoint upper_right() const { return upper_right_; } - DualPoint lower_right() const; - DualPoint upper_left() const; + DualPoint lower_right() const; + DualPoint upper_left() const; AxisType axis_type() const { return lower_left_.axis_type(); } AngleType angle_type() const { return lower_left_.angle_type(); } @@ -42,66 +49,35 @@ namespace md { bool is_flat() const { return upper_right_.is_flat(); } bool is_steep() const { return lower_left_.is_steep(); } - // return minimal and maximal value of func - // on the corners of the box - template - std::pair min_max_on_corners(const F& func) const; - - template - Real max_abs_value(const F& func) const; - - std::vector refine() const; - std::vector corners() const; - std::vector critical_points(const Point& p) const; + std::vector> corners() const; + std::vector> critical_points(const Point& p) const; // sample n points from the box uniformly; for tests - std::vector random_points(int n) const; + std::vector> random_points(int n) const; // return 2 dual points at the boundary // where push changes from horizontal to vertical - std::vector push_change_points(const Point& p) const; - - friend std::ostream& operator<<(std::ostream& os, const DualBox& db); + std::vector> push_change_points(const Point& p) const; // check that a has same sign, angles are all flat or all steep bool sanity_check() const; - bool contains(const DualPoint& dp) const; + bool contains(const DualPoint& dp) const; bool operator==(const DualBox& other) const; private: - DualPoint lower_left_; - DualPoint upper_right_; + DualPoint lower_left_; + DualPoint upper_right_; }; - std::ostream& operator<<(std::ostream& os, const DualBox& db); - - template - std::pair DualBox::min_max_on_corners(const F& func) const + template + std::ostream& operator<<(std::ostream& os, const DualBox& db) { - std::pair min_max { std::numeric_limits::max(), -std::numeric_limits::max() }; - for(auto p : corners()) { - Real value = func(p); - min_max.first = std::min(min_max.first, value); - min_max.second = std::max(min_max.second, value); - } - return min_max; - }; - - - template - Real DualBox::max_abs_value(const F& func) const - { - Real result = 0; - for(auto p_1 : corners()) { - for(auto p_2 : corners()) { - Real value = fabs(func(p_1, p_2)); - result = std::max(value, result); - } - } - return result; - }; - + os << "DualBox(" << db.lower_left() << ", " << db.upper_right() << ")"; + return os; + } } +#include "dual_box.hpp" + #endif //MATCHING_DISTANCE_DUAL_BOX_H diff --git a/matching/include/dual_box.hpp b/matching/include/dual_box.hpp new file mode 100644 index 0000000..85f7f27 --- /dev/null +++ b/matching/include/dual_box.hpp @@ -0,0 +1,190 @@ +namespace md { + + template + DualBox::DualBox(DualPoint ll, DualPoint ur) + :lower_left_(ll), upper_right_(ur) + { + } + + template + std::vector> DualBox::corners() const + { + return {lower_left_, + DualPoint(axis_type(), angle_type(), lower_left_.lambda(), upper_right_.mu()), + upper_right_, + DualPoint(axis_type(), angle_type(), upper_right_.lambda(), lower_left_.mu())}; + } + + template + std::vector> DualBox::push_change_points(const Point& p) const + { + std::vector> result; + result.reserve(2); + + bool is_y_type = lower_left_.is_y_type(); + bool is_flat = lower_left_.is_flat(); + + auto mu_from_lambda = [p, is_y_type, is_flat](Real lambda) { + bool is_x_type = not is_y_type, is_steep = not is_flat; + if (is_y_type && is_flat) { + return p.y - lambda * p.x; + } else if (is_y_type && is_steep) { + return p.y - p.x / lambda; + } else if (is_x_type && is_flat) { + return p.x - p.y / lambda; + } else if (is_x_type && is_steep) { + return p.x - lambda * p.y; + } + // to shut up compiler warning + return static_cast(1.0 / 0.0); + }; + + auto lambda_from_mu = [p, is_y_type, is_flat](Real mu) { + bool is_x_type = not is_y_type, is_steep = not is_flat; + if (is_y_type && is_flat) { + return (p.y - mu) / p.x; + } else if (is_y_type && is_steep) { + return p.x / (p.y - mu); + } else if (is_x_type && is_flat) { + return p.y / (p.x - mu); + } else if (is_x_type && is_steep) { + return (p.x - mu) / p.y; + } + // to shut up compiler warning + return static_cast(1.0 / 0.0); + }; + + // all inequalities below are strict: equality means it is a corner + // && critical_points() returns corners anyway + + Real mu_intersect_min = mu_from_lambda(lambda_min()); + + if (mu_min() < mu_intersect_min && mu_intersect_min < mu_max()) + result.emplace_back(axis_type(), angle_type(), lambda_min(), mu_intersect_min); + + Real mu_intersect_max = mu_from_lambda(lambda_max()); + + if (mu_max() < mu_intersect_max && mu_intersect_max < mu_max()) + result.emplace_back(axis_type(), angle_type(), lambda_max(), mu_intersect_max); + + Real lambda_intersect_min = lambda_from_mu(mu_min()); + + if (lambda_min() < lambda_intersect_min && lambda_intersect_min < lambda_max()) + result.emplace_back(axis_type(), angle_type(), lambda_intersect_min, mu_min()); + + Real lambda_intersect_max = lambda_from_mu(mu_max()); + if (lambda_min() < lambda_intersect_max && lambda_intersect_max < lambda_max()) + result.emplace_back(axis_type(), angle_type(), lambda_intersect_max, mu_max()); + + assert(result.size() <= 2); + + if (result.size() > 2) { + fmt::print("Error in push_change_points, p = {}, dual_box = {}, result = {}\n", p, *this, + container_to_string(result)); + throw std::runtime_error("push_change_points returned more than 2 points"); + } + + return result; + } + + template + std::vector> DualBox::critical_points(const Point& /*p*/) const + { + // maximal difference is attained at corners + return corners(); +// std::vector> result; +// result.reserve(6); +// for(auto dp : corners()) result.push_back(dp); +// for(auto dp : push_change_points(p)) result.push_back(dp); +// return result; + } + + template + std::vector> DualBox::random_points(int n) const + { + assert(n >= 0); + std::mt19937_64 gen(1); + std::vector> result; + result.reserve(n); + std::uniform_real_distribution mu_distr(mu_min(), mu_max()); + std::uniform_real_distribution lambda_distr(lambda_min(), lambda_max()); + for(int i = 0; i < n; ++i) { + result.emplace_back(axis_type(), angle_type(), lambda_distr(gen), mu_distr(gen)); + } + return result; + } + + template + bool DualBox::sanity_check() const + { + lower_left_.sanity_check(); + upper_right_.sanity_check(); + + if (lower_left_.angle_type() != upper_right_.angle_type()) + throw std::runtime_error("angle types differ"); + + if (lower_left_.axis_type() != upper_right_.axis_type()) + throw std::runtime_error("axis types differ"); + + if (lower_left_.lambda() >= upper_right_.lambda()) + throw std::runtime_error("lambda of lower_left_ greater than lambda of upper_right "); + + if (lower_left_.mu() >= upper_right_.mu()) + throw std::runtime_error("mu of lower_left_ greater than mu of upper_right "); + + return true; + } + + template + std::vector> DualBox::refine() const + { + std::vector> result; + + result.reserve(4); + + Real lambda_middle = (lower_left().lambda() + upper_right().lambda()) / 2.0; + Real mu_middle = (lower_left().mu() + upper_right().mu()) / 2.0; + + DualPoint refinement_center(axis_type(), angle_type(), lambda_middle, mu_middle); + + result.emplace_back(lower_left_, refinement_center); + + result.emplace_back(DualPoint(axis_type(), angle_type(), lambda_middle, mu_min()), + DualPoint(axis_type(), angle_type(), lambda_max(), mu_middle)); + + result.emplace_back(refinement_center, upper_right_); + + result.emplace_back(DualPoint(axis_type(), angle_type(), lambda_min(), mu_middle), + DualPoint(axis_type(), angle_type(), lambda_middle, mu_max())); + return result; + } + + template + bool DualBox::operator==(const DualBox& other) const + { + return lower_left() == other.lower_left() && + upper_right() == other.upper_right(); + } + + template + bool DualBox::contains(const DualPoint& dp) const + { + return dp.angle_type() == angle_type() && dp.axis_type() == axis_type() && + mu_max() >= dp.mu() && + mu_min() <= dp.mu() && + lambda_min() <= dp.lambda() && + lambda_max() >= dp.lambda(); + } + + template + DualPoint DualBox::lower_right() const + { + return DualPoint(lower_left_.axis_type(), lower_left_.angle_type(), lambda_max(), mu_min()); + } + + template + DualPoint DualBox::upper_left() const + { + return DualPoint(lower_left_.axis_type(), lower_left_.angle_type(), lambda_min(), mu_max()); + } +} diff --git a/matching/include/dual_point.h b/matching/include/dual_point.h index db32f1a..8438860 100644 --- a/matching/include/dual_point.h +++ b/matching/include/dual_point.h @@ -1,12 +1,9 @@ -// -// Created by narn on 12.02.19. -// - #ifndef MATCHING_DISTANCE_DUAL_POINT_H #define MATCHING_DISTANCE_DUAL_POINT_H #include #include +#include #include "common_util.h" #include "box.h" @@ -25,9 +22,10 @@ namespace md { // so, e.g., line y = x has 4 different non-equal representation. // we are unlikely to ever need this, because 4 cases are // always treated separately. + template class DualPoint { public: - using Real = md::Real; + using Real = Real_; DualPoint() = default; @@ -56,7 +54,6 @@ namespace md { bool is_y_type() const { return axis_type_ == AxisType::y_type; } - friend std::ostream& operator<<(std::ostream& os, const DualPoint& dp); bool operator<(const DualPoint& rhs) const; AxisType axis_type() const { return axis_type_; } @@ -66,16 +63,16 @@ namespace md { // return true otherwise bool sanity_check() const; - Real weighted_push(Point p) const; - Point push(Point p) const; + Real weighted_push(Point p) const; + Point push(Point p) const; bool is_horizontal() const; bool is_vertical() const; - bool goes_below(Point p) const; - bool goes_above(Point p) const; + bool goes_below(Point p) const; + bool goes_above(Point p) const; - bool contains(Point p) const; + bool contains(Point p) const; Real x_slope() const; Real y_slope() const; @@ -98,9 +95,13 @@ namespace md { Real mu_ {-1.0}; }; - std::ostream& operator<<(std::ostream& os, const DualPoint& dp); + template + std::ostream& operator<<(std::ostream& os, const DualPoint& dp); - DualPoint midpoint(DualPoint x, DualPoint y); + template + DualPoint midpoint(DualPoint x, DualPoint y); }; +#include "dual_point.hpp" + #endif //MATCHING_DISTANCE_DUAL_POINT_H diff --git a/matching/include/dual_point.hpp b/matching/include/dual_point.hpp new file mode 100644 index 0000000..04e25f2 --- /dev/null +++ b/matching/include/dual_point.hpp @@ -0,0 +1,299 @@ +namespace md { + + inline std::ostream& operator<<(std::ostream& os, const AxisType& at) + { + if (at == AxisType::x_type) + os << "x-type"; + else + os << "y-type"; + return os; + } + + inline std::ostream& operator<<(std::ostream& os, const AngleType& at) + { + if (at == AngleType::flat) + os << "flat"; + else + os << "steep"; + return os; + } + + template + std::ostream& operator<<(std::ostream& os, const DualPoint& dp) + { + os << "Line(" << dp.axis_type() << ", "; + os << dp.angle_type() << ", "; + os << dp.lambda() << ", "; + os << dp.mu() << ", equation: "; + if (not dp.is_vertical()) { + os << "y = " << dp.y_slope() << " x + " << dp.y_intercept(); + } else { + os << "x = " << dp.x_intercept(); + } + os << " )"; + return os; + } + + template + bool DualPoint::operator<(const DualPoint& rhs) const + { + return std::tie(axis_type_, angle_type_, lambda_, mu_) + < std::tie(rhs.axis_type_, rhs.angle_type_, rhs.lambda_, rhs.mu_); + } + + template + DualPoint::DualPoint(AxisType axis_type, AngleType angle_type, Real lambda, Real mu) + : + axis_type_(axis_type), + angle_type_(angle_type), + lambda_(lambda), + mu_(mu) + { + assert(sanity_check()); + } + + template + bool DualPoint::sanity_check() const + { + if (lambda_ < 0.0) + throw std::runtime_error("Invalid line, negative lambda"); + if (lambda_ > 1.0) + throw std::runtime_error("Invalid line, lambda > 1"); + if (mu_ < 0.0) + throw std::runtime_error("Invalid line, negative mu"); + return true; + } + + template + Real DualPoint::gamma() const + { + if (is_steep()) + return atan(Real(1.0) / lambda_); + else + return atan(lambda_); + } + + template + DualPoint midpoint(DualPoint x, DualPoint y) + { + assert(x.angle_type() == y.angle_type() and x.axis_type() == y.axis_type()); + Real lambda_mid = (x.lambda() + y.lambda()) / 2; + Real mu_mid = (x.mu() + y.mu()) / 2; + return DualPoint(x.axis_type(), x.angle_type(), lambda_mid, mu_mid); + + } + + // return k in the line equation y = kx + b + template + Real DualPoint::y_slope() const + { + if (is_flat()) + return lambda(); + else + return Real(1.0) / lambda(); + } + + // return k in the line equation x = ky + b + template + Real DualPoint::x_slope() const + { + if (is_flat()) + return Real(1.0) / lambda(); + else + return lambda(); + } + + // return b in the line equation y = kx + b + template + Real DualPoint::y_intercept() const + { + if (is_y_type()) { + return mu(); + } else { + // x = x_slope * y + mu = x_slope * (y + mu / x_slope) + // x-intercept is -mu/x_slope = -mu * y_slope + return -mu() * y_slope(); + } + } + + // return k in the line equation x = ky + b + template + Real DualPoint::x_intercept() const + { + if (is_x_type()) { + return mu(); + } else { + // y = y_slope * x + mu = y_slope (x + mu / y_slope) + // x_intercept is -mu/y_slope = -mu * x_slope + return -mu() * x_slope(); + } + } + + template + Real DualPoint::x_from_y(Real y) const + { + if (is_horizontal()) + throw std::runtime_error("x_from_y called on horizontal line"); + else + return x_slope() * y + x_intercept(); + } + + template + Real DualPoint::y_from_x(Real x) const + { + if (is_vertical()) + throw std::runtime_error("x_from_y called on horizontal line"); + else + return y_slope() * x + y_intercept(); + } + + template + bool DualPoint::is_horizontal() const + { + return is_flat() and lambda() == 0; + } + + template + bool DualPoint::is_vertical() const + { + return is_steep() and lambda() == 0; + } + + template + bool DualPoint::contains(Point p) const + { + if (is_vertical()) + return p.x == x_from_y(p.y); + else + return p.y == y_from_x(p.x); + } + + template + bool DualPoint::goes_below(Point p) const + { + if (is_vertical()) + return p.x <= x_from_y(p.y); + else + return p.y >= y_from_x(p.x); + } + + template + bool DualPoint::goes_above(Point p) const + { + if (is_vertical()) + return p.x >= x_from_y(p.y); + else + return p.y <= y_from_x(p.x); + } + + template + Point DualPoint::push(Point p) const + { + Point result; + // if line is below p, we push horizontally + bool horizontal_push = goes_below(p); + if (is_x_type()) { + if (is_flat()) { + if (horizontal_push) { + result.x = p.y / lambda() + mu(); + result.y = p.y; + } else { + // vertical push + result.x = p.x; + result.y = lambda() * (p.x - mu()); + } + } else { + // steep + if (horizontal_push) { + result.x = lambda() * p.y + mu(); + result.y = p.y; + } else { + // vertical push + result.x = p.x; + result.y = (p.x - mu()) / lambda(); + } + } + } else { + // y-type + if (is_flat()) { + if (horizontal_push) { + result.x = (p.y - mu()) / lambda(); + result.y = p.y; + } else { + // vertical push + result.x = p.x; + result.y = lambda() * p.x + mu(); + } + } else { + // steep + if (horizontal_push) { + result.x = (p.y - mu()) * lambda(); + result.y = p.y; + } else { + // vertical push + result.x = p.x; + result.y = p.x / lambda() + mu(); + } + } + } + return result; + } + + template + Real DualPoint::weighted_push(Point p) const + { + // if line is below p, we push horizontally + bool horizontal_push = goes_below(p); + if (is_x_type()) { + if (is_flat()) { + if (horizontal_push) { + return p.y; + } else { + // vertical push + return lambda() * (p.x - mu()); + } + } else { + // steep + if (horizontal_push) { + return lambda() * p.y; + } else { + // vertical push + return (p.x - mu()); + } + } + } else { + // y-type + if (is_flat()) { + if (horizontal_push) { + return p.y - mu(); + } else { + // vertical push + return lambda() * p.x; + } + } else { + // steep + if (horizontal_push) { + return lambda() * (p.y - mu()); + } else { + // vertical push + return p.x; + } + } + } + } + + template + bool DualPoint::operator==(const DualPoint& other) const + { + return axis_type() == other.axis_type() and + angle_type() == other.angle_type() and + mu() == other.mu() and + lambda() == other.lambda(); + } + + template + Real DualPoint::weight() const + { + return lambda_ / sqrt(1 + lambda_ * lambda_); + } +} // namespace md diff --git a/matching/include/matching_distance.h b/matching/include/matching_distance.h index bb10203..5be34c7 100644 --- a/matching/include/matching_distance.h +++ b/matching/include/matching_distance.h @@ -4,9 +4,10 @@ #include #include #include +#include +#include +#include -#include "spdlog/spdlog.h" -#include "spdlog/fmt/ostr.h" #include "common_defs.h" #include "cell_with_value.h" @@ -17,12 +18,15 @@ #include "bifiltration.h" #include "bottleneck.h" -namespace spd = spdlog; - namespace md { - using HeatMap = std::map; - using HeatMaps = std::map; +#ifdef MD_PRINT_HEAT_MAP + template + using HeatMap = std::map, Real>; + + template + using HeatMaps = std::map>; +#endif enum class BoundStrategy { bruteforce, @@ -39,18 +43,107 @@ namespace md { upper_bound }; - std::ostream& operator<<(std::ostream& os, const BoundStrategy& s); - - std::ostream& operator<<(std::ostream& os, const TraverseStrategy& s); - - std::istream& operator>>(std::istream& is, BoundStrategy& s); - - std::istream& operator>>(std::istream& is, TraverseStrategy& s); - - BoundStrategy bs_from_string(std::string s); - - TraverseStrategy ts_from_string(std::string s); - + inline std::ostream& operator<<(std::ostream& os, const BoundStrategy& s) + { + switch(s) { + case BoundStrategy::bruteforce : + os << "bruteforce"; + break; + case BoundStrategy::local_dual_bound : + os << "local_grob"; + break; + case BoundStrategy::local_combined : + os << "local_combined"; + break; + case BoundStrategy::local_dual_bound_refined : + os << "local_refined"; + break; + case BoundStrategy::local_dual_bound_for_each_point : + os << "local_for_each_point"; + break; + default: + os << "FORGOTTEN BOUND STRATEGY"; + } + return os; + } + + inline std::ostream& operator<<(std::ostream& os, const TraverseStrategy& s) + { + switch(s) { + case TraverseStrategy::depth_first : + os << "DFS"; + break; + case TraverseStrategy::breadth_first : + os << "BFS"; + break; + case TraverseStrategy::breadth_first_value : + os << "BFS-VAL"; + break; + case TraverseStrategy::upper_bound : + os << "UB"; + break; + default: + os << "FORGOTTEN TRAVERSE STRATEGY"; + } + return os; + } + + inline std::istream& operator>>(std::istream& is, TraverseStrategy& s) + { + std::string ss; + is >> ss; + if (ss == "DFS") { + s = TraverseStrategy::depth_first; + } else if (ss == "BFS") { + s = TraverseStrategy::breadth_first; + } else if (ss == "BFS-VAL") { + s = TraverseStrategy::breadth_first_value; + } else if (ss == "UB") { + s = TraverseStrategy::upper_bound; + } else { + throw std::runtime_error("UNKNOWN TRAVERSE STRATEGY"); + } + return is; + } + + + inline std::istream& operator>>(std::istream& is, BoundStrategy& s) + { + std::string ss; + is >> ss; + if (ss == "bruteforce") { + s = BoundStrategy::bruteforce; + } else if (ss == "local_grob") { + s = BoundStrategy::local_dual_bound; + } else if (ss == "local_combined") { + s = BoundStrategy::local_combined; + } else if (ss == "local_refined") { + s = BoundStrategy::local_dual_bound_refined; + } else if (ss == "local_for_each_point") { + s = BoundStrategy::local_dual_bound_for_each_point; + } else { + throw std::runtime_error("UNKNOWN BOUND STRATEGY"); + } + return is; + } + + inline BoundStrategy bs_from_string(std::string s) + { + std::stringstream ss(s); + BoundStrategy result; + ss >> result; + return result; + } + + inline TraverseStrategy ts_from_string(std::string s) + { + std::stringstream ss(s); + TraverseStrategy result; + ss >> result; + return result; + } + + template struct CalculationParams { static constexpr int ALL_DIMENSIONS = -1; @@ -75,22 +168,22 @@ namespace md { // print statistics on each quad-tree level bool print_stats { false }; -#ifdef PRINT_HEAT_MAP +#ifdef MD_PRINT_HEAT_MAP HeatMaps heat_maps; #endif }; - template + template class DistanceCalculator { - using DualBox = md::DualBox; - using CellValueVector = std::vector; + using Real = Real_; + using CellValueVector = std::vector>; public: DistanceCalculator(const DiagramProvider& a, const DiagramProvider& b, - CalculationParams& params); + CalculationParams& params); Real distance(); @@ -100,7 +193,7 @@ namespace md { DiagramProvider module_a_; DiagramProvider module_b_; - CalculationParams& params_; + CalculationParams& params_; int n_hera_calls_; std::map n_hera_calls_per_level_; @@ -112,65 +205,83 @@ namespace md { CellValueVector get_initial_dual_grid(Real& lower_bound); +#ifdef MD_PRINT_HEAT_MAP void heatmap_in_dimension(int dim, int depth); +#endif Real get_max_x(int module) const; Real get_max_y(int module) const; - void set_cell_central_value(CellWithValue& dual_cell); + void set_cell_central_value(CellWithValue& dual_cell); Real get_distance(); Real get_distance_pq(); - // temporary, to try priority queue - Real get_max_possible_value(const CellWithValue* first_cell_ptr, int n_cells); + Real get_max_possible_value(const CellWithValue* first_cell_ptr, int n_cells); - Real get_upper_bound(const CellWithValue& dual_cell, Real good_enough_upper_bound) const; + Real get_upper_bound(const CellWithValue& dual_cell, Real good_enough_upper_bound) const; - Real get_single_dgm_bound(const CellWithValue& dual_cell, ValuePoint vp, int module, + Real get_single_dgm_bound(const CellWithValue& dual_cell, ValuePoint vp, int module, Real good_enough_value) const; // this bound depends only on dual box - Real get_local_dual_bound(int module, const DualBox& dual_box) const; + Real get_local_dual_bound(int module, const DualBox& dual_box) const; - Real get_local_dual_bound(const DualBox& dual_box) const; + Real get_local_dual_bound(const DualBox& dual_box) const; // this bound depends only on dual box, is more accurate - Real get_local_refined_bound(int module, const md::DualBox& dual_box) const; + Real get_local_refined_bound(int module, const DualBox& dual_box) const; - Real get_local_refined_bound(const md::DualBox& dual_box) const; + Real get_local_refined_bound(const DualBox& dual_box) const; Real get_good_enough_upper_bound(Real lower_bound) const; - Real - get_max_displacement_single_point(const CellWithValue& dual_cell, ValuePoint value_point, const Point& p) const; + Real get_max_displacement_single_point(const CellWithValue& dual_cell, ValuePoint value_point, + const Point& p) const; - void check_upper_bound(const CellWithValue& dual_cell) const; + void check_upper_bound(const CellWithValue& dual_cell) const; - Real distance_on_line(DualPoint line); - Real distance_on_line_const(DualPoint line) const; + Real distance_on_line(DualPoint line); + Real distance_on_line_const(DualPoint line) const; Real current_error(Real lower_bound, Real upper_bound); }; - Real matching_distance(const Bifiltration& bif_a, const Bifiltration& bif_b, CalculationParams& params); + template + Real matching_distance(const Bifiltration& bif_a, const Bifiltration& bif_b, + CalculationParams& params); - Real matching_distance(const ModulePresentation& mod_a, const ModulePresentation& mod_b, CalculationParams& params); + template + Real matching_distance(const ModulePresentation& mod_a, const ModulePresentation& mod_b, + CalculationParams& params); // for upper bound experiment struct UbExperimentRecord { - Real error; - Real lower_bound; - Real upper_bound; - CellWithValue cell; + double error; + double lower_bound; + double upper_bound; + CellWithValue cell; long long int time; long long int n_hera_calls; }; - std::ostream& operator<<(std::ostream& os, const UbExperimentRecord& r); + inline std::ostream& operator<<(std::ostream& os, const UbExperimentRecord& r) + { + os << r.time << "\t" << r.n_hera_calls << "\t" << r.error << "\t" << r.lower_bound << "\t" << r.upper_bound; + return os; + } + + + template + void print_map(const std::map& dic) + { + for(const auto kv : dic) { + fmt::print("{} -> {}\n", kv.first, kv.second); + } + } -} +} // namespace md #include "matching_distance.hpp" diff --git a/matching/include/matching_distance.hpp b/matching/include/matching_distance.hpp index d2d2fbc..48c8464 100644 --- a/matching/include/matching_distance.hpp +++ b/matching/include/matching_distance.hpp @@ -1,34 +1,26 @@ namespace md { - template - void print_map(const std::map& dic) - { - for(const auto kv : dic) { - fmt::print("{} -> {}\n", kv.first, kv.second); - } - } - - template - void DistanceCalculator::check_upper_bound(const CellWithValue& dual_cell) const + template + void DistanceCalculator::check_upper_bound(const CellWithValue& dual_cell) const { spd::debug("Enter check_get_max_delta_on_cell"); const int n_samples_lambda = 100; const int n_samples_mu = 100; - DualBox db = dual_cell.dual_box(); - Real min_lambda = db.lambda_min(); - Real max_lambda = db.lambda_max(); - Real min_mu = db.mu_min(); - Real max_mu = db.mu_max(); - - Real h_lambda = (max_lambda - min_lambda) / n_samples_lambda; - Real h_mu = (max_mu - min_mu) / n_samples_mu; + DualBox db = dual_cell.dual_box(); + R min_lambda = db.lambda_min(); + R max_lambda = db.lambda_max(); + R min_mu = db.mu_min(); + R max_mu = db.mu_max(); + + R h_lambda = (max_lambda - min_lambda) / n_samples_lambda; + R h_mu = (max_mu - min_mu) / n_samples_mu; for(int i = 1; i < n_samples_lambda; ++i) { for(int j = 1; j < n_samples_mu; ++j) { - Real lambda = min_lambda + i * h_lambda; - Real mu = min_mu + j * h_mu; - DualPoint l(db.axis_type(), db.angle_type(), lambda, mu); - Real other_result = distance_on_line_const(l); - Real diff = fabs(dual_cell.stored_upper_bound() - other_result); + R lambda = min_lambda + i * h_lambda; + R mu = min_mu + j * h_mu; + DualPoint l(db.axis_type(), db.angle_type(), lambda, mu); + R other_result = distance_on_line_const(l); + R diff = fabs(dual_cell.stored_upper_bound() - other_result); if (other_result > dual_cell.stored_upper_bound()) { spd::error( "in check_upper_bound, upper_bound = {}, other_result = {}, diff = {}\ndual_cell = {}", @@ -42,10 +34,10 @@ namespace md { // for all lines l, l' inside dual box, // find the upper bound on the difference of weighted pushes of p - template - Real - DistanceCalculator::get_max_displacement_single_point(const CellWithValue& dual_cell, ValuePoint vp, - const Point& p) const + template + R + DistanceCalculator::get_max_displacement_single_point(const CellWithValue& dual_cell, ValuePoint vp, + const Point& p) const { assert(p.x >= 0 && p.y >= 0); @@ -53,15 +45,15 @@ namespace md { std::vector debug_ids = {3, 13, 54, 218, 350, 382, 484, 795, 2040, 8415, 44076}; bool debug = false; // std::find(debug_ids.begin(), debug_ids.end(), dual_cell.id) != debug_ids.end(); #endif - DualPoint line = dual_cell.value_point(vp); - const Real base_value = line.weighted_push(p); + DualPoint line = dual_cell.value_point(vp); + const R base_value = line.weighted_push(p); spd::debug("Enter get_max_displacement_single_point, p = {},\ndual_cell = {},\nline = {}, base_value = {}\n", p, dual_cell, line, base_value); - Real result = 0.0; - for(DualPoint dp : dual_cell.dual_box().critical_points(p)) { - Real dp_value = dp.weighted_push(p); + R result = 0.0; + for(DualPoint dp : dual_cell.dual_box().critical_points(p)) { + R dp_value = dp.weighted_push(p); spd::debug( "In get_max_displacement_single_point, p = {}, critical dp = {},\ndp_value = {}, diff = {},\ndual_cell = {}\n", p, dp, dp_value, fabs(base_value - dp_value), dual_cell); @@ -69,15 +61,15 @@ namespace md { } #ifdef MD_DO_FULL_CHECK - DualBox db = dual_cell.dual_box(); - std::uniform_real_distribution dlambda(db.lambda_min(), db.lambda_max()); - std::uniform_real_distribution dmu(db.mu_min(), db.mu_max()); + auto db = dual_cell.dual_box(); + std::uniform_real_distribution dlambda(db.lambda_min(), db.lambda_max()); + std::uniform_real_distribution dmu(db.mu_min(), db.mu_max()); std::mt19937 gen(1); for(int i = 0; i < 1000; ++i) { - Real lambda = dlambda(gen); - Real mu = dmu(gen); - DualPoint dp_random { db.axis_type(), db.angle_type(), lambda, mu }; - Real dp_value = dp_random.weighted_push(p); + R lambda = dlambda(gen); + R mu = dmu(gen); + DualPoint dp_random { db.axis_type(), db.angle_type(), lambda, mu }; + R dp_value = dp_random.weighted_push(p); if (fabs(base_value - dp_value) > result) { spd::error("in get_max_displacement_single_point, p = {}, vp = {}\ndb = {}\nresult = {}, base_value = {}, dp_value = {}, dp_random = {}", p, vp, db, result, base_value, dp_value, dp_random); @@ -89,12 +81,12 @@ namespace md { return result; } - template - typename DistanceCalculator::CellValueVector DistanceCalculator::get_initial_dual_grid(Real& lower_bound) + template + typename DistanceCalculator::CellValueVector DistanceCalculator::get_initial_dual_grid(R& lower_bound) { CellValueVector result = get_refined_grid(params_.initialization_depth, false, true); - lower_bound = -1.0; + lower_bound = -1; for(const auto& dc : result) { lower_bound = std::max(lower_bound, dc.max_corner_value()); } @@ -102,8 +94,8 @@ namespace md { assert(lower_bound >= 0); for(auto& dual_cell : result) { - Real good_enough_ub = get_good_enough_upper_bound(lower_bound); - Real max_value_on_cell = get_upper_bound(dual_cell, good_enough_ub); + R good_enough_ub = get_good_enough_upper_bound(lower_bound); + R max_value_on_cell = get_upper_bound(dual_cell, good_enough_ub); dual_cell.set_max_possible_value(max_value_on_cell); #ifdef MD_DO_FULL_CHECK @@ -116,39 +108,39 @@ namespace md { return result; } - template - typename DistanceCalculator::CellValueVector - DistanceCalculator::get_refined_grid(int init_depth, bool calculate_on_intermediate, bool calculate_on_last) + template + typename DistanceCalculator::CellValueVector + DistanceCalculator::get_refined_grid(int init_depth, bool calculate_on_intermediate, bool calculate_on_last) { - const Real y_max = std::max(module_a_.max_y(), module_b_.max_y()); - const Real x_max = std::max(module_a_.max_x(), module_b_.max_x()); + const R y_max = std::max(module_a_.max_y(), module_b_.max_y()); + const R x_max = std::max(module_a_.max_x(), module_b_.max_x()); - const Real lambda_min = 0; - const Real lambda_max = 1; + const R lambda_min = 0; + const R lambda_max = 1; - const Real mu_min = 0; + const R mu_min = 0; - DualBox x_flat(DualPoint(AxisType::x_type, AngleType::flat, lambda_min, mu_min), - DualPoint(AxisType::x_type, AngleType::flat, lambda_max, x_max)); + DualBox x_flat(DualPoint(AxisType::x_type, AngleType::flat, lambda_min, mu_min), + DualPoint(AxisType::x_type, AngleType::flat, lambda_max, x_max)); - DualBox x_steep(DualPoint(AxisType::x_type, AngleType::steep, lambda_min, mu_min), - DualPoint(AxisType::x_type, AngleType::steep, lambda_max, x_max)); + DualBox x_steep(DualPoint(AxisType::x_type, AngleType::steep, lambda_min, mu_min), + DualPoint(AxisType::x_type, AngleType::steep, lambda_max, x_max)); - DualBox y_flat(DualPoint(AxisType::y_type, AngleType::flat, lambda_min, mu_min), - DualPoint(AxisType::y_type, AngleType::flat, lambda_max, y_max)); + DualBox y_flat(DualPoint(AxisType::y_type, AngleType::flat, lambda_min, mu_min), + DualPoint(AxisType::y_type, AngleType::flat, lambda_max, y_max)); - DualBox y_steep(DualPoint(AxisType::y_type, AngleType::steep, lambda_min, mu_min), - DualPoint(AxisType::y_type, AngleType::steep, lambda_max, y_max)); + DualBox y_steep(DualPoint(AxisType::y_type, AngleType::steep, lambda_min, mu_min), + DualPoint(AxisType::y_type, AngleType::steep, lambda_max, y_max)); - CellWithValue x_flat_cell(x_flat, 0); - CellWithValue x_steep_cell(x_steep, 0); - CellWithValue y_flat_cell(y_flat, 0); - CellWithValue y_steep_cell(y_steep, 0); + CellWithValue x_flat_cell(x_flat, 0); + CellWithValue x_steep_cell(x_steep, 0); + CellWithValue y_flat_cell(y_flat, 0); + CellWithValue y_steep_cell(y_steep, 0); if (init_depth == 0) { - DualPoint diagonal_x_flat(AxisType::x_type, AngleType::flat, 1, 0); + DualPoint diagonal_x_flat(AxisType::x_type, AngleType::flat, 1, 0); - Real diagonal_value = distance_on_line(diagonal_x_flat); + R diagonal_value = distance_on_line(diagonal_x_flat); n_hera_calls_per_level_[0]++; x_flat_cell.set_value_at(ValuePoint::lower_right, diagonal_value); @@ -162,7 +154,7 @@ namespace md { x_steep_cell.id = 2; y_flat_cell.id = 3; y_steep_cell.id = 4; - CellWithValue::max_id = 4; + CellWithValue::max_id = 4; #endif CellValueVector result {x_flat_cell, x_steep_cell, y_flat_cell, y_steep_cell}; @@ -189,10 +181,10 @@ namespace md { return result; } - template - DistanceCalculator::DistanceCalculator(const T& a, + template + DistanceCalculator::DistanceCalculator(const T& a, const T& b, - CalculationParams& params) + CalculationParams& params) : module_a_(a), module_b_(b), @@ -213,33 +205,33 @@ namespace md { module_a_.max_x(), module_a_.max_y(), module_b_.max_x(), module_b_.max_y()); } - template - Real DistanceCalculator::get_max_x(int module) const + template + R DistanceCalculator::get_max_x(int module) const { return (module == 0) ? module_a_.max_x() : module_b_.max_x(); } - template - Real DistanceCalculator::get_max_y(int module) const + template + R DistanceCalculator::get_max_y(int module) const { return (module == 0) ? module_a_.max_y() : module_b_.max_y(); } - template - Real - DistanceCalculator::get_local_refined_bound(const md::DualBox& dual_box) const + template + R + DistanceCalculator::get_local_refined_bound(const DualBox& dual_box) const { return get_local_refined_bound(0, dual_box) + get_local_refined_bound(1, dual_box); } - template - Real - DistanceCalculator::get_local_refined_bound(int module, const md::DualBox& dual_box) const + template + R + DistanceCalculator::get_local_refined_bound(int module, const DualBox& dual_box) const { spd::debug("Enter get_local_refined_bound, dual_box = {}", dual_box); - Real d_lambda = dual_box.lambda_max() - dual_box.lambda_min(); - Real d_mu = dual_box.mu_max() - dual_box.mu_min(); - Real result; + R d_lambda = dual_box.lambda_max() - dual_box.lambda_min(); + R d_mu = dual_box.mu_max() - dual_box.mu_min(); + R result; if (dual_box.axis_type() == AxisType::x_type) { if (dual_box.is_flat()) { result = dual_box.lambda_max() * d_mu + (get_max_x(module) - dual_box.mu_min()) * d_lambda; @@ -258,11 +250,11 @@ namespace md { return result; } - template - Real DistanceCalculator::get_local_dual_bound(int module, const md::DualBox& dual_box) const + template + R DistanceCalculator::get_local_dual_bound(int module, const DualBox& dual_box) const { - Real dlambda = dual_box.lambda_max() - dual_box.lambda_min(); - Real dmu = dual_box.mu_max() - dual_box.mu_min(); + R dlambda = dual_box.lambda_max() - dual_box.lambda_min(); + R dmu = dual_box.mu_max() - dual_box.mu_min(); if (dual_box.is_flat()) { return get_max_x(module) * dlambda + dmu; @@ -271,20 +263,20 @@ namespace md { } } - template - Real DistanceCalculator::get_local_dual_bound(const md::DualBox& dual_box) const + template + R DistanceCalculator::get_local_dual_bound(const DualBox& dual_box) const { return get_local_dual_bound(0, dual_box) + get_local_dual_bound(1, dual_box); } - template - Real DistanceCalculator::get_upper_bound(const CellWithValue& dual_cell, Real good_enough_ub) const + template + R DistanceCalculator::get_upper_bound(const CellWithValue& dual_cell, R good_enough_ub) const { assert(good_enough_ub >= 0); switch(params_.bound_strategy) { case BoundStrategy::bruteforce: - return std::numeric_limits::max(); + return std::numeric_limits::max(); case BoundStrategy::local_dual_bound: return dual_cell.min_value() + get_local_dual_bound(dual_cell.dual_box()); @@ -293,7 +285,7 @@ namespace md { return dual_cell.min_value() + get_local_refined_bound(dual_cell.dual_box()); case BoundStrategy::local_combined: { - Real cheap_upper_bound = dual_cell.min_value() + get_local_refined_bound(dual_cell.dual_box()); + R cheap_upper_bound = dual_cell.min_value() + get_local_refined_bound(dual_cell.dual_box()); if (cheap_upper_bound < good_enough_ub) { return cheap_upper_bound; } else { @@ -302,14 +294,14 @@ namespace md { } case BoundStrategy::local_dual_bound_for_each_point: { - Real result = std::numeric_limits::max(); + R result = std::numeric_limits::max(); for(ValuePoint vp : k_corner_vps) { if (not dual_cell.has_value_at(vp)) { continue; } - Real base_value = dual_cell.value_at(vp); - Real bound_dgm_a = get_single_dgm_bound(dual_cell, vp, 0, good_enough_ub); + R base_value = dual_cell.value_at(vp); + R bound_dgm_a = get_single_dgm_bound(dual_cell, vp, 0, good_enough_ub); if (params_.stop_asap and bound_dgm_a + base_value >= good_enough_ub) { // we want to return a valid upper bound, not just something that will prevent discarding the cell @@ -318,8 +310,8 @@ namespace md { return dual_cell.min_value() + get_local_refined_bound(dual_cell.dual_box()); } - Real bound_dgm_b = get_single_dgm_bound(dual_cell, vp, 1, - std::max(Real(0), good_enough_ub - bound_dgm_a)); + R bound_dgm_b = get_single_dgm_bound(dual_cell, vp, 1, + std::max(R(0), good_enough_ub - bound_dgm_a)); result = std::min(result, base_value + bound_dgm_a + bound_dgm_b); @@ -336,19 +328,19 @@ namespace md { } } // to suppress compiler warning - return std::numeric_limits::max(); + return std::numeric_limits::max(); } // find maximal displacement of weighted points of m for all lines in dual_box - template - Real - DistanceCalculator::get_single_dgm_bound(const CellWithValue& dual_cell, + template + R + DistanceCalculator::get_single_dgm_bound(const CellWithValue& dual_cell, ValuePoint vp, int module, - [[maybe_unused]] Real good_enough_value) const + R good_enough_value) const { - Real result = 0; - Point max_point; + R result = 0; + Point max_point; spd::debug( "Enter get_single_dgm_bound, module = {}, dual_cell = {}, vp = {}, good_enough_value = {}, stop_asap = {}\n", @@ -358,7 +350,7 @@ namespace md { for(const auto& position : m.positions()) { spd::debug("in get_single_dgm_bound, simplex = {}\n", position); - Real x = get_max_displacement_single_point(dual_cell, vp, position); + R x = get_max_displacement_single_point(dual_cell, vp, position); spd::debug("In get_single_dgm_bound, point = {}, displacement = {}", position, x); @@ -385,30 +377,30 @@ namespace md { return result; } - template - Real DistanceCalculator::distance() + template + R DistanceCalculator::distance() { return get_distance_pq(); } // calculate weighted bottleneneck distance between slices on line // increments hera calls counter - template - Real DistanceCalculator::distance_on_line(DualPoint line) + template + R DistanceCalculator::distance_on_line(DualPoint line) { ++n_hera_calls_; - Real result = distance_on_line_const(line); + R result = distance_on_line_const(line); return result; } - template - Real DistanceCalculator::distance_on_line_const(DualPoint line) const + template + R DistanceCalculator::distance_on_line_const(DualPoint line) const { // TODO: think about this - how to call Hera auto dgm_a = module_a_.weighted_slice_diagram(line); auto dgm_b = module_b_.weighted_slice_diagram(line); - Real result; - if (params_.hera_epsilon > static_cast(0)) { + R result; + if (params_.hera_epsilon > static_cast(0)) { result = hera::bottleneckDistApprox(dgm_a, dgm_b, params_.hera_epsilon) / ( params_.hera_epsilon + 1); } else { result = hera::bottleneckDistExact(dgm_a, dgm_b); @@ -423,10 +415,10 @@ namespace md { return result; } - template - Real DistanceCalculator::get_good_enough_upper_bound(Real lower_bound) const + template + R DistanceCalculator::get_good_enough_upper_bound(R lower_bound) const { - Real result; + R result; // in upper_bound strategy we only prune cells if they cannot improve the lower bound, // otherwise the experiment is supposed to run indefinitely if (params_.traverse_strategy == TraverseStrategy::upper_bound) { @@ -440,14 +432,14 @@ namespace md { // helper function // calculate weighted bt distance on cell center, // assign distance value to cell, keep it in heat_map, and return - template - void DistanceCalculator::set_cell_central_value(CellWithValue& dual_cell) + template + void DistanceCalculator::set_cell_central_value(CellWithValue& dual_cell) { - DualPoint central_line {dual_cell.center()}; + DualPoint central_line {dual_cell.center()}; spd::debug("In set_cell_central_value, processing dual cell = {}, line = {}", dual_cell.dual_box(), central_line); - Real new_value = distance_on_line(central_line); + R new_value = distance_on_line(central_line); n_hera_calls_per_level_[dual_cell.level() + 1]++; dual_cell.set_value_at(ValuePoint::center, new_value); params_.actual_max_depth = std::max(params_.actual_max_depth, dual_cell.level() + 1); @@ -472,10 +464,10 @@ namespace md { // assumes that the underlying container is vector! // cell_ptr: pointer to the first element in queue // n_cells: queue size - template - Real DistanceCalculator::get_max_possible_value(const CellWithValue* cell_ptr, int n_cells) + template + R DistanceCalculator::get_max_possible_value(const CellWithValue* cell_ptr, int n_cells) { - Real result = (n_cells > 0) ? cell_ptr->stored_upper_bound() : 0; + R result = (n_cells > 0) ? cell_ptr->stored_upper_bound() : 0; for(int i = 0; i < n_cells; ++i, ++cell_ptr) { result = std::max(result, cell_ptr->stored_upper_bound()); } @@ -485,11 +477,11 @@ namespace md { // helper function: // return current error from lower and upper bounds // and save it in params_ (hence not const) - template - Real DistanceCalculator::current_error(Real lower_bound, Real upper_bound) + template + R DistanceCalculator::current_error(R lower_bound, R upper_bound) { - Real current_error = (lower_bound > 0.0) ? (upper_bound - lower_bound) / lower_bound - : std::numeric_limits::max(); + R current_error = (lower_bound > 0.0) ? (upper_bound - lower_bound) / lower_bound + : std::numeric_limits::max(); params_.actual_error = current_error; @@ -505,8 +497,8 @@ namespace md { // use priority queue to store dual cells // comparison function depends on the strategies in params_ // ressets hera calls counter - template - Real DistanceCalculator::get_distance_pq() + template + R DistanceCalculator::get_distance_pq() { std::map n_cells_considered; std::map n_cells_pushed_into_queue; @@ -527,26 +519,26 @@ namespace md { // if cell is too deep and is not pushed into queue, // we still need to take its max value into account; // the max over such cells is stored in max_result_on_too_fine_cells - Real upper_bound_on_deep_cells = -1; + R upper_bound_on_deep_cells = -1; spd::debug("Started iterations in dual space, delta = {}, bound_strategy = {}", params_.delta, params_.bound_strategy); // user-defined less lambda function // to regulate priority queue depending on strategy - auto dual_cell_less = [this](const CellWithValue& a, const CellWithValue& b) { + auto dual_cell_less = [this](const CellWithValue& a, const CellWithValue& b) { int a_level = a.level(); int b_level = b.level(); - Real a_value = a.max_corner_value(); - Real b_value = b.max_corner_value(); - Real a_ub = a.stored_upper_bound(); - Real b_ub = b.stored_upper_bound(); + R a_value = a.max_corner_value(); + R b_value = b.max_corner_value(); + R a_ub = a.stored_upper_bound(); + R b_ub = b.stored_upper_bound(); if (this->params_.traverse_strategy == TraverseStrategy::upper_bound and (not a.has_max_possible_value() or not b.has_max_possible_value())) { throw std::runtime_error("no upper bound on cell"); } - DualPoint a_lower_left = a.dual_box().lower_left(); - DualPoint b_lower_left = b.dual_box().lower_left(); + DualPoint a_lower_left = a.dual_box().lower_left(); + DualPoint b_lower_left = b.dual_box().lower_left(); switch(this->params_.traverse_strategy) { // in both breadth_first searches we want coarser cells @@ -569,24 +561,24 @@ namespace md { } }; - std::priority_queue dual_cells_queue( + std::priority_queue, CellValueVector, decltype(dual_cell_less)> dual_cells_queue( dual_cell_less); // weighted bt distance on the center of current cell - Real lower_bound = std::numeric_limits::min(); + R lower_bound = std::numeric_limits::min(); // init pq and lower bound for(auto& init_cell : get_initial_dual_grid(lower_bound)) { dual_cells_queue.push(init_cell); } - Real upper_bound = get_max_possible_value(&dual_cells_queue.top(), dual_cells_queue.size()); + R upper_bound = get_max_possible_value(&dual_cells_queue.top(), dual_cells_queue.size()); std::vector ub_experiment_results; while(not dual_cells_queue.empty()) { - CellWithValue dual_cell = dual_cells_queue.top(); + CellWithValue dual_cell = dual_cells_queue.top(); dual_cells_queue.pop(); assert(dual_cell.has_corner_value() and dual_cell.has_max_possible_value() @@ -620,7 +612,7 @@ namespace md { // until now, dual_cell knows its value in one of its corners // new_value will be the weighted distance at its center set_cell_central_value(dual_cell); - Real new_value = dual_cell.value_at(ValuePoint::center); + R new_value = dual_cell.value_at(ValuePoint::center); lower_bound = std::max(new_value, lower_bound); spd::debug("Processed cell = {}, weighted value = {}, lower_bound = {}", dual_cell, new_value, lower_bound); @@ -638,11 +630,11 @@ namespace md { throw std::runtime_error("no value on cell"); // if delta is smaller than good_enough_value, it allows to prune cell - Real good_enough_ub = get_good_enough_upper_bound(lower_bound); + R good_enough_ub = get_good_enough_upper_bound(lower_bound); // upper bound of the parent holds for refined_cell // and can sometimes be smaller! - Real upper_bound_on_refined_cell = std::min(dual_cell.stored_upper_bound(), + R upper_bound_on_refined_cell = std::min(dual_cell.stored_upper_bound(), get_upper_bound(refined_cell, good_enough_ub)); spd::debug("upper_bound_on_refined_cell = {}, dual_cell.stored_upper_bound = {}, get_upper_bound = {}", @@ -774,10 +766,46 @@ namespace md { return lower_bound; } - template - int DistanceCalculator::get_hera_calls_number() const + template + int DistanceCalculator::get_hera_calls_number() const { return n_hera_calls_; } -} \ No newline at end of file + template + R matching_distance(const Bifiltration& bif_a, const Bifiltration& bif_b, + CalculationParams& params) + { + R result; + // compute distance only in one dimension + if (params.dim != CalculationParams::ALL_DIMENSIONS) { + BifiltrationProxy bifp_a(bif_a, params.dim); + BifiltrationProxy bifp_b(bif_b, params.dim); + DistanceCalculator> runner(bifp_a, bifp_b, params); + result = runner.distance(); + params.n_hera_calls = runner.get_hera_calls_number(); + } else { + // compute distance in all dimensions, return maximal + result = -1; + for(int dim = 0; dim < std::max(bif_a.maximal_dim(), bif_b.maximal_dim()); ++dim) { + BifiltrationProxy bifp_a(bif_a, params.dim); + BifiltrationProxy bifp_b(bif_a, params.dim); + DistanceCalculator> runner(bifp_a, bifp_b, params); + result = std::max(result, runner.distance()); + params.n_hera_calls += runner.get_hera_calls_number(); + } + } + return result; + } + + + template + R matching_distance(const ModulePresentation& mod_a, const ModulePresentation& mod_b, + CalculationParams& params) + { + DistanceCalculator> runner(mod_a, mod_b, params); + R result = runner.distance(); + params.n_hera_calls = runner.get_hera_calls_number(); + return result; + } +} // namespace md diff --git a/matching/include/persistence_module.h b/matching/include/persistence_module.h index a1fc67e..e99771f 100644 --- a/matching/include/persistence_module.h +++ b/matching/include/persistence_module.h @@ -5,6 +5,12 @@ #include #include #include +#include +#include +#include + +#include "phat/boundary_matrix.h" +#include "phat/compute_persistence_pairs.h" #include "common_util.h" #include "dual_point.h" @@ -28,17 +34,20 @@ namespace md { */ + template class ModulePresentation { public: + using RealVec = std::vector; + enum Format { rivet_firep }; struct Relation { - Point position_; + Point position_; IndexVec components_; Relation() {} - Relation(const Point& _pos, const IndexVec& _components); + Relation(const Point& _pos, const IndexVec& _components); Real get_x() const { return position_.x; } Real get_y() const { return position_.y; } @@ -48,9 +57,9 @@ namespace md { ModulePresentation() {} - ModulePresentation(const PointVec& _generators, const RelVec& _relations); + ModulePresentation(const PointVec& _generators, const RelVec& _relations); - Diagram weighted_slice_diagram(const DualPoint& line) const; + Diagram weighted_slice_diagram(const DualPoint& line) const; // translate all points by vector (a,a) void translate(Real a); @@ -59,9 +68,7 @@ namespace md { Real minimal_coordinate() const { return std::min(min_x(), min_y()); } // return box that contains all positions of all simplices - Box bounding_box() const; - - friend std::ostream& operator<<(std::ostream& os, const ModulePresentation& mp); + Box bounding_box() const; Real max_x() const { return max_x_; } @@ -71,26 +78,27 @@ namespace md { Real min_y() const { return min_y_; } - PointVec positions() const; + PointVec positions() const; private: - PointVec generators_; + PointVec generators_; std::vector relations_; - PointVec positions_; + PointVec positions_; Real max_x_ { std::numeric_limits::max() }; Real max_y_ { std::numeric_limits::max() }; Real min_x_ { -std::numeric_limits::max() }; Real min_y_ { -std::numeric_limits::max() }; - Box bounding_box_; + Box bounding_box_; void init_boundaries(); - void project_generators(const DualPoint& slice, IndexVec& sorted_indices, RealVec& projections) const; - void project_relations(const DualPoint& slice, IndexVec& sorted_indices, RealVec& projections) const; + void project_generators(const DualPoint& slice, IndexVec& sorted_indices, RealVec& projections) const; + void project_relations(const DualPoint& slice, IndexVec& sorted_indices, RealVec& projections) const; }; } // namespace md +#include "persistence_module.hpp" #endif //MATCHING_DISTANCE_PERSISTENCE_MODULE_H diff --git a/matching/include/persistence_module.hpp b/matching/include/persistence_module.hpp new file mode 100644 index 0000000..6e49b2e --- /dev/null +++ b/matching/include/persistence_module.hpp @@ -0,0 +1,177 @@ +namespace md { + + /** + * + * @param values vector of length n + * @return [a_1,...,a_n] such that + * 1) values[a_1] <= values[a_2] <= ... <= values[a_n] + * 2) a_1,...,a_n is a permutation of 1,..,n + */ + + template + IndexVec get_sorted_indices(const std::vector& values) + { + IndexVec result(values.size()); + std::iota(result.begin(), result.end(), 0); + std::sort(result.begin(), result.end(), + [&values](size_t a, size_t b) { return values[a] < values[b]; }); + return result; + } + + // helper function to initialize const member positions_ in ModulePresentation + template + PointVec concat_gen_and_rel_positions(const PointVec& generators, + const typename ModulePresentation::RelVec& relations) + { + std::unordered_set> ps(generators.begin(), generators.end()); + for(const auto& rel : relations) { + ps.insert(rel.position_); + } + return PointVec(ps.begin(), ps.end()); + } + + + template + void ModulePresentation::init_boundaries() + { + max_x_ = std::numeric_limits::max(); + max_y_ = std::numeric_limits::max(); + min_x_ = -std::numeric_limits::max(); + min_y_ = -std::numeric_limits::max(); + + for(const auto& gen : positions_) { + min_x_ = std::min(gen.x, min_x_); + min_y_ = std::min(gen.y, min_y_); + max_x_ = std::max(gen.x, max_x_); + max_y_ = std::max(gen.y, max_y_); + } + + bounding_box_ = Box(Point(min_x_, min_y_), Point(max_x_, max_y_)); + } + + + template + ModulePresentation::ModulePresentation(const PointVec& _generators, const RelVec& _relations) : + generators_(_generators), + relations_(_relations) + { + init_boundaries(); + } + + template + void ModulePresentation::translate(Real a) + { + for(auto& g : generators_) { + g.translate(a); + } + + for(auto& r : relations_) { + r.position_.translate(a); + } + + positions_ = concat_gen_and_rel_positions(generators_, relations_); + init_boundaries(); + } + + + /** + * + * @param slice line on which generators are projected + * @param sorted_indices [a_1,...,a_n] s.t. wpush(generator[a_1]) <= wpush(generator[a_2]) <= .. + * @param projections sorted weighted pushes of generators + */ + + template + void ModulePresentation::project_generators(const DualPoint& slice, + IndexVec& sorted_indices, RealVec& projections) const + { + size_t num_gens = generators_.size(); + + RealVec gen_values; + gen_values.reserve(num_gens); + for(const auto& pos : generators_) { + gen_values.push_back(slice.weighted_push(pos)); + } + sorted_indices = get_sorted_indices(gen_values); + projections.clear(); + projections.reserve(num_gens); + for(auto i : sorted_indices) { + projections.push_back(gen_values[i]); + } + } + + template + void ModulePresentation::project_relations(const DualPoint& slice, IndexVec& sorted_rel_indices, + RealVec& projections) const + { + size_t num_rels = relations_.size(); + + RealVec rel_values; + rel_values.reserve(num_rels); + for(const auto& rel : relations_) { + rel_values.push_back(slice.weighted_push(rel.position_)); + } + sorted_rel_indices = get_sorted_indices(rel_values); + projections.clear(); + projections.reserve(num_rels); + for(auto i : sorted_rel_indices) { + projections.push_back(rel_values[i]); + } + } + + template + Diagram ModulePresentation::weighted_slice_diagram(const DualPoint& slice) const + { + IndexVec sorted_gen_indices, sorted_rel_indices; + RealVec gen_projections, rel_projections; + + project_generators(slice, sorted_gen_indices, gen_projections); + project_relations(slice, sorted_rel_indices, rel_projections); + + phat::boundary_matrix<> phat_matrix; + + phat_matrix.set_num_cols(relations_.size()); + + for(Index i = 0; i < (Index) relations_.size(); i++) { + IndexVec current_relation = relations_[sorted_rel_indices[i]].components_; + for(auto& j : current_relation) { + j = sorted_gen_indices[j]; + } + std::sort(current_relation.begin(), current_relation.end()); + phat_matrix.set_dim(i, current_relation.size()); + phat_matrix.set_col(i, current_relation); + } + + phat::persistence_pairs phat_persistence_pairs; + phat::compute_persistence_pairs(phat_persistence_pairs, phat_matrix); + + Diagram dgm; + + constexpr Real real_inf = std::numeric_limits::infinity(); + + for(Index i = 0; i < (Index) phat_persistence_pairs.get_num_pairs(); i++) { + std::pair new_pair = phat_persistence_pairs.get_pair(i); + bool is_finite_pair = new_pair.second != phat::k_infinity_index; + Real birth = gen_projections.at(new_pair.first); + Real death = is_finite_pair ? rel_projections.at(new_pair.second) : real_inf; + if (birth != death) { + dgm.emplace_back(birth, death); + } + } + + return dgm; + } + + template + PointVec ModulePresentation::positions() const + { + return positions_; + } + + template + Box ModulePresentation::bounding_box() const + { + return bounding_box_; + } + +} // namespace md diff --git a/matching/include/simplex.h b/matching/include/simplex.h index e9d0e30..75bbcae 100644 --- a/matching/include/simplex.h +++ b/matching/include/simplex.h @@ -9,6 +9,7 @@ namespace md { + template class Bifiltration; enum class BifiltrationFormat { @@ -38,11 +39,21 @@ namespace md { int dim() const { return vertices_.size() - 1; } - void push_back(int v); + void push_back(int v) + { + vertices_.push_back(v); + std::sort(vertices_.begin(), vertices_.end()); + } AbstractSimplex() { } - AbstractSimplex(std::vector vertices, bool sort = true); + AbstractSimplex(std::vector vertices, bool sort = true) + :vertices_(vertices) + { + if (sort) + std::sort(vertices_.begin(), vertices_.end()); + } + template AbstractSimplex(Iter beg_iter, Iter end_iter, bool sort = true) @@ -53,22 +64,51 @@ namespace md { std::sort(vertices_.begin(), end()); } - std::vector facets() const; + std::vector facets() const + { + std::vector result; + for (int i = 0; i < static_cast(vertices_.size()); ++i) { + std::vector facet_vertices; + facet_vertices.reserve(dim()); + for (int j = 0; j < static_cast(vertices_.size()); ++j) { + if (j != i) + facet_vertices.push_back(vertices_[j]); + } + if (!facet_vertices.empty()) { + result.emplace_back(facet_vertices, false); + } + } + return result; + } friend std::ostream& operator<<(std::ostream& os, const AbstractSimplex& s); - // compare by vertices_ only friend bool operator==(const AbstractSimplex& s1, const AbstractSimplex& s2); friend bool operator<(const AbstractSimplex&, const AbstractSimplex&); }; - std::ostream& operator<<(std::ostream& os, const AbstractSimplex& s); + inline std::ostream& operator<<(std::ostream& os, const AbstractSimplex& s) + { + os << "AbstractSimplex(id = " << s.id << ", vertices_ = " << container_to_string(s.vertices_) << ")"; + return os; + } + + inline bool operator<(const AbstractSimplex& a, const AbstractSimplex& b) + { + return a.vertices_ < b.vertices_; + } + + inline bool operator==(const AbstractSimplex& s1, const AbstractSimplex& s2) + { + return s1.vertices_ == s2.vertices_; + } + template class Simplex { private: Index id_; - Point pos_; + Point pos_; int dim_; // in our format we use facet indices, // this is the fastest representation for homology @@ -77,11 +117,11 @@ namespace md { // conversion routines are in Bifiltration Column facet_indices_; Column vertices_; - Real v {0.0}; // used when constructed a filtration for a slice + Real v {0}; // used when constructed a filtration for a slice public: Simplex(Index _id, std::string s, BifiltrationFormat input_format); - Simplex(Index _id, Point birth, int _dim, const Column& _bdry); + Simplex(Index _id, Point birth, int _dim, const Column& _bdry); void init_rivet(std::string s); @@ -96,9 +136,9 @@ namespace md { Real value() const { return v; } // assumes 1-criticality - Point position() const { return pos_; } + Point position() const { return pos_; } - void set_position(const Point& new_pos) { pos_ = new_pos; } + void set_position(const Point& new_pos) { pos_ = new_pos; } void scale(Real lambda) { @@ -110,12 +150,14 @@ namespace md { void set_value(Real new_val) { v = new_val; } - friend std::ostream& operator<<(std::ostream& os, const Simplex& s); - - friend Bifiltration; + friend Bifiltration; }; - std::ostream& operator<<(std::ostream& os, const Simplex& s); + template + std::ostream& operator<<(std::ostream& os, const Simplex& s); } + +#include "simplex.hpp" + #endif //MATCHING_DISTANCE_SIMPLEX_H diff --git a/matching/include/simplex.hpp b/matching/include/simplex.hpp new file mode 100644 index 0000000..ce0e30f --- /dev/null +++ b/matching/include/simplex.hpp @@ -0,0 +1,79 @@ +namespace md { + + template + Simplex::Simplex(Index id, Point birth, int dim, const Column& bdry) + : + id_(id), + pos_(birth), + dim_(dim), + facet_indices_(bdry) { } + + template + void Simplex::translate(Real a) + { + pos_.translate(a); + } + + template + void Simplex::init_rivet(std::string s) + { + auto delim_pos = s.find_first_of(";"); + assert(delim_pos > 0); + std::string vertices_str = s.substr(0, delim_pos); + std::string pos_str = s.substr(delim_pos + 1); + assert(not vertices_str.empty() and not pos_str.empty()); + // get vertices + std::stringstream vertices_ss(vertices_str); + int dim = 0; + int vertex; + while (vertices_ss >> vertex) { + dim++; + vertices_.push_back(vertex); + } + // + std::sort(vertices_.begin(), vertices_.end()); + assert(dim > 0); + + std::stringstream pos_ss(pos_str); + // TODO: get rid of 1-criticaltiy assumption + pos_ss >> pos_.x >> pos_.y; + } + + template + void Simplex::init_phat_like(std::string s) + { + facet_indices_.clear(); + std::stringstream ss(s); + ss >> dim_ >> pos_.x >> pos_.y; + if (dim_ > 0) { + facet_indices_.reserve(dim_ + 1); + for (int j = 0; j <= dim_; j++) { + Index k; + ss >> k; + facet_indices_.push_back(k); + } + } + } + + template + Simplex::Simplex(Index _id, std::string s, BifiltrationFormat input_format) + :id_(_id) + { + switch (input_format) { + case BifiltrationFormat::phat_like : + init_phat_like(s); + break; + case BifiltrationFormat::rivet : + init_rivet(s); + break; + } + } + + template + std::ostream& operator<<(std::ostream& os, const Simplex& x) + { + os << "Simplex(id = " << x.id() << ", dim = " << x.dim(); + os << ", boundary = " << container_to_string(x.boundary()) << ", pos = " << x.position() << ")"; + return os; + } +} -- cgit v1.2.3