summaryrefslogtreecommitdiff
path: root/matching/src
diff options
context:
space:
mode:
Diffstat (limited to 'matching/src')
-rw-r--r--matching/src/bifiltration.cpp309
-rw-r--r--matching/src/box.cpp61
-rw-r--r--matching/src/cell_with_value.cpp247
-rw-r--r--matching/src/common_util.cpp243
-rw-r--r--matching/src/dual_box.cpp194
-rw-r--r--matching/src/dual_point.cpp282
-rw-r--r--matching/src/main.cpp367
-rw-r--r--matching/src/matching_distance.cpp907
-rw-r--r--matching/src/persistence_module.cpp104
-rw-r--r--matching/src/simplex.cpp123
-rw-r--r--matching/src/test_generator.cpp208
-rw-r--r--matching/src/tests/prism_1_lesnick.bif27
-rw-r--r--matching/src/tests/prism_2_lesnick.bif28
-rw-r--r--matching/src/tests/test_bifiltration.cpp36
l---------matching/src/tests/test_bifiltration_1.txt1
l---------matching/src/tests/test_bifiltration_full_triangle_rene.txt1
-rw-r--r--matching/src/tests/test_common.cpp190
-rw-r--r--matching/src/tests/test_list.txt1
-rw-r--r--matching/src/tests/test_matching_distance.cpp146
-rw-r--r--matching/src/tests/tests_main.cpp2
20 files changed, 3477 insertions, 0 deletions
diff --git a/matching/src/bifiltration.cpp b/matching/src/bifiltration.cpp
new file mode 100644
index 0000000..429a9a8
--- /dev/null
+++ b/matching/src/bifiltration.cpp
@@ -0,0 +1,309 @@
+#include <iostream>
+#include <fstream>
+#include <sstream>
+#include <cassert>
+
+#include<phat/boundary_matrix.h>
+#include<phat/compute_persistence_pairs.h>
+
+#include "spdlog/spdlog.h"
+#include "spdlog/fmt/fmt.h"
+#include "spdlog/fmt/ostr.h"
+
+#include "common_util.h"
+#include "bifiltration.h"
+
+namespace spd = spdlog;
+
+namespace md {
+
+ void Bifiltration::init()
+ {
+ Point lower_left = max_point();
+ Point upper_right = min_point();
+ for(const auto& simplex : simplices_) {
+ lower_left = greatest_lower_bound(lower_left, simplex.position());
+ upper_right = least_upper_bound(upper_right, simplex.position());
+ maximal_dim_ = std::max(maximal_dim_, simplex.dim());
+ }
+ bounding_box_ = Box(lower_left, upper_right);
+ }
+
+ Bifiltration::Bifiltration(const std::string& fname, BifiltrationFormat input_format)
+ {
+ std::ifstream ifstr{fname.c_str()};
+ if (!ifstr.good()) {
+ std::string error_message = fmt::format("Cannot open file {0}", fname);
+ std::cerr << error_message << std::endl;
+ throw std::runtime_error(error_message);
+ }
+
+ switch(input_format) {
+ case BifiltrationFormat::rivet :
+ rivet_format_reader(ifstr);
+ break;
+ case BifiltrationFormat::rene :
+ rene_format_reader(ifstr);
+ break;
+ }
+ init();
+ }
+
+ void Bifiltration::rivet_format_reader(std::ifstream& ifstr)
+ {
+ std::string s;
+ std::getline(ifstr, s);
+ assert(s == std::string("bifiltration"));
+ std::getline(ifstr, parameter_1_name_);
+ std::getline(ifstr, parameter_2_name_);
+
+ Index index = 0;
+ while(std::getline(ifstr, s)) {
+ if (not ignore_line(s))
+ simplices_.emplace_back(index++, s, BifiltrationFormat::rivet);
+ }
+ }
+
+ void Bifiltration::rene_format_reader(std::ifstream& ifstr)
+ {
+ spd::debug("Enter rene_format_reader");
+ // read stream line by line; do not use >> operator
+ std::string s;
+ std::getline(ifstr, s);
+ long n_simplices = std::stol(s);
+
+ for(Index index = 0; index < n_simplices; index++) {
+ std::getline(ifstr, s);
+ simplices_.emplace_back(index, s, BifiltrationFormat::rene);
+ }
+ spd::debug("Read {} simplices from file", n_simplices);
+ }
+
+ void Bifiltration::scale(Real lambda)
+ {
+ for(auto& s : simplices_) {
+ s.scale(lambda);
+ }
+ init();
+ }
+
+ void Bifiltration::sanity_check() const
+ {
+#ifdef DEBUG
+ spd::debug("Enter Bifiltration::sanity_check");
+ // check that boundary has correct number of simplices,
+ // each bounding simplex has correct dim
+ // and appears in the filtration before the simplex it bounds
+ for(const auto& s : simplices_) {
+ assert(s.dim() >= 0);
+ assert(s.dim() == 0 or s.dim() + 1 == (int) s.boundary().size());
+ for(auto bdry_idx : s.boundary()) {
+ Simplex bdry_simplex = simplices()[bdry_idx];
+ assert(bdry_simplex.dim() == s.dim() - 1);
+ assert(bdry_simplex.position().is_less(s.position(), false));
+ }
+ }
+ spd::debug("Exit Bifiltration::sanity_check");
+#endif
+ }
+
+ DiagramKeeper Bifiltration::weighted_slice_diagram(const DualPoint& line, int /*dim*/) const
+ {
+ DiagramKeeper dgm;
+
+ // make a copy for now; I want slice_diagram to be const
+ std::vector<Simplex> simplices(simplices_);
+
+// std::vector<Simplex> simplices;
+// simplices.reserve(simplices_.size() / 2);
+// for(const auto& s : simplices_) {
+// if (s.dim() <= dim + 1 and s.dim() >= dim)
+// simplices.emplace_back(s);
+// }
+
+ spd::debug("Enter slice diagram, line = {}, simplices.size = {}", line, simplices.size());
+
+ for(auto& simplex : simplices) {
+ Real value = line.weighted_push(simplex.position());
+// spd::debug("in slice_diagram, simplex = {}, value = {}\n", simplex, value);
+ simplex.set_value(value);
+ }
+
+ std::sort(simplices.begin(), simplices.end(),
+ [](const Simplex& a, const Simplex& b) { return a.value() < b.value(); });
+ std::map<Index, Index> index_map;
+ for(Index i = 0; i < (int) simplices.size(); i++) {
+ index_map[simplices[i].id()] = i;
+ }
+
+ phat::boundary_matrix<> phat_matrix;
+ phat_matrix.set_num_cols(simplices.size());
+ std::vector<Index> bd_in_slice_filtration;
+ for(Index i = 0; i < (int) simplices.size(); i++) {
+ phat_matrix.set_dim(i, simplices[i].dim());
+ bd_in_slice_filtration.clear();
+ //std::cout << "new col" << i << std::endl;
+ for(int j = 0; j < (int) simplices[i].boundary().size(); j++) {
+ // F[i] contains the indices of its facet wrt to the
+ // original filtration. We have to express it, however,
+ // wrt to the filtration along the slice. That is why
+ // we need the index_map
+ //std::cout << "Found " << F[i].bd[j] << ", returning " << index_map[F[i].bd[j]] << std::endl;
+ bd_in_slice_filtration.push_back(index_map[simplices[i].boundary()[j]]);
+ }
+ std::sort(bd_in_slice_filtration.begin(), bd_in_slice_filtration.end());
+ phat_matrix.set_col(i, bd_in_slice_filtration);
+ }
+ phat::persistence_pairs phat_persistence_pairs;
+ phat::compute_persistence_pairs<phat::twist_reduction>(phat_persistence_pairs, phat_matrix);
+
+ dgm.clear();
+ for(long i = 0; i < (long) phat_persistence_pairs.get_num_pairs(); i++) {
+ std::pair<phat::index, phat::index> new_pair = phat_persistence_pairs.get_pair(i);
+ Real birth = simplices.at(new_pair.first).value();
+ Real death = simplices.at(new_pair.second).value();
+ int dim = simplices[new_pair.first].dim();
+ assert(dim + 1 == simplices[new_pair.second].dim());
+ if (birth != death) {
+ dgm.add_point(dim, birth, death);
+ }
+ }
+
+ spdlog::debug("Exiting slice_diagram, #dgm[0] = {}", dgm.get_diagram(0).size());
+
+ return dgm;
+ }
+
+ Box Bifiltration::bounding_box() const
+ {
+ return bounding_box_;
+ }
+
+ Real Bifiltration::minimal_coordinate() const
+ {
+ return std::min(bounding_box_.lower_left().x, bounding_box_.lower_left().y);
+ }
+
+ void Bifiltration::translate(Real a)
+ {
+ bounding_box_.translate(a);
+ for(auto& simplex : simplices_) {
+ simplex.translate(a);
+ }
+ }
+
+ Real Bifiltration::max_x() const
+ {
+ if (simplices_.empty())
+ return 1;
+ auto me = std::max_element(simplices_.cbegin(), simplices_.cend(),
+ [](const auto& s_a, const auto& s_b) { return s_a.position().x < s_b.position().x; });
+ assert(me != simplices_.cend());
+ return me->position().x;
+ }
+
+ Real Bifiltration::max_y() const
+ {
+ if (simplices_.empty())
+ return 1;
+ auto me = std::max_element(simplices_.cbegin(), simplices_.cend(),
+ [](const auto& s_a, const auto& s_b) { return s_a.position().y < s_b.position().y; });
+ assert(me != simplices_.cend());
+ return me->position().y;
+ }
+
+ Real Bifiltration::min_x() const
+ {
+ if (simplices_.empty())
+ return 0;
+ auto me = std::min_element(simplices_.cbegin(), simplices_.cend(),
+ [](const auto& s_a, const auto& s_b) { return s_a.position().x < s_b.position().x; });
+ assert(me != simplices_.cend());
+ return me->position().x;
+ }
+
+ Real Bifiltration::min_y() const
+ {
+ if (simplices_.empty())
+ return 0;
+ auto me = std::min_element(simplices_.cbegin(), simplices_.cend(),
+ [](const auto& s_a, const auto& s_b) { return s_a.position().y < s_b.position().y; });
+ assert(me != simplices_.cend());
+ return me->position().y;
+ }
+
+ void Bifiltration::add_simplex(md::Index _id, md::Point birth, int _dim, const md::Column& _bdry)
+ {
+ simplices_.emplace_back(_id, birth, _dim, _bdry);
+ }
+
+ void Bifiltration::save(const std::string& filename, md::BifiltrationFormat format)
+ {
+ switch(format) {
+ case BifiltrationFormat::rivet:
+ throw std::runtime_error("Not implemented");
+ break;
+ case BifiltrationFormat::rene: {
+ std::ofstream f(filename);
+ if (not f.good()) {
+ std::cerr << "Bifiltration::save: cannot open file " << filename << std::endl;
+ throw std::runtime_error("Cannot open file for writing ");
+ }
+ f << simplices_.size() << "\n";
+
+ for(const auto& s : simplices_) {
+ f << s.dim() << " " << s.position().x << " " << s.position().y << " ";
+ for(int b : s.boundary()) {
+ f << b << " ";
+ }
+ f << std::endl;
+ }
+
+ }
+ break;
+ }
+ }
+
+
+ void Bifiltration::postprocess_rivet_format()
+ {
+ std::map<Column, Index> facets_to_ids;
+
+ // fill the map
+ for(Index i = 0; i < (Index)simplices_.size(); ++i) {
+ assert(simplices_[i].id() == i);
+ facets_to_ids[simplices_[i].vertices_] = i;
+ }
+
+// for(const auto& s : simplices_) {
+// facets_to_ids[s] = s.id();
+// }
+
+ // main loop
+ for(auto& s : simplices_) {
+ assert(not s.vertices_.empty());
+ assert(s.facet_indices_.empty());
+ Column facet_indices;
+ for(Index i = 0; i <= s.dim(); ++i) {
+ Column facet;
+ for(Index j : s.vertices_) {
+ if (j != i)
+ facet.push_back(j);
+ }
+ auto facet_index = facets_to_ids.at(facet);
+ facet_indices.push_back(facet_index);
+ }
+ s.facet_indices_ = facet_indices;
+ } // loop over simplices
+ }
+
+ std::ostream& operator<<(std::ostream& os, const Bifiltration& bif)
+ {
+ os << "Bifiltration, axes = " << bif.parameter_1_name_ << ", " << bif.parameter_2_name_ << std::endl;
+ for(const auto& s : bif.simplices()) {
+ os << s << std::endl;
+ }
+ return os;
+ }
+}
+
diff --git a/matching/src/box.cpp b/matching/src/box.cpp
new file mode 100644
index 0000000..c128698
--- /dev/null
+++ b/matching/src/box.cpp
@@ -0,0 +1,61 @@
+
+#include "box.h"
+
+namespace md {
+
+ std::ostream& operator<<(std::ostream& os, const Box& box)
+ {
+ os << "Box(lower_left = " << box.lower_left() << ", upper_right = " << box.upper_right() << ")";
+ return os;
+ }
+
+ Box get_enclosing_box(const Box& box_a, const Box& box_b)
+ {
+ Point lower_left(std::min(box_a.lower_left().x, box_b.lower_left().x),
+ std::min(box_a.lower_left().y, box_b.lower_left().y));
+ Point upper_right(std::max(box_a.upper_right().x, box_b.upper_right().x),
+ std::max(box_a.upper_right().y, box_b.upper_right().y));
+ return Box(lower_left, upper_right);
+ }
+
+ void Box::translate(md::Real a)
+ {
+ ll.x += a;
+ ll.y += a;
+ ur.x += a;
+ ur.y += a;
+ }
+
+ std::vector<Box> Box::refine() const
+ {
+ std::vector<Box> result;
+
+// 1 | 2
+// 0 | 3
+
+ Point new_ll = lower_left();
+ Point new_ur = center();
+ result.emplace_back(new_ll, new_ur);
+
+ new_ll.y = center().y;
+ new_ur.y = ur.y;
+ result.emplace_back(new_ll, new_ur);
+
+ new_ll = center();
+ new_ur = upper_right();
+ result.emplace_back(new_ll, new_ur);
+
+ new_ll.y = ll.y;
+ new_ur.y = center().y;
+ result.emplace_back(new_ll, new_ur);
+
+ return result;
+ }
+
+ std::vector<Point> Box::corners() const
+ {
+ return {ll, Point(ll.x, ur.y), ur, Point(ur.x, ll.y)};
+ };
+
+
+}
diff --git a/matching/src/cell_with_value.cpp b/matching/src/cell_with_value.cpp
new file mode 100644
index 0000000..d8fd7d4
--- /dev/null
+++ b/matching/src/cell_with_value.cpp
@@ -0,0 +1,247 @@
+#include <spdlog/spdlog.h>
+#include <spdlog/fmt/ostr.h>
+
+namespace spd = spdlog;
+
+#include "cell_with_value.h"
+
+namespace md {
+
+#ifdef MD_DEBUG
+ long long int CellWithValue::max_id = 0;
+#endif
+
+ Real CellWithValue::value_at(ValuePoint vp) const
+ {
+ switch(vp) {
+ case ValuePoint::upper_left :
+ return upper_left_value_;
+ case ValuePoint::upper_right :
+ return upper_right_value_;
+ case ValuePoint::lower_left :
+ return lower_left_value_;
+ case ValuePoint::lower_right :
+ return lower_right_value_;
+ case ValuePoint::center:
+ return central_value_;
+ }
+ // to shut up compiler warning
+ return 1.0 / 0.0;
+ }
+
+ bool CellWithValue::has_value_at(ValuePoint vp) const
+ {
+ switch(vp) {
+ case ValuePoint::upper_left :
+ return upper_left_value_ >= 0;
+ case ValuePoint::upper_right :
+ return upper_right_value_ >= 0;
+ case ValuePoint::lower_left :
+ return lower_left_value_ >= 0;
+ case ValuePoint::lower_right :
+ return lower_right_value_ >= 0;
+ case ValuePoint::center:
+ return central_value_ >= 0;
+ }
+ // to shut up compiler warning
+ return 1.0 / 0.0;
+ }
+
+ DualPoint CellWithValue::value_point(md::ValuePoint vp) const
+ {
+ switch(vp) {
+ case ValuePoint::upper_left :
+ return dual_box().upper_left();
+ case ValuePoint::upper_right :
+ return dual_box().upper_right();
+ case ValuePoint::lower_left :
+ return dual_box().lower_left();
+ case ValuePoint::lower_right :
+ return dual_box().lower_right();
+ case ValuePoint::center:
+ return dual_box().center();
+ }
+ // to shut up compiler warning
+ return DualPoint();
+ }
+
+ bool CellWithValue::has_corner_value() const
+ {
+ return has_lower_left_value() or has_lower_right_value() or has_upper_left_value()
+ or has_upper_right_value();
+ }
+
+ Real CellWithValue::stored_upper_bound() const
+ {
+ assert(has_max_possible_value_);
+ return max_possible_value_;
+ }
+
+ Real CellWithValue::max_corner_value() const
+ {
+ return std::max({lower_left_value_, lower_right_value_, upper_left_value_, upper_right_value_});
+ }
+
+ Real CellWithValue::min_value() const
+ {
+ Real result = std::numeric_limits<Real>::max();
+ for(auto vp : k_all_vps) {
+ if (not has_value_at(vp)) {
+ continue;
+ }
+ result = std::min(result, value_at(vp));
+ }
+ return result;
+ }
+
+ std::vector<CellWithValue> CellWithValue::get_refined_cells() const
+ {
+ std::vector<CellWithValue> result;
+ result.reserve(4);
+ for(const auto& refined_box : dual_box_.refine()) {
+
+ CellWithValue refined_cell(refined_box, level() + 1);
+
+#ifdef MD_DEBUG
+ refined_cell.parent_ids = parent_ids;
+ refined_cell.parent_ids.push_back(id);
+ refined_cell.id = ++max_id;
+#endif
+
+ if (refined_box.lower_left() == dual_box_.lower_left()) {
+ // _|_
+ // H|_
+
+ refined_cell.set_value_at(ValuePoint::lower_left, lower_left_value_);
+ refined_cell.set_value_at(ValuePoint::upper_right, central_value_);
+
+ } else if (refined_box.upper_right() == dual_box_.upper_right()) {
+ // _|H
+ // _|_
+
+ refined_cell.set_value_at(ValuePoint::lower_left, central_value_);
+ refined_cell.set_value_at(ValuePoint::upper_right, upper_right_value_);
+
+ } else if (refined_box.lower_right() == dual_box_.lower_right()) {
+ // _|_
+ // _|H
+
+ refined_cell.set_value_at(ValuePoint::lower_right, lower_right_value_);
+ refined_cell.set_value_at(ValuePoint::upper_left, central_value_);
+
+ } else if (refined_box.upper_left() == dual_box_.upper_left()) {
+
+ // H|_
+ // _|_
+
+ refined_cell.set_value_at(ValuePoint::lower_right, central_value_);
+ refined_cell.set_value_at(ValuePoint::upper_left, upper_left_value_);
+ }
+ result.emplace_back(refined_cell);
+ }
+ return result;
+ }
+
+ void CellWithValue::set_value_at(md::ValuePoint vp, md::Real new_value)
+ {
+ if (has_value_at(vp))
+ spd::error("CellWithValue: trying to re-assign value!, this = {}, vp = {}", *this, vp);
+
+ switch(vp) {
+ case ValuePoint::upper_left :
+ upper_left_value_ = new_value;
+ break;
+ case ValuePoint::upper_right :
+ upper_right_value_ = new_value;
+ break;
+ case ValuePoint::lower_left :
+ lower_left_value_ = new_value;
+ break;
+ case ValuePoint::lower_right :
+ lower_right_value_ = new_value;
+ break;
+ case ValuePoint::center:
+ central_value_ = new_value;
+ break;
+ }
+
+
+ }
+
+ int CellWithValue::num_values() const
+ {
+ int result = 0;
+ for(ValuePoint vp : k_all_vps) {
+ result += has_value_at(vp);
+ }
+ return result;
+ }
+
+
+ void CellWithValue::set_max_possible_value(Real new_upper_bound)
+ {
+ assert(new_upper_bound >= central_value_);
+ assert(new_upper_bound >= lower_left_value_);
+ assert(new_upper_bound >= lower_right_value_);
+ assert(new_upper_bound >= upper_left_value_);
+ assert(new_upper_bound >= upper_right_value_);
+ has_max_possible_value_ = true;
+ max_possible_value_ = new_upper_bound;
+ }
+
+ std::ostream& operator<<(std::ostream& os, const ValuePoint& vp)
+ {
+ switch(vp) {
+ case ValuePoint::upper_left :
+ os << "upper_left";
+ break;
+ case ValuePoint::upper_right :
+ os << "upper_right";
+ break;
+ case ValuePoint::lower_left :
+ os << "lower_left";
+ break;
+ case ValuePoint::lower_right :
+ os << "lower_right";
+ break;
+ case ValuePoint::center:
+ os << "center";
+ break;
+ default:
+ os << "FORGOTTEN ValuePoint";
+ }
+ return os;
+ }
+
+
+
+ std::ostream& operator<<(std::ostream& os, const CellWithValue& cell)
+ {
+ os << "CellWithValue(box = " << cell.dual_box() << ", ";
+
+#ifdef MD_DEBUG
+ os << "id = " << cell.id;
+ if (not cell.parent_ids.empty())
+ os << ", parent_ids = " << container_to_string(cell.parent_ids) << ", ";
+#endif
+
+ for(ValuePoint vp : k_all_vps) {
+ if (cell.has_value_at(vp)) {
+ os << "value = " << cell.value_at(vp);
+ os << ", at " << vp << " " << cell.value_point(vp);
+ }
+ }
+
+ os << ", max_corner_value = ";
+ if (cell.has_max_possible_value()) {
+ os << cell.stored_upper_bound();
+ } else {
+ os << "-";
+ }
+
+ os << ", level = " << cell.level() << ")";
+ return os;
+ }
+
+} // namespace md
+
diff --git a/matching/src/common_util.cpp b/matching/src/common_util.cpp
new file mode 100644
index 0000000..96c3388
--- /dev/null
+++ b/matching/src/common_util.cpp
@@ -0,0 +1,243 @@
+#include <vector>
+#include <utility>
+#include <cmath>
+#include <ostream>
+#include <limits>
+#include <algorithm>
+
+#include <common_util.h>
+
+#include "spdlog/spdlog.h"
+#include "spdlog/fmt/ostr.h"
+
+namespace md {
+
+
+ int gcd(int a, int b)
+ {
+ assert(a != 0 or b != 0);
+ // make b <= a
+ std::tie(b, a) = std::minmax({ abs(a), abs(b) });
+ if (b == 0)
+ return a;
+ while((a = a % b)) {
+ std::swap(a, b);
+ }
+ return b;
+ }
+
+ int signum(int a)
+ {
+ if (a < 0)
+ return -1;
+ else if (a > 0)
+ return 1;
+ else
+ return 0;
+ }
+
+ Rational reduce(Rational frac)
+ {
+ int d = gcd(frac.numerator, frac.denominator);
+ frac.numerator /= d;
+ frac.denominator /= d;
+ return frac;
+ }
+
+ void Rational::reduce() { *this = md::reduce(*this); }
+
+
+ Rational& Rational::operator*=(const md::Rational& rhs)
+ {
+ numerator *= rhs.numerator;
+ denominator *= rhs.denominator;
+ reduce();
+ return *this;
+ }
+
+ Rational& Rational::operator/=(const md::Rational& rhs)
+ {
+ numerator *= rhs.denominator;
+ denominator *= rhs.numerator;
+ reduce();
+ return *this;
+ }
+
+ Rational& Rational::operator+=(const md::Rational& rhs)
+ {
+ numerator = numerator * rhs.denominator + denominator * rhs.numerator;
+ denominator *= rhs.denominator;
+ reduce();
+ return *this;
+ }
+
+ Rational& Rational::operator-=(const md::Rational& rhs)
+ {
+ numerator = numerator * rhs.denominator - denominator * rhs.numerator;
+ denominator *= rhs.denominator;
+ reduce();
+ return *this;
+ }
+
+
+ Rational midpoint(Rational a, Rational b)
+ {
+ return reduce({a.numerator * b.denominator + a.denominator * b.numerator, 2 * a.denominator * b.denominator });
+ }
+
+ Rational operator+(Rational a, const Rational& b)
+ {
+ a += b;
+ return a;
+ }
+
+ Rational operator-(Rational a, const Rational& b)
+ {
+ a -= b;
+ return a;
+ }
+
+ Rational operator*(Rational a, const Rational& b)
+ {
+ a *= b;
+ return a;
+ }
+
+ Rational operator/(Rational a, const Rational& b)
+ {
+ a /= b;
+ return a;
+ }
+
+ bool is_less(Rational a, Rational b)
+ {
+ // compute a - b = a_1 / a_2 - b_1 / b_2
+ long numer = a.numerator * b.denominator - a.denominator * b.numerator;
+ long denom = a.denominator * b.denominator;
+ assert(denom != 0);
+ return signum(numer) * signum(denom) < 0;
+ }
+
+ bool operator==(const Rational& a, const Rational& b)
+ {
+ return std::tie(a.numerator, a.denominator) == std::tie(b.numerator, b.denominator);
+ }
+
+ bool operator<(const Rational& a, const Rational& b)
+ {
+ // do not remove signum - overflow
+ long numer = a.numerator * b.denominator - a.denominator * b.numerator;
+ long denom = a.denominator * b.denominator;
+ assert(denom != 0);
+// spdlog::debug("a = {}, b = {}, numer = {}, denom = {}, result = {}", a, b, numer, denom, signum(numer) * signum(denom) <= 0);
+ return signum(numer) * signum(denom) < 0;
+ }
+
+ bool is_leq(Rational a, Rational b)
+ {
+ // compute a - b = a_1 / a_2 - b_1 / b_2
+ long numer = a.numerator * b.denominator - a.denominator * b.numerator;
+ long denom = a.denominator * b.denominator;
+ assert(denom != 0);
+ return signum(numer) * signum(denom) <= 0;
+ }
+
+ bool is_greater(Rational a, Rational b)
+ {
+ return not is_leq(a, b);
+ }
+
+ bool is_geq(Rational a, Rational b)
+ {
+ return not is_less(a, b);
+ }
+
+ Point operator+(const Point& u, const Point& v)
+ {
+ return Point(u.x + v.x, u.y + v.y);
+ }
+
+ Point operator-(const Point& u, const Point& v)
+ {
+ return Point(u.x - v.x, u.y - v.y);
+ }
+
+ Point least_upper_bound(const Point& u, const Point& v)
+ {
+ return Point(std::max(u.x, v.x), std::max(u.y, v.y));
+ }
+
+ Point greatest_lower_bound(const Point& u, const Point& v)
+ {
+ return Point(std::min(u.x, v.x), std::min(u.y, v.y));
+ }
+
+ Point max_point()
+ {
+ return Point(std::numeric_limits<Real>::max(), std::numeric_limits<Real>::min());
+ }
+
+ Point min_point()
+ {
+ return Point(-std::numeric_limits<Real>::max(), -std::numeric_limits<Real>::min());
+ }
+
+ std::ostream& operator<<(std::ostream& ostr, const Point& vec)
+ {
+ ostr << "(" << vec.x << ", " << vec.y << ")";
+ return ostr;
+ }
+
+ Real l_infty_norm(const Point& v)
+ {
+ return std::max(std::abs(v.x), std::abs(v.y));
+ }
+
+ Real l_2_norm(const Point& v)
+ {
+ return v.norm();
+ }
+
+ Real l_2_dist(const Point& x, const Point& y)
+ {
+ return l_2_norm(x - y);
+ }
+
+ Real l_infty_dist(const Point& x, const Point& y)
+ {
+ return l_infty_norm(x - y);
+ }
+
+ void DiagramKeeper::add_point(int dim, md::Real birth, md::Real death)
+ {
+ data_[dim].emplace_back(birth, death);
+ }
+
+ DiagramKeeper::Diagram DiagramKeeper::get_diagram(int dim) const
+ {
+ if (data_.count(dim) == 1)
+ return data_.at(dim);
+ else
+ return DiagramKeeper::Diagram();
+ }
+
+ // return true, if line starts with #
+ // or contains only spaces
+ bool ignore_line(const std::string& s)
+ {
+ for(auto c : s) {
+ if (isspace(c))
+ continue;
+ return (c == '#');
+ }
+ return true;
+ }
+
+
+
+ std::ostream& operator<<(std::ostream& os, const Rational& a)
+ {
+ os << a.numerator << " / " << a.denominator;
+ return os;
+ }
+}
diff --git a/matching/src/dual_box.cpp b/matching/src/dual_box.cpp
new file mode 100644
index 0000000..f9d2979
--- /dev/null
+++ b/matching/src/dual_box.cpp
@@ -0,0 +1,194 @@
+#include <random>
+
+#include "spdlog/spdlog.h"
+#include "spdlog/fmt/ostr.h"
+
+namespace spd = spdlog;
+
+#include "dual_box.h"
+
+namespace md {
+
+ std::ostream& operator<<(std::ostream& os, const DualBox& db)
+ {
+ os << "DualBox(" << db.lower_left_ << ", " << db.upper_right_ << ")";
+ return os;
+ }
+
+ DualBox::DualBox(DualPoint ll, DualPoint ur)
+ :lower_left_(ll), upper_right_(ur)
+ {
+ }
+
+ std::vector<DualPoint> DualBox::corners() const
+ {
+ return {lower_left_,
+ DualPoint(axis_type(), angle_type(), lower_left_.lambda(), upper_right_.mu()),
+ upper_right_,
+ DualPoint(axis_type(), angle_type(), upper_right_.lambda(), lower_left_.mu())};
+ }
+
+ std::vector<DualPoint> DualBox::push_change_points(const Point& p) const
+ {
+ std::vector<DualPoint> result;
+ result.reserve(2);
+
+ bool is_y_type = lower_left_.is_y_type();
+ bool is_flat = lower_left_.is_flat();
+
+ auto mu_from_lambda = [p, is_y_type, is_flat](Real lambda) {
+ bool is_x_type = not is_y_type, is_steep = not is_flat;
+ if (is_y_type and is_flat) {
+ return p.y - lambda * p.x;
+ } else if (is_y_type and is_steep) {
+ return p.y - p.x / lambda;
+ } else if (is_x_type and is_flat) {
+ return p.x - p.y / lambda;
+ } else if (is_x_type and is_steep) {
+ return p.x - lambda * p.y;
+ }
+ // to shut up compiler warning
+ return static_cast<Real>(1.0 / 0.0);
+ };
+
+ auto lambda_from_mu = [p, is_y_type, is_flat](Real mu) {
+ bool is_x_type = not is_y_type, is_steep = not is_flat;
+ if (is_y_type and is_flat) {
+ return (p.y - mu) / p.x;
+ } else if (is_y_type and is_steep) {
+ return p.x / (p.y - mu);
+ } else if (is_x_type and is_flat) {
+ return p.y / (p.x - mu);
+ } else if (is_x_type and is_steep) {
+ return (p.x - mu) / p.y;
+ }
+ // to shut up compiler warning
+ return static_cast<Real>(1.0 / 0.0);
+ };
+
+ // all inequalities below are strict: equality means it is a corner
+ // and critical_points() returns corners anyway
+
+ Real mu_intersect_min = mu_from_lambda(lambda_min());
+
+ if (mu_min() < mu_intersect_min && mu_intersect_min < mu_max())
+ result.emplace_back(axis_type(), angle_type(), lambda_min(), mu_intersect_min);
+
+ Real mu_intersect_max = mu_from_lambda(lambda_max());
+
+ if (mu_max() < mu_intersect_max && mu_intersect_max < mu_max())
+ result.emplace_back(axis_type(), angle_type(), lambda_max(), mu_intersect_max);
+
+ Real lambda_intersect_min = lambda_from_mu(mu_min());
+
+ if (lambda_min() < lambda_intersect_min && lambda_intersect_min < lambda_max())
+ result.emplace_back(axis_type(), angle_type(), lambda_intersect_min, mu_min());
+
+ Real lambda_intersect_max = lambda_from_mu(mu_max());
+ if (lambda_min() < lambda_intersect_max && lambda_intersect_max < lambda_max())
+ result.emplace_back(axis_type(), angle_type(), lambda_intersect_max, mu_max());
+
+ assert(result.size() <= 2);
+
+ if (result.size() > 2) {
+ fmt::print("Error in push_change_points, p = {}, dual_box = {}, result = {}\n", p, *this,
+ container_to_string(result));
+ throw std::runtime_error("push_change_points returned more than 2 points");
+ }
+
+ return result;
+ }
+
+ std::vector<DualPoint> DualBox::critical_points(const Point& p) const
+ {
+ // maximal difference is attained at corners
+ return corners();
+// std::vector<DualPoint> result;
+// result.reserve(6);
+// for(auto dp : corners()) result.push_back(dp);
+// for(auto dp : push_change_points(p)) result.push_back(dp);
+// return result;
+ }
+
+ std::vector<DualPoint> DualBox::random_points(int n) const
+ {
+ assert(n >= 0);
+ std::mt19937_64 gen(1);
+ std::vector<DualPoint> result;
+ result.reserve(n);
+ std::uniform_real_distribution<Real> mu_distr(mu_min(), mu_max());
+ std::uniform_real_distribution<Real> lambda_distr(lambda_min(), lambda_max());
+ for(int i = 0; i < n; ++i) {
+ result.emplace_back(axis_type(), angle_type(), lambda_distr(gen), mu_distr(gen));
+ }
+ return result;
+ }
+
+ bool DualBox::sanity_check() const
+ {
+ lower_left_.sanity_check();
+ upper_right_.sanity_check();
+
+ if (lower_left_.angle_type() != upper_right_.angle_type())
+ throw std::runtime_error("angle types differ");
+
+ if (lower_left_.axis_type() != upper_right_.axis_type())
+ throw std::runtime_error("axis types differ");
+
+ if (lower_left_.lambda() >= upper_right_.lambda())
+ throw std::runtime_error("lambda of lower_left_ greater than lambda of upper_right ");
+
+ if (lower_left_.mu() >= upper_right_.mu())
+ throw std::runtime_error("mu of lower_left_ greater than mu of upper_right ");
+
+ return true;
+ }
+
+ std::vector<DualBox> DualBox::refine() const
+ {
+ std::vector<DualBox> result;
+
+ result.reserve(4);
+
+ Real lambda_middle = (lower_left().lambda() + upper_right().lambda()) / 2.0;
+ Real mu_middle = (lower_left().mu() + upper_right().mu()) / 2.0;
+
+ DualPoint refinement_center(axis_type(), angle_type(), lambda_middle, mu_middle);
+
+ result.emplace_back(lower_left_, refinement_center);
+
+ result.emplace_back(DualPoint(axis_type(), angle_type(), lambda_middle, mu_min()),
+ DualPoint(axis_type(), angle_type(), lambda_max(), mu_middle));
+
+ result.emplace_back(refinement_center, upper_right_);
+
+ result.emplace_back(DualPoint(axis_type(), angle_type(), lambda_min(), mu_middle),
+ DualPoint(axis_type(), angle_type(), lambda_middle, mu_max()));
+ return result;
+ }
+
+ bool DualBox::operator==(const DualBox& other) const
+ {
+ return lower_left() == other.lower_left() and
+ upper_right() == other.upper_right();
+ }
+
+ bool DualBox::contains(const DualPoint& dp) const
+ {
+ return dp.angle_type() == angle_type() and dp.axis_type() == axis_type() and
+ mu_max() >= dp.mu() and
+ mu_min() <= dp.mu() and
+ lambda_min() <= dp.lambda() and
+ lambda_max() >= dp.lambda();
+ }
+
+ DualPoint DualBox::lower_right() const
+ {
+ return DualPoint(lower_left_.axis_type(), lower_left_.angle_type(), lambda_max(), mu_min());
+ }
+
+ DualPoint DualBox::upper_left() const
+ {
+ return DualPoint(lower_left_.axis_type(), lower_left_.angle_type(), lambda_min(), mu_max());
+ }
+}
diff --git a/matching/src/dual_point.cpp b/matching/src/dual_point.cpp
new file mode 100644
index 0000000..1c00b58
--- /dev/null
+++ b/matching/src/dual_point.cpp
@@ -0,0 +1,282 @@
+#include <tuple>
+
+#include "dual_point.h"
+
+namespace md {
+
+ std::ostream& operator<<(std::ostream& os, const AxisType& at)
+ {
+ if (at == AxisType::x_type)
+ os << "x-type";
+ else
+ os << "y-type";
+ return os;
+ }
+
+ std::ostream& operator<<(std::ostream& os, const AngleType& at)
+ {
+ if (at == AngleType::flat)
+ os << "flat";
+ else
+ os << "steep";
+ return os;
+ }
+
+ std::ostream& operator<<(std::ostream& os, const DualPoint& dp)
+ {
+ os << "Line(" << dp.axis_type() << ", ";
+ os << dp.angle_type() << ", ";
+ os << dp.lambda() << ", ";
+ os << dp.mu() << ", equation: ";
+ if (not dp.is_vertical()) {
+ os << "y = " << dp.y_slope() << " x + " << dp.y_intercept();
+ } else {
+ os << "x = " << dp.x_intercept();
+ }
+ os << " )";
+ return os;
+ }
+
+ bool DualPoint::operator<(const DualPoint& rhs) const
+ {
+ return std::tie(axis_type_, angle_type_, lambda_, mu_)
+ < std::tie(rhs.axis_type_, rhs.angle_type_, rhs.lambda_, rhs.mu_);
+ }
+
+ DualPoint::DualPoint(AxisType axis_type, AngleType angle_type, Real lambda, Real mu)
+ :
+ axis_type_(axis_type),
+ angle_type_(angle_type),
+ lambda_(lambda),
+ mu_(mu)
+ {
+ assert(sanity_check());
+ }
+
+ bool DualPoint::sanity_check() const
+ {
+ if (lambda_ < 0.0)
+ throw std::runtime_error("Invalid line, negative lambda");
+ if (lambda_ > 1.0)
+ throw std::runtime_error("Invalid line, lambda > 1");
+ if (mu_ < 0.0)
+ throw std::runtime_error("Invalid line, negative mu");
+ return true;
+ }
+
+ Real DualPoint::gamma() const
+ {
+ if (is_steep())
+ return atan(Real(1.0) / lambda_);
+ else
+ return atan(lambda_);
+ }
+
+ DualPoint midpoint(DualPoint x, DualPoint y)
+ {
+ assert(x.angle_type() == y.angle_type() and x.axis_type() == y.axis_type());
+ Real lambda_mid = (x.lambda() + y.lambda()) / 2;
+ Real mu_mid = (x.mu() + y.mu()) / 2;
+ return DualPoint(x.axis_type(), x.angle_type(), lambda_mid, mu_mid);
+
+ }
+
+ // return k in the line equation y = kx + b
+ Real DualPoint::y_slope() const
+ {
+ if (is_flat())
+ return lambda();
+ else
+ return Real(1.0) / lambda();
+ }
+
+ // return k in the line equation x = ky + b
+ Real DualPoint::x_slope() const
+ {
+ if (is_flat())
+ return Real(1.0) / lambda();
+ else
+ return lambda();
+ }
+
+ // return b in the line equation y = kx + b
+ Real DualPoint::y_intercept() const
+ {
+ if (is_y_type()) {
+ return mu();
+ } else {
+ // x = x_slope * y + mu = x_slope * (y + mu / x_slope)
+ // x-intercept is -mu/x_slope = -mu * y_slope
+ return -mu() * y_slope();
+ }
+ }
+
+ // return k in the line equation x = ky + b
+ Real DualPoint::x_intercept() const
+ {
+ if (is_x_type()) {
+ return mu();
+ } else {
+ // y = y_slope * x + mu = y_slope (x + mu / y_slope)
+ // x_intercept is -mu/y_slope = -mu * x_slope
+ return -mu() * x_slope();
+ }
+ }
+
+ Real DualPoint::x_from_y(Real y) const
+ {
+ if (is_horizontal())
+ throw std::runtime_error("x_from_y called on horizontal line");
+ else
+ return x_slope() * y + x_intercept();
+ }
+
+ Real DualPoint::y_from_x(Real x) const
+ {
+ if (is_vertical())
+ throw std::runtime_error("x_from_y called on horizontal line");
+ else
+ return y_slope() * x + y_intercept();
+ }
+
+ bool DualPoint::is_horizontal() const
+ {
+ return is_flat() and lambda() == 0;
+ }
+
+ bool DualPoint::is_vertical() const
+ {
+ return is_steep() and lambda() == 0;
+ }
+
+ bool DualPoint::contains(Point p) const
+ {
+ if (is_vertical())
+ return p.x == x_from_y(p.y);
+ else
+ return p.y == y_from_x(p.x);
+ }
+
+ bool DualPoint::goes_below(Point p) const
+ {
+ if (is_vertical())
+ return p.x <= x_from_y(p.y);
+ else
+ return p.y >= y_from_x(p.x);
+ }
+
+ bool DualPoint::goes_above(Point p) const
+ {
+ if (is_vertical())
+ return p.x >= x_from_y(p.y);
+ else
+ return p.y <= y_from_x(p.x);
+ }
+
+ Point DualPoint::push(Point p) const
+ {
+ Point result;
+ // if line is below p, we push horizontally
+ bool horizontal_push = goes_below(p);
+ if (is_x_type()) {
+ if (is_flat()) {
+ if (horizontal_push) {
+ result.x = p.y / lambda() + mu();
+ result.y = p.y;
+ } else {
+ // vertical push
+ result.x = p.x;
+ result.y = lambda() * (p.x - mu());
+ }
+ } else {
+ // steep
+ if (horizontal_push) {
+ result.x = lambda() * p.y + mu();
+ result.y = p.y;
+ } else {
+ // vertical push
+ result.x = p.x;
+ result.y = (p.x - mu()) / lambda();
+ }
+ }
+ } else {
+ // y-type
+ if (is_flat()) {
+ if (horizontal_push) {
+ result.x = (p.y - mu()) / lambda();
+ result.y = p.y;
+ } else {
+ // vertical push
+ result.x = p.x;
+ result.y = lambda() * p.x + mu();
+ }
+ } else {
+ // steep
+ if (horizontal_push) {
+ result.x = (p.y - mu()) * lambda();
+ result.y = p.y;
+ } else {
+ // vertical push
+ result.x = p.x;
+ result.y = p.x / lambda() + mu();
+ }
+ }
+ }
+ return result;
+ }
+
+ Real DualPoint::weighted_push(Point p) const
+ {
+ // if line is below p, we push horizontally
+ bool horizontal_push = goes_below(p);
+ if (is_x_type()) {
+ if (is_flat()) {
+ if (horizontal_push) {
+ return p.y;
+ } else {
+ // vertical push
+ return lambda() * (p.x - mu());
+ }
+ } else {
+ // steep
+ if (horizontal_push) {
+ return lambda() * p.y;
+ } else {
+ // vertical push
+ return (p.x - mu());
+ }
+ }
+ } else {
+ // y-type
+ if (is_flat()) {
+ if (horizontal_push) {
+ return p.y - mu();
+ } else {
+ // vertical push
+ return lambda() * p.x;
+ }
+ } else {
+ // steep
+ if (horizontal_push) {
+ return lambda() * (p.y - mu());
+ } else {
+ // vertical push
+ return p.x;
+ }
+ }
+ }
+ }
+
+ bool DualPoint::operator==(const DualPoint& other) const
+ {
+ return axis_type() == other.axis_type() and
+ angle_type() == other.angle_type() and
+ mu() == other.mu() and
+ lambda() == other.lambda();
+ }
+
+ Real DualPoint::weight() const
+ {
+ return lambda_ / sqrt(1 + lambda_ * lambda_);
+ }
+} // namespace md
diff --git a/matching/src/main.cpp b/matching/src/main.cpp
new file mode 100644
index 0000000..34c4240
--- /dev/null
+++ b/matching/src/main.cpp
@@ -0,0 +1,367 @@
+#include "common_defs.h"
+
+#include <iostream>
+#include <string>
+#include <cassert>
+#include <experimental/filesystem>
+
+#ifdef EXPERIMENTAL_TIMING
+#include <chrono>
+#endif
+
+#include "opts/opts.h"
+#include "spdlog/spdlog.h"
+#include "spdlog/fmt/ostr.h"
+
+//#include "persistence_module.h"
+#include "bifiltration.h"
+#include "box.h"
+#include "matching_distance.h"
+
+using namespace md;
+
+namespace fs = std::experimental::filesystem;
+
+void print_heat_map(const md::HeatMaps& hms, std::string fname, const CalculationParams& params)
+{
+#ifdef PRINT_HEAT_MAP
+ spd::debug("Entered print_heat_map");
+ std::set<Real> mu_vals, lambda_vals;
+ auto hm_iter = hms.end();
+ --hm_iter;
+ int max_level = hm_iter->first;
+
+ int level_cardinality = 4;
+ for(int i = 0; i < params.initialization_depth; ++i) {
+ level_cardinality *= 4;
+ }
+ for(int i = params.initialization_depth + 1; i <= max_level; ++i) {
+ spd::debug("hms.at({}).size = {}, must be {}", i, hms.at(i).size(), level_cardinality);
+ assert(static_cast<decltype(level_cardinality)>(hms.at(i).size()) == level_cardinality);
+ level_cardinality *= 4;
+ }
+
+ std::map<std::pair<Real, Real>, Real> hm_x_flat, hm_x_steep, hm_y_flat, hm_y_steep;
+
+ for(const auto& dual_point_value_pair : hms.at(max_level)) {
+ const DualPoint& k = dual_point_value_pair.first;
+ spd::debug("HM DP: {}", k);
+ mu_vals.insert(k.mu());
+ lambda_vals.insert(k.lambda());
+ }
+
+ std::vector<Real> lambda_vals_vec(lambda_vals.begin(), lambda_vals.end());
+ std::vector<Real> mu_vals_vec(mu_vals.begin(), mu_vals.end());
+
+ std::ofstream ofs {fname};
+ if (not ofs.good()) {
+ std::cerr << "Cannot write heat map to file " << fname << std::endl;
+ throw std::runtime_error("Cannot open file for writing heat map");
+ }
+
+ std::vector<std::vector<Real>> heatmap_to_print(2 * mu_vals_vec.size(),
+ std::vector<Real>(2 * lambda_vals_vec.size(), 0.0));
+
+ for(auto axis_type : {AxisType::x_type, AxisType::y_type}) {
+ bool is_x_type = axis_type == AxisType::x_type;
+ for(auto angle_type : {AngleType::flat, AngleType::steep}) {
+ bool is_flat = angle_type == AngleType::flat;
+
+ int mu_idx_begin, mu_idx_end;
+
+ if (is_x_type) {
+ mu_idx_begin = mu_vals_vec.size() - 1;
+ mu_idx_end = -1;
+ } else {
+ mu_idx_begin = 0;
+ mu_idx_end = mu_vals_vec.size();
+ }
+
+ int lambda_idx_begin, lambda_idx_end;
+
+ if (is_flat) {
+ lambda_idx_begin = 0;
+ lambda_idx_end = lambda_vals_vec.size();
+ } else {
+ lambda_idx_begin = lambda_vals_vec.size() - 1;
+ lambda_idx_end = -1;
+ }
+
+ int mu_idx_final = is_x_type ? 0 : mu_vals_vec.size();
+
+ for(int mu_idx = mu_idx_begin; mu_idx != mu_idx_end; (mu_idx_begin < mu_idx_end) ? mu_idx++ : mu_idx--) {
+ Real mu = mu_vals_vec.at(mu_idx);
+
+ if (mu == 0.0 and axis_type == AxisType::x_type)
+ continue;
+
+ int lambda_idx_final = is_flat ? 0 : lambda_vals_vec.size();
+
+ for(int lambda_idx = lambda_idx_begin;
+ lambda_idx != lambda_idx_end;
+ (lambda_idx_begin < lambda_idx_end) ? lambda_idx++ : lambda_idx--) {
+
+ Real lambda = lambda_vals_vec.at(lambda_idx);
+
+ if (lambda == 0.0 and angle_type == AngleType::flat)
+ continue;
+
+ DualPoint dp(axis_type, angle_type, lambda, mu);
+ Real dist_value = hms.at(max_level).at(dp);
+
+ heatmap_to_print.at(mu_idx_final).at(lambda_idx_final) = dist_value;
+
+// fmt::print("HM, dp = {}, mu_idx_final = {}, lambda_idx_final = {}, value = {}\n", dp, mu_idx_final,
+// lambda_idx_final, dist_value);
+
+ lambda_idx_final++;
+ }
+ mu_idx_final++;
+ }
+ }
+ }
+
+ for(size_t m_idx = 0; m_idx < heatmap_to_print.size(); ++m_idx) {
+ for(size_t l_idx = 0; l_idx < heatmap_to_print.at(m_idx).size(); ++l_idx) {
+ ofs << heatmap_to_print.at(m_idx).at(l_idx) << " ";
+ }
+ ofs << std::endl;
+ }
+
+ ofs.close();
+ spd::debug("Exit print_heat_map");
+#endif
+}
+
+int main(int argc, char** argv)
+{
+ spdlog::set_level(spdlog::level::info);
+ //spdlog::set_pattern("[%L] %v");
+
+ using opts::Option;
+ using opts::PosOption;
+ opts::Options ops;
+
+ bool help = false;
+ bool heatmap_only = false;
+ bool no_stop_asap = false;
+ CalculationParams params;
+
+ std::string bounds_list_str = "local_combined";
+ std::string traverse_list_str = "BFS";
+
+ ops >> Option('m', "max-iterations", params.max_depth, "maximal number of iterations (refinements)")
+ >> Option('e', "max-error", params.delta, "error threshold")
+ >> Option('d', "dim", params.dim, "dim")
+ >> Option('i', "initial-depth", params.initialization_depth, "initialization depth")
+ >> Option("no-stop-asap", no_stop_asap,
+ "don't stop looping over points, if cell cannot be pruned (asap is on by default)")
+ >> Option("bounds", bounds_list_str, "bounds to use, separated by ,")
+ >> Option("traverse", traverse_list_str, "traverse to use, separated by ,")
+#ifdef PRINT_HEAT_MAP
+ >> Option("heatmap-only", heatmap_only, "only save heatmap (bruteforce)")
+#endif
+ >> Option('h', "help", help, "show help message");
+
+ std::string fname_a;
+ std::string fname_b;
+
+ if (!ops.parse(argc, argv) || help || !(ops >> PosOption(fname_a) >> PosOption(fname_b))) {
+ std::cerr << "Usage: " << argv[0] << "bifiltration-file-1 bifiltration-file-2\n" << ops << std::endl;
+ return 1;
+ }
+
+ params.stop_asap = not no_stop_asap;
+
+ auto bounds_list = split_by_delim(bounds_list_str, ',');
+ auto traverse_list = split_by_delim(traverse_list_str, ',');
+
+ Bifiltration bif_a(fname_a, BifiltrationFormat::rene);
+ Bifiltration bif_b(fname_b, BifiltrationFormat::rene);
+
+ bif_a.sanity_check();
+ bif_b.sanity_check();
+
+ spd::info("Read bifiltrations {} {}", fname_a, fname_b);
+
+ std::vector<BoundStrategy> bound_strategies;
+ std::vector<TraverseStrategy> traverse_strategies;
+
+ for(std::string s : bounds_list) {
+ bound_strategies.push_back(bs_from_string(s));
+ }
+
+ for(std::string s : traverse_list) {
+ traverse_strategies.push_back(ts_from_string(s));
+ }
+
+ for(auto bs : bound_strategies) {
+ for(auto ts : traverse_strategies) {
+ spd::info("Will test combination {} {}", bs, ts);
+ }
+ }
+
+#ifdef EXPERIMENTAL_TIMING
+ struct ExperimentResult {
+ CalculationParams params {CalculationParams()};
+ int n_hera_calls {0};
+ double total_milliseconds_elapsed {0};
+ double distance {0};
+ double actual_error {std::numeric_limits<double>::max()};
+ int actual_max_depth {0};
+
+ int x_wins {0};
+ int y_wins {0};
+ int ad_wins {0};
+
+ int seconds_elapsed() const
+ {
+ return static_cast<int>(total_milliseconds_elapsed / 1000);
+ }
+
+ double savings_ratio_old() const
+ {
+ long int max_possible_calls = 0;
+ long int calls_on_level = 4;
+ for(int i = 0; i <= actual_max_depth; ++i) {
+ max_possible_calls += calls_on_level;
+ calls_on_level *= 4;
+ }
+ return static_cast<double>(n_hera_calls) / static_cast<double>(max_possible_calls);
+ }
+
+ double savings_ratio() const
+ {
+ return static_cast<double>(n_hera_calls) / calls_on_actual_max_depth();
+ }
+
+ long long int calls_on_actual_max_depth() const
+ {
+ long long int result = 1;
+ for(int i = 0; i < actual_max_depth; ++i) {
+ result *= 4;
+ }
+ return result;
+ }
+
+ ExperimentResult() { }
+
+ ExperimentResult(CalculationParams p, int nhc, double tme, double d)
+ :
+ params(p), n_hera_calls(nhc), total_milliseconds_elapsed(tme), distance(d) { }
+ };
+
+ const int n_repetitions = 1;
+
+ if (heatmap_only) {
+ bound_strategies.clear();
+ bound_strategies.push_back(BoundStrategy::bruteforce);
+ traverse_strategies.clear();
+ traverse_strategies.push_back(TraverseStrategy::breadth_first);
+ }
+
+ std::map<std::tuple<BoundStrategy, TraverseStrategy>, ExperimentResult> results;
+ for(BoundStrategy bound_strategy : bound_strategies) {
+ for(TraverseStrategy traverse_strategy : traverse_strategies) {
+ CalculationParams params_experiment;
+ params_experiment.bound_strategy = bound_strategy;
+ params_experiment.traverse_strategy = traverse_strategy;
+ params_experiment.max_depth = params.max_depth;
+ params_experiment.initialization_depth = params.initialization_depth;
+ params_experiment.delta = params.delta;
+ params_experiment.dim = params.dim;
+ params_experiment.hera_epsilon = params.hera_epsilon;
+ params_experiment.stop_asap = params.stop_asap;
+
+ if (traverse_strategy == TraverseStrategy::depth_first and bound_strategy == BoundStrategy::bruteforce)
+ continue;
+
+ // if bruteforce, clamp max iterations number to 7,
+ // save user-provided max_iters in user_max_iters and restore it later.
+ // remember: params is passed by reference to return real relative error and heat map
+
+ int user_max_iters = params.max_depth;
+ if (bound_strategy == BoundStrategy::bruteforce and not heatmap_only) {
+ params_experiment.max_depth = std::min(7, params.max_depth);
+ }
+ double total_milliseconds_elapsed = 0;
+ int total_n_hera_calls = 0;
+ Real dist;
+ for(int i = 0; i < n_repetitions; ++i) {
+ spd::debug("Processing bound_strategy {}, traverse_strategy {}, iteration = {}", bound_strategy,
+ traverse_strategy, i);
+ auto t1 = std::chrono::high_resolution_clock().now();
+ dist = matching_distance(bif_a, bif_b, params_experiment);
+ auto t2 = std::chrono::high_resolution_clock().now();
+ total_milliseconds_elapsed += std::chrono::duration_cast<std::chrono::milliseconds>(
+ t2 - t1).count();
+ total_n_hera_calls += params_experiment.n_hera_calls;
+ }
+
+ auto key = std::make_tuple(bound_strategy, traverse_strategy);
+ results[key].params = params_experiment;
+ results[key].n_hera_calls = total_n_hera_calls / n_repetitions;
+ results[key].total_milliseconds_elapsed = total_milliseconds_elapsed / n_repetitions;
+ results[key].distance = dist;
+ results[key].actual_error = params_experiment.actual_error;
+ results[key].actual_max_depth = params_experiment.actual_max_depth;
+
+ spd::info(
+ "Done (bound = {}, traverse = {}), n_hera_calls = {}, time = {} sec, d = {}, error = {}, savings = {}, max_depth = {}",
+ bound_strategy, traverse_strategy, results[key].n_hera_calls, results[key].seconds_elapsed(),
+ dist,
+ params_experiment.actual_error, results[key].savings_ratio(), results[key].actual_max_depth);
+
+ if (bound_strategy == BoundStrategy::bruteforce) { params_experiment.max_depth = user_max_iters; }
+
+#ifdef PRINT_HEAT_MAP
+ if (bound_strategy == BoundStrategy::bruteforce) {
+ fs::path fname_a_path {fname_a.c_str()};
+ fs::path fname_b_path {fname_b.c_str()};
+ fs::path fname_a_wo = fname_a_path.filename();
+ fs::path fname_b_wo = fname_b_path.filename();
+ std::string heat_map_fname = fmt::format("{0}_{1}_dim_{2}_weighted_values_xyp.txt", fname_a_wo.string(),
+ fname_b_wo.string(), params_experiment.dim);
+ fs::path path_hm = fname_a_path.replace_filename(fs::path(heat_map_fname.c_str()));
+ spd::debug("Saving heatmap to {}", heat_map_fname);
+ print_heat_map(params_experiment.heat_maps, path_hm.string(), params);
+ }
+#endif
+ spd::debug("Finished processing bound_strategy {}", bound_strategy);
+ }
+ }
+
+// std::cout << "File_1;File_2;Boundstrategy;TraverseStrategy;InitalDepth;NHeraCalls;SavingsRatio;Time;Distance;Error;PushStrategy;MaxDepth;CallsOnMaxDepth;Delta;Dimension" << std::endl;
+ for(auto bs : bound_strategies) {
+ for(auto ts : traverse_strategies) {
+ auto key = std::make_tuple(bs, ts);
+
+ fs::path fname_a_path {fname_a.c_str()};
+ fs::path fname_b_path {fname_b.c_str()};
+ fs::path fname_a_wo = fname_a_path.filename();
+ fs::path fname_b_wo = fname_b_path.filename();
+
+ std::cout << fname_a_wo.string() << ";" << fname_b_wo.string() << ";" << bs << ";" << ts << ";";
+ std::cout << results[key].params.initialization_depth << ";";
+ std::cout << results[key].n_hera_calls << ";"
+ << results[key].savings_ratio() << ";"
+ << results[key].total_milliseconds_elapsed << ";"
+ << results[key].distance << ";"
+ << results[key].actual_error << ";"
+ << "xyp" << ";"
+ << results[key].actual_max_depth << ";"
+ << results[key].calls_on_actual_max_depth() << ";"
+ << params.delta << ";"
+ << params.dim
+ << std::endl;
+ }
+ }
+#else
+ params.bound_strategy = bound_strategies.back();
+ params.traverse_strategy = traverse_strategies.back();
+
+ Real dist = matching_distance(bif_a, bif_b, params);
+ std::cout << dist << std::endl;
+#endif
+ return 0;
+}
diff --git a/matching/src/matching_distance.cpp b/matching/src/matching_distance.cpp
new file mode 100644
index 0000000..ac96ba2
--- /dev/null
+++ b/matching/src/matching_distance.cpp
@@ -0,0 +1,907 @@
+#include <chrono>
+#include <tuple>
+#include <algorithm>
+
+#include "common_defs.h"
+
+#include "spdlog/fmt/ostr.h"
+#include "matching_distance.h"
+
+namespace md {
+
+ template<class K, class V>
+ void print_map(const std::map<K, V>& dic)
+ {
+ for(const auto kv : dic) {
+ fmt::print("{} -> {}\n", kv.first, kv.second);
+ }
+ }
+
+ void DistanceCalculator::check_upper_bound(const CellWithValue& dual_cell, int dim) const
+ {
+ spd::debug("Enter check_get_max_delta_on_cell");
+ const int n_samples_lambda = 100;
+ const int n_samples_mu = 100;
+ DualBox db = dual_cell.dual_box();
+ Real min_lambda = db.lambda_min();
+ Real max_lambda = db.lambda_max();
+ Real min_mu = db.mu_min();
+ Real max_mu = db.mu_max();
+
+ Real h_lambda = (max_lambda - min_lambda) / n_samples_lambda;
+ Real h_mu = (max_mu - min_mu) / n_samples_mu;
+ for(int i = 1; i < n_samples_lambda; ++i) {
+ for(int j = 1; j < n_samples_mu; ++j) {
+ Real lambda = min_lambda + i * h_lambda;
+ Real mu = min_mu + j * h_mu;
+ DualPoint l(db.axis_type(), db.angle_type(), lambda, mu);
+ Real other_result = distance_on_line_const(dim, l);
+ Real diff = fabs(dual_cell.stored_upper_bound() - other_result);
+ if (other_result > dual_cell.stored_upper_bound()) {
+ spd::error(
+ "in check_upper_bound, upper_bound = {}, other_result = {}, diff = {}, dim = {}\ndual_cell = {}",
+ dual_cell.stored_upper_bound(), other_result, diff, dim, dual_cell);
+ throw std::runtime_error("Wrong delta estimate");
+ }
+ }
+ }
+ spd::debug("Exit check_get_max_delta_on_cell");
+ }
+
+ // for all lines l, l' inside dual box,
+ // find the upper bound on the difference of weighted pushes of p
+ Real
+ DistanceCalculator::get_max_displacement_single_point(const CellWithValue& dual_cell, ValuePoint vp,
+ const Point& p) const
+ {
+ assert(p.x >= 0 && p.y >= 0);
+
+#ifdef MD_DEBUG
+ std::vector<long long int> debug_ids = {3, 13, 54, 218, 350, 382, 484, 795, 2040, 8415, 44076};
+ bool debug = false; // std::find(debug_ids.begin(), debug_ids.end(), dual_cell.id) != debug_ids.end();
+#endif
+ DualPoint line = dual_cell.value_point(vp);
+ const Real base_value = line.weighted_push(p);
+
+ spd::debug("Enter get_max_displacement_single_point, p = {},\ndual_cell = {},\nline = {}, base_value = {}\n", p,
+ dual_cell, line, base_value);
+
+ Real result = 0.0;
+ for(DualPoint dp : dual_cell.dual_box().critical_points(p)) {
+ Real dp_value = dp.weighted_push(p);
+ spd::debug(
+ "In get_max_displacement_single_point, p = {}, critical dp = {},\ndp_value = {}, diff = {},\ndual_cell = {}\n",
+ p, dp, dp_value, fabs(base_value - dp_value), dual_cell);
+ result = std::max(result, fabs(base_value - dp_value));
+ }
+
+#ifdef MD_DO_FULL_CHECK
+ DualBox db = dual_cell.dual_box();
+ std::uniform_real_distribution<Real> dlambda(db.lambda_min(), db.lambda_max());
+ std::uniform_real_distribution<Real> dmu(db.mu_min(), db.mu_max());
+ std::mt19937 gen(1);
+ for(int i = 0; i < 1000; ++i) {
+ Real lambda = dlambda(gen);
+ Real mu = dmu(gen);
+ DualPoint dp_random { db.axis_type(), db.angle_type(), lambda, mu };
+ Real dp_value = dp_random.weighted_push(p);
+ if (fabs(base_value - dp_value) > result) {
+ spd::error("in get_max_displacement_single_point, p = {}, vp = {}\ndb = {}\nresult = {}, base_value = {}, dp_value = {}, dp_random = {}",
+ p, vp, db, result, base_value, dp_value, dp_random);
+ throw std::runtime_error("error in get_max_displacement_single_value");
+ }
+ }
+#endif
+
+ return result;
+ }
+
+ DistanceCalculator::CellValueVector DistanceCalculator::get_initial_dual_grid(Real& lower_bound)
+ {
+ CellValueVector result = get_refined_grid(params_.initialization_depth, false, true);
+
+ lower_bound = -1.0;
+ for(const auto& dc : result) {
+ lower_bound = std::max(lower_bound, dc.max_corner_value());
+ }
+
+ assert(lower_bound >= 0);
+
+ for(auto& dual_cell : result) {
+ Real good_enough_ub = get_good_enough_upper_bound(lower_bound);
+ Real max_value_on_cell = get_upper_bound(dual_cell, params_.dim, good_enough_ub);
+ dual_cell.set_max_possible_value(max_value_on_cell);
+
+#ifdef MD_DO_FULL_CHECK
+ check_upper_bound(dual_cell, params_.dim);
+#endif
+
+ spd::debug("DEBUG INIT: added cell {}", dual_cell);
+ }
+
+
+
+ return result;
+ }
+
+ DistanceCalculator::CellValueVector
+ DistanceCalculator::get_refined_grid(int init_depth, bool calculate_on_intermediate, bool calculate_on_last)
+ {
+ const Real y_max = std::max(module_a_.max_y(), module_b_.max_y());
+ const Real x_max = std::max(module_a_.max_x(), module_b_.max_x());
+
+ const Real lambda_min = 0;
+ const Real lambda_max = 1;
+
+ const Real mu_min = 0;
+
+ DualBox x_flat(DualPoint(AxisType::x_type, AngleType::flat, lambda_min, mu_min),
+ DualPoint(AxisType::x_type, AngleType::flat, lambda_max, x_max));
+
+ DualBox x_steep(DualPoint(AxisType::x_type, AngleType::steep, lambda_min, mu_min),
+ DualPoint(AxisType::x_type, AngleType::steep, lambda_max, x_max));
+
+ DualBox y_flat(DualPoint(AxisType::y_type, AngleType::flat, lambda_min, mu_min),
+ DualPoint(AxisType::y_type, AngleType::flat, lambda_max, y_max));
+
+ DualBox y_steep(DualPoint(AxisType::y_type, AngleType::steep, lambda_min, mu_min),
+ DualPoint(AxisType::y_type, AngleType::steep, lambda_max, y_max));
+
+ CellWithValue x_flat_cell(x_flat, 0);
+ CellWithValue x_steep_cell(x_steep, 0);
+ CellWithValue y_flat_cell(y_flat, 0);
+ CellWithValue y_steep_cell(y_steep, 0);
+
+ if (init_depth == 0) {
+ DualPoint diagonal_x_flat(AxisType::x_type, AngleType::flat, 1, 0);
+
+ Real diagonal_value = distance_on_line(params_.dim, diagonal_x_flat);
+ n_hera_calls_per_level_[0]++;
+
+ x_flat_cell.set_value_at(ValuePoint::lower_right, diagonal_value);
+ y_flat_cell.set_value_at(ValuePoint::lower_right, diagonal_value);
+ x_steep_cell.set_value_at(ValuePoint::lower_right, diagonal_value);
+ y_steep_cell.set_value_at(ValuePoint::lower_right, diagonal_value);
+ }
+
+#ifdef MD_DEBUG
+ x_flat_cell.id = 1;
+ x_steep_cell.id = 2;
+ y_flat_cell.id = 3;
+ y_steep_cell.id = 4;
+ CellWithValue::max_id = 4;
+#endif
+
+ CellValueVector result {x_flat_cell, x_steep_cell, y_flat_cell, y_steep_cell};
+
+ if (init_depth == 0) {
+ return result;
+ }
+
+ CellValueVector refined_result;
+
+ for(int i = 1; i <= init_depth; ++i) {
+ refined_result.clear();
+ for(const auto& dual_cell : result) {
+ for(auto refined_cell : dual_cell.get_refined_cells()) {
+ // we calculate for init_dept - 1, not init_depth,
+ // because we want the cells to have value at a corner
+ if ((i == init_depth - 1 and calculate_on_last) or calculate_on_intermediate)
+ set_cell_central_value(refined_cell, params_.dim);
+ refined_result.push_back(refined_cell);
+ }
+ }
+ result = std::move(refined_result);
+ }
+ return result;
+ }
+
+ DistanceCalculator::DistanceCalculator(const DiagramProvider& a,
+ const DiagramProvider& b,
+ CalculationParams& params)
+ :
+ module_a_(a),
+ module_b_(b),
+ params_(params),
+ maximal_dim_(std::max(a.maximal_dim(), b.maximal_dim())),
+ distances_(1 + std::max(a.maximal_dim(), b.maximal_dim()), Real(-1))
+ {
+ // make all coordinates non-negative
+ auto min_coord = std::min(module_a_.minimal_coordinate(),
+ module_b_.minimal_coordinate());
+ if (min_coord < 0) {
+ module_a_.translate(-min_coord);
+ module_b_.translate(-min_coord);
+ }
+
+ assert(std::min({module_a_.min_x(), module_b_.min_x(), module_a_.min_y(),
+ module_b_.min_y()}) >= 0);
+
+ spd::info("DistanceCalculator constructed, module_a: max_x = {}, max_y = {}, module_b: max_x = {}, max_y = {}",
+ module_a_.max_x(), module_a_.max_y(), module_b_.max_x(), module_b_.max_y());
+ }
+
+ void DistanceCalculator::clear_cache()
+ {
+ distances_ = std::vector<Real>(maximal_dim_, Real(-1));
+ }
+
+ Real DistanceCalculator::get_max_x(int module) const
+ {
+ return (module == 0) ? module_a_.max_x() : module_b_.max_x();
+ }
+
+ Real DistanceCalculator::get_max_y(int module) const
+ {
+ return (module == 0) ? module_a_.max_y() : module_b_.max_y();
+ }
+
+ Real
+ DistanceCalculator::get_local_refined_bound(const md::DualBox& dual_box) const
+ {
+ return get_local_refined_bound(0, dual_box) + get_local_refined_bound(1, dual_box);
+ }
+
+ Real
+ DistanceCalculator::get_local_refined_bound(int module, const md::DualBox& dual_box) const
+ {
+ spd::debug("Enter get_local_refined_bound, dual_box = {}", dual_box);
+ Real d_lambda = dual_box.lambda_max() - dual_box.lambda_min();
+ Real d_mu = dual_box.mu_max() - dual_box.mu_min();
+ Real result;
+ if (dual_box.axis_type() == AxisType::x_type) {
+ if (dual_box.is_flat()) {
+ result = dual_box.lambda_max() * d_mu + (get_max_x(module) - dual_box.mu_min()) * d_lambda;
+ } else {
+ result = d_mu + get_max_y(module) * d_lambda;
+ }
+ } else {
+ // y-type
+ if (dual_box.is_flat()) {
+ result = d_mu + get_max_x(module) * d_lambda;
+ } else {
+ // steep
+ result = dual_box.lambda_max() * d_mu + (get_max_y(module) - dual_box.mu_min()) * d_lambda;
+ }
+ }
+ return result;
+ }
+
+ Real DistanceCalculator::get_local_dual_bound(int module, const md::DualBox& dual_box) const
+ {
+ Real dlambda = dual_box.lambda_max() - dual_box.lambda_min();
+ Real dmu = dual_box.mu_max() - dual_box.mu_min();
+ Real C = std::max(get_max_x(module), get_max_y(module));
+
+ //return 2 * (C * dlambda + dmu);
+
+ // additional factor of 2 because we mimic Cerri's paper
+ // where subdivision is on angle spaces,
+ // and tangent/cotangent is 2-Lipschitz
+ if (dual_box.is_flat()) {
+ return get_max_x(module) * dlambda + dmu;
+ } else {
+ return get_max_y(module) * dlambda + dmu;
+ }
+ }
+
+ Real DistanceCalculator::get_local_dual_bound(const md::DualBox& dual_box) const
+ {
+ return get_local_dual_bound(0, dual_box) + get_local_dual_bound(1, dual_box);
+ }
+
+ Real DistanceCalculator::get_upper_bound(const CellWithValue& dual_cell, int dim, Real good_enough_ub) const
+ {
+ assert(good_enough_ub >= 0);
+
+ switch(params_.bound_strategy) {
+ case BoundStrategy::bruteforce:
+ return std::numeric_limits<Real>::max();
+
+ case BoundStrategy::local_dual_bound:
+ return dual_cell.min_value() + get_local_dual_bound(dual_cell.dual_box());
+
+ case BoundStrategy::local_dual_bound_refined:
+ return dual_cell.min_value() + get_local_refined_bound(dual_cell.dual_box());
+
+ case BoundStrategy::local_combined: {
+ Real cheap_upper_bound = dual_cell.min_value() + get_local_refined_bound(dual_cell.dual_box());
+ if (cheap_upper_bound < good_enough_ub) {
+ return cheap_upper_bound;
+ } else {
+ [[fallthrough]];
+ }
+ }
+
+ case BoundStrategy::local_dual_bound_for_each_point: {
+ Real result = std::numeric_limits<Real>::max();
+ for(ValuePoint vp : k_corner_vps) {
+ if (not dual_cell.has_value_at(vp)) {
+ continue;
+ }
+
+ Real base_value = dual_cell.value_at(vp);
+ Real bound_dgm_a = get_single_dgm_bound(dual_cell, vp, 0, dim, good_enough_ub);
+
+ if (params_.stop_asap and bound_dgm_a + base_value >= good_enough_ub) {
+ // we want to return a valid upper bound, not just something that will prevent discarding the cell
+ // and we don't want to compute pushes for points in second bifiltration.
+ // so just return a constant time bound
+ return dual_cell.min_value() + get_local_refined_bound(dual_cell.dual_box());
+ }
+
+ Real bound_dgm_b = get_single_dgm_bound(dual_cell, vp, 1, dim,
+ std::max(Real(0), good_enough_ub - bound_dgm_a));
+
+ result = std::min(result, base_value + bound_dgm_a + bound_dgm_b);
+
+#ifdef MD_DEBUG
+ spd::debug("In get_upper_bound, cell = {}", dual_cell);
+ spd::debug("In get_upper_bound, vp = {}, base_value = {}, bound_dgm_a = {}, bound_dgm_b = {}, result = {}", vp, base_value, bound_dgm_a, bound_dgm_b, result);
+#endif
+
+ if (params_.stop_asap and result < good_enough_ub) {
+ break;
+ }
+ }
+ return result;
+ }
+ }
+ // to suppress compiler warning
+ return std::numeric_limits<Real>::max();
+ }
+
+ // find maximal displacement of weighted points of m for all lines in dual_box
+ Real
+ DistanceCalculator::get_single_dgm_bound(const CellWithValue& dual_cell,
+ ValuePoint vp,
+ int module,
+ int dim,
+ [[maybe_unused]] Real good_enough_value) const
+ {
+ Real result = 0;
+ Point max_point;
+
+ spd::debug("Enter get_single_dgm_bound, module = {}, dual_cell = {}, vp = {}, good_enough_value = {}, stop_asap = {}\n", module, dual_cell, vp, good_enough_value, params_.stop_asap);
+
+ const DiagramProvider& m = (module == 0) ? module_a_ : module_b_;
+ for(const auto& simplex : m.simplices()) {
+ spd::debug("in get_single_dgm_bound, simplex = {}\n", simplex);
+ if (dim != simplex.dim() and dim + 1 != simplex.dim())
+ continue;
+
+ Real x = get_max_displacement_single_point(dual_cell, vp, simplex.position());
+
+ spd::debug("In get_single_dgm_bound, point = {}, displacement = {}", simplex.position(), x);
+
+ if (x > result) {
+ result = x;
+ max_point = simplex.position();
+ spd::debug("In get_single_dgm_bound, point = {}, result now = displacement = {}", simplex.position(), x);
+ }
+
+ if (params_.stop_asap and result > good_enough_value) {
+ // we want to return a valid upper bound,
+ // now we just see it is worse than we need, but it may be even more
+ // just return a valid upper bound
+ spd::debug("result {} > good_enough_value {}, exit and return refined bound {}", result, good_enough_value, get_local_refined_bound(dual_cell.dual_box()));
+ result = get_local_refined_bound(dual_cell.dual_box());
+ break;
+ }
+ }
+
+ spd::debug("Exit get_single_dgm_bound,\ndual_cell = {}\nmodule = {}, dim = {}, result = {}, max_point = {}", dual_cell, module, dim, result, max_point);
+
+ return result;
+ }
+
+ Real DistanceCalculator::distance()
+ {
+ if (params_.dim != CalculationParams::ALL_DIMENSIONS) {
+ return distance_in_dimension_pq(params_.dim);
+ } else {
+ Real result = -1.0;
+ for(int d = 0; d <= maximal_dim_; ++d) {
+ result = std::max(result, distance_in_dimension_pq(d));
+ }
+ return result;
+ }
+ }
+
+ // calculate weighted bottleneneck distance between slices on line
+ // in dimension dim
+ // increments hera calls counter
+ Real DistanceCalculator::distance_on_line(int dim, DualPoint line)
+ {
+ // order matters - distance_on_line_const assumes n_hera_calls_ map has entry for dim
+ ++n_hera_calls_[dim];
+ Real result = distance_on_line_const(dim, line);
+ return result;
+ }
+
+ Real DistanceCalculator::distance_on_line_const(int dim, DualPoint line) const
+ {
+ // TODO: think about this - how to call Hera
+ Real hera_epsilon = 0.001;
+ auto dgm_a = module_a_.weighted_slice_diagram(line, dim).get_diagram(dim);
+ auto dgm_b = module_b_.weighted_slice_diagram(line, dim).get_diagram(dim);
+// Real result = hera::bottleneckDistApprox(dgm_a, dgm_b, hera_epsilon);
+ Real result = hera::bottleneckDistExact(dgm_a, dgm_b);
+ if (n_hera_calls_.at(dim) % 100 == 1) {
+ spd::debug("Calling Hera, dgm_a.size = {}, dgm_b.size = {}, line = {}, result = {}", dgm_a.size(), dgm_b.size(), line, result);
+ } else {
+ spd::debug("Calling Hera, dgm_a.size = {}, dgm_b.size = {}, line = {}, result = {}", dgm_a.size(), dgm_b.size(), line, result);
+ }
+ return result;
+ }
+
+ Real DistanceCalculator::get_good_enough_upper_bound(Real lower_bound) const
+ {
+ Real result;
+ // in upper_bound strategy we only prune cells if they cannot improve the lower bound,
+ // otherwise the experiment is supposed to run indefinitely
+ if (params_.traverse_strategy == TraverseStrategy::upper_bound) {
+ result = lower_bound;
+ } else {
+ result = (1.0 + params_.delta) * lower_bound;
+ }
+ return result;
+ }
+
+ // helper function
+ // calculate weighted bt distance in dim on cell center,
+ // assign distance value to cell, keep it in heat_map, and return
+ void DistanceCalculator::set_cell_central_value(CellWithValue& dual_cell, int dim)
+ {
+ DualPoint central_line {dual_cell.center()};
+
+ spd::debug("In set_cell_central_value, processing dual cell = {}, line = {}", dual_cell.dual_box(),
+ central_line);
+ Real new_value = distance_on_line(dim, central_line);
+ n_hera_calls_per_level_[dual_cell.level() + 1]++;
+ dual_cell.set_value_at(ValuePoint::center, new_value);
+ params_.actual_max_depth = std::max(params_.actual_max_depth, dual_cell.level() + 1);
+
+#ifdef PRINT_HEAT_MAP
+ if (params_.bound_strategy == BoundStrategy::bruteforce) {
+ spd::debug("In set_cell_central_value, adding to heat_map pair {} - {}", dual_cell.center(), new_value);
+ if (dual_cell.level() > params_.initialization_depth + 1
+ and params_.heat_maps[dual_cell.level()].count(dual_cell.center()) > 0) {
+ auto existing = params_.heat_maps[dual_cell.level()].find(dual_cell.center());
+ spd::debug("EXISTING: {} -> {}", existing->first, existing->second);
+ }
+ assert(dual_cell.level() <= params_.initialization_depth + 1
+ or params_.heat_maps[dual_cell.level()].count(dual_cell.center()) == 0);
+ params_.heat_maps[dual_cell.level()][dual_cell.center()] = new_value;
+ }
+#endif
+ }
+
+ // quick-and-dirty hack to efficiently traverse priority queue with dual cells
+ // returns maximal possible value on all cells in queue
+ // assumes that the underlying container is vector!
+ // cell_ptr: pointer to the first element in queue
+ // n_cells: queue size
+ Real DistanceCalculator::get_max_possible_value(const CellWithValue* cell_ptr, int n_cells)
+ {
+ Real result = (n_cells > 0) ? cell_ptr->stored_upper_bound() : 0;
+ for(int i = 0; i < n_cells; ++i, ++cell_ptr) {
+ result = std::max(result, cell_ptr->stored_upper_bound());
+ }
+ return result;
+ }
+
+ // helper function:
+ // return current error from lower and upper bounds
+ // and save it in params_ (hence not const)
+ Real DistanceCalculator::current_error(Real lower_bound, Real upper_bound)
+ {
+ Real current_error = (lower_bound > 0.0) ? (upper_bound - lower_bound) / lower_bound
+ : std::numeric_limits<Real>::max();
+
+ params_.actual_error = current_error;
+
+ if (current_error < params_.delta) {
+ spd::debug(
+ "Threshold achieved! bound_strategy = {}, traverse_strategy = {}, upper_bound = {}, current_error = {}",
+ params_.bound_strategy, params_.traverse_strategy, upper_bound, current_error);
+ }
+ return current_error;
+ }
+
+ struct UbExperimentRecord {
+ Real error;
+ Real lower_bound;
+ Real upper_bound;
+ CellWithValue cell;
+ long long int time;
+ long long int n_hera_calls;
+ };
+
+ std::ostream& operator<<(std::ostream& os, const UbExperimentRecord& r);
+
+ // return matching distance in dimension dim
+ // use priority queue to store dual cells
+ // comparison function depends on the strategies in params_
+ // ressets hera calls counter
+ Real DistanceCalculator::distance_in_dimension_pq(int dim)
+ {
+ std::map<int, long> n_cells_considered;
+ std::map<int, long> n_cells_pushed_into_queue;
+ long int n_too_deep_cells = 0;
+ std::map<int, long> n_cells_discarded;
+ std::map<int, long> n_cells_pruned;
+
+ spd::info("Enter distance_in_dimension_pq, dim = {}, bound strategy = {}, traverse strategy = {}, stop_asap = {} ", dim, params_.bound_strategy, params_.traverse_strategy, params_.stop_asap);
+
+ std::chrono::high_resolution_clock timer;
+ auto start_time = timer.now();
+
+ n_hera_calls_[dim] = 0;
+ n_hera_calls_per_level_.clear();
+
+
+ // if cell is too deep and is not pushed into queue,
+ // we still need to take its max value into account;
+ // the max over such cells is stored in max_result_on_too_fine_cells
+ Real upper_bound_on_deep_cells = -1;
+
+ spd::debug("Started iterations in dual space, delta = {}, bound_strategy = {}", params_.delta, params_.bound_strategy);
+ // user-defined less lambda function
+ // to regulate priority queue depending on strategy
+ auto dual_cell_less = [this](const CellWithValue& a, const CellWithValue& b) {
+
+ int a_level = a.level();
+ int b_level = b.level();
+ Real a_value = a.max_corner_value();
+ Real b_value = b.max_corner_value();
+ Real a_ub = a.stored_upper_bound();
+ Real b_ub = b.stored_upper_bound();
+ if (this->params_.traverse_strategy == TraverseStrategy::upper_bound and
+ (not a.has_max_possible_value() or not b.has_max_possible_value())) {
+ throw std::runtime_error("no upper bound on cell");
+ }
+ DualPoint a_lower_left = a.dual_box().lower_left();
+ DualPoint b_lower_left = b.dual_box().lower_left();
+
+ switch(this->params_.traverse_strategy) {
+ // in both breadth_first searches we want coarser cells
+ // to be processed first. Cells with smaller level must be larger,
+ // hence the minus in front of level
+ case TraverseStrategy::breadth_first:
+ return std::make_tuple(-a_level, a_lower_left)
+ < std::make_tuple(-b_level, b_lower_left);
+ case TraverseStrategy::breadth_first_value:
+ return std::make_tuple(-a_level, a_value, a_lower_left)
+ < std::make_tuple(-b_level, b_value, b_lower_left);
+ case TraverseStrategy::depth_first:
+ return std::make_tuple(a_value, a_level, a_lower_left)
+ < std::make_tuple(b_value, b_level, b_lower_left);
+ case TraverseStrategy::upper_bound:
+ return std::make_tuple(a_ub, a_level, a_lower_left)
+ < std::make_tuple(b_ub, b_level, b_lower_left);
+ default:
+ throw std::runtime_error("Forgotten case");
+ }
+ };
+
+ std::priority_queue<CellWithValue, CellValueVector, decltype(dual_cell_less)> dual_cells_queue(
+ dual_cell_less);
+
+ // weighted bt distance on the center of current cell
+ Real lower_bound = std::numeric_limits<Real>::min();
+
+ // init pq and lower bound
+ for(auto& init_cell : get_initial_dual_grid(lower_bound)) {
+ dual_cells_queue.push(init_cell);
+ }
+
+ Real upper_bound = get_max_possible_value(&dual_cells_queue.top(), dual_cells_queue.size());
+
+ std::vector<UbExperimentRecord> ub_experiment_results;
+
+ while(not dual_cells_queue.empty()) {
+
+ CellWithValue dual_cell = dual_cells_queue.top();
+ dual_cells_queue.pop();
+ assert(dual_cell.has_corner_value()
+ and dual_cell.has_max_possible_value()
+ and dual_cell.max_corner_value() <= upper_bound);
+
+ n_cells_considered[dual_cell.level()]++;
+
+ bool discard_cell = false;
+
+ if (not params_.stop_asap) {
+ // if stop_asap is on, it is safer to never discard a cell
+ if (params_.bound_strategy == BoundStrategy::bruteforce) {
+ discard_cell = false;
+ } else if (params_.traverse_strategy == TraverseStrategy::upper_bound) {
+ discard_cell = (dual_cell.stored_upper_bound() <= lower_bound);
+ } else {
+ discard_cell = (dual_cell.stored_upper_bound() <= (1.0 + params_.delta) * lower_bound);
+ }
+ }
+
+ spd::debug("CURRENT CELL bound_strategy = {}, traverse_strategy = {}, dual cell: {}, upper_bound = {}, lower_bound = {}, current_error = {}, discard_cell = {}",
+ params_.bound_strategy, params_.traverse_strategy, dual_cell, upper_bound, lower_bound, current_error(lower_bound, upper_bound), discard_cell);
+
+ if (discard_cell) {
+ n_cells_discarded[dual_cell.level()]++;
+ continue;
+ }
+
+ // until now, dual_cell knows its value in one of its corners
+ // new_value will be the weighted distance at its center
+ set_cell_central_value(dual_cell, dim);
+ Real new_value = dual_cell.value_at(ValuePoint::center);
+ lower_bound = std::max(new_value, lower_bound);
+
+ spd::debug("Processed cell = {}, weighted value = {}, lower_bound = {}", dual_cell, new_value, lower_bound);
+
+ assert(upper_bound >= lower_bound);
+
+ if (current_error(lower_bound, upper_bound) < params_.delta) {
+ break;
+ }
+
+ // refine cell and push 4 smaller cells into queue
+ for(auto refined_cell : dual_cell.get_refined_cells()) {
+
+ if (refined_cell.num_values() == 0)
+ throw std::runtime_error("no value on cell");
+
+ // if delta is smaller than good_enough_value, it allows to prune cell
+ Real good_enough_ub = get_good_enough_upper_bound(lower_bound);
+
+ // upper bound of the parent holds for refined_cell
+ // and can sometimes be smaller!
+ Real upper_bound_on_refined_cell = std::min(dual_cell.stored_upper_bound(),
+ get_upper_bound(refined_cell, dim, good_enough_ub));
+
+ spd::debug("upper_bound_on_refined_cell = {}, dual_cell.stored_upper_bound = {}, get_upper_bound = {}",
+ upper_bound_on_refined_cell, dual_cell.stored_upper_bound(), get_upper_bound(refined_cell, dim, good_enough_ub));
+
+ refined_cell.set_max_possible_value(upper_bound_on_refined_cell);
+
+#ifdef MD_DO_FULL_CHECK
+ check_upper_bound(refined_cell, dim);
+#endif
+
+ bool prune_cell = false;
+
+ if (refined_cell.level() <= params_.max_depth) {
+ // cell might be added to queue; if it is not added, its maximal value can be safely ignored
+ if (params_.traverse_strategy == TraverseStrategy::upper_bound) {
+ prune_cell = (refined_cell.stored_upper_bound() <= lower_bound);
+ } else if (params_.bound_strategy != BoundStrategy::bruteforce) {
+ prune_cell = (refined_cell.stored_upper_bound() <= (1.0 + params_.delta) * lower_bound);
+ }
+ if (prune_cell)
+ n_cells_pruned[refined_cell.level()]++;
+// prune_cell = (max_result_on_refined_cell <= lower_bound);
+ } else {
+ // cell is too deep, it won't be added to queue
+ // we must memorize maximal value on this cell, because we won't see it anymore
+ prune_cell = true;
+ if (refined_cell.stored_upper_bound() > (1 + params_.delta) * lower_bound) {
+ n_too_deep_cells++;
+ }
+ upper_bound_on_deep_cells = std::max(upper_bound_on_deep_cells, refined_cell.stored_upper_bound());
+ }
+
+ spd::debug("In distance_in_dimension_pq, loop over refined cells, bound_strategy = {}, traverse_strategy = {}, refined cell: {}, max_value_on_cell = {}, upper_bound = {}, current_error = {}, prune_cell = {}",
+ params_.bound_strategy, params_.traverse_strategy, refined_cell, refined_cell.stored_upper_bound(), upper_bound, current_error(lower_bound, upper_bound), prune_cell);
+
+ if (not prune_cell) {
+ n_cells_pushed_into_queue[refined_cell.level()]++;
+ dual_cells_queue.push(refined_cell);
+ }
+ } // end loop over refined cells
+
+ if (dual_cells_queue.empty())
+ upper_bound = std::max(upper_bound, upper_bound_on_deep_cells);
+ else
+ upper_bound = std::max(upper_bound_on_deep_cells,
+ get_max_possible_value(&dual_cells_queue.top(), dual_cells_queue.size()));
+
+ if (params_.traverse_strategy == TraverseStrategy::upper_bound) {
+ upper_bound = dual_cells_queue.top().stored_upper_bound();
+
+ if (get_hera_calls_number(params_.dim) < 20 || get_hera_calls_number(params_.dim) % 20 == 0) {
+ auto elapsed = timer.now() - start_time;
+ UbExperimentRecord ub_exp_record;
+
+ ub_exp_record.error = current_error(lower_bound, upper_bound);
+ ub_exp_record.lower_bound = lower_bound;
+ ub_exp_record.upper_bound = upper_bound;
+ ub_exp_record.cell = dual_cells_queue.top();
+ ub_exp_record.n_hera_calls = n_hera_calls_[dim];
+ ub_exp_record.time = std::chrono::duration_cast<std::chrono::milliseconds>(elapsed).count();
+
+#ifdef MD_DO_CHECKS
+ if (ub_experiment_results.size() > 0) {
+ auto prev = ub_experiment_results.back();
+ if (upper_bound > prev.upper_bound) {
+ spd::error("ALARM 1, upper_bound = {}, top = {}, prev.ub = {}, prev cell = {}, lower_bound = {}, prev.lower_bound = {}",
+ upper_bound, ub_exp_record.cell, prev.upper_bound, prev.cell, lower_bound, prev.lower_bound);
+ throw std::runtime_error("die");
+ }
+
+ if (lower_bound < prev.lower_bound) {
+ spd::error("ALARM 2, lower_bound = {}, prev.lower_bound = {}, top = {}, prev.ub = {}, prev cell = {}", lower_bound, prev.lower_bound, ub_exp_record.cell, prev.upper_bound, prev.cell);
+ throw std::runtime_error("die");
+ }
+ }
+#endif
+
+ ub_experiment_results.emplace_back(ub_exp_record);
+
+ fmt::print(std::cerr, "[UB_EXPERIMENT]\t{}\n", ub_exp_record);
+ }
+ }
+
+ assert(upper_bound >= lower_bound);
+
+ if (current_error(lower_bound, upper_bound) < params_.delta) {
+ break;
+ }
+ }
+
+ params_.actual_error = current_error(lower_bound, upper_bound);
+
+ if (n_too_deep_cells > 0) {
+ spd::warn("Error not guaranteed, there were {} too deep cells. Actual error = {}. Increase max_depth or delta", n_too_deep_cells, params_.actual_error);
+ }
+ // otherwise actual_error in params can be larger than delta,
+ // but this is OK
+
+ spd::info("#############################################################");
+ spd::info("Exiting distance_in_dimension_pq, bound_strategy = {}, traverse_strategy = {}, lower_bound = {}, upper_bound = {}, current_error = {}, actual_max_level = {}",
+ params_.bound_strategy, params_.traverse_strategy, lower_bound,
+ upper_bound, params_.actual_error, params_.actual_max_depth);
+
+ spd::info("#############################################################");
+
+ bool print_stats = true;
+ if (print_stats) {
+ fmt::print("EXIT STATS, cells considered:\n");
+ print_map(n_cells_considered);
+ fmt::print("EXIT STATS, cells discarded:\n");
+ print_map(n_cells_discarded);
+ fmt::print("EXIT STATS, cells pruned:\n");
+ print_map(n_cells_pruned);
+ fmt::print("EXIT STATS, cells pushed:\n");
+ print_map(n_cells_pushed_into_queue);
+ fmt::print("EXIT STATS, hera calls:\n");
+ print_map(n_hera_calls_per_level_);
+
+ fmt::print("EXIT STATS, too deep cells with high value: {}\n", n_too_deep_cells);
+ }
+
+ return lower_bound;
+ }
+
+ int DistanceCalculator::get_hera_calls_number(int dim) const
+ {
+ if (dim == CalculationParams::ALL_DIMENSIONS)
+ return std::accumulate(n_hera_calls_.begin(), n_hera_calls_.end(), 0,
+ [](auto x, auto y) { return x + y.second; });
+ else
+ return n_hera_calls_.at(dim);
+ }
+
+ Real matching_distance(const Bifiltration& bif_a, const Bifiltration& bif_b,
+ CalculationParams& params)
+ {
+ DistanceCalculator runner(bif_a, bif_b, params);
+ Real result = runner.distance();
+ params.n_hera_calls = runner.get_hera_calls_number(params.dim);
+ return result;
+ }
+
+ std::istream& operator>>(std::istream& is, BoundStrategy& s)
+ {
+ std::string ss;
+ is >> ss;
+ if (ss == "bruteforce") {
+ s = BoundStrategy::bruteforce;
+ } else if (ss == "local_grob") {
+ s = BoundStrategy::local_dual_bound;
+ } else if (ss == "local_combined") {
+ s = BoundStrategy::local_combined;
+ } else if (ss == "local_refined") {
+ s = BoundStrategy::local_dual_bound_refined;
+ } else if (ss == "local_for_each_point") {
+ s = BoundStrategy::local_dual_bound_for_each_point;
+ } else {
+ throw std::runtime_error("UNKNOWN BOUND STRATEGY");
+ }
+ return is;
+ }
+
+ BoundStrategy bs_from_string(std::string s)
+ {
+ std::stringstream ss(s);
+ BoundStrategy result;
+ ss >> result;
+ return result;
+ }
+
+ TraverseStrategy ts_from_string(std::string s)
+ {
+ std::stringstream ss(s);
+ TraverseStrategy result;
+ ss >> result;
+ return result;
+ }
+
+ std::istream& operator>>(std::istream& is, TraverseStrategy& s)
+ {
+ std::string ss;
+ is >> ss;
+ if (ss == "DFS") {
+ s = TraverseStrategy::depth_first;
+ } else if (ss == "BFS") {
+ s = TraverseStrategy::breadth_first;
+ } else if (ss == "BFS-VAL") {
+ s = TraverseStrategy::breadth_first_value;
+ } else if (ss == "UB") {
+ s = TraverseStrategy::upper_bound;
+ } else {
+ throw std::runtime_error("UNKNOWN TRAVERSE STRATEGY");
+ }
+ return is;
+ }
+
+ std::ostream& operator<<(std::ostream& os, const UbExperimentRecord& r)
+ {
+ os << r.time << "\t" << r.n_hera_calls << "\t" << r.error << "\t" << r.lower_bound << "\t" << r.upper_bound;
+ return os;
+ }
+
+ std::ostream& operator<<(std::ostream& os, const BoundStrategy& s)
+ {
+ switch(s) {
+ case BoundStrategy::bruteforce :
+ os << "bruteforce";
+ break;
+ case BoundStrategy::local_dual_bound :
+ os << "local_grob";
+ break;
+ case BoundStrategy::local_combined :
+ os << "local_combined";
+ break;
+ case BoundStrategy::local_dual_bound_refined :
+ os << "local_refined";
+ break;
+ case BoundStrategy::local_dual_bound_for_each_point :
+ os << "local_for_each_point";
+ break;
+ default:
+ os << "FORGOTTEN BOUND STRATEGY";
+ }
+ return os;
+ }
+
+ std::ostream& operator<<(std::ostream& os, const TraverseStrategy& s)
+ {
+ switch(s) {
+ case TraverseStrategy::depth_first :
+ os << "DFS";
+ break;
+ case TraverseStrategy::breadth_first :
+ os << "BFS";
+ break;
+ case TraverseStrategy::breadth_first_value :
+ os << "BFS-VAL";
+ break;
+ case TraverseStrategy::upper_bound :
+ os << "UB";
+ break;
+ default:
+ os << "FORGOTTEN TRAVERSE STRATEGY";
+ }
+ return os;
+ }
+}
diff --git a/matching/src/persistence_module.cpp b/matching/src/persistence_module.cpp
new file mode 100644
index 0000000..e947925
--- /dev/null
+++ b/matching/src/persistence_module.cpp
@@ -0,0 +1,104 @@
+#include<phat/boundary_matrix.h>
+#include<phat/compute_persistence_pairs.h>
+
+#include "persistence_module.h"
+
+namespace md {
+ PersistenceModule::PersistenceModule(const std::string& /*fname*/) // read from file
+ :
+ generators_(),
+ relations_()
+ {
+ }
+
+ Diagram PersistenceModule::slice_diagram(const DualPoint& /*line*/)
+ {
+ //Vector2D b_of_line(L.b, -L.b);
+ //for (int i = 0; i<(int) F.size(); i++) {
+ // Simplex_in_2D_filtration& curr_simplex = F[i];
+ // Vector2D proj = push(curr_simplex.pos, L);
+
+ // curr_simplex.v = L_2(proj - b_of_line);
+ // //std::cout << proj << std::endl;
+ // //std::cout << "v=" << curr_simplex.v << std::endl;
+ //}
+ //std::sort(F.begin(), F.end(), sort_functor);
+ //std::map<Index, Index> index_map;
+ //for (Index i = 0; i<(int) F.size(); i++) {
+ // index_map[F[i].index] = i;
+ // //std::cout << F[i].index << " -> " << i << std::endl;
+ //}
+ //phat::boundary_matrix<> phat_matrix;
+ //phat_matrix.set_num_cols(F.size());
+ //std::vector<Index> bd_in_slice_filtration;
+ //for (Index i = 0; i<(int) F.size(); i++) {
+ // phat_matrix.set_dim(i, F[i].dim);
+ // bd_in_slice_filtration.clear();
+ // //std::cout << "new col" << i << std::endl;
+ // for (int j = 0; j<(int) F[i].bd.size(); j++) {
+ // // F[i] contains the indices of its facet wrt to the
+ // // original filtration. We have to express it, however,
+ // // wrt to the filtration along the slice. That is why
+ // // we need the index_map
+ // //std::cout << "Found " << F[i].bd[j] << ", returning " << index_map[F[i].bd[j]] << std::endl;
+ // bd_in_slice_filtration.push_back(index_map[F[i].bd[j]]);
+ // }
+ // std::sort(bd_in_slice_filtration.begin(), bd_in_slice_filtration.end());
+ // [>
+ // for(int j=0;j<bd_in_slice_filtration.size();j++) {
+ // std::cout << bd_in_slice_filtration[j] << " ";
+ // }
+ // std::cout << std::endl;
+ // */
+ // phat_matrix.set_col(i, bd_in_slice_filtration);
+
+ //}
+ //phat::persistence_pairs phat_persistence_pairs;
+ //phat::compute_persistence_pairs<phat::twist_reduction>(phat_persistence_pairs, phat_matrix);
+
+ //dgm.clear();
+ //for (long i = 0; i<(long) phat_persistence_pairs.get_num_pairs(); i++) {
+ // std::pair<phat::index, phat::index> new_pair = phat_persistence_pairs.get_pair(i);
+ // double birth = F[new_pair.first].v;
+ // double death = F[new_pair.second].v;
+ // if (birth!=death) {
+ // dgm.push_back(std::make_pair(birth, death));
+ // }
+ //}
+
+ //[>
+ //std::cout << "Done, created diagram: " << std::endl;
+ //for(int i=0;i<(int)dgm.size();i++) {
+ // std::cout << dgm[i].first << " " << dgm[i].second << std::endl;
+ //}
+ //*/
+ return Diagram();
+
+ }
+
+ PersistenceModule::Box PersistenceModule::bounding_box() const
+ {
+ Real ll_x = std::numeric_limits<Real>::max();
+ Real ll_y = std::numeric_limits<Real>::max();
+ Real ur_x = -std::numeric_limits<Real>::max();
+ Real ur_y = -std::numeric_limits<Real>::max();
+
+ for(const auto& gen : generators_) {
+ ll_x = std::min(gen.x, ll_x);
+ ll_y = std::min(gen.y, ll_y);
+ ur_x = std::max(gen.x, ur_x);
+ ur_y = std::max(gen.y, ur_y);
+ }
+
+ for(const auto& rel : relations_) {
+
+ ll_x = std::min(rel.get_x(), ll_x);
+ ll_y = std::min(rel.get_y(), ll_y);
+
+ ur_x = std::max(rel.get_x(), ur_x);
+ ur_y = std::max(rel.get_y(), ur_y);
+ }
+
+ return Box(Point(ll_x, ll_y), Point(ur_x, ur_y));
+ }
+}
diff --git a/matching/src/simplex.cpp b/matching/src/simplex.cpp
new file mode 100644
index 0000000..c5cdd25
--- /dev/null
+++ b/matching/src/simplex.cpp
@@ -0,0 +1,123 @@
+#include "simplex.h"
+
+namespace md {
+
+ std::ostream& operator<<(std::ostream& os, const AbstractSimplex& s)
+ {
+ os << "AbstractSimplex(id = " << s.id << ", vertices_ = " << container_to_string(s.vertices_) << ")";
+ return os;
+ }
+
+ bool operator<(const AbstractSimplex& a, const AbstractSimplex& b)
+ {
+ return a.vertices_ < b.vertices_;
+ }
+
+ bool operator==(const AbstractSimplex& s1, const AbstractSimplex& s2)
+ {
+ return s1.vertices_ == s2.vertices_;
+ }
+
+ void AbstractSimplex::push_back(int v)
+ {
+ vertices_.push_back(v);
+ std::sort(vertices_.begin(), vertices_.end());
+ }
+
+ AbstractSimplex::AbstractSimplex(std::vector<int> vertices, bool sort)
+ :vertices_(vertices)
+ {
+ if (sort)
+ std::sort(vertices_.begin(), vertices_.end());
+ }
+
+ std::vector<AbstractSimplex> AbstractSimplex::facets() const
+ {
+ std::vector<AbstractSimplex> result;
+ for (int i = 0; i < static_cast<int>(vertices_.size()); ++i) {
+ std::vector<int> facet_vertices;
+ facet_vertices.reserve(dim());
+ for (int j = 0; j < static_cast<int>(vertices_.size()); ++j) {
+ if (j != i)
+ facet_vertices.push_back(vertices_[j]);
+ }
+ if (!facet_vertices.empty()) {
+ result.emplace_back(facet_vertices, false);
+ }
+ }
+ return result;
+ }
+
+ Simplex::Simplex(md::Index id, md::Point birth, int dim, const md::Column& bdry)
+ :
+ id_(id),
+ pos_(birth),
+ dim_(dim),
+ facet_indices_(bdry) { }
+
+ void Simplex::translate(Real a)
+ {
+ pos_.x += a;
+ pos_.y += a;
+ }
+
+ void Simplex::init_rivet(std::string s)
+ {
+// throw std::runtime_error("Not implemented");
+ auto delim_pos = s.find_first_of(";");
+ assert(delim_pos > 0);
+ std::string vertices_str = s.substr(0, delim_pos);
+ std::string pos_str = s.substr(delim_pos + 1);
+ assert(not vertices_str.empty() and not pos_str.empty());
+ // get vertices
+ std::stringstream vertices_ss(vertices_str);
+ int dim = 0;
+ int vertex;
+ while (vertices_ss >> vertex) {
+ dim++;
+ vertices_.push_back(vertex);
+ }
+ //
+ std::sort(vertices_.begin(), vertices_.end());
+ assert(dim > 0);
+
+ std::stringstream pos_ss(pos_str);
+ // TODO: get rid of 1-criticaltiy assumption
+ pos_ss >> pos_.x >> pos_.y;
+ }
+
+ void Simplex::init_rene(std::string s)
+ {
+ facet_indices_.clear();
+ std::stringstream ss(s);
+ ss >> dim_ >> pos_.x >> pos_.y;
+ if (dim_ > 0) {
+ facet_indices_.reserve(dim_ + 1);
+ for (int j = 0; j <= dim_; j++) {
+ Index k;
+ ss >> k;
+ facet_indices_.push_back(k);
+ }
+ }
+ }
+
+ Simplex::Simplex(Index _id, std::string s, BifiltrationFormat input_format)
+ :id_(_id)
+ {
+ switch (input_format) {
+ case BifiltrationFormat::rene :
+ init_rene(s);
+ break;
+ case BifiltrationFormat::rivet :
+ init_rivet(s);
+ break;
+ }
+ }
+
+ std::ostream& operator<<(std::ostream& os, const Simplex& x)
+ {
+ os << "Simplex(id = " << x.id() << ", dim = " << x.dim();
+ os << ", boundary = " << container_to_string(x.boundary()) << ", pos = " << x.position() << ")";
+ return os;
+ }
+}
diff --git a/matching/src/test_generator.cpp b/matching/src/test_generator.cpp
new file mode 100644
index 0000000..a0d8fc7
--- /dev/null
+++ b/matching/src/test_generator.cpp
@@ -0,0 +1,208 @@
+#include <map>
+#include <vector>
+#include <random>
+#include <iostream>
+#include <algorithm>
+
+#include "opts/opts.h"
+#include "spdlog/spdlog.h"
+#include "spdlog/fmt/ostr.h"
+
+#include "common_util.h"
+#include "bifiltration.h"
+
+using Index = md::Index;
+using Point = md::Point;
+using Column = md::Column;
+
+int g_max_coord = 100;
+
+using ASimplex = md::AbstractSimplex;
+
+using ASimplexToBirthMap = std::map<ASimplex, Point>;
+
+namespace spd = spdlog;
+
+// random generator is global
+std::random_device rd;
+std::mt19937_64 gen(rd());
+
+//std::mt19937_64 gen(42);
+
+Point get_random_position(int max_coord)
+{
+ assert(max_coord > 0);
+ std::uniform_int_distribution<int> distr(0, max_coord);
+ return Point(distr(gen), distr(gen));
+
+}
+
+Point get_random_position_less_than(Point ub)
+{
+ std::uniform_int_distribution<int> distr_x(0, ub.x);
+ std::uniform_int_distribution<int> distr_y(0, ub.y);
+ return Point(distr_x(gen), distr_y(gen));
+}
+
+Point get_random_position_greater_than(Point lb)
+{
+ std::uniform_int_distribution<int> distr_x(lb.x, g_max_coord);
+ std::uniform_int_distribution<int> distr_y(lb.y, g_max_coord);
+ return Point(distr_x(gen), distr_y(gen));
+}
+
+// non-proper faces (empty and simplex itself) are also faces
+bool is_face(const ASimplex& face_candidate, const ASimplex& coface_candidate)
+{
+ return std::includes(coface_candidate.begin(), coface_candidate.end(), face_candidate.begin(),
+ face_candidate.end());
+}
+
+bool is_top_simplex(const ASimplex& candidate_simplex, const std::vector<ASimplex>& current_top_simplices)
+{
+ return std::none_of(current_top_simplices.begin(), current_top_simplices.end(),
+ [&candidate_simplex](const ASimplex& ts) { return is_face(candidate_simplex, ts); });
+}
+
+void add_if_top(const ASimplex& candidate_simplex, std::vector<ASimplex>& current_top_simplices)
+{
+ // check that candidate simplex is not face of someone in current_top_simplices
+ if (!is_top_simplex(candidate_simplex, current_top_simplices))
+ return;
+
+ spd::debug("candidate_simplex is top, will be added to top_simplices");
+ // remove s from currrent_top_simplices, if s is face of candidate_simplex
+ current_top_simplices.erase(std::remove_if(current_top_simplices.begin(), current_top_simplices.end(),
+ [candidate_simplex](const ASimplex& s) { return is_face(s, candidate_simplex); }),
+ current_top_simplices.end());
+
+ current_top_simplices.push_back(candidate_simplex);
+}
+
+ASimplex get_random_simplex(int n_vertices, int dim)
+{
+ std::vector<int> all_vertices(n_vertices, 0);
+ // fill in with 0..n-1
+ std::iota(all_vertices.begin(), all_vertices.end(), 0);
+ std::shuffle(all_vertices.begin(), all_vertices.end(), gen);
+ return ASimplex(all_vertices.begin(), all_vertices.begin() + dim + 1, true);
+}
+
+void generate_positions(const ASimplex& s, ASimplexToBirthMap& simplex_to_birth, Point upper_bound)
+{
+ auto pos = get_random_position_less_than(upper_bound);
+ auto curr_pos_iter = simplex_to_birth.find(s);
+ if (curr_pos_iter != simplex_to_birth.end())
+ pos = md::greatest_lower_bound(pos, curr_pos_iter->second);
+ simplex_to_birth[s] = pos;
+ for(const ASimplex& facet : s.facets()) {
+ generate_positions(facet, simplex_to_birth, pos);
+ }
+}
+
+md::Bifiltration get_random_bifiltration(int n_vertices, int max_dim, int n_top_simplices)
+{
+ ASimplexToBirthMap simplex_to_birth;
+
+ // generate vertices
+ for(int i = 0; i < n_vertices; ++i) {
+ Point vertex_birth = get_random_position(g_max_coord / 10);
+ ASimplex vertex;
+ vertex.push_back(i);
+ simplex_to_birth[vertex] = vertex_birth;
+ }
+
+ std::vector<ASimplex> top_simplices;
+ // generate top simplices
+ while((int)top_simplices.size() < n_top_simplices) {
+ std::uniform_int_distribution<int> dimension_distr(1, max_dim);
+ int dim = dimension_distr(gen);
+ auto candidate_simplex = get_random_simplex(n_vertices, dim);
+ spd::debug("candidate_simplex = {}", candidate_simplex);
+ add_if_top(candidate_simplex, top_simplices);
+ }
+
+ Point upper_bound{static_cast<md::Real>(g_max_coord), static_cast<md::Real>(g_max_coord)};
+ for(const auto& top_simplex : top_simplices) {
+ generate_positions(top_simplex, simplex_to_birth, upper_bound);
+ }
+
+ std::vector<std::pair<ASimplex, Point>> simplex_birth_pairs{simplex_to_birth.begin(), simplex_to_birth.end()};
+ std::vector<md::Column> boundaries{simplex_to_birth.size(), md::Column()};
+
+// assign ids and save boundaries
+ int id = 0;
+
+ for(int dim = 0; dim <= max_dim; ++dim) {
+ for(int i = 0; i < (int) simplex_birth_pairs.size(); ++i) {
+ ASimplex& simplex = simplex_birth_pairs[i].first;
+ if (simplex.dim() == dim) {
+ simplex.id = id++;
+ md::Column bdry;
+ for(auto& facet : simplex.facets()) {
+ auto facet_iter = std::find_if(simplex_birth_pairs.begin(), simplex_birth_pairs.end(),
+ [facet](const std::pair<ASimplex, Point>& sbp) { return facet == sbp.first; });
+ assert(facet_iter != simplex_birth_pairs.end());
+ assert(facet_iter->first.id >= 0);
+ bdry.push_back(facet_iter->first.id);
+ }
+ std::sort(bdry.begin(), bdry.end());
+ boundaries[i] = bdry;
+ }
+ }
+ }
+
+// create vector of Simplex-es
+ std::vector<md::Simplex> simplices;
+ for(int i = 0; i < (int) simplex_birth_pairs.size(); ++i) {
+ int id = simplex_birth_pairs[i].first.id;
+ int dim = simplex_birth_pairs[i].first.dim();
+ Point birth = simplex_birth_pairs[i].second;
+ Column bdry = boundaries[i];
+ simplices.emplace_back(id, birth, dim, bdry);
+ }
+
+// sort by id
+ std::sort(simplices.begin(), simplices.end(),
+ [](const md::Simplex& s1, const md::Simplex& s2) { return s1.id() < s2.id(); });
+ for(int i = 0; i < (int)simplices.size(); ++i) {
+ assert(simplices[i].id() == i);
+ assert(i == 0 || simplices[i].dim() >= simplices[i - 1].dim());
+ }
+
+ return md::Bifiltration(simplices.begin(), simplices.end());
+}
+
+int main(int argc, char** argv)
+{
+ spd::set_level(spd::level::info);
+ int n_vertices;
+ int max_dim;
+ int n_top_simplices;
+ std::string fname;
+
+ using opts::Option;
+ using opts::PosOption;
+ opts::Options ops;
+
+ bool help = false;
+
+ ops >> Option('v', "n-vertices", n_vertices, "number of vertices")
+ >> Option('d', "max-dim", max_dim, "maximal dim")
+ >> Option('m', "max-coord", g_max_coord, "maximal coordinate")
+ >> Option('t', "n-top-simplices", n_top_simplices, "number of top simplices")
+ >> Option('h', "help", help, "show help message");
+
+ if (!ops.parse(argc, argv) || help || !(ops >> PosOption(fname))) {
+ std::cerr << "Usage: " << argv[0] << "\n" << ops << std::endl;
+ return 1;
+ }
+
+
+ auto bif1 = get_random_bifiltration(n_vertices, max_dim, n_top_simplices);
+ std::cout << "Generated bifiltration." << std::endl;
+ bif1.save(fname, md::BifiltrationFormat::rene);
+ std::cout << "Saved to file " << fname << std::endl;
+ return 0;
+}
+
diff --git a/matching/src/tests/prism_1_lesnick.bif b/matching/src/tests/prism_1_lesnick.bif
new file mode 100644
index 0000000..416c40c
--- /dev/null
+++ b/matching/src/tests/prism_1_lesnick.bif
@@ -0,0 +1,27 @@
+26
+0 0 0
+0 0 0
+0 0 0
+0 0 0
+0 0 0
+0 0 0
+1 1 0 1 0
+1 1 0 2 1
+1 0 0 3 1
+1 0 1 5 4
+1 0 0 4 1
+1 0 0 3 0
+1 0 0 5 0
+1 0 0 4 2
+1 0 0 5 2
+1 0 1 4 3
+1 1 0 2 0
+1 0 1 5 3
+2 1 1 13 10 7
+2 1 1 9 14 13
+2 1 1 8 11 6
+2 1 1 15 10 8
+2 1 1 14 12 16
+2 1 1 17 12 11
+2 4 0 7 16 6
+2 0 4 9 17 15
diff --git a/matching/src/tests/prism_2_lesnick.bif b/matching/src/tests/prism_2_lesnick.bif
new file mode 100644
index 0000000..e11faa9
--- /dev/null
+++ b/matching/src/tests/prism_2_lesnick.bif
@@ -0,0 +1,28 @@
+27
+0 0 0
+0 0 0
+0 0 0
+0 0 0
+0 0 0
+0 0 0
+1 0 0 2 1
+1 0 0 1 0
+1 0 0 5 3
+1 0 0 3 1
+1 0 0 5 4
+1 0 0 2 0
+1 0 0 4 1
+1 0 0 3 2
+1 0 0 5 2
+1 0 0 4 3
+1 0 0 4 2
+1 0 0 3 0
+2 0 0 16 12 6
+2 0 0 10 14 16
+2 0 0 9 17 7
+2 0 0 15 12 9
+2 0 0 13 17 11
+2 0 0 8 14 13
+2 4 0 6 11 7
+2 0 4 10 8 15
+2 3 3 15 16 13
diff --git a/matching/src/tests/test_bifiltration.cpp b/matching/src/tests/test_bifiltration.cpp
new file mode 100644
index 0000000..caa50ac
--- /dev/null
+++ b/matching/src/tests/test_bifiltration.cpp
@@ -0,0 +1,36 @@
+#include "catch/catch.hpp"
+
+#include <sstream>
+#include <iostream>
+
+#include "common_util.h"
+#include "box.h"
+#include "bifiltration.h"
+
+using namespace md;
+
+//TEST_CASE("Small check", "[bifiltration][dim2]")
+//{
+// Bifiltration bif("/home/narn/code/matching_distance/code/src/tests/test_bifiltration_full_triangle_rene.txt", BifiltrationFormat::rene);
+// auto simplices = bif.simplices();
+// bif.sanity_check();
+//
+// REQUIRE( simplices.size() == 7 );
+//
+// REQUIRE( simplices[0].dim() == 0 );
+// REQUIRE( simplices[1].dim() == 0 );
+// REQUIRE( simplices[2].dim() == 0 );
+// REQUIRE( simplices[3].dim() == 1 );
+// REQUIRE( simplices[4].dim() == 1 );
+// REQUIRE( simplices[5].dim() == 1 );
+// REQUIRE( simplices[6].dim() == 2);
+//
+// REQUIRE( simplices[0].position() == Point(0, 0));
+// REQUIRE( simplices[1].position() == Point(0, 0));
+// REQUIRE( simplices[2].position() == Point(0, 0));
+// REQUIRE( simplices[3].position() == Point(3, 1));
+// REQUIRE( simplices[6].position() == Point(30, 40));
+//
+// Line line_1(Line::pi / 2.0, 0.0);
+// auto dgm = bif.slice_diagram(line_1);
+//}
diff --git a/matching/src/tests/test_bifiltration_1.txt b/matching/src/tests/test_bifiltration_1.txt
new file mode 120000
index 0000000..ddd23e9
--- /dev/null
+++ b/matching/src/tests/test_bifiltration_1.txt
@@ -0,0 +1 @@
+../../data/test_bifiltration_1.txt \ No newline at end of file
diff --git a/matching/src/tests/test_bifiltration_full_triangle_rene.txt b/matching/src/tests/test_bifiltration_full_triangle_rene.txt
new file mode 120000
index 0000000..47f49fd
--- /dev/null
+++ b/matching/src/tests/test_bifiltration_full_triangle_rene.txt
@@ -0,0 +1 @@
+../../data/test_bifiltration_full_triangle_rene.txt \ No newline at end of file
diff --git a/matching/src/tests/test_common.cpp b/matching/src/tests/test_common.cpp
new file mode 100644
index 0000000..30465e4
--- /dev/null
+++ b/matching/src/tests/test_common.cpp
@@ -0,0 +1,190 @@
+#include "catch/catch.hpp"
+
+#include <sstream>
+#include <iostream>
+#include <string>
+
+#include "common_util.h"
+#include "simplex.h"
+#include "matching_distance.h"
+
+using namespace md;
+
+TEST_CASE("Rational", "[common_utils][rational]")
+{
+ // gcd
+ REQUIRE(gcd(10, 5) == 5);
+ REQUIRE(gcd(5, 10) == 5);
+ REQUIRE(gcd(5, 7) == 1);
+ REQUIRE(gcd(7, 5) == 1);
+ REQUIRE(gcd(13, 0) == 13);
+ REQUIRE(gcd(0, 13) == 13);
+ REQUIRE(gcd(16, 24) == 8);
+ REQUIRE(gcd(24, 16) == 8);
+ REQUIRE(gcd(16, 32) == 16);
+ REQUIRE(gcd(32, 16) == 16);
+
+
+ // reduce
+ REQUIRE(reduce({2, 1}) == std::make_pair(2, 1));
+ REQUIRE(reduce({1, 2}) == std::make_pair(1, 2));
+ REQUIRE(reduce({2, 2}) == std::make_pair(1, 1));
+ REQUIRE(reduce({0, 2}) == std::make_pair(0, 1));
+ REQUIRE(reduce({0, 20}) == std::make_pair(0, 1));
+ REQUIRE(reduce({35, 49}) == std::make_pair(5, 7));
+ REQUIRE(reduce({35, 25}) == std::make_pair(7, 5));
+
+ // midpoint
+ REQUIRE(midpoint(Rational {0, 1}, Rational {1, 2}) == std::make_pair(1, 4));
+ REQUIRE(midpoint(Rational {1, 4}, Rational {1, 2}) == std::make_pair(3, 8));
+ REQUIRE(midpoint(Rational {1, 2}, Rational {1, 2}) == std::make_pair(1, 2));
+ REQUIRE(midpoint(Rational {1, 2}, Rational {1, 1}) == std::make_pair(3, 4));
+ REQUIRE(midpoint(Rational {3, 7}, Rational {5, 14}) == std::make_pair(11, 28));
+
+
+ // arithmetic
+
+ REQUIRE(Rational(1, 2) + Rational(3, 5) == Rational(11, 10));
+ REQUIRE(Rational(2, 5) - Rational(3, 10) == Rational(1, 10));
+ REQUIRE(Rational(2, 3) * Rational(4, 7) == Rational(8, 21));
+ REQUIRE(Rational(2, 3) * Rational(3, 2) == Rational(1));
+ REQUIRE(Rational(2, 3) / Rational(3, 2) == Rational(4, 9));
+ REQUIRE(Rational(1, 2) * Rational(3, 5) == Rational(3, 10));
+
+ // comparison
+ REQUIRE(Rational(100000, 2000000) < Rational(100001, 2000000));
+ REQUIRE(!(Rational(100001, 2000000) < Rational(100000, 2000000)));
+ REQUIRE(!(Rational(100000, 2000000) < Rational(100000, 2000000)));
+ REQUIRE(Rational(-100000, 2000000) < Rational(100001, 2000000));
+ REQUIRE(Rational(-100001, 2000000) < Rational(100000, 2000000));
+};
+
+TEST_CASE("AbstractSimplex", "[abstract_simplex]")
+{
+ AbstractSimplex as;
+ REQUIRE(as.dim() == -1);
+
+ as.push_back(1);
+ REQUIRE(as.dim() == 0);
+ REQUIRE(as.facets().size() == 0);
+
+ as.push_back(0);
+ REQUIRE(as.dim() == 1);
+ REQUIRE(as.facets().size() == 2);
+ REQUIRE(as.facets()[0].dim() == 0);
+ REQUIRE(as.facets()[1].dim() == 0);
+
+}
+
+TEST_CASE("Vertical line", "[vertical_line]")
+{
+ // line x = 1
+ DualPoint l_vertical(AxisType::x_type, AngleType::steep, 0, 1);
+
+ REQUIRE(l_vertical.is_vertical());
+ REQUIRE(l_vertical.is_steep());
+
+ Point p_1(0.5, 0.5);
+ Point p_2(1.5, 0.5);
+ Point p_3(1.5, 1.5);
+ Point p_4(0.5, 1.5);
+ Point p_5(1, 10);
+
+ REQUIRE(l_vertical.x_from_y(10) == 1);
+ REQUIRE(l_vertical.x_from_y(-10) == 1);
+ REQUIRE(l_vertical.x_from_y(0) == 1);
+
+ REQUIRE(not l_vertical.contains(p_1));
+ REQUIRE(not l_vertical.contains(p_2));
+ REQUIRE(not l_vertical.contains(p_3));
+ REQUIRE(not l_vertical.contains(p_4));
+ REQUIRE(l_vertical.contains(p_5));
+
+ REQUIRE(l_vertical.goes_below(p_1));
+ REQUIRE(not l_vertical.goes_below(p_2));
+ REQUIRE(not l_vertical.goes_below(p_3));
+ REQUIRE(l_vertical.goes_below(p_4));
+
+ REQUIRE(not l_vertical.goes_above(p_1));
+ REQUIRE(l_vertical.goes_above(p_2));
+ REQUIRE(l_vertical.goes_above(p_3));
+ REQUIRE(not l_vertical.goes_above(p_4));
+
+}
+
+TEST_CASE("Horizontal line", "[horizontal_line]")
+{
+ // line y = 1
+ DualPoint l_horizontal(AxisType::y_type, AngleType::flat, 0, 1);
+
+ REQUIRE(l_horizontal.is_horizontal());
+ REQUIRE(l_horizontal.is_flat());
+ REQUIRE(l_horizontal.y_slope() == 0);
+ REQUIRE(l_horizontal.y_intercept() == 1);
+
+ Point p_1(0.5, 0.5);
+ Point p_2(1.5, 0.5);
+ Point p_3(1.5, 1.5);
+ Point p_4(0.5, 1.5);
+ Point p_5(2, 1);
+
+ REQUIRE((not l_horizontal.contains(p_1) and
+ not l_horizontal.contains(p_2) and
+ not l_horizontal.contains(p_3) and
+ not l_horizontal.contains(p_4) and
+ l_horizontal.contains(p_5)));
+
+ REQUIRE(not l_horizontal.goes_below(p_1));
+ REQUIRE(not l_horizontal.goes_below(p_2));
+ REQUIRE(l_horizontal.goes_below(p_3));
+ REQUIRE(l_horizontal.goes_below(p_4));
+ REQUIRE(l_horizontal.goes_below(p_5));
+
+ REQUIRE(l_horizontal.goes_above(p_1));
+ REQUIRE(l_horizontal.goes_above(p_2));
+ REQUIRE(not l_horizontal.goes_above(p_3));
+ REQUIRE(not l_horizontal.goes_above(p_4));
+ REQUIRE(l_horizontal.goes_above(p_5));
+}
+
+TEST_CASE("Flat Line with positive slope", "[flat_line]")
+{
+ // line y = x / 2 + 1
+ DualPoint l_flat(AxisType::y_type, AngleType::flat, 0.5, 1);
+
+ REQUIRE(not l_flat.is_horizontal());
+ REQUIRE(l_flat.is_flat());
+ REQUIRE(l_flat.y_slope() == 0.5);
+ REQUIRE(l_flat.y_intercept() == 1);
+
+ REQUIRE(l_flat.y_from_x(0) == 1);
+ REQUIRE(l_flat.y_from_x(1) == 1.5);
+ REQUIRE(l_flat.y_from_x(2) == 2);
+
+ Point p_1(3, 2);
+ Point p_2(-2, 0.01);
+ Point p_3(0, 1.25);
+ Point p_4(6, 4.5);
+ Point p_5(2, 2);
+
+ std::cout << "AHOY " << l_flat.y_from_x(p_2.x) << std::endl;
+
+ REQUIRE((not l_flat.contains(p_1) and
+ not l_flat.contains(p_2) and
+ not l_flat.contains(p_3) and
+ not l_flat.contains(p_4) and
+ l_flat.contains(p_5)));
+
+ REQUIRE(not l_flat.goes_below(p_1));
+ REQUIRE(l_flat.goes_below(p_2));
+ REQUIRE(l_flat.goes_below(p_3));
+ REQUIRE(l_flat.goes_below(p_4));
+ REQUIRE(l_flat.goes_below(p_5));
+
+ REQUIRE(l_flat.goes_above(p_1));
+ REQUIRE(not l_flat.goes_above(p_2));
+ REQUIRE(not l_flat.goes_above(p_3));
+ REQUIRE(not l_flat.goes_above(p_4));
+ REQUIRE(l_flat.goes_above(p_5));
+
+}
diff --git a/matching/src/tests/test_list.txt b/matching/src/tests/test_list.txt
new file mode 100644
index 0000000..1984606
--- /dev/null
+++ b/matching/src/tests/test_list.txt
@@ -0,0 +1 @@
+prism_lesnick_1.bif prism_lesnick_2.bif 1.0
diff --git a/matching/src/tests/test_matching_distance.cpp b/matching/src/tests/test_matching_distance.cpp
new file mode 100644
index 0000000..a54e18e
--- /dev/null
+++ b/matching/src/tests/test_matching_distance.cpp
@@ -0,0 +1,146 @@
+#include "catch/catch.hpp"
+
+#include <sstream>
+#include <iostream>
+#include <string>
+
+#include "spdlog/spdlog.h"
+#include "spdlog/fmt/ostr.h"
+
+#include "common_util.h"
+#include "simplex.h"
+#include "matching_distance.h"
+
+using namespace md;
+namespace spd = spdlog;
+
+TEST_CASE("Different bounds", "[bounds]")
+{
+ std::vector<Simplex> simplices;
+ std::vector<Point> points;
+
+ Real max_x = 10;
+ Real max_y = 20;
+
+ int simplex_id = 0;
+ for(int i = 0; i <= max_x; ++i) {
+ for(int j = 0; j <= max_y; ++j) {
+ Point p(i, j);
+ simplices.emplace_back(simplex_id++, p, 0, Column());
+ points.push_back(p);
+ }
+ }
+
+ Bifiltration bif_a(simplices.begin(), simplices.end());
+ Bifiltration bif_b(simplices.begin(), simplices.end());
+
+ CalculationParams params;
+ params.initialization_depth = 2;
+
+ DistanceCalculator calc(bif_a, bif_b, params);
+
+// REQUIRE(calc.max_x_ == Approx(max_x));
+// REQUIRE(calc.max_y_ == Approx(max_y));
+
+ std::vector<DualBox> boxes;
+
+ for(CellWithValue c : calc.get_refined_grid(5, false, false)) {
+ boxes.push_back(c.dual_box());
+ }
+
+ // fill in boxes and points
+
+ for(DualBox db : boxes) {
+ Real local_bound = calc.get_local_dual_bound(db);
+ Real local_bound_refined = calc.get_local_refined_bound(db);
+ REQUIRE(local_bound >= local_bound_refined);
+ for(Point p : points) {
+ for(ValuePoint vp_a : k_corner_vps) {
+ CellWithValue dual_cell(db, 1);
+ DualPoint corner_a = dual_cell.value_point(vp_a);
+ Real wp_a = corner_a.weighted_push(p);
+ dual_cell.set_value_at(vp_a, wp_a);
+ Real point_bound = calc.get_max_displacement_single_point(dual_cell, vp_a, p);
+ for(ValuePoint vp_b : k_corner_vps) {
+ if (vp_b <= vp_a)
+ continue;
+ DualPoint corner_b = dual_cell.value_point(vp_b);
+ Real wp_b = corner_b.weighted_push(p);
+ Real diff = fabs(wp_a - wp_b);
+ if (not(point_bound <= Approx(local_bound_refined))) {
+ std::cerr << "ERROR point: " << p << ", box = " << db << ", point bound = " << point_bound
+ << ", refined local = " << local_bound_refined << std::endl;
+ spd::set_level(spd::level::debug);
+ calc.get_max_displacement_single_point(dual_cell, vp_a, p);
+ calc.get_local_refined_bound(db);
+ spd::set_level(spd::level::info);
+ }
+
+ REQUIRE(point_bound <= Approx(local_bound_refined));
+ REQUIRE(diff <= Approx(point_bound));
+ REQUIRE(diff <= Approx(local_bound_refined));
+ }
+
+ for(DualPoint l_random : db.random_points(100)) {
+ Real wp_random = l_random.weighted_push(p);
+ Real diff = fabs(wp_a - wp_random);
+ if (not(diff <= Approx(point_bound))) {
+ if (db.critical_points(p).size() > 4) {
+ std::cerr << "ERROR interesting case" << std::endl;
+ } else {
+ std::cerr << "ERROR boring case" << std::endl;
+ }
+ spd::set_level(spd::level::debug);
+ l_random.weighted_push(p);
+ spd::set_level(spd::level::info);
+ std::cerr << "ERROR point: " << p << ", box = " << db << ", point bound = " << point_bound
+ << ", refined local = " << local_bound_refined;
+ std::cerr << ", random_dual_point = " << l_random << ", wp_random = " << wp_random
+ << ", diff = " << diff << std::endl;
+ }
+ REQUIRE(diff <= Approx(point_bound));
+ }
+ }
+ }
+ }
+}
+
+TEST_CASE("Bifiltrations from file", "[matching_distance][small_example][lesnick]")
+{
+ std::string fname_a, fname_b;
+
+ fname_a = "/home/narn/code/matching_distance/code/python_scripts/prism_1_lesnick.bif";
+ fname_b = "/home/narn/code/matching_distance/code/python_scripts/prism_2_lesnick.bif";
+
+ Bifiltration bif_a(fname_a, BifiltrationFormat::rene);
+ Bifiltration bif_b(fname_b, BifiltrationFormat::rene);
+
+ CalculationParams params;
+
+ std::vector<BoundStrategy> bound_strategies {BoundStrategy::local_combined,
+ BoundStrategy::local_dual_bound_refined};
+
+ std::vector<TraverseStrategy> traverse_strategies {TraverseStrategy::breadth_first, TraverseStrategy::depth_first};
+
+ std::vector<double> scaling_factors {10, 1.0};
+
+ for(auto bs : bound_strategies) {
+ for(auto ts : traverse_strategies) {
+ for(double lambda : scaling_factors) {
+ Bifiltration bif_a_copy(bif_a);
+ Bifiltration bif_b_copy(bif_b);
+ bif_a_copy.scale(lambda);
+ bif_b_copy.scale(lambda);
+ params.bound_strategy = bs;
+ params.traverse_strategy = ts;
+ params.max_depth = 7;
+ params.delta = 0.01;
+ params.dim = 1;
+ Real answer = matching_distance(bif_a_copy, bif_b_copy, params);
+ Real correct_answer = lambda * 1.0;
+ REQUIRE(fabs(answer - correct_answer) < lambda * 0.05);
+ }
+ }
+ }
+}
+
diff --git a/matching/src/tests/tests_main.cpp b/matching/src/tests/tests_main.cpp
new file mode 100644
index 0000000..1c77b13
--- /dev/null
+++ b/matching/src/tests/tests_main.cpp
@@ -0,0 +1,2 @@
+#define CATCH_CONFIG_MAIN
+#include "catch/catch.hpp"