diff options
Diffstat (limited to 'matching/src')
-rw-r--r-- | matching/src/bifiltration.cpp | 309 | ||||
-rw-r--r-- | matching/src/box.cpp | 61 | ||||
-rw-r--r-- | matching/src/cell_with_value.cpp | 247 | ||||
-rw-r--r-- | matching/src/common_util.cpp | 243 | ||||
-rw-r--r-- | matching/src/dual_box.cpp | 194 | ||||
-rw-r--r-- | matching/src/dual_point.cpp | 282 | ||||
-rw-r--r-- | matching/src/main.cpp | 367 | ||||
-rw-r--r-- | matching/src/matching_distance.cpp | 907 | ||||
-rw-r--r-- | matching/src/persistence_module.cpp | 104 | ||||
-rw-r--r-- | matching/src/simplex.cpp | 123 | ||||
-rw-r--r-- | matching/src/test_generator.cpp | 208 | ||||
-rw-r--r-- | matching/src/tests/prism_1_lesnick.bif | 27 | ||||
-rw-r--r-- | matching/src/tests/prism_2_lesnick.bif | 28 | ||||
-rw-r--r-- | matching/src/tests/test_bifiltration.cpp | 36 | ||||
l--------- | matching/src/tests/test_bifiltration_1.txt | 1 | ||||
l--------- | matching/src/tests/test_bifiltration_full_triangle_rene.txt | 1 | ||||
-rw-r--r-- | matching/src/tests/test_common.cpp | 190 | ||||
-rw-r--r-- | matching/src/tests/test_list.txt | 1 | ||||
-rw-r--r-- | matching/src/tests/test_matching_distance.cpp | 146 | ||||
-rw-r--r-- | matching/src/tests/tests_main.cpp | 2 |
20 files changed, 3477 insertions, 0 deletions
diff --git a/matching/src/bifiltration.cpp b/matching/src/bifiltration.cpp new file mode 100644 index 0000000..429a9a8 --- /dev/null +++ b/matching/src/bifiltration.cpp @@ -0,0 +1,309 @@ +#include <iostream> +#include <fstream> +#include <sstream> +#include <cassert> + +#include<phat/boundary_matrix.h> +#include<phat/compute_persistence_pairs.h> + +#include "spdlog/spdlog.h" +#include "spdlog/fmt/fmt.h" +#include "spdlog/fmt/ostr.h" + +#include "common_util.h" +#include "bifiltration.h" + +namespace spd = spdlog; + +namespace md { + + void Bifiltration::init() + { + Point lower_left = max_point(); + Point upper_right = min_point(); + for(const auto& simplex : simplices_) { + lower_left = greatest_lower_bound(lower_left, simplex.position()); + upper_right = least_upper_bound(upper_right, simplex.position()); + maximal_dim_ = std::max(maximal_dim_, simplex.dim()); + } + bounding_box_ = Box(lower_left, upper_right); + } + + Bifiltration::Bifiltration(const std::string& fname, BifiltrationFormat input_format) + { + std::ifstream ifstr{fname.c_str()}; + if (!ifstr.good()) { + std::string error_message = fmt::format("Cannot open file {0}", fname); + std::cerr << error_message << std::endl; + throw std::runtime_error(error_message); + } + + switch(input_format) { + case BifiltrationFormat::rivet : + rivet_format_reader(ifstr); + break; + case BifiltrationFormat::rene : + rene_format_reader(ifstr); + break; + } + init(); + } + + void Bifiltration::rivet_format_reader(std::ifstream& ifstr) + { + std::string s; + std::getline(ifstr, s); + assert(s == std::string("bifiltration")); + std::getline(ifstr, parameter_1_name_); + std::getline(ifstr, parameter_2_name_); + + Index index = 0; + while(std::getline(ifstr, s)) { + if (not ignore_line(s)) + simplices_.emplace_back(index++, s, BifiltrationFormat::rivet); + } + } + + void Bifiltration::rene_format_reader(std::ifstream& ifstr) + { + spd::debug("Enter rene_format_reader"); + // read stream line by line; do not use >> operator 
+ std::string s; + std::getline(ifstr, s); + long n_simplices = std::stol(s); + + for(Index index = 0; index < n_simplices; index++) { + std::getline(ifstr, s); + simplices_.emplace_back(index, s, BifiltrationFormat::rene); + } + spd::debug("Read {} simplices from file", n_simplices); + } + + void Bifiltration::scale(Real lambda) + { + for(auto& s : simplices_) { + s.scale(lambda); + } + init(); + } + + void Bifiltration::sanity_check() const + { +#ifdef DEBUG + spd::debug("Enter Bifiltration::sanity_check"); + // check that boundary has correct number of simplices, + // each bounding simplex has correct dim + // and appears in the filtration before the simplex it bounds + for(const auto& s : simplices_) { + assert(s.dim() >= 0); + assert(s.dim() == 0 or s.dim() + 1 == (int) s.boundary().size()); + for(auto bdry_idx : s.boundary()) { + Simplex bdry_simplex = simplices()[bdry_idx]; + assert(bdry_simplex.dim() == s.dim() - 1); + assert(bdry_simplex.position().is_less(s.position(), false)); + } + } + spd::debug("Exit Bifiltration::sanity_check"); +#endif + } + + DiagramKeeper Bifiltration::weighted_slice_diagram(const DualPoint& line, int /*dim*/) const + { + DiagramKeeper dgm; + + // make a copy for now; I want slice_diagram to be const + std::vector<Simplex> simplices(simplices_); + +// std::vector<Simplex> simplices; +// simplices.reserve(simplices_.size() / 2); +// for(const auto& s : simplices_) { +// if (s.dim() <= dim + 1 and s.dim() >= dim) +// simplices.emplace_back(s); +// } + + spd::debug("Enter slice diagram, line = {}, simplices.size = {}", line, simplices.size()); + + for(auto& simplex : simplices) { + Real value = line.weighted_push(simplex.position()); +// spd::debug("in slice_diagram, simplex = {}, value = {}\n", simplex, value); + simplex.set_value(value); + } + + std::sort(simplices.begin(), simplices.end(), + [](const Simplex& a, const Simplex& b) { return a.value() < b.value(); }); + std::map<Index, Index> index_map; + for(Index i = 0; i < (int) 
simplices.size(); i++) { + index_map[simplices[i].id()] = i; + } + + phat::boundary_matrix<> phat_matrix; + phat_matrix.set_num_cols(simplices.size()); + std::vector<Index> bd_in_slice_filtration; + for(Index i = 0; i < (int) simplices.size(); i++) { + phat_matrix.set_dim(i, simplices[i].dim()); + bd_in_slice_filtration.clear(); + //std::cout << "new col" << i << std::endl; + for(int j = 0; j < (int) simplices[i].boundary().size(); j++) { + // F[i] contains the indices of its facet wrt to the + // original filtration. We have to express it, however, + // wrt to the filtration along the slice. That is why + // we need the index_map + //std::cout << "Found " << F[i].bd[j] << ", returning " << index_map[F[i].bd[j]] << std::endl; + bd_in_slice_filtration.push_back(index_map[simplices[i].boundary()[j]]); + } + std::sort(bd_in_slice_filtration.begin(), bd_in_slice_filtration.end()); + phat_matrix.set_col(i, bd_in_slice_filtration); + } + phat::persistence_pairs phat_persistence_pairs; + phat::compute_persistence_pairs<phat::twist_reduction>(phat_persistence_pairs, phat_matrix); + + dgm.clear(); + for(long i = 0; i < (long) phat_persistence_pairs.get_num_pairs(); i++) { + std::pair<phat::index, phat::index> new_pair = phat_persistence_pairs.get_pair(i); + Real birth = simplices.at(new_pair.first).value(); + Real death = simplices.at(new_pair.second).value(); + int dim = simplices[new_pair.first].dim(); + assert(dim + 1 == simplices[new_pair.second].dim()); + if (birth != death) { + dgm.add_point(dim, birth, death); + } + } + + spdlog::debug("Exiting slice_diagram, #dgm[0] = {}", dgm.get_diagram(0).size()); + + return dgm; + } + + Box Bifiltration::bounding_box() const + { + return bounding_box_; + } + + Real Bifiltration::minimal_coordinate() const + { + return std::min(bounding_box_.lower_left().x, bounding_box_.lower_left().y); + } + + void Bifiltration::translate(Real a) + { + bounding_box_.translate(a); + for(auto& simplex : simplices_) { + simplex.translate(a); + } + 
} + + Real Bifiltration::max_x() const + { + if (simplices_.empty()) + return 1; + auto me = std::max_element(simplices_.cbegin(), simplices_.cend(), + [](const auto& s_a, const auto& s_b) { return s_a.position().x < s_b.position().x; }); + assert(me != simplices_.cend()); + return me->position().x; + } + + Real Bifiltration::max_y() const + { + if (simplices_.empty()) + return 1; + auto me = std::max_element(simplices_.cbegin(), simplices_.cend(), + [](const auto& s_a, const auto& s_b) { return s_a.position().y < s_b.position().y; }); + assert(me != simplices_.cend()); + return me->position().y; + } + + Real Bifiltration::min_x() const + { + if (simplices_.empty()) + return 0; + auto me = std::min_element(simplices_.cbegin(), simplices_.cend(), + [](const auto& s_a, const auto& s_b) { return s_a.position().x < s_b.position().x; }); + assert(me != simplices_.cend()); + return me->position().x; + } + + Real Bifiltration::min_y() const + { + if (simplices_.empty()) + return 0; + auto me = std::min_element(simplices_.cbegin(), simplices_.cend(), + [](const auto& s_a, const auto& s_b) { return s_a.position().y < s_b.position().y; }); + assert(me != simplices_.cend()); + return me->position().y; + } + + void Bifiltration::add_simplex(md::Index _id, md::Point birth, int _dim, const md::Column& _bdry) + { + simplices_.emplace_back(_id, birth, _dim, _bdry); + } + + void Bifiltration::save(const std::string& filename, md::BifiltrationFormat format) + { + switch(format) { + case BifiltrationFormat::rivet: + throw std::runtime_error("Not implemented"); + break; + case BifiltrationFormat::rene: { + std::ofstream f(filename); + if (not f.good()) { + std::cerr << "Bifiltration::save: cannot open file " << filename << std::endl; + throw std::runtime_error("Cannot open file for writing "); + } + f << simplices_.size() << "\n"; + + for(const auto& s : simplices_) { + f << s.dim() << " " << s.position().x << " " << s.position().y << " "; + for(int b : s.boundary()) { + f << b << " 
"; + } + f << std::endl; + } + + } + break; + } + } + + + void Bifiltration::postprocess_rivet_format() + { + std::map<Column, Index> facets_to_ids; + + // fill the map + for(Index i = 0; i < (Index)simplices_.size(); ++i) { + assert(simplices_[i].id() == i); + facets_to_ids[simplices_[i].vertices_] = i; + } + +// for(const auto& s : simplices_) { +// facets_to_ids[s] = s.id(); +// } + + // main loop + for(auto& s : simplices_) { + assert(not s.vertices_.empty()); + assert(s.facet_indices_.empty()); + Column facet_indices; + for(Index i = 0; i <= s.dim(); ++i) { + Column facet; + for(Index j : s.vertices_) { + if (j != i) + facet.push_back(j); + } + auto facet_index = facets_to_ids.at(facet); + facet_indices.push_back(facet_index); + } + s.facet_indices_ = facet_indices; + } // loop over simplices + } + + std::ostream& operator<<(std::ostream& os, const Bifiltration& bif) + { + os << "Bifiltration, axes = " << bif.parameter_1_name_ << ", " << bif.parameter_2_name_ << std::endl; + for(const auto& s : bif.simplices()) { + os << s << std::endl; + } + return os; + } +} + diff --git a/matching/src/box.cpp b/matching/src/box.cpp new file mode 100644 index 0000000..c128698 --- /dev/null +++ b/matching/src/box.cpp @@ -0,0 +1,61 @@ + +#include "box.h" + +namespace md { + + std::ostream& operator<<(std::ostream& os, const Box& box) + { + os << "Box(lower_left = " << box.lower_left() << ", upper_right = " << box.upper_right() << ")"; + return os; + } + + Box get_enclosing_box(const Box& box_a, const Box& box_b) + { + Point lower_left(std::min(box_a.lower_left().x, box_b.lower_left().x), + std::min(box_a.lower_left().y, box_b.lower_left().y)); + Point upper_right(std::max(box_a.upper_right().x, box_b.upper_right().x), + std::max(box_a.upper_right().y, box_b.upper_right().y)); + return Box(lower_left, upper_right); + } + + void Box::translate(md::Real a) + { + ll.x += a; + ll.y += a; + ur.x += a; + ur.y += a; + } + + std::vector<Box> Box::refine() const + { + std::vector<Box> 
result; + +// 1 | 2 +// 0 | 3 + + Point new_ll = lower_left(); + Point new_ur = center(); + result.emplace_back(new_ll, new_ur); + + new_ll.y = center().y; + new_ur.y = ur.y; + result.emplace_back(new_ll, new_ur); + + new_ll = center(); + new_ur = upper_right(); + result.emplace_back(new_ll, new_ur); + + new_ll.y = ll.y; + new_ur.y = center().y; + result.emplace_back(new_ll, new_ur); + + return result; + } + + std::vector<Point> Box::corners() const + { + return {ll, Point(ll.x, ur.y), ur, Point(ur.x, ll.y)}; + }; + + +} diff --git a/matching/src/cell_with_value.cpp b/matching/src/cell_with_value.cpp new file mode 100644 index 0000000..d8fd7d4 --- /dev/null +++ b/matching/src/cell_with_value.cpp @@ -0,0 +1,247 @@ +#include <spdlog/spdlog.h> +#include <spdlog/fmt/ostr.h> + +namespace spd = spdlog; + +#include "cell_with_value.h" + +namespace md { + +#ifdef MD_DEBUG + long long int CellWithValue::max_id = 0; +#endif + + Real CellWithValue::value_at(ValuePoint vp) const + { + switch(vp) { + case ValuePoint::upper_left : + return upper_left_value_; + case ValuePoint::upper_right : + return upper_right_value_; + case ValuePoint::lower_left : + return lower_left_value_; + case ValuePoint::lower_right : + return lower_right_value_; + case ValuePoint::center: + return central_value_; + } + // to shut up compiler warning + return 1.0 / 0.0; + } + + bool CellWithValue::has_value_at(ValuePoint vp) const + { + switch(vp) { + case ValuePoint::upper_left : + return upper_left_value_ >= 0; + case ValuePoint::upper_right : + return upper_right_value_ >= 0; + case ValuePoint::lower_left : + return lower_left_value_ >= 0; + case ValuePoint::lower_right : + return lower_right_value_ >= 0; + case ValuePoint::center: + return central_value_ >= 0; + } + // not reached for a valid ValuePoint; an unknown point holds no value + return false; + } + + DualPoint CellWithValue::value_point(md::ValuePoint vp) const + { + switch(vp) { + case ValuePoint::upper_left : + return dual_box().upper_left(); + case ValuePoint::upper_right 
: + return dual_box().upper_right(); + case ValuePoint::lower_left : + return dual_box().lower_left(); + case ValuePoint::lower_right : + return dual_box().lower_right(); + case ValuePoint::center: + return dual_box().center(); + } + // to shut up compiler warning + return DualPoint(); + } + + bool CellWithValue::has_corner_value() const + { + return has_lower_left_value() or has_lower_right_value() or has_upper_left_value() + or has_upper_right_value(); + } + + Real CellWithValue::stored_upper_bound() const + { + assert(has_max_possible_value_); + return max_possible_value_; + } + + Real CellWithValue::max_corner_value() const + { + return std::max({lower_left_value_, lower_right_value_, upper_left_value_, upper_right_value_}); + } + + Real CellWithValue::min_value() const + { + Real result = std::numeric_limits<Real>::max(); + for(auto vp : k_all_vps) { + if (not has_value_at(vp)) { + continue; + } + result = std::min(result, value_at(vp)); + } + return result; + } + + std::vector<CellWithValue> CellWithValue::get_refined_cells() const + { + std::vector<CellWithValue> result; + result.reserve(4); + for(const auto& refined_box : dual_box_.refine()) { + + CellWithValue refined_cell(refined_box, level() + 1); + +#ifdef MD_DEBUG + refined_cell.parent_ids = parent_ids; + refined_cell.parent_ids.push_back(id); + refined_cell.id = ++max_id; +#endif + + if (refined_box.lower_left() == dual_box_.lower_left()) { + // _|_ + // H|_ + + refined_cell.set_value_at(ValuePoint::lower_left, lower_left_value_); + refined_cell.set_value_at(ValuePoint::upper_right, central_value_); + + } else if (refined_box.upper_right() == dual_box_.upper_right()) { + // _|H + // _|_ + + refined_cell.set_value_at(ValuePoint::lower_left, central_value_); + refined_cell.set_value_at(ValuePoint::upper_right, upper_right_value_); + + } else if (refined_box.lower_right() == dual_box_.lower_right()) { + // _|_ + // _|H + + refined_cell.set_value_at(ValuePoint::lower_right, lower_right_value_); + 
refined_cell.set_value_at(ValuePoint::upper_left, central_value_); + + } else if (refined_box.upper_left() == dual_box_.upper_left()) { + + // H|_ + // _|_ + + refined_cell.set_value_at(ValuePoint::lower_right, central_value_); + refined_cell.set_value_at(ValuePoint::upper_left, upper_left_value_); + } + result.emplace_back(refined_cell); + } + return result; + } + + void CellWithValue::set_value_at(md::ValuePoint vp, md::Real new_value) + { + if (has_value_at(vp)) + spd::error("CellWithValue: trying to re-assign value!, this = {}, vp = {}", *this, vp); + + switch(vp) { + case ValuePoint::upper_left : + upper_left_value_ = new_value; + break; + case ValuePoint::upper_right : + upper_right_value_ = new_value; + break; + case ValuePoint::lower_left : + lower_left_value_ = new_value; + break; + case ValuePoint::lower_right : + lower_right_value_ = new_value; + break; + case ValuePoint::center: + central_value_ = new_value; + break; + } + + + } + + int CellWithValue::num_values() const + { + int result = 0; + for(ValuePoint vp : k_all_vps) { + result += has_value_at(vp); + } + return result; + } + + + void CellWithValue::set_max_possible_value(Real new_upper_bound) + { + assert(new_upper_bound >= central_value_); + assert(new_upper_bound >= lower_left_value_); + assert(new_upper_bound >= lower_right_value_); + assert(new_upper_bound >= upper_left_value_); + assert(new_upper_bound >= upper_right_value_); + has_max_possible_value_ = true; + max_possible_value_ = new_upper_bound; + } + + std::ostream& operator<<(std::ostream& os, const ValuePoint& vp) + { + switch(vp) { + case ValuePoint::upper_left : + os << "upper_left"; + break; + case ValuePoint::upper_right : + os << "upper_right"; + break; + case ValuePoint::lower_left : + os << "lower_left"; + break; + case ValuePoint::lower_right : + os << "lower_right"; + break; + case ValuePoint::center: + os << "center"; + break; + default: + os << "FORGOTTEN ValuePoint"; + } + return os; + } + + + + std::ostream& 
operator<<(std::ostream& os, const CellWithValue& cell) + { + os << "CellWithValue(box = " << cell.dual_box() << ", "; + +#ifdef MD_DEBUG + os << "id = " << cell.id; + if (not cell.parent_ids.empty()) + os << ", parent_ids = " << container_to_string(cell.parent_ids) << ", "; +#endif + + for(ValuePoint vp : k_all_vps) { + if (cell.has_value_at(vp)) { + os << "value = " << cell.value_at(vp); + os << ", at " << vp << " " << cell.value_point(vp); + } + } + + os << ", max_corner_value = "; + if (cell.has_max_possible_value()) { + os << cell.stored_upper_bound(); + } else { + os << "-"; + } + + os << ", level = " << cell.level() << ")"; + return os; + } + +} // namespace md + diff --git a/matching/src/common_util.cpp b/matching/src/common_util.cpp new file mode 100644 index 0000000..96c3388 --- /dev/null +++ b/matching/src/common_util.cpp @@ -0,0 +1,243 @@ +#include <vector> +#include <utility> +#include <cmath> +#include <ostream> +#include <limits> +#include <algorithm> + +#include <common_util.h> + +#include "spdlog/spdlog.h" +#include "spdlog/fmt/ostr.h" + +namespace md { + + + int gcd(int a, int b) + { + assert(a != 0 or b != 0); + // make b <= a + std::tie(b, a) = std::minmax({ abs(a), abs(b) }); + if (b == 0) + return a; + while((a = a % b)) { + std::swap(a, b); + } + return b; + } + + int signum(int a) + { + if (a < 0) + return -1; + else if (a > 0) + return 1; + else + return 0; + } + + Rational reduce(Rational frac) + { + int d = gcd(frac.numerator, frac.denominator); + frac.numerator /= d; + frac.denominator /= d; + return frac; + } + + void Rational::reduce() { *this = md::reduce(*this); } + + + Rational& Rational::operator*=(const md::Rational& rhs) + { + numerator *= rhs.numerator; + denominator *= rhs.denominator; + reduce(); + return *this; + } + + Rational& Rational::operator/=(const md::Rational& rhs) + { + numerator *= rhs.denominator; + denominator *= rhs.numerator; + reduce(); + return *this; + } + + Rational& Rational::operator+=(const md::Rational& 
rhs) + { + numerator = numerator * rhs.denominator + denominator * rhs.numerator; + denominator *= rhs.denominator; + reduce(); + return *this; + } + + Rational& Rational::operator-=(const md::Rational& rhs) + { + numerator = numerator * rhs.denominator - denominator * rhs.numerator; + denominator *= rhs.denominator; + reduce(); + return *this; + } + + + Rational midpoint(Rational a, Rational b) + { + return reduce({a.numerator * b.denominator + a.denominator * b.numerator, 2 * a.denominator * b.denominator }); + } + + Rational operator+(Rational a, const Rational& b) + { + a += b; + return a; + } + + Rational operator-(Rational a, const Rational& b) + { + a -= b; + return a; + } + + Rational operator*(Rational a, const Rational& b) + { + a *= b; + return a; + } + + Rational operator/(Rational a, const Rational& b) + { + a /= b; + return a; + } + + bool is_less(Rational a, Rational b) + { + // compute a - b = a_1 / a_2 - b_1 / b_2 + long numer = a.numerator * b.denominator - a.denominator * b.numerator; + long denom = a.denominator * b.denominator; + assert(denom != 0); + return signum(numer) * signum(denom) < 0; + } + + bool operator==(const Rational& a, const Rational& b) + { + return std::tie(a.numerator, a.denominator) == std::tie(b.numerator, b.denominator); + } + + bool operator<(const Rational& a, const Rational& b) + { + // do not remove signum - overflow + long numer = a.numerator * b.denominator - a.denominator * b.numerator; + long denom = a.denominator * b.denominator; + assert(denom != 0); +// spdlog::debug("a = {}, b = {}, numer = {}, denom = {}, result = {}", a, b, numer, denom, signum(numer) * signum(denom) <= 0); + return signum(numer) * signum(denom) < 0; + } + + bool is_leq(Rational a, Rational b) + { + // compute a - b = a_1 / a_2 - b_1 / b_2 + long numer = a.numerator * b.denominator - a.denominator * b.numerator; + long denom = a.denominator * b.denominator; + assert(denom != 0); + return signum(numer) * signum(denom) <= 0; + } + + bool 
is_greater(Rational a, Rational b) + { + return not is_leq(a, b); + } + + bool is_geq(Rational a, Rational b) + { + return not is_less(a, b); + } + + Point operator+(const Point& u, const Point& v) + { + return Point(u.x + v.x, u.y + v.y); + } + + Point operator-(const Point& u, const Point& v) + { + return Point(u.x - v.x, u.y - v.y); + } + + Point least_upper_bound(const Point& u, const Point& v) + { + return Point(std::max(u.x, v.x), std::max(u.y, v.y)); + } + + Point greatest_lower_bound(const Point& u, const Point& v) + { + return Point(std::min(u.x, v.x), std::min(u.y, v.y)); + } + + Point max_point() + { + return Point(std::numeric_limits<Real>::max(), std::numeric_limits<Real>::max()); + } + + Point min_point() + { + return Point(-std::numeric_limits<Real>::max(), -std::numeric_limits<Real>::max()); + } + + std::ostream& operator<<(std::ostream& ostr, const Point& vec) + { + ostr << "(" << vec.x << ", " << vec.y << ")"; + return ostr; + } + + Real l_infty_norm(const Point& v) + { + return std::max(std::abs(v.x), std::abs(v.y)); + } + + Real l_2_norm(const Point& v) + { + return v.norm(); + } + + Real l_2_dist(const Point& x, const Point& y) + { + return l_2_norm(x - y); + } + + Real l_infty_dist(const Point& x, const Point& y) + { + return l_infty_norm(x - y); + } + + void DiagramKeeper::add_point(int dim, md::Real birth, md::Real death) + { + data_[dim].emplace_back(birth, death); + } + + DiagramKeeper::Diagram DiagramKeeper::get_diagram(int dim) const + { + if (data_.count(dim) == 1) + return data_.at(dim); + else + return DiagramKeeper::Diagram(); + } + + // return true, if line starts with # + // or contains only spaces + bool ignore_line(const std::string& s) + { + for(auto c : s) { + if (isspace(c)) + continue; + return (c == '#'); + } + return true; + } + + + + std::ostream& operator<<(std::ostream& os, const Rational& a) + { + os << a.numerator << " / " << a.denominator; + return os; + } +} diff --git a/matching/src/dual_box.cpp 
b/matching/src/dual_box.cpp new file mode 100644 index 0000000..f9d2979 --- /dev/null +++ b/matching/src/dual_box.cpp @@ -0,0 +1,194 @@ +#include <random> + +#include "spdlog/spdlog.h" +#include "spdlog/fmt/ostr.h" + +namespace spd = spdlog; + +#include "dual_box.h" + +namespace md { + + std::ostream& operator<<(std::ostream& os, const DualBox& db) + { + os << "DualBox(" << db.lower_left_ << ", " << db.upper_right_ << ")"; + return os; + } + + DualBox::DualBox(DualPoint ll, DualPoint ur) + :lower_left_(ll), upper_right_(ur) + { + } + + std::vector<DualPoint> DualBox::corners() const + { + return {lower_left_, + DualPoint(axis_type(), angle_type(), lower_left_.lambda(), upper_right_.mu()), + upper_right_, + DualPoint(axis_type(), angle_type(), upper_right_.lambda(), lower_left_.mu())}; + } + + std::vector<DualPoint> DualBox::push_change_points(const Point& p) const + { + std::vector<DualPoint> result; + result.reserve(2); + + bool is_y_type = lower_left_.is_y_type(); + bool is_flat = lower_left_.is_flat(); + + auto mu_from_lambda = [p, is_y_type, is_flat](Real lambda) { + bool is_x_type = not is_y_type, is_steep = not is_flat; + if (is_y_type and is_flat) { + return p.y - lambda * p.x; + } else if (is_y_type and is_steep) { + return p.y - p.x / lambda; + } else if (is_x_type and is_flat) { + return p.x - p.y / lambda; + } else if (is_x_type and is_steep) { + return p.x - lambda * p.y; + } + // to shut up compiler warning + return static_cast<Real>(1.0 / 0.0); + }; + + auto lambda_from_mu = [p, is_y_type, is_flat](Real mu) { + bool is_x_type = not is_y_type, is_steep = not is_flat; + if (is_y_type and is_flat) { + return (p.y - mu) / p.x; + } else if (is_y_type and is_steep) { + return p.x / (p.y - mu); + } else if (is_x_type and is_flat) { + return p.y / (p.x - mu); + } else if (is_x_type and is_steep) { + return (p.x - mu) / p.y; + } + // to shut up compiler warning + return static_cast<Real>(1.0 / 0.0); + }; + + // all inequalities below are strict: equality means 
it is a corner + // and critical_points() returns corners anyway + + Real mu_intersect_min = mu_from_lambda(lambda_min()); + + if (mu_min() < mu_intersect_min && mu_intersect_min < mu_max()) + result.emplace_back(axis_type(), angle_type(), lambda_min(), mu_intersect_min); + + Real mu_intersect_max = mu_from_lambda(lambda_max()); + + if (mu_min() < mu_intersect_max && mu_intersect_max < mu_max()) + result.emplace_back(axis_type(), angle_type(), lambda_max(), mu_intersect_max); + + Real lambda_intersect_min = lambda_from_mu(mu_min()); + + if (lambda_min() < lambda_intersect_min && lambda_intersect_min < lambda_max()) + result.emplace_back(axis_type(), angle_type(), lambda_intersect_min, mu_min()); + + Real lambda_intersect_max = lambda_from_mu(mu_max()); + if (lambda_min() < lambda_intersect_max && lambda_intersect_max < lambda_max()) + result.emplace_back(axis_type(), angle_type(), lambda_intersect_max, mu_max()); + + assert(result.size() <= 2); + + if (result.size() > 2) { + fmt::print("Error in push_change_points, p = {}, dual_box = {}, result = {}\n", p, *this, + container_to_string(result)); + throw std::runtime_error("push_change_points returned more than 2 points"); + } + + return result; + } + + std::vector<DualPoint> DualBox::critical_points(const Point& p) const + { + // maximal difference is attained at corners + return corners(); +// std::vector<DualPoint> result; +// result.reserve(6); +// for(auto dp : corners()) result.push_back(dp); +// for(auto dp : push_change_points(p)) result.push_back(dp); +// return result; + } + + std::vector<DualPoint> DualBox::random_points(int n) const + { + assert(n >= 0); + std::mt19937_64 gen(1); + std::vector<DualPoint> result; + result.reserve(n); + std::uniform_real_distribution<Real> mu_distr(mu_min(), mu_max()); + std::uniform_real_distribution<Real> lambda_distr(lambda_min(), lambda_max()); + for(int i = 0; i < n; ++i) { + result.emplace_back(axis_type(), angle_type(), lambda_distr(gen), mu_distr(gen)); + } + return 
result; + } + + bool DualBox::sanity_check() const + { + lower_left_.sanity_check(); + upper_right_.sanity_check(); + + if (lower_left_.angle_type() != upper_right_.angle_type()) + throw std::runtime_error("angle types differ"); + + if (lower_left_.axis_type() != upper_right_.axis_type()) + throw std::runtime_error("axis types differ"); + + if (lower_left_.lambda() >= upper_right_.lambda()) + throw std::runtime_error("lambda of lower_left_ greater than lambda of upper_right "); + + if (lower_left_.mu() >= upper_right_.mu()) + throw std::runtime_error("mu of lower_left_ greater than mu of upper_right "); + + return true; + } + + std::vector<DualBox> DualBox::refine() const + { + std::vector<DualBox> result; + + result.reserve(4); + + Real lambda_middle = (lower_left().lambda() + upper_right().lambda()) / 2.0; + Real mu_middle = (lower_left().mu() + upper_right().mu()) / 2.0; + + DualPoint refinement_center(axis_type(), angle_type(), lambda_middle, mu_middle); + + result.emplace_back(lower_left_, refinement_center); + + result.emplace_back(DualPoint(axis_type(), angle_type(), lambda_middle, mu_min()), + DualPoint(axis_type(), angle_type(), lambda_max(), mu_middle)); + + result.emplace_back(refinement_center, upper_right_); + + result.emplace_back(DualPoint(axis_type(), angle_type(), lambda_min(), mu_middle), + DualPoint(axis_type(), angle_type(), lambda_middle, mu_max())); + return result; + } + + bool DualBox::operator==(const DualBox& other) const + { + return lower_left() == other.lower_left() and + upper_right() == other.upper_right(); + } + + bool DualBox::contains(const DualPoint& dp) const + { + return dp.angle_type() == angle_type() and dp.axis_type() == axis_type() and + mu_max() >= dp.mu() and + mu_min() <= dp.mu() and + lambda_min() <= dp.lambda() and + lambda_max() >= dp.lambda(); + } + + DualPoint DualBox::lower_right() const + { + return DualPoint(lower_left_.axis_type(), lower_left_.angle_type(), lambda_max(), mu_min()); + } + + DualPoint 
DualBox::upper_left() const + { + return DualPoint(lower_left_.axis_type(), lower_left_.angle_type(), lambda_min(), mu_max()); + } +} diff --git a/matching/src/dual_point.cpp b/matching/src/dual_point.cpp new file mode 100644 index 0000000..1c00b58 --- /dev/null +++ b/matching/src/dual_point.cpp @@ -0,0 +1,282 @@ +#include <tuple> + +#include "dual_point.h" + +namespace md { + + std::ostream& operator<<(std::ostream& os, const AxisType& at) + { + if (at == AxisType::x_type) + os << "x-type"; + else + os << "y-type"; + return os; + } + + std::ostream& operator<<(std::ostream& os, const AngleType& at) + { + if (at == AngleType::flat) + os << "flat"; + else + os << "steep"; + return os; + } + + std::ostream& operator<<(std::ostream& os, const DualPoint& dp) + { + os << "Line(" << dp.axis_type() << ", "; + os << dp.angle_type() << ", "; + os << dp.lambda() << ", "; + os << dp.mu() << ", equation: "; + if (not dp.is_vertical()) { + os << "y = " << dp.y_slope() << " x + " << dp.y_intercept(); + } else { + os << "x = " << dp.x_intercept(); + } + os << " )"; + return os; + } + + bool DualPoint::operator<(const DualPoint& rhs) const + { + return std::tie(axis_type_, angle_type_, lambda_, mu_) + < std::tie(rhs.axis_type_, rhs.angle_type_, rhs.lambda_, rhs.mu_); + } + + DualPoint::DualPoint(AxisType axis_type, AngleType angle_type, Real lambda, Real mu) + : + axis_type_(axis_type), + angle_type_(angle_type), + lambda_(lambda), + mu_(mu) + { + assert(sanity_check()); + } + + bool DualPoint::sanity_check() const + { + if (lambda_ < 0.0) + throw std::runtime_error("Invalid line, negative lambda"); + if (lambda_ > 1.0) + throw std::runtime_error("Invalid line, lambda > 1"); + if (mu_ < 0.0) + throw std::runtime_error("Invalid line, negative mu"); + return true; + } + + Real DualPoint::gamma() const + { + if (is_steep()) + return atan(Real(1.0) / lambda_); + else + return atan(lambda_); + } + + DualPoint midpoint(DualPoint x, DualPoint y) + { + assert(x.angle_type() == 
y.angle_type() and x.axis_type() == y.axis_type()); + Real lambda_mid = (x.lambda() + y.lambda()) / 2; + Real mu_mid = (x.mu() + y.mu()) / 2; + return DualPoint(x.axis_type(), x.angle_type(), lambda_mid, mu_mid); + + } + + // return k in the line equation y = kx + b + Real DualPoint::y_slope() const + { + if (is_flat()) + return lambda(); + else + return Real(1.0) / lambda(); + } + + // return k in the line equation x = ky + b + Real DualPoint::x_slope() const + { + if (is_flat()) + return Real(1.0) / lambda(); + else + return lambda(); + } + + // return b in the line equation y = kx + b + Real DualPoint::y_intercept() const + { + if (is_y_type()) { + return mu(); + } else { + // x = x_slope * y + mu = x_slope * (y + mu / x_slope) + // y-intercept is -mu/x_slope = -mu * y_slope + return -mu() * y_slope(); + } + } + + // return b in the line equation x = ky + b + Real DualPoint::x_intercept() const + { + if (is_x_type()) { + return mu(); + } else { + // y = y_slope * x + mu = y_slope (x + mu / y_slope) + // x_intercept is -mu/y_slope = -mu * x_slope + return -mu() * x_slope(); + } + } + + Real DualPoint::x_from_y(Real y) const + { + if (is_horizontal()) + throw std::runtime_error("x_from_y called on horizontal line"); + else + return x_slope() * y + x_intercept(); + } + + Real DualPoint::y_from_x(Real x) const + { + if (is_vertical()) + throw std::runtime_error("y_from_x called on vertical line"); + else + return y_slope() * x + y_intercept(); + } + + bool DualPoint::is_horizontal() const + { + return is_flat() and lambda() == 0; + } + + bool DualPoint::is_vertical() const + { + return is_steep() and lambda() == 0; + } + + bool DualPoint::contains(Point p) const + { + if (is_vertical()) + return p.x == x_from_y(p.y); + else + return p.y == y_from_x(p.x); + } + + bool DualPoint::goes_below(Point p) const + { + if (is_vertical()) + return p.x <= x_from_y(p.y); + else + return p.y >= y_from_x(p.x); + } + + bool DualPoint::goes_above(Point p) const + { + if 
(is_vertical()) + return p.x >= x_from_y(p.y); + else + return p.y <= y_from_x(p.x); + } + + Point DualPoint::push(Point p) const + { + Point result; + // if line is below p, we push horizontally + bool horizontal_push = goes_below(p); + if (is_x_type()) { + if (is_flat()) { + if (horizontal_push) { + result.x = p.y / lambda() + mu(); + result.y = p.y; + } else { + // vertical push + result.x = p.x; + result.y = lambda() * (p.x - mu()); + } + } else { + // steep + if (horizontal_push) { + result.x = lambda() * p.y + mu(); + result.y = p.y; + } else { + // vertical push + result.x = p.x; + result.y = (p.x - mu()) / lambda(); + } + } + } else { + // y-type + if (is_flat()) { + if (horizontal_push) { + result.x = (p.y - mu()) / lambda(); + result.y = p.y; + } else { + // vertical push + result.x = p.x; + result.y = lambda() * p.x + mu(); + } + } else { + // steep + if (horizontal_push) { + result.x = (p.y - mu()) * lambda(); + result.y = p.y; + } else { + // vertical push + result.x = p.x; + result.y = p.x / lambda() + mu(); + } + } + } + return result; + } + + Real DualPoint::weighted_push(Point p) const + { + // if line is below p, we push horizontally + bool horizontal_push = goes_below(p); + if (is_x_type()) { + if (is_flat()) { + if (horizontal_push) { + return p.y; + } else { + // vertical push + return lambda() * (p.x - mu()); + } + } else { + // steep + if (horizontal_push) { + return lambda() * p.y; + } else { + // vertical push + return (p.x - mu()); + } + } + } else { + // y-type + if (is_flat()) { + if (horizontal_push) { + return p.y - mu(); + } else { + // vertical push + return lambda() * p.x; + } + } else { + // steep + if (horizontal_push) { + return lambda() * (p.y - mu()); + } else { + // vertical push + return p.x; + } + } + } + } + + bool DualPoint::operator==(const DualPoint& other) const + { + return axis_type() == other.axis_type() and + angle_type() == other.angle_type() and + mu() == other.mu() and + lambda() == other.lambda(); + } + + Real 
DualPoint::weight() const + { + return lambda_ / sqrt(1 + lambda_ * lambda_); + } +} // namespace md diff --git a/matching/src/main.cpp b/matching/src/main.cpp new file mode 100644 index 0000000..34c4240 --- /dev/null +++ b/matching/src/main.cpp @@ -0,0 +1,367 @@ +#include "common_defs.h"
+
+#include <iostream>
+#include <string>
+#include <cassert>
+#include <experimental/filesystem>
+
+#ifdef EXPERIMENTAL_TIMING
+#include <chrono>
+#endif
+
+#include "opts/opts.h"
+#include "spdlog/spdlog.h"
+#include "spdlog/fmt/ostr.h"
+
+//#include "persistence_module.h"
+#include "bifiltration.h"
+#include "box.h"
+#include "matching_distance.h"
+
+using namespace md;
+
+namespace fs = std::experimental::filesystem;
+
+/**
+ * Write the heat map of the deepest level stored in hms to the text file fname.
+ *
+ * The body is compiled only when PRINT_HEAT_MAP is defined; otherwise the
+ * function does nothing.
+ *
+ * Output layout: a matrix with 2 * (number of distinct mu values) rows and
+ * 2 * (number of distinct lambda values) columns, one matrix row per text
+ * line, entries separated by a single space.
+ *   - rows:    x-type lines first (mu descending, mu == 0 skipped),
+ *              then y-type lines (mu ascending);
+ *   - columns: flat lines first (lambda ascending, lambda == 0 skipped),
+ *              then steep lines (lambda descending).
+ * Entries that are never assigned keep their initial value 0.
+ *
+ * @param hms    heat maps keyed by refinement level; each level maps a
+ *               DualPoint to a distance value
+ * @param fname  path of the output file
+ * @param params only initialization_depth is read (per-level size check)
+ * @throws std::runtime_error if the output file cannot be opened.
+ *         The at() lookups also fail when a level or a
+ *         (axis, angle, lambda, mu) combination is missing
+ *         (presumably std::out_of_range; HeatMaps is declared elsewhere).
+ */
+void print_heat_map(const md::HeatMaps& hms, std::string fname, const CalculationParams& params)
+{
+#ifdef PRINT_HEAT_MAP
+    spd::debug("Entered print_heat_map");
+
+    // decrementing end() of an empty container is undefined behaviour,
+    // so leave early when there is nothing to print
+    if (hms.empty()) {
+        std::cerr << "Heat map is empty, nothing to write to file " << fname << std::endl;
+        return;
+    }
+
+    std::set<Real> mu_vals, lambda_vals;
+    auto hm_iter = hms.end();
+    --hm_iter;
+    // the last key is the deepest refinement level that was computed
+    int max_level = hm_iter->first;
+
+    // expected number of entries on level (initialization_depth + 1):
+    // 4 * 4^initialization_depth, multiplied by 4 for every further level
+    int level_cardinality = 4;
+    for(int i = 0; i < params.initialization_depth; ++i) {
+        level_cardinality *= 4;
+    }
+    for(int i = params.initialization_depth + 1; i <= max_level; ++i) {
+        spd::debug("hms.at({}).size = {}, must be {}", i, hms.at(i).size(), level_cardinality);
+        assert(static_cast<decltype(level_cardinality)>(hms.at(i).size()) == level_cardinality);
+        level_cardinality *= 4;
+    }
+
+    std::map<std::pair<Real, Real>, Real> hm_x_flat, hm_x_steep, hm_y_flat, hm_y_steep;
+
+    // collect the distinct mu and lambda coordinates present on the deepest level
+    for(const auto& dual_point_value_pair : hms.at(max_level)) {
+        const DualPoint& k = dual_point_value_pair.first;
+        spd::debug("HM DP: {}", k);
+        mu_vals.insert(k.mu());
+        lambda_vals.insert(k.lambda());
+    }
+
+    // sorted (ascending) copies that can be addressed by index
+    std::vector<Real> lambda_vals_vec(lambda_vals.begin(), lambda_vals.end());
+    std::vector<Real> mu_vals_vec(mu_vals.begin(), mu_vals.end());
+
+    std::ofstream ofs {fname};
+    if (not ofs.good()) {
+        std::cerr << "Cannot write heat map to file " << fname << std::endl;
+        throw std::runtime_error("Cannot open file for writing heat map");
+    }
+
+    std::vector<std::vector<Real>> heatmap_to_print(2 * mu_vals_vec.size(),
+            std::vector<Real>(2 * lambda_vals_vec.size(), 0.0));
+
+    for(auto axis_type : {AxisType::x_type, AxisType::y_type}) {
+        bool is_x_type = axis_type == AxisType::x_type;
+        for(auto angle_type : {AngleType::flat, AngleType::steep}) {
+            bool is_flat = angle_type == AngleType::flat;
+
+            int mu_idx_begin, mu_idx_end;
+
+            // x-type: walk mu from largest to smallest; y-type: smallest to largest
+            if (is_x_type) {
+                mu_idx_begin = mu_vals_vec.size() - 1;
+                mu_idx_end = -1;
+            } else {
+                mu_idx_begin = 0;
+                mu_idx_end = mu_vals_vec.size();
+            }
+
+            int lambda_idx_begin, lambda_idx_end;
+
+            // flat: walk lambda from smallest to largest; steep: largest to smallest
+            if (is_flat) {
+                lambda_idx_begin = 0;
+                lambda_idx_end = lambda_vals_vec.size();
+            } else {
+                lambda_idx_begin = lambda_vals_vec.size() - 1;
+                lambda_idx_end = -1;
+            }
+
+            // first output row of this quadrant: x-type occupies the upper half
+            int mu_idx_final = is_x_type ? 0 : mu_vals_vec.size();
+
+            for(int mu_idx = mu_idx_begin; mu_idx != mu_idx_end; (mu_idx_begin < mu_idx_end) ? mu_idx++ : mu_idx--) {
+                Real mu = mu_vals_vec.at(mu_idx);
+
+                // skipped rows do not advance mu_idx_final
+                if (mu == 0.0 and axis_type == AxisType::x_type)
+                    continue;
+
+                // first output column of this quadrant: flat occupies the left half
+                int lambda_idx_final = is_flat ? 0 : lambda_vals_vec.size();
+
+                for(int lambda_idx = lambda_idx_begin;
+                    lambda_idx != lambda_idx_end;
+                    (lambda_idx_begin < lambda_idx_end) ? lambda_idx++ : lambda_idx--) {
+
+                    Real lambda = lambda_vals_vec.at(lambda_idx);
+
+                    // skipped columns do not advance lambda_idx_final
+                    if (lambda == 0.0 and angle_type == AngleType::flat)
+                        continue;
+
+                    DualPoint dp(axis_type, angle_type, lambda, mu);
+                    Real dist_value = hms.at(max_level).at(dp);
+
+                    heatmap_to_print.at(mu_idx_final).at(lambda_idx_final) = dist_value;
+
+//                    fmt::print("HM, dp = {}, mu_idx_final = {}, lambda_idx_final = {}, value = {}\n", dp, mu_idx_final,
+//                            lambda_idx_final, dist_value);
+
+                    lambda_idx_final++;
+                }
+                mu_idx_final++;
+            }
+        }
+    }
+
+    // one matrix row per line, entries separated by a space
+    for(size_t m_idx = 0; m_idx < heatmap_to_print.size(); ++m_idx) {
+        for(size_t l_idx = 0; l_idx < heatmap_to_print.at(m_idx).size(); ++l_idx) {
+            ofs << heatmap_to_print.at(m_idx).at(l_idx) << " ";
+        }
+        ofs << std::endl;
+    }
+
+    ofs.close();
+    spd::debug("Exit print_heat_map");
+#else
+    // nothing to do without PRINT_HEAT_MAP; mark the parameters as used
+    // so that the build stays free of unused-parameter warnings
+    static_cast<void>(hms);
+    static_cast<void>(fname);
+    static_cast<void>(params);
+#endif
+}
+
+/**
+ * Command-line driver: reads two bifiltrations and computes the matching
+ * distance between them.
+ *
+ * Positional arguments: two bifiltration files, both read with
+ * BifiltrationFormat::rene.
+ *
+ * Without EXPERIMENTAL_TIMING the last bound strategy and the last traverse
+ * strategy given on the command line are used and the distance is printed
+ * to stdout. With EXPERIMENTAL_TIMING every (bound, traverse) combination is
+ * run and timed, and one ';'-separated line per combination is printed.
+ *
+ * @return 0 on success, 1 on a command-line error (or when help is requested)
+ */
+int main(int argc, char** argv)
+{
+    spdlog::set_level(spdlog::level::info);
+    //spdlog::set_pattern("[%L] %v");
+
+    using opts::Option;
+    using opts::PosOption;
+    opts::Options ops;
+
+    bool help = false;
+    bool heatmap_only = false;
+    bool no_stop_asap = false;
+    CalculationParams params;
+
+    std::string bounds_list_str = "local_combined";
+    std::string traverse_list_str = "BFS";
+
+    ops >> Option('m', "max-iterations", params.max_depth, "maximal number of iterations (refinements)")
+        >> Option('e', "max-error", params.delta, "error threshold")
+        >> Option('d', "dim", params.dim, "dim")
+        >> Option('i', "initial-depth", params.initialization_depth, "initialization depth")
+        >> Option("no-stop-asap", no_stop_asap,
+                "don't stop looping over points, if cell cannot be pruned (asap is on by default)")
+        >> Option("bounds", bounds_list_str, "bounds to use, separated by ,")
+        >> Option("traverse", traverse_list_str, "traverse to use, separated by ,")
+#ifdef PRINT_HEAT_MAP
+        >> Option("heatmap-only", heatmap_only, "only save heatmap (bruteforce)")
+#endif
+        >> Option('h', "help", help, "show help message");
+
+    std::string fname_a;
+    std::string fname_b;
+
+    if (!ops.parse(argc, argv) || help || !(ops >> PosOption(fname_a) >> PosOption(fname_b))) {
+        // note the leading space: the program name must not run into the argument list
+        std::cerr << "Usage: " << argv[0] << " bifiltration-file-1 bifiltration-file-2\n" << ops << std::endl;
+        return 1;
+    }
+
+    params.stop_asap = not no_stop_asap;
+
+    auto bounds_list = split_by_delim(bounds_list_str, ',');
+    auto traverse_list = split_by_delim(traverse_list_str, ',');
+
+    Bifiltration bif_a(fname_a, BifiltrationFormat::rene);
+    Bifiltration bif_b(fname_b, BifiltrationFormat::rene);
+
+    bif_a.sanity_check();
+    bif_b.sanity_check();
+
+    spd::info("Read bifiltrations {} {}", fname_a, fname_b);
+
+    std::vector<BoundStrategy> bound_strategies;
+    std::vector<TraverseStrategy> traverse_strategies;
+
+    for(std::string s : bounds_list) {
+        bound_strategies.push_back(bs_from_string(s));
+    }
+
+    for(std::string s : traverse_list) {
+        traverse_strategies.push_back(ts_from_string(s));
+    }
+
+    for(auto bs : bound_strategies) {
+        for(auto ts : traverse_strategies) {
+            spd::info("Will test combination {} {}", bs, ts);
+        }
+    }
+
+#ifdef EXPERIMENTAL_TIMING
+    // measurements collected for one (bound strategy, traverse strategy) pair
+    struct ExperimentResult {
+        CalculationParams params {CalculationParams()};
+        int n_hera_calls {0};
+        double total_milliseconds_elapsed {0};
+        double distance {0};
+        double actual_error {std::numeric_limits<double>::max()};
+        int actual_max_depth {0};
+
+        int x_wins {0};
+        int y_wins {0};
+        int ad_wins {0};
+
+        int seconds_elapsed() const
+        {
+            return static_cast<int>(total_milliseconds_elapsed / 1000);
+        }
+
+        // ratio of performed calls to 4 + 4^2 + ... + 4^(actual_max_depth + 1)
+        double savings_ratio_old() const
+        {
+            long int max_possible_calls = 0;
+            long int calls_on_level = 4;
+            for(int i = 0; i <= actual_max_depth; ++i) {
+                max_possible_calls += calls_on_level;
+                calls_on_level *= 4;
+            }
+            return static_cast<double>(n_hera_calls) / static_cast<double>(max_possible_calls);
+        }
+
+        // ratio of performed calls to 4^actual_max_depth
+        double savings_ratio() const
+        {
+            return static_cast<double>(n_hera_calls) / calls_on_actual_max_depth();
+        }
+
+        // 4^actual_max_depth
+        long long int calls_on_actual_max_depth() const
+        {
+            long long int result = 1;
+            for(int i = 0; i < actual_max_depth; ++i) {
+                result *= 4;
+            }
+            return result;
+        }
+
+        ExperimentResult() { }
+
+        ExperimentResult(CalculationParams p, int nhc, double tme, double d)
+                :
+                params(p), n_hera_calls(nhc), total_milliseconds_elapsed(tme), distance(d) { }
+    };
+
+    const int n_repetitions = 1;
+
+    // heat map mode: run exactly one combination, bruteforce + breadth-first
+    if (heatmap_only) {
+        bound_strategies.clear();
+        bound_strategies.push_back(BoundStrategy::bruteforce);
+        traverse_strategies.clear();
+        traverse_strategies.push_back(TraverseStrategy::breadth_first);
+    }
+
+    std::map<std::tuple<BoundStrategy, TraverseStrategy>, ExperimentResult> results;
+    for(BoundStrategy bound_strategy : bound_strategies) {
+        for(TraverseStrategy traverse_strategy : traverse_strategies) {
+            CalculationParams params_experiment;
+            params_experiment.bound_strategy = bound_strategy;
+            params_experiment.traverse_strategy = traverse_strategy;
+            params_experiment.max_depth = params.max_depth;
+            params_experiment.initialization_depth = params.initialization_depth;
+            params_experiment.delta = params.delta;
+            params_experiment.dim = params.dim;
+            params_experiment.hera_epsilon = params.hera_epsilon;
+            params_experiment.stop_asap = params.stop_asap;
+
+            // NOTE(review): a skipped combination still gets a default-constructed
+            // row in the summary below, because results[key] inserts on access
+            if (traverse_strategy == TraverseStrategy::depth_first and bound_strategy == BoundStrategy::bruteforce)
+                continue;
+
+            // if bruteforce, clamp max iterations number to 7,
+            // save user-provided max_iters in user_max_iters and restore it later.
+            // remember: params is passed by reference to return real relative error and heat map
+
+            int user_max_iters = params.max_depth;
+            if (bound_strategy == BoundStrategy::bruteforce and not heatmap_only) {
+                params_experiment.max_depth = std::min(7, params.max_depth);
+            }
+            double total_milliseconds_elapsed = 0;
+            int total_n_hera_calls = 0;
+            // initialised so that the value is defined even if n_repetitions is set to 0
+            Real dist = 0;
+            for(int i = 0; i < n_repetitions; ++i) {
+                spd::debug("Processing bound_strategy {}, traverse_strategy {}, iteration = {}", bound_strategy,
+                        traverse_strategy, i);
+                auto t1 = std::chrono::high_resolution_clock().now();
+                dist = matching_distance(bif_a, bif_b, params_experiment);
+                auto t2 = std::chrono::high_resolution_clock().now();
+                total_milliseconds_elapsed += std::chrono::duration_cast<std::chrono::milliseconds>(
+                        t2 - t1).count();
+                total_n_hera_calls += params_experiment.n_hera_calls;
+            }
+
+            // store per-repetition averages together with the values of the last run
+            auto key = std::make_tuple(bound_strategy, traverse_strategy);
+            results[key].params = params_experiment;
+            results[key].n_hera_calls = total_n_hera_calls / n_repetitions;
+            results[key].total_milliseconds_elapsed = total_milliseconds_elapsed / n_repetitions;
+            results[key].distance = dist;
+            results[key].actual_error = params_experiment.actual_error;
+            results[key].actual_max_depth = params_experiment.actual_max_depth;
+
+            spd::info(
+                    "Done (bound = {}, traverse = {}), n_hera_calls = {}, time = {} sec, d = {}, error = {}, savings = {}, max_depth = {}",
+                    bound_strategy, traverse_strategy, results[key].n_hera_calls, results[key].seconds_elapsed(),
+                    dist,
+                    params_experiment.actual_error, results[key].savings_ratio(), results[key].actual_max_depth);
+
+            if (bound_strategy == BoundStrategy::bruteforce) { params_experiment.max_depth = user_max_iters; }
+
+#ifdef PRINT_HEAT_MAP
+            if (bound_strategy == BoundStrategy::bruteforce) {
+                // the heat map is written next to the first input file
+                fs::path fname_a_path {fname_a.c_str()};
+                fs::path fname_b_path {fname_b.c_str()};
+                fs::path fname_a_wo = fname_a_path.filename();
+                fs::path fname_b_wo = fname_b_path.filename();
+                std::string heat_map_fname = fmt::format("{0}_{1}_dim_{2}_weighted_values_xyp.txt", fname_a_wo.string(),
+                        fname_b_wo.string(), params_experiment.dim);
+                fs::path path_hm = fname_a_path.replace_filename(fs::path(heat_map_fname.c_str()));
+                spd::debug("Saving heatmap to {}", heat_map_fname);
+                print_heat_map(params_experiment.heat_maps, path_hm.string(), params);
+            }
+#endif
+            spd::debug("Finished processing bound_strategy {}", bound_strategy);
+        }
+    }
+
+//    std::cout << "File_1;File_2;Boundstrategy;TraverseStrategy;InitalDepth;NHeraCalls;SavingsRatio;Time;Distance;Error;PushStrategy;MaxDepth;CallsOnMaxDepth;Delta;Dimension" << std::endl;
+    for(auto bs : bound_strategies) {
+        for(auto ts : traverse_strategies) {
+            auto key = std::make_tuple(bs, ts);
+
+            fs::path fname_a_path {fname_a.c_str()};
+            fs::path fname_b_path {fname_b.c_str()};
+            fs::path fname_a_wo = fname_a_path.filename();
+            fs::path fname_b_wo = fname_b_path.filename();
+
+            std::cout << fname_a_wo.string() << ";" << fname_b_wo.string() << ";" << bs << ";" << ts << ";";
+            std::cout << results[key].params.initialization_depth << ";";
+            std::cout << results[key].n_hera_calls << ";"
+                      << results[key].savings_ratio() << ";"
+                      << results[key].total_milliseconds_elapsed << ";"
+                      << results[key].distance << ";"
+                      << results[key].actual_error << ";"
+                      << "xyp" << ";"
+                      << results[key].actual_max_depth << ";"
+                      << results[key].calls_on_actual_max_depth() << ";"
+                      << params.delta << ";"
+                      << params.dim
+                      << std::endl;
+        }
+    }
+#else
+    // back() on an empty vector is undefined behaviour; an empty --bounds or
+    // --traverse list is a usage error
+    if (bound_strategies.empty() or traverse_strategies.empty()) {
+        std::cerr << "At least one bound strategy and one traverse strategy must be given" << std::endl;
+        return 1;
+    }
+
+    // when several strategies are listed, the last one of each kind is used
+    params.bound_strategy = bound_strategies.back();
+    params.traverse_strategy = traverse_strategies.back();
+
+    Real dist = matching_distance(bif_a, bif_b, params);
+    std::cout << dist << std::endl;
+#endif
+    return 0;
+}
diff --git a/matching/src/matching_distance.cpp b/matching/src/matching_distance.cpp new file mode 100644 index 0000000..ac96ba2 --- /dev/null +++ b/matching/src/matching_distance.cpp @@ -0,0 +1,907 @@ +#include <chrono> +#include <tuple> +#include <algorithm> + +#include "common_defs.h" + +#include "spdlog/fmt/ostr.h" +#include "matching_distance.h" + +namespace md { + + template<class K, class V> + void print_map(const std::map<K, V>& dic) + { + for(const auto kv : dic) { + fmt::print("{} -> {}\n", kv.first, kv.second); + } + } + + void DistanceCalculator::check_upper_bound(const CellWithValue& dual_cell, int dim) const + { + spd::debug("Enter check_get_max_delta_on_cell"); + const int n_samples_lambda = 100; + const int n_samples_mu = 100; + DualBox db = dual_cell.dual_box(); + Real min_lambda = db.lambda_min(); + Real max_lambda = db.lambda_max(); + Real min_mu = db.mu_min(); + Real max_mu = db.mu_max(); + + Real h_lambda = (max_lambda - min_lambda) / n_samples_lambda; + Real h_mu = (max_mu - min_mu) / n_samples_mu; + for(int i = 1; i < n_samples_lambda; ++i) { + for(int j = 1; j < n_samples_mu; ++j) { + Real lambda = min_lambda + i * h_lambda; + Real mu = min_mu + j * h_mu; + DualPoint l(db.axis_type(), db.angle_type(), lambda, mu); + Real other_result = distance_on_line_const(dim, l); + Real diff = fabs(dual_cell.stored_upper_bound() - other_result); + if (other_result > dual_cell.stored_upper_bound()) { + spd::error( + "in check_upper_bound, upper_bound = {}, other_result = {}, diff = {}, dim = {}\ndual_cell = {}", + dual_cell.stored_upper_bound(), other_result, diff, dim, dual_cell); + throw std::runtime_error("Wrong delta estimate"); + } + } + } + spd::debug("Exit check_get_max_delta_on_cell"); + } + + // for all lines l, l' inside dual box, + // find the upper bound on the difference of weighted pushes of p + Real + DistanceCalculator::get_max_displacement_single_point(const CellWithValue& dual_cell, ValuePoint vp, + const Point& p) const + { + assert(p.x 
>= 0 && p.y >= 0); + +#ifdef MD_DEBUG + std::vector<long long int> debug_ids = {3, 13, 54, 218, 350, 382, 484, 795, 2040, 8415, 44076}; + bool debug = false; // std::find(debug_ids.begin(), debug_ids.end(), dual_cell.id) != debug_ids.end(); +#endif + DualPoint line = dual_cell.value_point(vp); + const Real base_value = line.weighted_push(p); + + spd::debug("Enter get_max_displacement_single_point, p = {},\ndual_cell = {},\nline = {}, base_value = {}\n", p, + dual_cell, line, base_value); + + Real result = 0.0; + for(DualPoint dp : dual_cell.dual_box().critical_points(p)) { + Real dp_value = dp.weighted_push(p); + spd::debug( + "In get_max_displacement_single_point, p = {}, critical dp = {},\ndp_value = {}, diff = {},\ndual_cell = {}\n", + p, dp, dp_value, fabs(base_value - dp_value), dual_cell); + result = std::max(result, fabs(base_value - dp_value)); + } + +#ifdef MD_DO_FULL_CHECK + DualBox db = dual_cell.dual_box(); + std::uniform_real_distribution<Real> dlambda(db.lambda_min(), db.lambda_max()); + std::uniform_real_distribution<Real> dmu(db.mu_min(), db.mu_max()); + std::mt19937 gen(1); + for(int i = 0; i < 1000; ++i) { + Real lambda = dlambda(gen); + Real mu = dmu(gen); + DualPoint dp_random { db.axis_type(), db.angle_type(), lambda, mu }; + Real dp_value = dp_random.weighted_push(p); + if (fabs(base_value - dp_value) > result) { + spd::error("in get_max_displacement_single_point, p = {}, vp = {}\ndb = {}\nresult = {}, base_value = {}, dp_value = {}, dp_random = {}", + p, vp, db, result, base_value, dp_value, dp_random); + throw std::runtime_error("error in get_max_displacement_single_value"); + } + } +#endif + + return result; + } + + DistanceCalculator::CellValueVector DistanceCalculator::get_initial_dual_grid(Real& lower_bound) + { + CellValueVector result = get_refined_grid(params_.initialization_depth, false, true); + + lower_bound = -1.0; + for(const auto& dc : result) { + lower_bound = std::max(lower_bound, dc.max_corner_value()); + } + + 
assert(lower_bound >= 0); + + for(auto& dual_cell : result) { + Real good_enough_ub = get_good_enough_upper_bound(lower_bound); + Real max_value_on_cell = get_upper_bound(dual_cell, params_.dim, good_enough_ub); + dual_cell.set_max_possible_value(max_value_on_cell); + +#ifdef MD_DO_FULL_CHECK + check_upper_bound(dual_cell, params_.dim); +#endif + + spd::debug("DEBUG INIT: added cell {}", dual_cell); + } + + + + return result; + } + + DistanceCalculator::CellValueVector + DistanceCalculator::get_refined_grid(int init_depth, bool calculate_on_intermediate, bool calculate_on_last) + { + const Real y_max = std::max(module_a_.max_y(), module_b_.max_y()); + const Real x_max = std::max(module_a_.max_x(), module_b_.max_x()); + + const Real lambda_min = 0; + const Real lambda_max = 1; + + const Real mu_min = 0; + + DualBox x_flat(DualPoint(AxisType::x_type, AngleType::flat, lambda_min, mu_min), + DualPoint(AxisType::x_type, AngleType::flat, lambda_max, x_max)); + + DualBox x_steep(DualPoint(AxisType::x_type, AngleType::steep, lambda_min, mu_min), + DualPoint(AxisType::x_type, AngleType::steep, lambda_max, x_max)); + + DualBox y_flat(DualPoint(AxisType::y_type, AngleType::flat, lambda_min, mu_min), + DualPoint(AxisType::y_type, AngleType::flat, lambda_max, y_max)); + + DualBox y_steep(DualPoint(AxisType::y_type, AngleType::steep, lambda_min, mu_min), + DualPoint(AxisType::y_type, AngleType::steep, lambda_max, y_max)); + + CellWithValue x_flat_cell(x_flat, 0); + CellWithValue x_steep_cell(x_steep, 0); + CellWithValue y_flat_cell(y_flat, 0); + CellWithValue y_steep_cell(y_steep, 0); + + if (init_depth == 0) { + DualPoint diagonal_x_flat(AxisType::x_type, AngleType::flat, 1, 0); + + Real diagonal_value = distance_on_line(params_.dim, diagonal_x_flat); + n_hera_calls_per_level_[0]++; + + x_flat_cell.set_value_at(ValuePoint::lower_right, diagonal_value); + y_flat_cell.set_value_at(ValuePoint::lower_right, diagonal_value); + x_steep_cell.set_value_at(ValuePoint::lower_right, 
diagonal_value); + y_steep_cell.set_value_at(ValuePoint::lower_right, diagonal_value); + } + +#ifdef MD_DEBUG + x_flat_cell.id = 1; + x_steep_cell.id = 2; + y_flat_cell.id = 3; + y_steep_cell.id = 4; + CellWithValue::max_id = 4; +#endif + + CellValueVector result {x_flat_cell, x_steep_cell, y_flat_cell, y_steep_cell}; + + if (init_depth == 0) { + return result; + } + + CellValueVector refined_result; + + for(int i = 1; i <= init_depth; ++i) { + refined_result.clear(); + for(const auto& dual_cell : result) { + for(auto refined_cell : dual_cell.get_refined_cells()) { + // we calculate for init_dept - 1, not init_depth, + // because we want the cells to have value at a corner + if ((i == init_depth - 1 and calculate_on_last) or calculate_on_intermediate) + set_cell_central_value(refined_cell, params_.dim); + refined_result.push_back(refined_cell); + } + } + result = std::move(refined_result); + } + return result; + } + + DistanceCalculator::DistanceCalculator(const DiagramProvider& a, + const DiagramProvider& b, + CalculationParams& params) + : + module_a_(a), + module_b_(b), + params_(params), + maximal_dim_(std::max(a.maximal_dim(), b.maximal_dim())), + distances_(1 + std::max(a.maximal_dim(), b.maximal_dim()), Real(-1)) + { + // make all coordinates non-negative + auto min_coord = std::min(module_a_.minimal_coordinate(), + module_b_.minimal_coordinate()); + if (min_coord < 0) { + module_a_.translate(-min_coord); + module_b_.translate(-min_coord); + } + + assert(std::min({module_a_.min_x(), module_b_.min_x(), module_a_.min_y(), + module_b_.min_y()}) >= 0); + + spd::info("DistanceCalculator constructed, module_a: max_x = {}, max_y = {}, module_b: max_x = {}, max_y = {}", + module_a_.max_x(), module_a_.max_y(), module_b_.max_x(), module_b_.max_y()); + } + + void DistanceCalculator::clear_cache() + { + distances_ = std::vector<Real>(maximal_dim_, Real(-1)); + } + + Real DistanceCalculator::get_max_x(int module) const + { + return (module == 0) ? 
module_a_.max_x() : module_b_.max_x(); + } + + Real DistanceCalculator::get_max_y(int module) const + { + return (module == 0) ? module_a_.max_y() : module_b_.max_y(); + } + + Real + DistanceCalculator::get_local_refined_bound(const md::DualBox& dual_box) const + { + return get_local_refined_bound(0, dual_box) + get_local_refined_bound(1, dual_box); + } + + Real + DistanceCalculator::get_local_refined_bound(int module, const md::DualBox& dual_box) const + { + spd::debug("Enter get_local_refined_bound, dual_box = {}", dual_box); + Real d_lambda = dual_box.lambda_max() - dual_box.lambda_min(); + Real d_mu = dual_box.mu_max() - dual_box.mu_min(); + Real result; + if (dual_box.axis_type() == AxisType::x_type) { + if (dual_box.is_flat()) { + result = dual_box.lambda_max() * d_mu + (get_max_x(module) - dual_box.mu_min()) * d_lambda; + } else { + result = d_mu + get_max_y(module) * d_lambda; + } + } else { + // y-type + if (dual_box.is_flat()) { + result = d_mu + get_max_x(module) * d_lambda; + } else { + // steep + result = dual_box.lambda_max() * d_mu + (get_max_y(module) - dual_box.mu_min()) * d_lambda; + } + } + return result; + } + + Real DistanceCalculator::get_local_dual_bound(int module, const md::DualBox& dual_box) const + { + Real dlambda = dual_box.lambda_max() - dual_box.lambda_min(); + Real dmu = dual_box.mu_max() - dual_box.mu_min(); + Real C = std::max(get_max_x(module), get_max_y(module)); + + //return 2 * (C * dlambda + dmu); + + // additional factor of 2 because we mimic Cerri's paper + // where subdivision is on angle spaces, + // and tangent/cotangent is 2-Lipschitz + if (dual_box.is_flat()) { + return get_max_x(module) * dlambda + dmu; + } else { + return get_max_y(module) * dlambda + dmu; + } + } + + Real DistanceCalculator::get_local_dual_bound(const md::DualBox& dual_box) const + { + return get_local_dual_bound(0, dual_box) + get_local_dual_bound(1, dual_box); + } + + Real DistanceCalculator::get_upper_bound(const CellWithValue& dual_cell, int dim, 
Real good_enough_ub) const + { + assert(good_enough_ub >= 0); + + switch(params_.bound_strategy) { + case BoundStrategy::bruteforce: + return std::numeric_limits<Real>::max(); + + case BoundStrategy::local_dual_bound: + return dual_cell.min_value() + get_local_dual_bound(dual_cell.dual_box()); + + case BoundStrategy::local_dual_bound_refined: + return dual_cell.min_value() + get_local_refined_bound(dual_cell.dual_box()); + + case BoundStrategy::local_combined: { + Real cheap_upper_bound = dual_cell.min_value() + get_local_refined_bound(dual_cell.dual_box()); + if (cheap_upper_bound < good_enough_ub) { + return cheap_upper_bound; + } else { + [[fallthrough]]; + } + } + + case BoundStrategy::local_dual_bound_for_each_point: { + Real result = std::numeric_limits<Real>::max(); + for(ValuePoint vp : k_corner_vps) { + if (not dual_cell.has_value_at(vp)) { + continue; + } + + Real base_value = dual_cell.value_at(vp); + Real bound_dgm_a = get_single_dgm_bound(dual_cell, vp, 0, dim, good_enough_ub); + + if (params_.stop_asap and bound_dgm_a + base_value >= good_enough_ub) { + // we want to return a valid upper bound, not just something that will prevent discarding the cell + // and we don't want to compute pushes for points in second bifiltration. 
+ // so just return a constant time bound + return dual_cell.min_value() + get_local_refined_bound(dual_cell.dual_box()); + } + + Real bound_dgm_b = get_single_dgm_bound(dual_cell, vp, 1, dim, + std::max(Real(0), good_enough_ub - bound_dgm_a)); + + result = std::min(result, base_value + bound_dgm_a + bound_dgm_b); + +#ifdef MD_DEBUG + spd::debug("In get_upper_bound, cell = {}", dual_cell); + spd::debug("In get_upper_bound, vp = {}, base_value = {}, bound_dgm_a = {}, bound_dgm_b = {}, result = {}", vp, base_value, bound_dgm_a, bound_dgm_b, result); +#endif + + if (params_.stop_asap and result < good_enough_ub) { + break; + } + } + return result; + } + } + // to suppress compiler warning + return std::numeric_limits<Real>::max(); + } + + // find maximal displacement of weighted points of m for all lines in dual_box + Real + DistanceCalculator::get_single_dgm_bound(const CellWithValue& dual_cell, + ValuePoint vp, + int module, + int dim, + [[maybe_unused]] Real good_enough_value) const + { + Real result = 0; + Point max_point; + + spd::debug("Enter get_single_dgm_bound, module = {}, dual_cell = {}, vp = {}, good_enough_value = {}, stop_asap = {}\n", module, dual_cell, vp, good_enough_value, params_.stop_asap); + + const DiagramProvider& m = (module == 0) ? 
module_a_ : module_b_; + for(const auto& simplex : m.simplices()) { + spd::debug("in get_single_dgm_bound, simplex = {}\n", simplex); + if (dim != simplex.dim() and dim + 1 != simplex.dim()) + continue; + + Real x = get_max_displacement_single_point(dual_cell, vp, simplex.position()); + + spd::debug("In get_single_dgm_bound, point = {}, displacement = {}", simplex.position(), x); + + if (x > result) { + result = x; + max_point = simplex.position(); + spd::debug("In get_single_dgm_bound, point = {}, result now = displacement = {}", simplex.position(), x); + } + + if (params_.stop_asap and result > good_enough_value) { + // we want to return a valid upper bound, + // now we just see it is worse than we need, but it may be even more + // just return a valid upper bound + spd::debug("result {} > good_enough_value {}, exit and return refined bound {}", result, good_enough_value, get_local_refined_bound(dual_cell.dual_box())); + result = get_local_refined_bound(dual_cell.dual_box()); + break; + } + } + + spd::debug("Exit get_single_dgm_bound,\ndual_cell = {}\nmodule = {}, dim = {}, result = {}, max_point = {}", dual_cell, module, dim, result, max_point); + + return result; + } + + Real DistanceCalculator::distance() + { + if (params_.dim != CalculationParams::ALL_DIMENSIONS) { + return distance_in_dimension_pq(params_.dim); + } else { + Real result = -1.0; + for(int d = 0; d <= maximal_dim_; ++d) { + result = std::max(result, distance_in_dimension_pq(d)); + } + return result; + } + } + + // calculate weighted bottleneneck distance between slices on line + // in dimension dim + // increments hera calls counter + Real DistanceCalculator::distance_on_line(int dim, DualPoint line) + { + // order matters - distance_on_line_const assumes n_hera_calls_ map has entry for dim + ++n_hera_calls_[dim]; + Real result = distance_on_line_const(dim, line); + return result; + } + + Real DistanceCalculator::distance_on_line_const(int dim, DualPoint line) const + { + // TODO: think about 
this - how to call Hera + Real hera_epsilon = 0.001; + auto dgm_a = module_a_.weighted_slice_diagram(line, dim).get_diagram(dim); + auto dgm_b = module_b_.weighted_slice_diagram(line, dim).get_diagram(dim); +// Real result = hera::bottleneckDistApprox(dgm_a, dgm_b, hera_epsilon); + Real result = hera::bottleneckDistExact(dgm_a, dgm_b); + if (n_hera_calls_.at(dim) % 100 == 1) { + spd::debug("Calling Hera, dgm_a.size = {}, dgm_b.size = {}, line = {}, result = {}", dgm_a.size(), dgm_b.size(), line, result); + } else { + spd::debug("Calling Hera, dgm_a.size = {}, dgm_b.size = {}, line = {}, result = {}", dgm_a.size(), dgm_b.size(), line, result); + } + return result; + } + + Real DistanceCalculator::get_good_enough_upper_bound(Real lower_bound) const + { + Real result; + // in upper_bound strategy we only prune cells if they cannot improve the lower bound, + // otherwise the experiment is supposed to run indefinitely + if (params_.traverse_strategy == TraverseStrategy::upper_bound) { + result = lower_bound; + } else { + result = (1.0 + params_.delta) * lower_bound; + } + return result; + } + + // helper function + // calculate weighted bt distance in dim on cell center, + // assign distance value to cell, keep it in heat_map, and return + void DistanceCalculator::set_cell_central_value(CellWithValue& dual_cell, int dim) + { + DualPoint central_line {dual_cell.center()}; + + spd::debug("In set_cell_central_value, processing dual cell = {}, line = {}", dual_cell.dual_box(), + central_line); + Real new_value = distance_on_line(dim, central_line); + n_hera_calls_per_level_[dual_cell.level() + 1]++; + dual_cell.set_value_at(ValuePoint::center, new_value); + params_.actual_max_depth = std::max(params_.actual_max_depth, dual_cell.level() + 1); + +#ifdef PRINT_HEAT_MAP + if (params_.bound_strategy == BoundStrategy::bruteforce) { + spd::debug("In set_cell_central_value, adding to heat_map pair {} - {}", dual_cell.center(), new_value); + if (dual_cell.level() > 
params_.initialization_depth + 1 + and params_.heat_maps[dual_cell.level()].count(dual_cell.center()) > 0) { + auto existing = params_.heat_maps[dual_cell.level()].find(dual_cell.center()); + spd::debug("EXISTING: {} -> {}", existing->first, existing->second); + } + assert(dual_cell.level() <= params_.initialization_depth + 1 + or params_.heat_maps[dual_cell.level()].count(dual_cell.center()) == 0); + params_.heat_maps[dual_cell.level()][dual_cell.center()] = new_value; + } +#endif + } + + // quick-and-dirty hack to efficiently traverse priority queue with dual cells + // returns maximal possible value on all cells in queue + // assumes that the underlying container is vector! + // cell_ptr: pointer to the first element in queue + // n_cells: queue size + Real DistanceCalculator::get_max_possible_value(const CellWithValue* cell_ptr, int n_cells) + { + Real result = (n_cells > 0) ? cell_ptr->stored_upper_bound() : 0; + for(int i = 0; i < n_cells; ++i, ++cell_ptr) { + result = std::max(result, cell_ptr->stored_upper_bound()); + } + return result; + } + + // helper function: + // return current error from lower and upper bounds + // and save it in params_ (hence not const) + Real DistanceCalculator::current_error(Real lower_bound, Real upper_bound) + { + Real current_error = (lower_bound > 0.0) ? (upper_bound - lower_bound) / lower_bound + : std::numeric_limits<Real>::max(); + + params_.actual_error = current_error; + + if (current_error < params_.delta) { + spd::debug( + "Threshold achieved! 
bound_strategy = {}, traverse_strategy = {}, upper_bound = {}, current_error = {}", + params_.bound_strategy, params_.traverse_strategy, upper_bound, current_error); + } + return current_error; + } + + struct UbExperimentRecord { + Real error; + Real lower_bound; + Real upper_bound; + CellWithValue cell; + long long int time; + long long int n_hera_calls; + }; + + std::ostream& operator<<(std::ostream& os, const UbExperimentRecord& r); + + // return matching distance in dimension dim + // use priority queue to store dual cells + // comparison function depends on the strategies in params_ + // ressets hera calls counter + Real DistanceCalculator::distance_in_dimension_pq(int dim) + { + std::map<int, long> n_cells_considered; + std::map<int, long> n_cells_pushed_into_queue; + long int n_too_deep_cells = 0; + std::map<int, long> n_cells_discarded; + std::map<int, long> n_cells_pruned; + + spd::info("Enter distance_in_dimension_pq, dim = {}, bound strategy = {}, traverse strategy = {}, stop_asap = {} ", dim, params_.bound_strategy, params_.traverse_strategy, params_.stop_asap); + + std::chrono::high_resolution_clock timer; + auto start_time = timer.now(); + + n_hera_calls_[dim] = 0; + n_hera_calls_per_level_.clear(); + + + // if cell is too deep and is not pushed into queue, + // we still need to take its max value into account; + // the max over such cells is stored in max_result_on_too_fine_cells + Real upper_bound_on_deep_cells = -1; + + spd::debug("Started iterations in dual space, delta = {}, bound_strategy = {}", params_.delta, params_.bound_strategy); + // user-defined less lambda function + // to regulate priority queue depending on strategy + auto dual_cell_less = [this](const CellWithValue& a, const CellWithValue& b) { + + int a_level = a.level(); + int b_level = b.level(); + Real a_value = a.max_corner_value(); + Real b_value = b.max_corner_value(); + Real a_ub = a.stored_upper_bound(); + Real b_ub = b.stored_upper_bound(); + if 
(this->params_.traverse_strategy == TraverseStrategy::upper_bound and + (not a.has_max_possible_value() or not b.has_max_possible_value())) { + throw std::runtime_error("no upper bound on cell"); + } + DualPoint a_lower_left = a.dual_box().lower_left(); + DualPoint b_lower_left = b.dual_box().lower_left(); + + switch(this->params_.traverse_strategy) { + // in both breadth_first searches we want coarser cells + // to be processed first. Cells with smaller level must be larger, + // hence the minus in front of level + case TraverseStrategy::breadth_first: + return std::make_tuple(-a_level, a_lower_left) + < std::make_tuple(-b_level, b_lower_left); + case TraverseStrategy::breadth_first_value: + return std::make_tuple(-a_level, a_value, a_lower_left) + < std::make_tuple(-b_level, b_value, b_lower_left); + case TraverseStrategy::depth_first: + return std::make_tuple(a_value, a_level, a_lower_left) + < std::make_tuple(b_value, b_level, b_lower_left); + case TraverseStrategy::upper_bound: + return std::make_tuple(a_ub, a_level, a_lower_left) + < std::make_tuple(b_ub, b_level, b_lower_left); + default: + throw std::runtime_error("Forgotten case"); + } + }; + + std::priority_queue<CellWithValue, CellValueVector, decltype(dual_cell_less)> dual_cells_queue( + dual_cell_less); + + // weighted bt distance on the center of current cell + Real lower_bound = std::numeric_limits<Real>::min(); + + // init pq and lower bound + for(auto& init_cell : get_initial_dual_grid(lower_bound)) { + dual_cells_queue.push(init_cell); + } + + Real upper_bound = get_max_possible_value(&dual_cells_queue.top(), dual_cells_queue.size()); + + std::vector<UbExperimentRecord> ub_experiment_results; + + while(not dual_cells_queue.empty()) { + + CellWithValue dual_cell = dual_cells_queue.top(); + dual_cells_queue.pop(); + assert(dual_cell.has_corner_value() + and dual_cell.has_max_possible_value() + and dual_cell.max_corner_value() <= upper_bound); + + n_cells_considered[dual_cell.level()]++; + + bool 
discard_cell = false; + + if (not params_.stop_asap) { + // if stop_asap is on, it is safer to never discard a cell + if (params_.bound_strategy == BoundStrategy::bruteforce) { + discard_cell = false; + } else if (params_.traverse_strategy == TraverseStrategy::upper_bound) { + discard_cell = (dual_cell.stored_upper_bound() <= lower_bound); + } else { + discard_cell = (dual_cell.stored_upper_bound() <= (1.0 + params_.delta) * lower_bound); + } + } + + spd::debug("CURRENT CELL bound_strategy = {}, traverse_strategy = {}, dual cell: {}, upper_bound = {}, lower_bound = {}, current_error = {}, discard_cell = {}", + params_.bound_strategy, params_.traverse_strategy, dual_cell, upper_bound, lower_bound, current_error(lower_bound, upper_bound), discard_cell); + + if (discard_cell) { + n_cells_discarded[dual_cell.level()]++; + continue; + } + + // until now, dual_cell knows its value in one of its corners + // new_value will be the weighted distance at its center + set_cell_central_value(dual_cell, dim); + Real new_value = dual_cell.value_at(ValuePoint::center); + lower_bound = std::max(new_value, lower_bound); + + spd::debug("Processed cell = {}, weighted value = {}, lower_bound = {}", dual_cell, new_value, lower_bound); + + assert(upper_bound >= lower_bound); + + if (current_error(lower_bound, upper_bound) < params_.delta) { + break; + } + + // refine cell and push 4 smaller cells into queue + for(auto refined_cell : dual_cell.get_refined_cells()) { + + if (refined_cell.num_values() == 0) + throw std::runtime_error("no value on cell"); + + // if delta is smaller than good_enough_value, it allows to prune cell + Real good_enough_ub = get_good_enough_upper_bound(lower_bound); + + // upper bound of the parent holds for refined_cell + // and can sometimes be smaller! 
+ Real upper_bound_on_refined_cell = std::min(dual_cell.stored_upper_bound(), + get_upper_bound(refined_cell, dim, good_enough_ub)); + + spd::debug("upper_bound_on_refined_cell = {}, dual_cell.stored_upper_bound = {}, get_upper_bound = {}", + upper_bound_on_refined_cell, dual_cell.stored_upper_bound(), get_upper_bound(refined_cell, dim, good_enough_ub)); + + refined_cell.set_max_possible_value(upper_bound_on_refined_cell); + +#ifdef MD_DO_FULL_CHECK + check_upper_bound(refined_cell, dim); +#endif + + bool prune_cell = false; + + if (refined_cell.level() <= params_.max_depth) { + // cell might be added to queue; if it is not added, its maximal value can be safely ignored + if (params_.traverse_strategy == TraverseStrategy::upper_bound) { + prune_cell = (refined_cell.stored_upper_bound() <= lower_bound); + } else if (params_.bound_strategy != BoundStrategy::bruteforce) { + prune_cell = (refined_cell.stored_upper_bound() <= (1.0 + params_.delta) * lower_bound); + } + if (prune_cell) + n_cells_pruned[refined_cell.level()]++; +// prune_cell = (max_result_on_refined_cell <= lower_bound); + } else { + // cell is too deep, it won't be added to queue + // we must memorize maximal value on this cell, because we won't see it anymore + prune_cell = true; + if (refined_cell.stored_upper_bound() > (1 + params_.delta) * lower_bound) { + n_too_deep_cells++; + } + upper_bound_on_deep_cells = std::max(upper_bound_on_deep_cells, refined_cell.stored_upper_bound()); + } + + spd::debug("In distance_in_dimension_pq, loop over refined cells, bound_strategy = {}, traverse_strategy = {}, refined cell: {}, max_value_on_cell = {}, upper_bound = {}, current_error = {}, prune_cell = {}", + params_.bound_strategy, params_.traverse_strategy, refined_cell, refined_cell.stored_upper_bound(), upper_bound, current_error(lower_bound, upper_bound), prune_cell); + + if (not prune_cell) { + n_cells_pushed_into_queue[refined_cell.level()]++; + dual_cells_queue.push(refined_cell); + } + } // end loop over 
refined cells + + if (dual_cells_queue.empty()) + upper_bound = std::max(upper_bound, upper_bound_on_deep_cells); + else + upper_bound = std::max(upper_bound_on_deep_cells, + get_max_possible_value(&dual_cells_queue.top(), dual_cells_queue.size())); + + if (params_.traverse_strategy == TraverseStrategy::upper_bound) { + upper_bound = dual_cells_queue.top().stored_upper_bound(); + + if (get_hera_calls_number(params_.dim) < 20 || get_hera_calls_number(params_.dim) % 20 == 0) { + auto elapsed = timer.now() - start_time; + UbExperimentRecord ub_exp_record; + + ub_exp_record.error = current_error(lower_bound, upper_bound); + ub_exp_record.lower_bound = lower_bound; + ub_exp_record.upper_bound = upper_bound; + ub_exp_record.cell = dual_cells_queue.top(); + ub_exp_record.n_hera_calls = n_hera_calls_[dim]; + ub_exp_record.time = std::chrono::duration_cast<std::chrono::milliseconds>(elapsed).count(); + +#ifdef MD_DO_CHECKS + if (ub_experiment_results.size() > 0) { + auto prev = ub_experiment_results.back(); + if (upper_bound > prev.upper_bound) { + spd::error("ALARM 1, upper_bound = {}, top = {}, prev.ub = {}, prev cell = {}, lower_bound = {}, prev.lower_bound = {}", + upper_bound, ub_exp_record.cell, prev.upper_bound, prev.cell, lower_bound, prev.lower_bound); + throw std::runtime_error("die"); + } + + if (lower_bound < prev.lower_bound) { + spd::error("ALARM 2, lower_bound = {}, prev.lower_bound = {}, top = {}, prev.ub = {}, prev cell = {}", lower_bound, prev.lower_bound, ub_exp_record.cell, prev.upper_bound, prev.cell); + throw std::runtime_error("die"); + } + } +#endif + + ub_experiment_results.emplace_back(ub_exp_record); + + fmt::print(std::cerr, "[UB_EXPERIMENT]\t{}\n", ub_exp_record); + } + } + + assert(upper_bound >= lower_bound); + + if (current_error(lower_bound, upper_bound) < params_.delta) { + break; + } + } + + params_.actual_error = current_error(lower_bound, upper_bound); + + if (n_too_deep_cells > 0) { + spd::warn("Error not guaranteed, there were {} too 
deep cells. Actual error = {}. Increase max_depth or delta", n_too_deep_cells, params_.actual_error); + } + // otherwise actual_error in params can be larger than delta, + // but this is OK + + spd::info("#############################################################"); + spd::info("Exiting distance_in_dimension_pq, bound_strategy = {}, traverse_strategy = {}, lower_bound = {}, upper_bound = {}, current_error = {}, actual_max_level = {}", + params_.bound_strategy, params_.traverse_strategy, lower_bound, + upper_bound, params_.actual_error, params_.actual_max_depth); + + spd::info("#############################################################"); + + bool print_stats = true; + if (print_stats) { + fmt::print("EXIT STATS, cells considered:\n"); + print_map(n_cells_considered); + fmt::print("EXIT STATS, cells discarded:\n"); + print_map(n_cells_discarded); + fmt::print("EXIT STATS, cells pruned:\n"); + print_map(n_cells_pruned); + fmt::print("EXIT STATS, cells pushed:\n"); + print_map(n_cells_pushed_into_queue); + fmt::print("EXIT STATS, hera calls:\n"); + print_map(n_hera_calls_per_level_); + + fmt::print("EXIT STATS, too deep cells with high value: {}\n", n_too_deep_cells); + } + + return lower_bound; + } + + int DistanceCalculator::get_hera_calls_number(int dim) const + { + if (dim == CalculationParams::ALL_DIMENSIONS) + return std::accumulate(n_hera_calls_.begin(), n_hera_calls_.end(), 0, + [](auto x, auto y) { return x + y.second; }); + else + return n_hera_calls_.at(dim); + } + + Real matching_distance(const Bifiltration& bif_a, const Bifiltration& bif_b, + CalculationParams& params) + { + DistanceCalculator runner(bif_a, bif_b, params); + Real result = runner.distance(); + params.n_hera_calls = runner.get_hera_calls_number(params.dim); + return result; + } + + std::istream& operator>>(std::istream& is, BoundStrategy& s) + { + std::string ss; + is >> ss; + if (ss == "bruteforce") { + s = BoundStrategy::bruteforce; + } else if (ss == "local_grob") { + s = 
BoundStrategy::local_dual_bound; + } else if (ss == "local_combined") { + s = BoundStrategy::local_combined; + } else if (ss == "local_refined") { + s = BoundStrategy::local_dual_bound_refined; + } else if (ss == "local_for_each_point") { + s = BoundStrategy::local_dual_bound_for_each_point; + } else { + throw std::runtime_error("UNKNOWN BOUND STRATEGY"); + } + return is; + } + + BoundStrategy bs_from_string(std::string s) + { + std::stringstream ss(s); + BoundStrategy result; + ss >> result; + return result; + } + + TraverseStrategy ts_from_string(std::string s) + { + std::stringstream ss(s); + TraverseStrategy result; + ss >> result; + return result; + } + + std::istream& operator>>(std::istream& is, TraverseStrategy& s) + { + std::string ss; + is >> ss; + if (ss == "DFS") { + s = TraverseStrategy::depth_first; + } else if (ss == "BFS") { + s = TraverseStrategy::breadth_first; + } else if (ss == "BFS-VAL") { + s = TraverseStrategy::breadth_first_value; + } else if (ss == "UB") { + s = TraverseStrategy::upper_bound; + } else { + throw std::runtime_error("UNKNOWN TRAVERSE STRATEGY"); + } + return is; + } + + std::ostream& operator<<(std::ostream& os, const UbExperimentRecord& r) + { + os << r.time << "\t" << r.n_hera_calls << "\t" << r.error << "\t" << r.lower_bound << "\t" << r.upper_bound; + return os; + } + + std::ostream& operator<<(std::ostream& os, const BoundStrategy& s) + { + switch(s) { + case BoundStrategy::bruteforce : + os << "bruteforce"; + break; + case BoundStrategy::local_dual_bound : + os << "local_grob"; + break; + case BoundStrategy::local_combined : + os << "local_combined"; + break; + case BoundStrategy::local_dual_bound_refined : + os << "local_refined"; + break; + case BoundStrategy::local_dual_bound_for_each_point : + os << "local_for_each_point"; + break; + default: + os << "FORGOTTEN BOUND STRATEGY"; + } + return os; + } + + std::ostream& operator<<(std::ostream& os, const TraverseStrategy& s) + { + switch(s) { + case 
TraverseStrategy::depth_first : + os << "DFS"; + break; + case TraverseStrategy::breadth_first : + os << "BFS"; + break; + case TraverseStrategy::breadth_first_value : + os << "BFS-VAL"; + break; + case TraverseStrategy::upper_bound : + os << "UB"; + break; + default: + os << "FORGOTTEN TRAVERSE STRATEGY"; + } + return os; + } +} diff --git a/matching/src/persistence_module.cpp b/matching/src/persistence_module.cpp new file mode 100644 index 0000000..e947925 --- /dev/null +++ b/matching/src/persistence_module.cpp @@ -0,0 +1,104 @@ +#include<phat/boundary_matrix.h> +#include<phat/compute_persistence_pairs.h> + +#include "persistence_module.h" + +namespace md { + PersistenceModule::PersistenceModule(const std::string& /*fname*/) // read from file + : + generators_(), + relations_() + { + } + + Diagram PersistenceModule::slice_diagram(const DualPoint& /*line*/) + { + //Vector2D b_of_line(L.b, -L.b); + //for (int i = 0; i<(int) F.size(); i++) { + // Simplex_in_2D_filtration& curr_simplex = F[i]; + // Vector2D proj = push(curr_simplex.pos, L); + + // curr_simplex.v = L_2(proj - b_of_line); + // //std::cout << proj << std::endl; + // //std::cout << "v=" << curr_simplex.v << std::endl; + //} + //std::sort(F.begin(), F.end(), sort_functor); + //std::map<Index, Index> index_map; + //for (Index i = 0; i<(int) F.size(); i++) { + // index_map[F[i].index] = i; + // //std::cout << F[i].index << " -> " << i << std::endl; + //} + //phat::boundary_matrix<> phat_matrix; + //phat_matrix.set_num_cols(F.size()); + //std::vector<Index> bd_in_slice_filtration; + //for (Index i = 0; i<(int) F.size(); i++) { + // phat_matrix.set_dim(i, F[i].dim); + // bd_in_slice_filtration.clear(); + // //std::cout << "new col" << i << std::endl; + // for (int j = 0; j<(int) F[i].bd.size(); j++) { + // // F[i] contains the indices of its facet wrt to the + // // original filtration. We have to express it, however, + // // wrt to the filtration along the slice. 
That is why + // // we need the index_map + // //std::cout << "Found " << F[i].bd[j] << ", returning " << index_map[F[i].bd[j]] << std::endl; + // bd_in_slice_filtration.push_back(index_map[F[i].bd[j]]); + // } + // std::sort(bd_in_slice_filtration.begin(), bd_in_slice_filtration.end()); + // [> + // for(int j=0;j<bd_in_slice_filtration.size();j++) { + // std::cout << bd_in_slice_filtration[j] << " "; + // } + // std::cout << std::endl; + // */ + // phat_matrix.set_col(i, bd_in_slice_filtration); + + //} + //phat::persistence_pairs phat_persistence_pairs; + //phat::compute_persistence_pairs<phat::twist_reduction>(phat_persistence_pairs, phat_matrix); + + //dgm.clear(); + //for (long i = 0; i<(long) phat_persistence_pairs.get_num_pairs(); i++) { + // std::pair<phat::index, phat::index> new_pair = phat_persistence_pairs.get_pair(i); + // double birth = F[new_pair.first].v; + // double death = F[new_pair.second].v; + // if (birth!=death) { + // dgm.push_back(std::make_pair(birth, death)); + // } + //} + + //[> + //std::cout << "Done, created diagram: " << std::endl; + //for(int i=0;i<(int)dgm.size();i++) { + // std::cout << dgm[i].first << " " << dgm[i].second << std::endl; + //} + //*/ + return Diagram(); + + } + + PersistenceModule::Box PersistenceModule::bounding_box() const + { + Real ll_x = std::numeric_limits<Real>::max(); + Real ll_y = std::numeric_limits<Real>::max(); + Real ur_x = -std::numeric_limits<Real>::max(); + Real ur_y = -std::numeric_limits<Real>::max(); + + for(const auto& gen : generators_) { + ll_x = std::min(gen.x, ll_x); + ll_y = std::min(gen.y, ll_y); + ur_x = std::max(gen.x, ur_x); + ur_y = std::max(gen.y, ur_y); + } + + for(const auto& rel : relations_) { + + ll_x = std::min(rel.get_x(), ll_x); + ll_y = std::min(rel.get_y(), ll_y); + + ur_x = std::max(rel.get_x(), ur_x); + ur_y = std::max(rel.get_y(), ur_y); + } + + return Box(Point(ll_x, ll_y), Point(ur_x, ur_y)); + } +} diff --git a/matching/src/simplex.cpp b/matching/src/simplex.cpp new 
file mode 100644 index 0000000..c5cdd25 --- /dev/null +++ b/matching/src/simplex.cpp @@ -0,0 +1,123 @@ +#include "simplex.h" + +namespace md { + + std::ostream& operator<<(std::ostream& os, const AbstractSimplex& s) + { + os << "AbstractSimplex(id = " << s.id << ", vertices_ = " << container_to_string(s.vertices_) << ")"; + return os; + } + + bool operator<(const AbstractSimplex& a, const AbstractSimplex& b) + { + return a.vertices_ < b.vertices_; + } + + bool operator==(const AbstractSimplex& s1, const AbstractSimplex& s2) + { + return s1.vertices_ == s2.vertices_; + } + + void AbstractSimplex::push_back(int v) + { + vertices_.push_back(v); + std::sort(vertices_.begin(), vertices_.end()); + } + + AbstractSimplex::AbstractSimplex(std::vector<int> vertices, bool sort) + :vertices_(vertices) + { + if (sort) + std::sort(vertices_.begin(), vertices_.end()); + } + + std::vector<AbstractSimplex> AbstractSimplex::facets() const + { + std::vector<AbstractSimplex> result; + for (int i = 0; i < static_cast<int>(vertices_.size()); ++i) { + std::vector<int> facet_vertices; + facet_vertices.reserve(dim()); + for (int j = 0; j < static_cast<int>(vertices_.size()); ++j) { + if (j != i) + facet_vertices.push_back(vertices_[j]); + } + if (!facet_vertices.empty()) { + result.emplace_back(facet_vertices, false); + } + } + return result; + } + + Simplex::Simplex(md::Index id, md::Point birth, int dim, const md::Column& bdry) + : + id_(id), + pos_(birth), + dim_(dim), + facet_indices_(bdry) { } + + void Simplex::translate(Real a) + { + pos_.x += a; + pos_.y += a; + } + + void Simplex::init_rivet(std::string s) + { +// throw std::runtime_error("Not implemented"); + auto delim_pos = s.find_first_of(";"); + assert(delim_pos > 0); + std::string vertices_str = s.substr(0, delim_pos); + std::string pos_str = s.substr(delim_pos + 1); + assert(not vertices_str.empty() and not pos_str.empty()); + // get vertices + std::stringstream vertices_ss(vertices_str); + int dim = 0; + int vertex; + while 
(vertices_ss >> vertex) { + dim++; + vertices_.push_back(vertex); + } + // + std::sort(vertices_.begin(), vertices_.end()); + assert(dim > 0); + + std::stringstream pos_ss(pos_str); + // TODO: get rid of 1-criticaltiy assumption + pos_ss >> pos_.x >> pos_.y; + } + + void Simplex::init_rene(std::string s) + { + facet_indices_.clear(); + std::stringstream ss(s); + ss >> dim_ >> pos_.x >> pos_.y; + if (dim_ > 0) { + facet_indices_.reserve(dim_ + 1); + for (int j = 0; j <= dim_; j++) { + Index k; + ss >> k; + facet_indices_.push_back(k); + } + } + } + + Simplex::Simplex(Index _id, std::string s, BifiltrationFormat input_format) + :id_(_id) + { + switch (input_format) { + case BifiltrationFormat::rene : + init_rene(s); + break; + case BifiltrationFormat::rivet : + init_rivet(s); + break; + } + } + + std::ostream& operator<<(std::ostream& os, const Simplex& x) + { + os << "Simplex(id = " << x.id() << ", dim = " << x.dim(); + os << ", boundary = " << container_to_string(x.boundary()) << ", pos = " << x.position() << ")"; + return os; + } +} diff --git a/matching/src/test_generator.cpp b/matching/src/test_generator.cpp new file mode 100644 index 0000000..a0d8fc7 --- /dev/null +++ b/matching/src/test_generator.cpp @@ -0,0 +1,208 @@ +#include <map> +#include <vector> +#include <random> +#include <iostream> +#include <algorithm> + +#include "opts/opts.h" +#include "spdlog/spdlog.h" +#include "spdlog/fmt/ostr.h" + +#include "common_util.h" +#include "bifiltration.h" + +using Index = md::Index; +using Point = md::Point; +using Column = md::Column; + +int g_max_coord = 100; + +using ASimplex = md::AbstractSimplex; + +using ASimplexToBirthMap = std::map<ASimplex, Point>; + +namespace spd = spdlog; + +// random generator is global +std::random_device rd; +std::mt19937_64 gen(rd()); + +//std::mt19937_64 gen(42); + +Point get_random_position(int max_coord) +{ + assert(max_coord > 0); + std::uniform_int_distribution<int> distr(0, max_coord); + return Point(distr(gen), distr(gen)); + 
+} + +Point get_random_position_less_than(Point ub) +{ + std::uniform_int_distribution<int> distr_x(0, ub.x); + std::uniform_int_distribution<int> distr_y(0, ub.y); + return Point(distr_x(gen), distr_y(gen)); +} + +Point get_random_position_greater_than(Point lb) +{ + std::uniform_int_distribution<int> distr_x(lb.x, g_max_coord); + std::uniform_int_distribution<int> distr_y(lb.y, g_max_coord); + return Point(distr_x(gen), distr_y(gen)); +} + +// non-proper faces (empty and simplex itself) are also faces +bool is_face(const ASimplex& face_candidate, const ASimplex& coface_candidate) +{ + return std::includes(coface_candidate.begin(), coface_candidate.end(), face_candidate.begin(), + face_candidate.end()); +} + +bool is_top_simplex(const ASimplex& candidate_simplex, const std::vector<ASimplex>& current_top_simplices) +{ + return std::none_of(current_top_simplices.begin(), current_top_simplices.end(), + [&candidate_simplex](const ASimplex& ts) { return is_face(candidate_simplex, ts); }); +} + +void add_if_top(const ASimplex& candidate_simplex, std::vector<ASimplex>& current_top_simplices) +{ + // check that candidate simplex is not face of someone in current_top_simplices + if (!is_top_simplex(candidate_simplex, current_top_simplices)) + return; + + spd::debug("candidate_simplex is top, will be added to top_simplices"); + // remove s from currrent_top_simplices, if s is face of candidate_simplex + current_top_simplices.erase(std::remove_if(current_top_simplices.begin(), current_top_simplices.end(), + [candidate_simplex](const ASimplex& s) { return is_face(s, candidate_simplex); }), + current_top_simplices.end()); + + current_top_simplices.push_back(candidate_simplex); +} + +ASimplex get_random_simplex(int n_vertices, int dim) +{ + std::vector<int> all_vertices(n_vertices, 0); + // fill in with 0..n-1 + std::iota(all_vertices.begin(), all_vertices.end(), 0); + std::shuffle(all_vertices.begin(), all_vertices.end(), gen); + return ASimplex(all_vertices.begin(), 
all_vertices.begin() + dim + 1, true); +} + +void generate_positions(const ASimplex& s, ASimplexToBirthMap& simplex_to_birth, Point upper_bound) +{ + auto pos = get_random_position_less_than(upper_bound); + auto curr_pos_iter = simplex_to_birth.find(s); + if (curr_pos_iter != simplex_to_birth.end()) + pos = md::greatest_lower_bound(pos, curr_pos_iter->second); + simplex_to_birth[s] = pos; + for(const ASimplex& facet : s.facets()) { + generate_positions(facet, simplex_to_birth, pos); + } +} + +md::Bifiltration get_random_bifiltration(int n_vertices, int max_dim, int n_top_simplices) +{ + ASimplexToBirthMap simplex_to_birth; + + // generate vertices + for(int i = 0; i < n_vertices; ++i) { + Point vertex_birth = get_random_position(g_max_coord / 10); + ASimplex vertex; + vertex.push_back(i); + simplex_to_birth[vertex] = vertex_birth; + } + + std::vector<ASimplex> top_simplices; + // generate top simplices + while((int)top_simplices.size() < n_top_simplices) { + std::uniform_int_distribution<int> dimension_distr(1, max_dim); + int dim = dimension_distr(gen); + auto candidate_simplex = get_random_simplex(n_vertices, dim); + spd::debug("candidate_simplex = {}", candidate_simplex); + add_if_top(candidate_simplex, top_simplices); + } + + Point upper_bound{static_cast<md::Real>(g_max_coord), static_cast<md::Real>(g_max_coord)}; + for(const auto& top_simplex : top_simplices) { + generate_positions(top_simplex, simplex_to_birth, upper_bound); + } + + std::vector<std::pair<ASimplex, Point>> simplex_birth_pairs{simplex_to_birth.begin(), simplex_to_birth.end()}; + std::vector<md::Column> boundaries{simplex_to_birth.size(), md::Column()}; + +// assign ids and save boundaries + int id = 0; + + for(int dim = 0; dim <= max_dim; ++dim) { + for(int i = 0; i < (int) simplex_birth_pairs.size(); ++i) { + ASimplex& simplex = simplex_birth_pairs[i].first; + if (simplex.dim() == dim) { + simplex.id = id++; + md::Column bdry; + for(auto& facet : simplex.facets()) { + auto facet_iter = 
std::find_if(simplex_birth_pairs.begin(), simplex_birth_pairs.end(), + [facet](const std::pair<ASimplex, Point>& sbp) { return facet == sbp.first; }); + assert(facet_iter != simplex_birth_pairs.end()); + assert(facet_iter->first.id >= 0); + bdry.push_back(facet_iter->first.id); + } + std::sort(bdry.begin(), bdry.end()); + boundaries[i] = bdry; + } + } + } + +// create vector of Simplex-es + std::vector<md::Simplex> simplices; + for(int i = 0; i < (int) simplex_birth_pairs.size(); ++i) { + int id = simplex_birth_pairs[i].first.id; + int dim = simplex_birth_pairs[i].first.dim(); + Point birth = simplex_birth_pairs[i].second; + Column bdry = boundaries[i]; + simplices.emplace_back(id, birth, dim, bdry); + } + +// sort by id + std::sort(simplices.begin(), simplices.end(), + [](const md::Simplex& s1, const md::Simplex& s2) { return s1.id() < s2.id(); }); + for(int i = 0; i < (int)simplices.size(); ++i) { + assert(simplices[i].id() == i); + assert(i == 0 || simplices[i].dim() >= simplices[i - 1].dim()); + } + + return md::Bifiltration(simplices.begin(), simplices.end()); +} + +int main(int argc, char** argv) +{ + spd::set_level(spd::level::info); + int n_vertices; + int max_dim; + int n_top_simplices; + std::string fname; + + using opts::Option; + using opts::PosOption; + opts::Options ops; + + bool help = false; + + ops >> Option('v', "n-vertices", n_vertices, "number of vertices") + >> Option('d', "max-dim", max_dim, "maximal dim") + >> Option('m', "max-coord", g_max_coord, "maximal coordinate") + >> Option('t', "n-top-simplices", n_top_simplices, "number of top simplices") + >> Option('h', "help", help, "show help message"); + + if (!ops.parse(argc, argv) || help || !(ops >> PosOption(fname))) { + std::cerr << "Usage: " << argv[0] << "\n" << ops << std::endl; + return 1; + } + + + auto bif1 = get_random_bifiltration(n_vertices, max_dim, n_top_simplices); + std::cout << "Generated bifiltration." 
<< std::endl; + bif1.save(fname, md::BifiltrationFormat::rene); + std::cout << "Saved to file " << fname << std::endl; + return 0; +} + diff --git a/matching/src/tests/prism_1_lesnick.bif b/matching/src/tests/prism_1_lesnick.bif new file mode 100644 index 0000000..416c40c --- /dev/null +++ b/matching/src/tests/prism_1_lesnick.bif @@ -0,0 +1,27 @@ +26 +0 0 0 +0 0 0 +0 0 0 +0 0 0 +0 0 0 +0 0 0 +1 1 0 1 0 +1 1 0 2 1 +1 0 0 3 1 +1 0 1 5 4 +1 0 0 4 1 +1 0 0 3 0 +1 0 0 5 0 +1 0 0 4 2 +1 0 0 5 2 +1 0 1 4 3 +1 1 0 2 0 +1 0 1 5 3 +2 1 1 13 10 7 +2 1 1 9 14 13 +2 1 1 8 11 6 +2 1 1 15 10 8 +2 1 1 14 12 16 +2 1 1 17 12 11 +2 4 0 7 16 6 +2 0 4 9 17 15 diff --git a/matching/src/tests/prism_2_lesnick.bif b/matching/src/tests/prism_2_lesnick.bif new file mode 100644 index 0000000..e11faa9 --- /dev/null +++ b/matching/src/tests/prism_2_lesnick.bif @@ -0,0 +1,28 @@ +27 +0 0 0 +0 0 0 +0 0 0 +0 0 0 +0 0 0 +0 0 0 +1 0 0 2 1 +1 0 0 1 0 +1 0 0 5 3 +1 0 0 3 1 +1 0 0 5 4 +1 0 0 2 0 +1 0 0 4 1 +1 0 0 3 2 +1 0 0 5 2 +1 0 0 4 3 +1 0 0 4 2 +1 0 0 3 0 +2 0 0 16 12 6 +2 0 0 10 14 16 +2 0 0 9 17 7 +2 0 0 15 12 9 +2 0 0 13 17 11 +2 0 0 8 14 13 +2 4 0 6 11 7 +2 0 4 10 8 15 +2 3 3 15 16 13 diff --git a/matching/src/tests/test_bifiltration.cpp b/matching/src/tests/test_bifiltration.cpp new file mode 100644 index 0000000..caa50ac --- /dev/null +++ b/matching/src/tests/test_bifiltration.cpp @@ -0,0 +1,36 @@ +#include "catch/catch.hpp" + +#include <sstream> +#include <iostream> + +#include "common_util.h" +#include "box.h" +#include "bifiltration.h" + +using namespace md; + +//TEST_CASE("Small check", "[bifiltration][dim2]") +//{ +// Bifiltration bif("/home/narn/code/matching_distance/code/src/tests/test_bifiltration_full_triangle_rene.txt", BifiltrationFormat::rene); +// auto simplices = bif.simplices(); +// bif.sanity_check(); +// +// REQUIRE( simplices.size() == 7 ); +// +// REQUIRE( simplices[0].dim() == 0 ); +// REQUIRE( simplices[1].dim() == 0 ); +// REQUIRE( simplices[2].dim() == 0 ); +// 
REQUIRE( simplices[3].dim() == 1 ); +// REQUIRE( simplices[4].dim() == 1 ); +// REQUIRE( simplices[5].dim() == 1 ); +// REQUIRE( simplices[6].dim() == 2); +// +// REQUIRE( simplices[0].position() == Point(0, 0)); +// REQUIRE( simplices[1].position() == Point(0, 0)); +// REQUIRE( simplices[2].position() == Point(0, 0)); +// REQUIRE( simplices[3].position() == Point(3, 1)); +// REQUIRE( simplices[6].position() == Point(30, 40)); +// +// Line line_1(Line::pi / 2.0, 0.0); +// auto dgm = bif.slice_diagram(line_1); +//} diff --git a/matching/src/tests/test_bifiltration_1.txt b/matching/src/tests/test_bifiltration_1.txt new file mode 120000 index 0000000..ddd23e9 --- /dev/null +++ b/matching/src/tests/test_bifiltration_1.txt @@ -0,0 +1 @@ +../../data/test_bifiltration_1.txt
\ No newline at end of file diff --git a/matching/src/tests/test_bifiltration_full_triangle_rene.txt b/matching/src/tests/test_bifiltration_full_triangle_rene.txt new file mode 120000 index 0000000..47f49fd --- /dev/null +++ b/matching/src/tests/test_bifiltration_full_triangle_rene.txt @@ -0,0 +1 @@ +../../data/test_bifiltration_full_triangle_rene.txt
\ No newline at end of file diff --git a/matching/src/tests/test_common.cpp b/matching/src/tests/test_common.cpp new file mode 100644 index 0000000..30465e4 --- /dev/null +++ b/matching/src/tests/test_common.cpp @@ -0,0 +1,190 @@ +#include "catch/catch.hpp" + +#include <sstream> +#include <iostream> +#include <string> + +#include "common_util.h" +#include "simplex.h" +#include "matching_distance.h" + +using namespace md; + +TEST_CASE("Rational", "[common_utils][rational]") +{ + // gcd + REQUIRE(gcd(10, 5) == 5); + REQUIRE(gcd(5, 10) == 5); + REQUIRE(gcd(5, 7) == 1); + REQUIRE(gcd(7, 5) == 1); + REQUIRE(gcd(13, 0) == 13); + REQUIRE(gcd(0, 13) == 13); + REQUIRE(gcd(16, 24) == 8); + REQUIRE(gcd(24, 16) == 8); + REQUIRE(gcd(16, 32) == 16); + REQUIRE(gcd(32, 16) == 16); + + + // reduce + REQUIRE(reduce({2, 1}) == std::make_pair(2, 1)); + REQUIRE(reduce({1, 2}) == std::make_pair(1, 2)); + REQUIRE(reduce({2, 2}) == std::make_pair(1, 1)); + REQUIRE(reduce({0, 2}) == std::make_pair(0, 1)); + REQUIRE(reduce({0, 20}) == std::make_pair(0, 1)); + REQUIRE(reduce({35, 49}) == std::make_pair(5, 7)); + REQUIRE(reduce({35, 25}) == std::make_pair(7, 5)); + + // midpoint + REQUIRE(midpoint(Rational {0, 1}, Rational {1, 2}) == std::make_pair(1, 4)); + REQUIRE(midpoint(Rational {1, 4}, Rational {1, 2}) == std::make_pair(3, 8)); + REQUIRE(midpoint(Rational {1, 2}, Rational {1, 2}) == std::make_pair(1, 2)); + REQUIRE(midpoint(Rational {1, 2}, Rational {1, 1}) == std::make_pair(3, 4)); + REQUIRE(midpoint(Rational {3, 7}, Rational {5, 14}) == std::make_pair(11, 28)); + + + // arithmetic + + REQUIRE(Rational(1, 2) + Rational(3, 5) == Rational(11, 10)); + REQUIRE(Rational(2, 5) - Rational(3, 10) == Rational(1, 10)); + REQUIRE(Rational(2, 3) * Rational(4, 7) == Rational(8, 21)); + REQUIRE(Rational(2, 3) * Rational(3, 2) == Rational(1)); + REQUIRE(Rational(2, 3) / Rational(3, 2) == Rational(4, 9)); + REQUIRE(Rational(1, 2) * Rational(3, 5) == Rational(3, 10)); + + // comparison + 
REQUIRE(Rational(100000, 2000000) < Rational(100001, 2000000)); + REQUIRE(!(Rational(100001, 2000000) < Rational(100000, 2000000))); + REQUIRE(!(Rational(100000, 2000000) < Rational(100000, 2000000))); + REQUIRE(Rational(-100000, 2000000) < Rational(100001, 2000000)); + REQUIRE(Rational(-100001, 2000000) < Rational(100000, 2000000)); +}; + +TEST_CASE("AbstractSimplex", "[abstract_simplex]") +{ + AbstractSimplex as; + REQUIRE(as.dim() == -1); + + as.push_back(1); + REQUIRE(as.dim() == 0); + REQUIRE(as.facets().size() == 0); + + as.push_back(0); + REQUIRE(as.dim() == 1); + REQUIRE(as.facets().size() == 2); + REQUIRE(as.facets()[0].dim() == 0); + REQUIRE(as.facets()[1].dim() == 0); + +} + +TEST_CASE("Vertical line", "[vertical_line]") +{ + // line x = 1 + DualPoint l_vertical(AxisType::x_type, AngleType::steep, 0, 1); + + REQUIRE(l_vertical.is_vertical()); + REQUIRE(l_vertical.is_steep()); + + Point p_1(0.5, 0.5); + Point p_2(1.5, 0.5); + Point p_3(1.5, 1.5); + Point p_4(0.5, 1.5); + Point p_5(1, 10); + + REQUIRE(l_vertical.x_from_y(10) == 1); + REQUIRE(l_vertical.x_from_y(-10) == 1); + REQUIRE(l_vertical.x_from_y(0) == 1); + + REQUIRE(not l_vertical.contains(p_1)); + REQUIRE(not l_vertical.contains(p_2)); + REQUIRE(not l_vertical.contains(p_3)); + REQUIRE(not l_vertical.contains(p_4)); + REQUIRE(l_vertical.contains(p_5)); + + REQUIRE(l_vertical.goes_below(p_1)); + REQUIRE(not l_vertical.goes_below(p_2)); + REQUIRE(not l_vertical.goes_below(p_3)); + REQUIRE(l_vertical.goes_below(p_4)); + + REQUIRE(not l_vertical.goes_above(p_1)); + REQUIRE(l_vertical.goes_above(p_2)); + REQUIRE(l_vertical.goes_above(p_3)); + REQUIRE(not l_vertical.goes_above(p_4)); + +} + +TEST_CASE("Horizontal line", "[horizontal_line]") +{ + // line y = 1 + DualPoint l_horizontal(AxisType::y_type, AngleType::flat, 0, 1); + + REQUIRE(l_horizontal.is_horizontal()); + REQUIRE(l_horizontal.is_flat()); + REQUIRE(l_horizontal.y_slope() == 0); + REQUIRE(l_horizontal.y_intercept() == 1); + + Point p_1(0.5, 
0.5); + Point p_2(1.5, 0.5); + Point p_3(1.5, 1.5); + Point p_4(0.5, 1.5); + Point p_5(2, 1); + + REQUIRE((not l_horizontal.contains(p_1) and + not l_horizontal.contains(p_2) and + not l_horizontal.contains(p_3) and + not l_horizontal.contains(p_4) and + l_horizontal.contains(p_5))); + + REQUIRE(not l_horizontal.goes_below(p_1)); + REQUIRE(not l_horizontal.goes_below(p_2)); + REQUIRE(l_horizontal.goes_below(p_3)); + REQUIRE(l_horizontal.goes_below(p_4)); + REQUIRE(l_horizontal.goes_below(p_5)); + + REQUIRE(l_horizontal.goes_above(p_1)); + REQUIRE(l_horizontal.goes_above(p_2)); + REQUIRE(not l_horizontal.goes_above(p_3)); + REQUIRE(not l_horizontal.goes_above(p_4)); + REQUIRE(l_horizontal.goes_above(p_5)); +} + +TEST_CASE("Flat Line with positive slope", "[flat_line]") +{ + // line y = x / 2 + 1 + DualPoint l_flat(AxisType::y_type, AngleType::flat, 0.5, 1); + + REQUIRE(not l_flat.is_horizontal()); + REQUIRE(l_flat.is_flat()); + REQUIRE(l_flat.y_slope() == 0.5); + REQUIRE(l_flat.y_intercept() == 1); + + REQUIRE(l_flat.y_from_x(0) == 1); + REQUIRE(l_flat.y_from_x(1) == 1.5); + REQUIRE(l_flat.y_from_x(2) == 2); + + Point p_1(3, 2); + Point p_2(-2, 0.01); + Point p_3(0, 1.25); + Point p_4(6, 4.5); + Point p_5(2, 2); + + std::cout << "AHOY " << l_flat.y_from_x(p_2.x) << std::endl; + + REQUIRE((not l_flat.contains(p_1) and + not l_flat.contains(p_2) and + not l_flat.contains(p_3) and + not l_flat.contains(p_4) and + l_flat.contains(p_5))); + + REQUIRE(not l_flat.goes_below(p_1)); + REQUIRE(l_flat.goes_below(p_2)); + REQUIRE(l_flat.goes_below(p_3)); + REQUIRE(l_flat.goes_below(p_4)); + REQUIRE(l_flat.goes_below(p_5)); + + REQUIRE(l_flat.goes_above(p_1)); + REQUIRE(not l_flat.goes_above(p_2)); + REQUIRE(not l_flat.goes_above(p_3)); + REQUIRE(not l_flat.goes_above(p_4)); + REQUIRE(l_flat.goes_above(p_5)); + +} diff --git a/matching/src/tests/test_list.txt b/matching/src/tests/test_list.txt new file mode 100644 index 0000000..1984606 --- /dev/null +++ 
b/matching/src/tests/test_list.txt @@ -0,0 +1 @@ +prism_lesnick_1.bif prism_lesnick_2.bif 1.0 diff --git a/matching/src/tests/test_matching_distance.cpp b/matching/src/tests/test_matching_distance.cpp new file mode 100644 index 0000000..a54e18e --- /dev/null +++ b/matching/src/tests/test_matching_distance.cpp @@ -0,0 +1,146 @@ +#include "catch/catch.hpp" + +#include <sstream> +#include <iostream> +#include <string> + +#include "spdlog/spdlog.h" +#include "spdlog/fmt/ostr.h" + +#include "common_util.h" +#include "simplex.h" +#include "matching_distance.h" + +using namespace md; +namespace spd = spdlog; + +TEST_CASE("Different bounds", "[bounds]") +{ + std::vector<Simplex> simplices; + std::vector<Point> points; + + Real max_x = 10; + Real max_y = 20; + + int simplex_id = 0; + for(int i = 0; i <= max_x; ++i) { + for(int j = 0; j <= max_y; ++j) { + Point p(i, j); + simplices.emplace_back(simplex_id++, p, 0, Column()); + points.push_back(p); + } + } + + Bifiltration bif_a(simplices.begin(), simplices.end()); + Bifiltration bif_b(simplices.begin(), simplices.end()); + + CalculationParams params; + params.initialization_depth = 2; + + DistanceCalculator calc(bif_a, bif_b, params); + +// REQUIRE(calc.max_x_ == Approx(max_x)); +// REQUIRE(calc.max_y_ == Approx(max_y)); + + std::vector<DualBox> boxes; + + for(CellWithValue c : calc.get_refined_grid(5, false, false)) { + boxes.push_back(c.dual_box()); + } + + // fill in boxes and points + + for(DualBox db : boxes) { + Real local_bound = calc.get_local_dual_bound(db); + Real local_bound_refined = calc.get_local_refined_bound(db); + REQUIRE(local_bound >= local_bound_refined); + for(Point p : points) { + for(ValuePoint vp_a : k_corner_vps) { + CellWithValue dual_cell(db, 1); + DualPoint corner_a = dual_cell.value_point(vp_a); + Real wp_a = corner_a.weighted_push(p); + dual_cell.set_value_at(vp_a, wp_a); + Real point_bound = calc.get_max_displacement_single_point(dual_cell, vp_a, p); + for(ValuePoint vp_b : k_corner_vps) { + if 
(vp_b <= vp_a) + continue; + DualPoint corner_b = dual_cell.value_point(vp_b); + Real wp_b = corner_b.weighted_push(p); + Real diff = fabs(wp_a - wp_b); + if (not(point_bound <= Approx(local_bound_refined))) { + std::cerr << "ERROR point: " << p << ", box = " << db << ", point bound = " << point_bound + << ", refined local = " << local_bound_refined << std::endl; + spd::set_level(spd::level::debug); + calc.get_max_displacement_single_point(dual_cell, vp_a, p); + calc.get_local_refined_bound(db); + spd::set_level(spd::level::info); + } + + REQUIRE(point_bound <= Approx(local_bound_refined)); + REQUIRE(diff <= Approx(point_bound)); + REQUIRE(diff <= Approx(local_bound_refined)); + } + + for(DualPoint l_random : db.random_points(100)) { + Real wp_random = l_random.weighted_push(p); + Real diff = fabs(wp_a - wp_random); + if (not(diff <= Approx(point_bound))) { + if (db.critical_points(p).size() > 4) { + std::cerr << "ERROR interesting case" << std::endl; + } else { + std::cerr << "ERROR boring case" << std::endl; + } + spd::set_level(spd::level::debug); + l_random.weighted_push(p); + spd::set_level(spd::level::info); + std::cerr << "ERROR point: " << p << ", box = " << db << ", point bound = " << point_bound + << ", refined local = " << local_bound_refined; + std::cerr << ", random_dual_point = " << l_random << ", wp_random = " << wp_random + << ", diff = " << diff << std::endl; + } + REQUIRE(diff <= Approx(point_bound)); + } + } + } + } +} + +TEST_CASE("Bifiltrations from file", "[matching_distance][small_example][lesnick]") +{ + std::string fname_a, fname_b; + + fname_a = "/home/narn/code/matching_distance/code/python_scripts/prism_1_lesnick.bif"; + fname_b = "/home/narn/code/matching_distance/code/python_scripts/prism_2_lesnick.bif"; + + Bifiltration bif_a(fname_a, BifiltrationFormat::rene); + Bifiltration bif_b(fname_b, BifiltrationFormat::rene); + + CalculationParams params; + + std::vector<BoundStrategy> bound_strategies {BoundStrategy::local_combined, + 
BoundStrategy::local_dual_bound_refined}; + + std::vector<TraverseStrategy> traverse_strategies {TraverseStrategy::breadth_first, TraverseStrategy::depth_first}; + + std::vector<double> scaling_factors {10, 1.0}; + + for(auto bs : bound_strategies) { + for(auto ts : traverse_strategies) { + for(double lambda : scaling_factors) { + Bifiltration bif_a_copy(bif_a); + Bifiltration bif_b_copy(bif_b); + bif_a_copy.scale(lambda); + bif_b_copy.scale(lambda); + params.bound_strategy = bs; + params.traverse_strategy = ts; + params.max_depth = 7; + params.delta = 0.01; + params.dim = 1; + Real answer = matching_distance(bif_a_copy, bif_b_copy, params); + Real correct_answer = lambda * 1.0; + REQUIRE(fabs(answer - correct_answer) < lambda * 0.05); + } + } + } +} + diff --git a/matching/src/tests/tests_main.cpp b/matching/src/tests/tests_main.cpp new file mode 100644 index 0000000..1c77b13 --- /dev/null +++ b/matching/src/tests/tests_main.cpp @@ -0,0 +1,2 @@ +#define CATCH_CONFIG_MAIN +#include "catch/catch.hpp" |